from easydict import EasyDict

# Base default config
CONFIG = EasyDict({})
# to indicate this is a default setting, should not be changed by user
# (load_config() flips this flag to False when it is called)
CONFIG.is_default = True
CONFIG.version = "baseline"
CONFIG.phase = "train"
# distributed training
CONFIG.dist = False
CONFIG.wandb = False
# global variables which will be assigned in the runtime
CONFIG.local_rank = 0
CONFIG.gpu = 0
CONFIG.world_size = 1

# Model config
CONFIG.model = EasyDict({})
# use pretrained checkpoint as encoder
CONFIG.model.freeze_seg = True
CONFIG.model.multi_scale = False
CONFIG.model.imagenet_pretrain = True
CONFIG.model.imagenet_pretrain_path = "/home/liyaoyi/Source/python/attentionMatting/pretrain/model_best_resnet34_En_nomixup.pth"
CONFIG.model.batch_size = 16
# one-hot or class, choice: [3, 1]
CONFIG.model.mask_channel = 1
CONFIG.model.trimap_channel = 3
# hyper-parameter for refinement
CONFIG.model.self_refine_width1 = 30
CONFIG.model.self_refine_width2 = 15
CONFIG.model.self_mask_width = 10

# Model -> Architecture config
CONFIG.model.arch = EasyDict({})
# definition in networks/encoders/__init__.py and networks/encoders/__init__.py
# NOTE(review): the same path is listed twice; the second one presumably
# should point at the decoders package -- confirm against the repository.
CONFIG.model.arch.encoder = "res_shortcut_encoder_29"
CONFIG.model.arch.decoder = "res_shortcut_decoder_22"
CONFIG.model.arch.m2m = "conv_baseline"
CONFIG.model.arch.seg = "maskrcnn"
# predefined for GAN structure
CONFIG.model.arch.discriminator = None

# Dataloader config
CONFIG.data = EasyDict({})
CONFIG.data.cutmask_prob = 0
CONFIG.data.workers = 0
CONFIG.data.pha_ratio = 0.5
# data path for training and validation in training phase
CONFIG.data.train_fg = None
CONFIG.data.train_alpha = None
CONFIG.data.train_bg = None
CONFIG.data.test_merged = None
CONFIG.data.test_alpha = None
CONFIG.data.test_trimap = None
CONFIG.data.imagematte_fg = None
CONFIG.data.imagematte_pha = None
CONFIG.data.d646_fg = None
CONFIG.data.d646_pha = None
CONFIG.data.aim_fg = None
CONFIG.data.aim_pha = None
CONFIG.data.human2k_fg = None
CONFIG.data.human2k_pha = None
CONFIG.data.am2k_fg = None
CONFIG.data.am2k_pha = None
CONFIG.data.coco_bg = None
CONFIG.data.bg20k_bg = None
CONFIG.data.rim_pha = None
CONFIG.data.rim_img = None
CONFIG.data.spd_pha = None
CONFIG.data.spd_img = None
# feed forward image size (untested)
CONFIG.data.crop_size = 1024
# composition of two foregrounds, affine transform, crop and HSV jitter
CONFIG.data.real_world_aug = False
CONFIG.data.augmentation = True
CONFIG.data.random_interp = True

