Spaces:
Running
on
T4
Running
on
T4
import yaml | |
from easydict import EasyDict | |
import os | |
from .logger import print_log | |
def log_args_to_file(args, pre='args', logger=None):
    """Emit one ``<pre>.<name> : <value>`` line per instance attribute of ``args``.

    Args:
        args: Any object with an instance ``__dict__`` (presumably an
            ``argparse.Namespace`` — confirm against callers).
        pre: Prefix written before each attribute name.
        logger: Passed through unchanged to ``print_log``.
    """
    attributes = args.__dict__
    for name in attributes:
        value = attributes[name]
        print_log(f'{pre}.{name} : {value}', logger=logger)
def log_config_to_file(cfg, pre='cfg', logger=None):
    """Recursively log a config mapping, one line per leaf entry.

    Nested ``EasyDict`` values are announced with a ``<path> = edict()`` line
    and then walked with their key appended to the dotted prefix; every other
    value is written as ``<path> : <value>``.

    Args:
        cfg: Mapping to dump (iterated via ``.items()``).
        pre: Dotted prefix for the keys at this level.
        logger: Passed through unchanged to ``print_log``.
    """
    for name, entry in cfg.items():
        if not isinstance(entry, EasyDict):
            print_log(f'{pre}.{name} : {entry}', logger=logger)
            continue
        print_log(f'{pre}.{name} = edict()', logger=logger)
        log_config_to_file(entry, pre=pre + '.' + name, logger=logger)
def merge_new_config(config, new_config):
    """Recursively merge ``new_config`` into ``config`` in place.

    Non-dict values overwrite ``config[key]``. Dict values are merged
    recursively into ``config[key]``, which is created as an ``EasyDict`` when
    the key is absent. A non-dict ``_base_`` value is treated as a YAML file
    path: that file is loaded and its contents are merged into a fresh
    ``EasyDict`` stored under ``config['_base_']``.

    Args:
        config: Mapping that receives the merged values; mutated in place.
        new_config: Mapping of new values, or ``None`` (what ``yaml.load``
            returns for an empty document), which is treated as empty.

    Returns:
        The same ``config`` object.
    """
    if new_config is None:
        # Empty YAML document: nothing to merge.
        return config
    for key, val in new_config.items():
        if not isinstance(val, dict):
            if key == '_base_':
                with open(val, 'r') as f:
                    # PyYAML < 5.1 has no FullLoader. Probe for it up front rather
                    # than wrapping the load in a bare ``except``: that would also
                    # swallow YAML syntax errors and retry on a consumed stream.
                    loader = getattr(yaml, 'FullLoader', None)
                    if loader is not None:
                        val = yaml.load(f, Loader=loader)
                    else:
                        val = yaml.load(f)
                config[key] = EasyDict()
                merge_new_config(config[key], val)
            else:
                config[key] = val
            continue
        if key not in config:
            config[key] = EasyDict()
        merge_new_config(config[key], val)
    return config
def cfg_from_yaml_file(cfg_file):
    """Load a YAML config file into a new ``EasyDict``.

    The parsed document is merged into an empty ``EasyDict`` with
    ``merge_new_config``, so nested mappings and ``_base_`` includes are
    handled there.

    Args:
        cfg_file: Path of the YAML file to read.

    Returns:
        An ``EasyDict`` holding the file's contents; empty if the file holds an
        empty YAML document.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML (no longer masked by a
            retry on the already-consumed stream).
    """
    config = EasyDict()
    with open(cfg_file, 'r') as f:
        # PyYAML < 5.1 lacks FullLoader; detect that explicitly instead of
        # catching every exception from the load itself.
        loader = getattr(yaml, 'FullLoader', None)
        if loader is not None:
            new_config = yaml.load(f, Loader=loader)
        else:
            new_config = yaml.load(f)
    if new_config is not None:
        # yaml.load returns None for an empty document; keep the empty config.
        merge_new_config(config=config, new_config=new_config)
    return config
def get_config(args, logger=None):
    """Load the run's config, resuming from the experiment directory if asked.

    When ``args.resume`` is set, ``args.config`` is redirected to
    ``<args.experiment_path>/config.yaml`` before loading. Otherwise the
    config named by ``args.config`` is loaded and, on local rank 0 only,
    copied into the experiment directory via ``save_experiment_config``.

    Args:
        args: Run arguments; reads ``resume``, ``experiment_path``, ``config``
            and ``local_rank``, and overwrites ``config`` when resuming.
        logger: Passed through to ``print_log`` and ``save_experiment_config``.

    Returns:
        The loaded config as returned by ``cfg_from_yaml_file``.

    Raises:
        FileNotFoundError: If resuming and the saved config file is missing.
    """
    if args.resume:
        cfg_path = os.path.join(args.experiment_path, 'config.yaml')
        # isfile (not exists) so a directory at that path is also rejected here
        # instead of failing later inside open().
        if not os.path.isfile(cfg_path):
            print_log("Failed to resume", logger = logger)
            raise FileNotFoundError(f'Failed to resume: config file not found at {cfg_path}')
        print_log(f'Resume yaml from {cfg_path}', logger = logger)
        args.config = cfg_path
    config = cfg_from_yaml_file(args.config)
    if not args.resume and args.local_rank == 0:
        # Only one process writes the config snapshot.
        save_experiment_config(args, config, logger)
    return config
def save_experiment_config(args, config, logger = None):
    """Copy the source YAML config into the experiment directory as ``config.yaml``.

    Args:
        args: Run arguments; reads ``experiment_path`` (destination directory)
            and ``config`` (source file path).
        config: Unused here; kept for interface compatibility.
        logger: Passed through to ``print_log``.

    Raises:
        OSError: If the copy fails. Previously a failed ``cp`` was silently
            ignored while success was still logged.
    """
    import shutil  # function-local so this block needs no file-level import change

    config_path = os.path.join(args.experiment_path, 'config.yaml')
    try:
        # shutil instead of os.system('cp ...'): no shell interpolation of the
        # paths, works with spaces/metacharacters, portable, and reports errors.
        shutil.copy(args.config, config_path)
    except shutil.SameFileError:
        # Source already is the snapshot location; nothing to copy.
        pass
    print_log(f'Copy the Config file from {args.config} to {config_path}',logger = logger )