File size: 13,283 Bytes
355b5d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import argparse
import os
class BaseOptions():
    """Command-line options shared by the training, testing and reconstruction scripts.

    The ``argparse`` parser is built on the first call to :meth:`gather_options`
    (or :meth:`parse`) and cached on ``self.parser``, which :meth:`print_options`
    also relies on.
    """

    def __init__(self):
        # Set to True by initialize() once every argument has been registered.
        self.initialized = False
        # argparse.ArgumentParser created by gather_options(); None until then.
        self.parser = None

    def initialize(self, parser):
        """Register every option on *parser*.

        Args:
            parser: the ``argparse.ArgumentParser`` to populate.

        Returns:
            The same parser, with all arguments added.
        """
        # Datasets related
        g_data = parser.add_argument_group('Data')
        g_data.add_argument('--dataset', type=str, default='renderppl', help='dataset name')
        g_data.add_argument('--dataroot', type=str, default='./data',
                            help='path to images (data folder)')
        g_data.add_argument('--loadSize', type=int, default=512, help='load size of input image')

        # Experiment related
        g_exp = parser.add_argument_group('Experiment')
        g_exp.add_argument('--name', type=str, default='',
                           help='name of the experiment. It decides where to store samples and models')
        g_exp.add_argument('--debug', action='store_true', help='debug mode or not')
        g_exp.add_argument('--mode', type=str, default='inout', help='inout | color')

        # Training related
        g_train = parser.add_argument_group('Training')
        g_train.add_argument('--tmp_id', type=int, default=0, help='tmp_id')
        g_train.add_argument('--gpu_id', type=int, default=0, help='gpu id for cuda')
        g_train.add_argument('--batch_size', type=int, default=32, help='input batch size')
        g_train.add_argument('--num_threads', default=1, type=int, help='# threads for loading data')
        g_train.add_argument('--serial_batches', action='store_true',
                             help='if true, takes images in order to make batches, otherwise takes them randomly')
        g_train.add_argument('--pin_memory', action='store_true', help='pin_memory')
        g_train.add_argument('--learning_rate', type=float, default=1e-3, help='adam learning rate')
        g_train.add_argument('--num_iter', type=int, default=30000, help='num iterations to train')
        g_train.add_argument('--freq_plot', type=int, default=100, help='frequency of the error plot')
        g_train.add_argument('--freq_mesh', type=int, default=20000, help='frequency of mesh generation')
        g_train.add_argument('--freq_eval', type=int, default=5000, help='frequency of evaluation')
        g_train.add_argument('--freq_save_ply', type=int, default=5000, help='frequency of saving ply')
        g_train.add_argument('--freq_save_image', type=int, default=100, help='frequency of saving input images')
        g_train.add_argument('--resume_epoch', type=int, default=-1, help='epoch resuming the training')
        g_train.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        g_train.add_argument('--finetune', action='store_true', help='fine tuning netG in training C')

        # Testing related
        g_test = parser.add_argument_group('Testing')
        g_test.add_argument('--resolution', type=int, default=512, help='# of grid in mesh reconstruction')
        g_test.add_argument('--no_numel_eval', action='store_true', help='no numerical evaluation')
        g_test.add_argument('--no_mesh_recon', action='store_true', help='no mesh reconstruction')

        # Sampling related
        g_sample = parser.add_argument_group('Sampling')
        g_sample.add_argument('--num_sample_inout', type=int, default=6000, help='# of inside/outside sampling points')
        g_sample.add_argument('--num_sample_surface', type=int, default=0, help='# of surface sampling points')
        g_sample.add_argument('--num_sample_normal', type=int, default=0, help='# of normal sampling points')
        g_sample.add_argument('--num_sample_color', type=int, default=0, help='# of color sampling points')
        g_sample.add_argument('--num_pts_dic', type=int, default=1, help='# of pts dic you load')
        g_sample.add_argument('--crop_type', type=str, default='fullbody', help='Sampling file name.')
        g_sample.add_argument('--uniform_ratio', type=float, default=0.1, help='ratio of uniform sampling')
        g_sample.add_argument('--mask_ratio', type=float, default=0.5, help='ratio of mask sampling')
        g_sample.add_argument('--sampling_parts', action='store_true', help='sampling on parts')
        g_sample.add_argument('--sampling_otf', action='store_true', help='Sampling on the fly')
        g_sample.add_argument('--sampling_mode', type=str, default='sigma_uniform', help='sampling mode')
        g_sample.add_argument('--linear_anneal_sigma', action='store_true', help='linear annealing of sigma')
        g_sample.add_argument('--sigma_max', type=float, default=0.0, help='maximum sigma for sampling')
        g_sample.add_argument('--sigma_min', type=float, default=0.0, help='minimum sigma for sampling')
        # NOTE: parse() overwrites this value with --sigma_max.
        g_sample.add_argument('--sigma', type=float, default=1.0, help='sigma for sampling')
        g_sample.add_argument('--sigma_surface', type=float, default=1.0, help='sigma for surface sampling')
        g_sample.add_argument('--z_size', type=float, default=200.0, help='z normalization factor')

        # Model related
        g_model = parser.add_argument_group('Model')
        # General
        g_model.add_argument('--norm', type=str, default='batch',
                             help='instance normalization or batch normalization or group normalization')
        # Image filter General
        g_model.add_argument('--netG', type=str, default='hgpifu', help='piximp | fanimp | hghpifu')
        g_model.add_argument('--netC', type=str, default='resblkpifu', help='resblkpifu | resblkhpifu')
        # hgimp specific
        g_model.add_argument('--num_stack', type=int, default=4, help='# of hourglass')
        g_model.add_argument('--hg_depth', type=int, default=2, help='# of stacked layer of hourglass')
        g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave_pool | conv64 | conv128')
        g_model.add_argument('--hg_dim', type=int, default=256, help='256 | 512')
        # Classification General
        g_model.add_argument('--mlp_norm', type=str, default='group', help='normalization for volume branch')
        g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int,
                             help='# of dimensions of mlp. no need to put the first channel')
        g_model.add_argument('--mlp_dim_color', nargs='+', default=[1024, 512, 256, 128, 3], type=int,
                             help='# of dimensions of mlp. no need to put the first channel')
        g_model.add_argument('--mlp_res_layers', nargs='+', default=[2, 3, 4], type=int,
                             help='layers that have skip connections. use 0 for no residual pass')
        g_model.add_argument('--merge_layer', type=int, default=-1)

        # for train
        parser.add_argument('--random_body_chop', action='store_true', help='if random body chop')
        parser.add_argument('--random_flip', action='store_true', help='if random flip')
        parser.add_argument('--random_trans', action='store_true', help='if random translation')
        parser.add_argument('--random_scale', action='store_true', help='if random scale')
        parser.add_argument('--random_rotate', action='store_true', help='if random rotation')
        parser.add_argument('--random_bg', action='store_true', help='using random background')
        parser.add_argument('--schedule', type=int, nargs='+', default=[10, 15],
                            help='Decrease learning rate at these epochs.')
        parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
        parser.add_argument('--lambda_nml', type=float, default=0.0, help='weight of normal loss')
        parser.add_argument('--lambda_cmp_l1', type=float, default=0.0, help='weight of cmp l1 loss')
        parser.add_argument('--occ_loss_type', type=str, default='mse', help='bce | brock_bce | mse')
        parser.add_argument('--clr_loss_type', type=str, default='mse', help='mse | l1')
        parser.add_argument('--nml_loss_type', type=str, default='mse', help='mse | l1')
        parser.add_argument('--occ_gamma', type=float, default=None, help='weighting term')
        parser.add_argument('--no_finetune', action='store_true', help='no fine tuning netG in training C')

        # for eval
        parser.add_argument('--val_test_error', action='store_true', help='validate errors of test data')
        parser.add_argument('--val_train_error', action='store_true', help='validate errors of train data')
        parser.add_argument('--gen_test_mesh', action='store_true', help='generate test mesh')
        parser.add_argument('--gen_train_mesh', action='store_true', help='generate train mesh')
        parser.add_argument('--all_mesh', action='store_true', help='generate meshes from all hourglass output')
        parser.add_argument('--num_gen_mesh_test', type=int, default=4,
                            help='how many meshes to generate during testing')

        # path
        parser.add_argument('--load_netG_checkpoint_path', type=str, help='path to load netG checkpoint')
        parser.add_argument('--load_netC_checkpoint_path', type=str, help='path to load netC checkpoint')
        parser.add_argument('--checkpoints_path', type=str, default='./checkpoints', help='path to save checkpoints')
        parser.add_argument('--results_path', type=str, default='./results', help='path to save results ply')
        parser.add_argument('--load_checkpoint_path', type=str, help='path to load checkpoint')
        parser.add_argument('--single', type=str, default='', help='single data for training')

        # for single image reconstruction
        parser.add_argument('--mask_path', type=str, help='path for input mask')
        parser.add_argument('--img_path', type=str, help='path for input image')

        # for multi resolution
        parser.add_argument('--load_netMR_checkpoint_path', type=str, help='path to load netMR checkpoint')
        parser.add_argument('--loadSizeBig', type=int, default=1024, help='load size of input image')
        parser.add_argument('--loadSizeLocal', type=int, default=512, help='load size of input image')
        parser.add_argument('--train_full_pifu', action='store_true', help='enable end-to-end training')
        parser.add_argument('--num_local', type=int, default=1, help='number of local cropping')

        # for normal condition
        parser.add_argument('--load_netFB_checkpoint_path', type=str, help='path to load netFB checkpoint')
        parser.add_argument('--load_netF_checkpoint_path', type=str, help='path to load netF checkpoint')
        parser.add_argument('--load_netB_checkpoint_path', type=str, help='path to load netB checkpoint')
        parser.add_argument('--use_aio_normal', action='store_true')
        parser.add_argument('--use_front_normal', action='store_true')
        parser.add_argument('--use_back_normal', action='store_true')
        parser.add_argument('--no_intermediate_loss', action='store_true')

        # aug
        group_aug = parser.add_argument_group('aug')
        group_aug.add_argument('--aug_alstd', type=float, default=0.0, help='augmentation pca lighting alpha std')
        group_aug.add_argument('--aug_bri', type=float, default=0.2, help='augmentation brightness')
        group_aug.add_argument('--aug_con', type=float, default=0.2, help='augmentation contrast')
        group_aug.add_argument('--aug_sat', type=float, default=0.05, help='augmentation saturation')
        group_aug.add_argument('--aug_hue', type=float, default=0.05, help='augmentation hue')
        group_aug.add_argument('--aug_gry', type=float, default=0.1, help='augmentation gray scale')
        group_aug.add_argument('--aug_blur', type=float, default=0.0, help='augmentation blur')

        # for reconstruction
        parser.add_argument('--start_id', type=int, default=-1, help='start index for reconstruction')
        parser.add_argument('--end_id', type=int, default=-1, help='end index for reconstruction')

        # special tasks

        self.initialized = True
        return parser

    def gather_options(self, args=None):
        """Build (once) the parser and parse *args*.

        Args:
            args: list of argument strings; ``None`` parses ``sys.argv[1:]``.

        Returns:
            The ``argparse.Namespace`` produced by the parser.
        """
        if not self.initialized:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            self.parser = self.initialize(parser)
        # parse_args(None) falls back to sys.argv[1:], so no special case is needed.
        return self.parser.parse_args(args)

    def print_options(self, opt):
        """Print every option in *opt*, sorted by name.

        Values that differ from the parser default are annotated with that
        default. Requires ``self.parser`` to have been built by
        :meth:`gather_options` / :meth:`parse`.
        """
        lines = ['----------------- Options ---------------']
        for key, value in sorted(vars(opt).items()):
            default = self.parser.get_default(key)
            comment = f'\t[default: {default}]' if value != default else ''
            lines.append(f'{key!s:>25}: {value!s:<30}{comment}')
        lines.append('----------------- End -------------------')
        print('\n'.join(lines))

    def parse(self, args=None):
        """Parse *args* and apply post-processing to the resulting options.

        Post-processing:
          * ``opt.sigma`` is replaced by ``opt.sigma_max``.
          * ``--mlp_res_layers`` given as a single value < 1 (e.g. ``0``)
            becomes an empty list, meaning no residual layers.

        Args:
            args: list of argument strings; ``None`` parses ``sys.argv[1:]``.

        Returns:
            The post-processed ``argparse.Namespace``.
        """
        opt = self.gather_options(args)

        # NOTE(review): this unconditionally discards any user-supplied --sigma;
        # kept for compatibility with existing callers — confirm it is intended.
        opt.sigma = opt.sigma_max

        if len(opt.mlp_res_layers) == 1 and opt.mlp_res_layers[0] < 1:
            opt.mlp_res_layers = []

        return opt
|