### Benchmark config
CONFIG.benchmark = EasyDict({})
CONFIG.benchmark.him2k_img = '/home/jiachen.li/data/HIM2K/images/natural'
CONFIG.benchmark.him2k_alpha = '/home/jiachen.li/data/HIM2K/alphas/natural'
CONFIG.benchmark.him2k_comp_img = '/home/jiachen.li/data/HIM2K/images/comp'
CONFIG.benchmark.him2k_comp_alpha = '/home/jiachen.li/data/HIM2K/alphas/comp'
CONFIG.benchmark.rwp636_img = '/home/jiachen.li/data/RealWorldPortrait-636/image'
CONFIG.benchmark.rwp636_alpha = '/home/jiachen.li/data/RealWorldPortrait-636/alpha'
CONFIG.benchmark.ppm100_img = '/home/jiachen.li/data/PPM-100/image'
CONFIG.benchmark.ppm100_alpha = '/home/jiachen.li/data/PPM-100/matte'
CONFIG.benchmark.am2k_img = '/home/jiachen.li/data/AM2k/validation/original'
CONFIG.benchmark.am2k_alpha = '/home/jiachen.li/data/AM2k/validation/mask'
CONFIG.benchmark.rw100_img = '/home/jiachen.li/data/RefMatte_RW_100/image_all'
CONFIG.benchmark.rw100_alpha = '/home/jiachen.li/data/RefMatte_RW_100/mask'
CONFIG.benchmark.rw100_text = '/home/jiachen.li/data/RefMatte_RW_100/refmatte_rw100_label.json'
CONFIG.benchmark.rw100_index = '/home/jiachen.li/data/RefMatte_RW_100/eval_index_expression.json'
CONFIG.benchmark.vm_img = '/home/jiachen.li/data/videomatte_512x288'

# Training config
CONFIG.train = EasyDict({})
CONFIG.train.total_step = 100000
CONFIG.train.warmup_step = 5000
CONFIG.train.val_step = 1000
# basic learning rate of optimizer
CONFIG.train.G_lr = 1e-3
# beta1 and beta2 for Adam
CONFIG.train.beta1 = 0.5
CONFIG.train.beta2 = 0.999
# weight of different losses
CONFIG.train.rec_weight = 1
CONFIG.train.comp_weight = 1
CONFIG.train.lap_weight = 1
# clip large gradient
CONFIG.train.clip_grad = True
# resume the training (checkpoint file name)
CONFIG.train.resume_checkpoint = None
# reset the learning rate (this option will reset the optimizer and learning rate scheduler and ignore warmup)
CONFIG.train.reset_lr = False

# Logging config
CONFIG.log = EasyDict({})
CONFIG.log.tensorboard_path = "./logs/tensorboard"
CONFIG.log.tensorboard_step = 100
# save less images to save disk space
CONFIG.log.tensorboard_image_step = 500
CONFIG.log.logging_path = "./logs/stdout"
CONFIG.log.logging_step = 10
CONFIG.log.logging_level = "DEBUG"
CONFIG.log.checkpoint_path = "./checkpoints"
CONFIG.log.checkpoint_step = 10000


def load_config(custom_config, default_config=CONFIG, prefix="CONFIG"):
    """
    Recursively overwrite the default config with values from a custom config.

    ``default_config`` is modified in place. Every key of ``custom_config``
    must already exist in ``default_config``; nested dicts are merged key by
    key, and any other value replaces the default outright.

    :param custom_config: dict of overrides (per the original note, parsed
        from config/config.toml)
    :param default_config: config to update in place; defaults to the
        module-level ``CONFIG`` object, which is shared on purpose
    :param prefix: dotted path of ``default_config``, used to build the full
        key name shown in error messages
    :return: None
    :raises NotImplementedError: if a custom key is not present in the
        default config
    :raises ValueError: if one side holds a dict and the other does not
    """
    # Mark the config as customized. Only a level that actually carries the
    # flag is touched, so nested sections without it are left alone. Item
    # assignment is used so a plain dict works here as well as an EasyDict.
    if "is_default" in default_config:
        default_config["is_default"] = False

    for key, custom_value in custom_config.items():
        full_key = ".".join([prefix, key])
        if key not in default_config:
            raise NotImplementedError("Unknown config key: {}".format(full_key))

        default_value = default_config[key]
        custom_is_dict = isinstance(custom_value, dict)
        default_is_dict = isinstance(default_value, dict)

        if custom_is_dict and default_is_dict:
            # Both sides are sections: descend and merge key by key.
            load_config(default_config=default_value,
                        custom_config=custom_value,
                        prefix=full_key)
        elif custom_is_dict:
            # The expected type is the default's type; the custom value is
            # the offending dict.
            raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(default_value)))
        elif default_is_dict:
            raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(custom_value)))
        else:
            default_config[key] = custom_value