from easydict import EasyDict

# Base default config
CONFIG = EasyDict({})
# marks this config as the default; it should not be modified by the user
CONFIG.is_default = True
CONFIG.version = "baseline"
CONFIG.phase = "train"
# distributed training
CONFIG.dist = False
CONFIG.wandb = False
# global variables that will be assigned at runtime
CONFIG.local_rank = 0
CONFIG.gpu = 0
CONFIG.world_size = 1
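
# Illustrative sketch (not part of the original file): under a torch.distributed
# launch (e.g. torchrun), the runtime globals above are typically filled in from
# environment variables. The helper below is a hypothetical example of that
# wiring; the actual training script may differ.
def _example_assign_runtime_globals():
    import os

    import torch.distributed as dist

    CONFIG.local_rank = int(os.environ.get("LOCAL_RANK", 0))
    CONFIG.gpu = CONFIG.local_rank
    if CONFIG.dist and dist.is_available():
        dist.init_process_group(backend="nccl")
        CONFIG.world_size = dist.get_world_size()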

# Model config
CONFIG.model = EasyDict({})
# segmentation freezing and pretrained-encoder options
CONFIG.model.freeze_seg = True
CONFIG.model.multi_scale = False
CONFIG.model.imagenet_pretrain = True
CONFIG.model.imagenet_pretrain_path = "/home/liyaoyi/Source/python/attentionMatting/pretrain/model_best_resnet34_En_nomixup.pth"
CONFIG.model.batch_size = 16
# number of input channels: 3 for one-hot, 1 for class index
CONFIG.model.mask_channel = 1
CONFIG.model.trimap_channel = 3
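
# Illustrative sketch (assumption, not from the original file): with
# trimap_channel == 3 the trimap is fed as a one-hot map. Assuming class
# indices 0 (background), 1 (unknown) and 2 (foreground), the conversion
# could look like this:
def _example_trimap_to_one_hot(trimap):
    """Convert an (H, W) class-index trimap into a (3, H, W) one-hot array."""
    import numpy as np

    one_hot = np.zeros((3,) + trimap.shape, dtype=np.float32)
    for cls in range(3):
        one_hot[cls][trimap == cls] = 1.0
    return one_hot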

# hyper-parameters for refinement
CONFIG.model.self_refine_width1 = 30
CONFIG.model.self_refine_width2 = 15
CONFIG.model.self_mask_width = 10

# Model -> Architecture config
CONFIG.model.arch = EasyDict({})
# definitions in networks/encoders/__init__.py and networks/decoders/__init__.py
CONFIG.model.arch.encoder = "res_shortcut_encoder_29"
CONFIG.model.arch.decoder = "res_shortcut_decoder_22"
CONFIG.model.arch.m2m = "conv_baseline"
CONFIG.model.arch.seg = "maskrcnn"
# predefined slot for a GAN-style discriminator (None disables it)
CONFIG.model.arch.discriminator = None
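
# Illustrative sketch (hypothetical, not from this repository): architecture
# names like the strings above are commonly resolved through a simple
# name-to-constructor registry in the networks package. A generic lookup
# could look like this:
def _example_build_module(name, registry, **kwargs):
    """Hypothetical registry lookup mapping a config string to a constructor."""
    if name not in registry:
        raise NotImplementedError("Unknown module: {}".format(name))
    return registry[name](**kwargs)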


# Dataloader config
CONFIG.data = EasyDict({})
CONFIG.data.cutmask_prob = 0
CONFIG.data.workers = 0
CONFIG.data.pha_ratio = 0.5
# data path for training and validation in training phase
CONFIG.data.train_fg = None
CONFIG.data.train_alpha = None
CONFIG.data.train_bg = None
CONFIG.data.test_merged = None
CONFIG.data.test_alpha = None
CONFIG.data.test_trimap = None
CONFIG.data.imagematte_fg = None
CONFIG.data.imagematte_pha = None
CONFIG.data.d646_fg = None
CONFIG.data.d646_pha = None
CONFIG.data.aim_fg = None
CONFIG.data.aim_pha = None
CONFIG.data.human2k_fg = None
CONFIG.data.human2k_pha = None
CONFIG.data.am2k_fg = None
CONFIG.data.am2k_pha = None
CONFIG.data.coco_bg = None
CONFIG.data.bg20k_bg = None
CONFIG.data.rim_pha = None
CONFIG.data.rim_img = None
CONFIG.data.spd_pha = None
CONFIG.data.spd_img = None
# feed-forward image (crop) size; other sizes are untested
CONFIG.data.crop_size = 1024
# augmentations: composition of two foregrounds, affine transform, crop, and HSV jitter
CONFIG.data.real_world_aug = False
CONFIG.data.augmentation = True
CONFIG.data.random_interp = True

# Benchmark config
CONFIG.benchmark = EasyDict({})
CONFIG.benchmark.him2k_img = '/home/jiachen.li/data/HIM2K/images/natural'
CONFIG.benchmark.him2k_alpha = '/home/jiachen.li/data/HIM2K/alphas/natural'
CONFIG.benchmark.him2k_comp_img = '/home/jiachen.li/data/HIM2K/images/comp'
CONFIG.benchmark.him2k_comp_alpha = '/home/jiachen.li/data/HIM2K/alphas/comp'
CONFIG.benchmark.rwp636_img = '/home/jiachen.li/data/RealWorldPortrait-636/image'
CONFIG.benchmark.rwp636_alpha = '/home/jiachen.li/data/RealWorldPortrait-636/alpha'
CONFIG.benchmark.ppm100_img = '/home/jiachen.li/data/PPM-100/image'
CONFIG.benchmark.ppm100_alpha = '/home/jiachen.li/data/PPM-100/matte'
CONFIG.benchmark.am2k_img = '/home/jiachen.li/data/AM2k/validation/original'
CONFIG.benchmark.am2k_alpha = '/home/jiachen.li/data/AM2k/validation/mask'
CONFIG.benchmark.rw100_img = '/home/jiachen.li/data/RefMatte_RW_100/image_all'
CONFIG.benchmark.rw100_alpha = '/home/jiachen.li/data/RefMatte_RW_100/mask'
CONFIG.benchmark.rw100_text = '/home/jiachen.li/data/RefMatte_RW_100/refmatte_rw100_label.json'
CONFIG.benchmark.rw100_index = '/home/jiachen.li/data/RefMatte_RW_100/eval_index_expression.json'
CONFIG.benchmark.vm_img = '/home/jiachen.li/data/videomatte_512x288'

# Training config
CONFIG.train = EasyDict({})
CONFIG.train.total_step = 100000
CONFIG.train.warmup_step = 5000
CONFIG.train.val_step = 1000
# base learning rate of the optimizer
CONFIG.train.G_lr = 1e-3
# beta1 and beta2 for Adam
CONFIG.train.beta1 = 0.5
CONFIG.train.beta2 = 0.999
# weights of the individual losses
CONFIG.train.rec_weight = 1
CONFIG.train.comp_weight = 1
CONFIG.train.lap_weight = 1
# clip large gradients
CONFIG.train.clip_grad = True
# resume the training (checkpoint file name)
CONFIG.train.resume_checkpoint = None
# reset the learning rate (this resets the optimizer and learning-rate scheduler and skips warmup)
CONFIG.train.reset_lr = False
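
# Illustrative sketch (assumption, not from the original file): the training
# hyper-parameters above map onto a PyTorch Adam optimizer roughly as follows,
# where `model` is the assembled matting network:
def _example_build_optimizer(model):
    import torch

    return torch.optim.Adam(
        model.parameters(),
        lr=CONFIG.train.G_lr,
        betas=(CONFIG.train.beta1, CONFIG.train.beta2),
    )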


# Logging config
CONFIG.log = EasyDict({})
CONFIG.log.tensorboard_path = "./logs/tensorboard"
CONFIG.log.tensorboard_step = 100
# log fewer images to save disk space
CONFIG.log.tensorboard_image_step = 500
CONFIG.log.logging_path = "./logs/stdout"
CONFIG.log.logging_step = 10
CONFIG.log.logging_level = "DEBUG"
CONFIG.log.checkpoint_path = "./checkpoints"
CONFIG.log.checkpoint_step = 10000
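
# Illustrative sketch (not from the original file): how the logging settings
# above might be wired into the standard-library logger. The file name
# "train.log" is a hypothetical choice.
def _example_setup_logging():
    import logging
    import os

    os.makedirs(CONFIG.log.logging_path, exist_ok=True)
    logging.basicConfig(
        level=getattr(logging, CONFIG.log.logging_level),
        filename=os.path.join(CONFIG.log.logging_path, "train.log"),
    )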


def load_config(custom_config, default_config=CONFIG, prefix="CONFIG"):
    """
    This function will recursively overwrite the default config by a custom config
    :param default_config:
    :param custom_config: parsed from config/config.toml
    :param prefix: prefix for config key
    :return: None
    """
    if "is_default" in default_config:
        default_config.is_default = False

    for key in custom_config.keys():
        full_key = ".".join([prefix, key])
        if key not in default_config:
            raise NotImplementedError("Unknown config key: {}".format(full_key))
        elif isinstance(custom_config[key], dict):
            if isinstance(default_config[key], dict):
                load_config(default_config=default_config[key],
                            custom_config=custom_config[key],
                            prefix=full_key)
            else:
                raise ValueError("{}: Expected {}, got dict instead.".format(full_key, type(custom_config[key])))
        else:
            if isinstance(default_config[key], dict):
                raise ValueError("{}: Expected dict, got {} instead.".format(full_key, type(custom_config[key])))
            else:
                default_config[key] = custom_config[key]
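

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original file): parse a custom
    # TOML config and merge it into the defaults. tomllib is in the standard
    # library from Python 3.11; the project may use the third-party `toml`
    # package instead.
    import tomllib

    with open("config/config.toml", "rb") as f:
        custom = tomllib.load(f)
    load_config(custom)
    # After any merge, is_default is flipped to False.
    print(CONFIG.version, CONFIG.is_default)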