Spaces:
Runtime error
Runtime error
import numpy as np | |
import os | |
import torch | |
import random | |
import string | |
import yaml | |
from easydict import EasyDict as edict | |
import utils | |
from utils import log | |
def parse_arguments(args):
    """
    Parse arguments from command line into a nested EasyDict.
    Syntax: --key1.key2.key3=value --> value
            --key1.key2.key3=      --> None
            --key1.key2.key3       --> True
            --key1.key2.key3!      --> False
    Values are decoded with yaml.safe_load (so "3" becomes an int, "" becomes None, ...).
    Only the first "=" separates key from value; any later "=" stays part of the value.

    Args:
        args: iterable of argument strings, each starting with "--".
    Returns:
        EasyDict whose nesting mirrors the dotted keys.
    Raises:
        AssertionError: if an argument does not start with "--", if the same key is
            given twice, or if a dotted key descends into a key that already holds
            a non-dict value.
    """
    opt_cmd = {}
    for arg in args:
        assert arg.startswith("--"), arg
        body = arg[2:]
        if "=" not in body:
            # bare flag: a trailing "!" negates it
            key_str, value = (body[:-1], "false") if body.endswith("!") else (body, "true")
        else:
            # partition (not split) so that values containing "=" are kept intact
            key_str, _, value = body.partition("=")
        keys_sub = key_str.split(".")
        opt_sub = opt_cmd
        # walk/create the intermediate dicts for key1.key2...
        for k in keys_sub[:-1]:
            opt_sub = opt_sub.setdefault(k, {})
            # a previous argument may have stored a leaf value under this key
            assert isinstance(opt_sub, dict), k
        # refuse to silently overwrite a key that was already given
        assert keys_sub[-1] not in opt_sub, keys_sub[-1]
        opt_sub[keys_sub[-1]] = yaml.safe_load(value)
    opt_cmd = edict(opt_cmd)
    return opt_cmd
def set(opt_cmd={}):
    """
    Build the final options object from parsed command line options.

    opt_cmd must contain a "yaml" entry naming the base file options/<yaml>.yaml.
    The file is loaded with load_options, overridden by opt_cmd (with the
    interactive safe check enabled), post-processed by process_options, logged
    through log.options, and returned.

    Raises:
        AssertionError: if "yaml" is missing from opt_cmd.
    """
    log.info("setting configurations...")
    assert "yaml" in opt_cmd
    # base options come from the yaml file selected on the command line
    yaml_path = "options/{}.yaml".format(opt_cmd.yaml)
    # command line values take precedence over the file
    merged = override_options(
        load_options(yaml_path),
        opt_cmd,
        key_stack=[],
        safe_check=True,
    )
    process_options(merged)
    log.options(merged)
    return merged
def load_options(fname):
    """
    Load a yaml options file into an EasyDict, resolving "_parent_" references.

    If the file has a "_parent_" entry (a single file name or a list of them),
    each parent is loaded recursively and used as the base onto which the
    options collected so far are applied via override_options.

    Args:
        fname: path of the yaml file to read.
    Returns:
        The merged options.
    """
    with open(fname) as stream:
        opt = edict(yaml.safe_load(stream))
    if "_parent_" in opt:
        parents = opt.pop("_parent_")
        # a single parent may be written as a bare string
        parents = [parents] if type(parents) is str else parents
        for parent_path in parents:
            # the parent is the base; everything gathered so far overrides it
            opt = override_options(load_options(parent_path), opt, key_stack=[])
    print("loading {}...".format(fname))
    return opt
