Spaces:
Runtime error
Runtime error
from easydict import EasyDict | |
# Base default config.
# The whole default tree is declared as one nested literal; EasyDict converts
# every nested dict into an EasyDict, so attribute access (CONFIG.model.arch.seg)
# works at every level.
CONFIG = EasyDict({
    # marks this object as the untouched default setting; users should not
    # change it by hand
    "is_default": True,
    "version": "baseline",
    "phase": "train",

    # distributed training
    "dist": False,
    "wandb": False,

    # global variables which will be assigned at runtime
    "local_rank": 0,
    "gpu": 0,
    "world_size": 1,

    # Model config
    "model": {
        "freeze_seg": True,
        "multi_scale": False,
        # use pretrained checkpoint as encoder
        "imagenet_pretrain": True,
        "imagenet_pretrain_path": "/home/liyaoyi/Source/python/attentionMatting/pretrain/model_best_resnet34_En_nomixup.pth",
        "batch_size": 16,
        # one-hot or class, choice: [3, 1]
        "mask_channel": 1,
        "trimap_channel": 3,
        # hyper-parameters for refinement
        "self_refine_width1": 30,
        "self_refine_width2": 15,
        "self_mask_width": 10,

        # Model -> Architecture config
        "arch": {
            # encoder / decoder names; the original note points at
            # networks/encoders/__init__.py for their definitions
            # (presumably networks/decoders/__init__.py for the decoder -- confirm)
            "encoder": "res_shortcut_encoder_29",
            "decoder": "res_shortcut_decoder_22",
            "m2m": "conv_baseline",
            "seg": "maskrcnn",
            # predefined for GAN structure
            "discriminator": None,
        },
    },

    # Dataloader config
    "data": {
        "cutmask_prob": 0,
        "workers": 0,
        "pha_ratio": 0.5,
        # data paths for training and validation in the training phase
        "train_fg": None,
        "train_alpha": None,
        "train_bg": None,
        "test_merged": None,
        "test_alpha": None,
        "test_trimap": None,
        "imagematte_fg": None,
        "imagematte_pha": None,
        "d646_fg": None,
        "d646_pha": None,
        "aim_fg": None,
        "aim_pha": None,
        "human2k_fg": None,
        "human2k_pha": None,
        "am2k_fg": None,
        "am2k_pha": None,
        "coco_bg": None,
        "bg20k_bg": None,
        "rim_pha": None,
        "rim_img": None,
        "spd_pha": None,
        "spd_img": None,
        # feed forward image size (untested)
        "crop_size": 1024,
        # composition of two foregrounds, affine transform, crop and HSV jitter
        "real_world_aug": False,
        "augmentation": True,
        "random_interp": True,
    },

    # Benchmark config (machine-specific absolute paths)
    "benchmark": {
        "him2k_img": "/home/jiachen.li/data/HIM2K/images/natural",
        "him2k_alpha": "/home/jiachen.li/data/HIM2K/alphas/natural",
        "him2k_comp_img": "/home/jiachen.li/data/HIM2K/images/comp",
        "him2k_comp_alpha": "/home/jiachen.li/data/HIM2K/alphas/comp",
        "rwp636_img": "/home/jiachen.li/data/RealWorldPortrait-636/image",
        "rwp636_alpha": "/home/jiachen.li/data/RealWorldPortrait-636/alpha",
        "ppm100_img": "/home/jiachen.li/data/PPM-100/image",
        "ppm100_alpha": "/home/jiachen.li/data/PPM-100/matte",
        "am2k_img": "/home/jiachen.li/data/AM2k/validation/original",
        "am2k_alpha": "/home/jiachen.li/data/AM2k/validation/mask",
        "rw100_img": "/home/jiachen.li/data/RefMatte_RW_100/image_all",
        "rw100_alpha": "/home/jiachen.li/data/RefMatte_RW_100/mask",
        "rw100_text": "/home/jiachen.li/data/RefMatte_RW_100/refmatte_rw100_label.json",
        "rw100_index": "/home/jiachen.li/data/RefMatte_RW_100/eval_index_expression.json",
        "vm_img": "/home/jiachen.li/data/videomatte_512x288",
    },

    # Training config
    "train": {
        "total_step": 100000,
        "warmup_step": 5000,
        "val_step": 1000,
        # basic learning rate of optimizer
        "G_lr": 1e-3,
        # beta1 and beta2 for Adam
        "beta1": 0.5,
        "beta2": 0.999,
        # weight of different losses
        "rec_weight": 1,
        "comp_weight": 1,
        "lap_weight": 1,
        # clip large gradient
        "clip_grad": True,
        # resume the training (checkpoint file name)
        "resume_checkpoint": None,
        # reset the learning rate (this option will reset the optimizer and
        # learning rate scheduler and ignore warmup)
        "reset_lr": False,
    },

    # Logging config
    "log": {
        "tensorboard_path": "./logs/tensorboard",
        "tensorboard_step": 100,
        # save fewer images to save disk space
        "tensorboard_image_step": 500,
        "logging_path": "./logs/stdout",
        "logging_step": 10,
        "logging_level": "DEBUG",
        "checkpoint_path": "./checkpoints",
        "checkpoint_step": 10000,
    },
})
def load_config(custom_config, default_config=CONFIG, prefix="CONFIG"):
    """
    Recursively overwrite the default config, in place, with a custom config.

    Every key in ``custom_config`` must already exist in ``default_config``.
    Nested dicts are merged key by key; any other value replaces the default.
    If ``default_config`` contains an ``is_default`` key, it is set to False
    before any override is applied.

    :param custom_config: dict of overrides (parsed from config/config.toml)
    :param default_config: config object updated in place; defaults to the
        module-level CONFIG
    :param prefix: dotted path of ``default_config``, used to build the full
        key name shown in error messages
    :return: None
    :raises NotImplementedError: a key in ``custom_config`` does not exist in
        ``default_config``
    :raises ValueError: a dict is supplied where the default is not a dict,
        or a non-dict is supplied where the default is a dict
    """
    if "is_default" in default_config:
        default_config.is_default = False

    for key, value in custom_config.items():
        full_key = ".".join([prefix, key])
        if key not in default_config:
            raise NotImplementedError("Unknown config key: {}".format(full_key))

        default_value = default_config[key]
        if isinstance(value, dict):
            if not isinstance(default_value, dict):
                # Report the type of the *default* entry: the custom value is
                # known to be a dict here, so its type says nothing about
                # what was expected.
                raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(default_value)))
            # Both sides are dicts: merge the sub-tree.
            load_config(default_config=default_value,
                        custom_config=value,
                        prefix=full_key)
        elif isinstance(default_value, dict):
            raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(value)))
        else:
            default_config[key] = value