def override_options(opt,opt_over,key_stack=None,safe_check=False):
    """
    Recursively apply the entries of opt_over onto opt.

    opt is modified in place and also returned. Dict values in opt_over are
    merged key by key into the corresponding sub-dict of opt (created if
    missing); any other value replaces the entry in opt.

    Args:
        opt: dict-like options to be updated.
        opt_over: dict-like options taking precedence.
        key_stack: list of parent keys leading to this level, used only to
            build the dotted key name shown in the prompt. None (the default)
            means top level.
        safe_check: if True, ask on stdin before adding a leaf key that is not
            already present in opt; answering "n" exits the program.
    Returns:
        opt, after the overrides have been applied.
    """
    # the default is None, so normalize before concatenating lists below
    if key_stack is None:
        key_stack = []
    for key,value in opt_over.items():
        # trace every override as it is visited
        print(key,value)
        if isinstance(value,dict):
            # parse child options (until leaf nodes are reached)
            opt[key] = override_options(
                opt.get(key,dict()),
                value,
                key_stack=key_stack+[key],
                safe_check=safe_check,
            )
        else:
            # ensure command line argument to override is also in yaml file
            if safe_check and key not in opt:
                key_str = ".".join(str(k) for k in key_stack+[key])
                add_new = None
                while add_new not in ["y","n"]:
                    add_new = input("\"{}\" not found in original opt, add? (y/n) ".format(key_str))
                if add_new=="n":
                    print("safe exiting...")
                    exit()
            opt[key] = value
    return opt
def process_options(opt):
    """
    Finalize options in place: seed the RNGs (or tag the run name) and pick the device.

    Reads opt.seed, opt.name, opt.gpu and opt.cpu.
    - If opt.seed is set, it seeds random, numpy, torch and all torch CUDA generators.
    - Otherwise a random 4-letter uppercase suffix is appended to opt.name as run ID.
    - opt.device becomes "cpu" when opt.cpu is truthy or CUDA is unavailable,
      else "cuda:<opt.gpu>".

    Raises:
        AssertionError: if opt.gpu is not an int.
    """
    seed = opt.seed
    if seed is None:
        # no seed given: make the run name unique with a random string
        letters = [random.choice(string.ascii_uppercase) for _ in range(4)]
        opt.name = "{}_{}".format(str(opt.name), "".join(letters))
    else:
        # same seed for every random number generator in use
        seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
        for seed_fn in seeders:
            seed_fn(seed)
    # disable multi-GPU support for now, single is enough
    assert isinstance(opt.gpu, int)
    if opt.cpu or not torch.cuda.is_available():
        opt.device = "cpu"
    else:
        opt.device = "cuda:{}".format(opt.gpu)
def save_options_file(opt,output_path):
    """
    Write the options to <output_path>/options.yaml, asking before replacing a different one.

    If an options file already exists and its loaded content differs from opt,
    the current options are dumped to a temporary file, a `diff` of the two is
    shown, the temporary file is removed, and the user is asked on stdin whether
    to override; answering "n" exits the program. In every other case the file
    is (re)written with yaml.safe_dump(utils.to_dict(opt)).

    Args:
        opt: options object to save.
        output_path: directory that holds options.yaml.
    """
    import shlex  # local import: only needed to quote paths for the shell
    opt_fname = "{}/options.yaml".format(output_path)
    if os.path.isfile(opt_fname):
        with open(opt_fname) as file:
            opt_old = yaml.safe_load(file)
        if opt!=opt_old:
            # prompt if options are not identical
            opt_new_fname = "{}/options_temp.yaml".format(output_path)
            with open(opt_new_fname,"w") as file:
                yaml.safe_dump(utils.to_dict(opt),file,default_flow_style=False,indent=4)
            print("existing options file found (different from current one)...")
            # quote the paths so spaces/metacharacters in output_path cannot break the command
            os.system("diff {} {}".format(shlex.quote(opt_fname),shlex.quote(opt_new_fname)))
            # delete the temp file directly instead of shelling out to `rm`
            os.remove(opt_new_fname)
            override = None
            while override not in ["y","n"]:
                override = input("override? (y/n) ")
            if override=="n":
                print("safe exiting...")
                exit()
        else:
            print("existing options file found (identical)")
    else:
        print("(creating new options file...)")
    with open(opt_fname,"w") as file:
        yaml.safe_dump(utils.to_dict(opt),file,default_flow_style=False,indent=4)