File size: 9,398 Bytes
47162d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : [email protected]
@File : train.py
@Time : 8/4/19 3:36 PM
@Desc :
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import json
import timeit
import argparse
import torch
import torch.optim as optim
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils import data
import networks
import utils.schp as schp
from datasets.datasets import LIPDataSet
from datasets.target_generation import generate_edge_tensor
from utils.transforms import BGR2RGB_transform
from utils.criterion import CriterionAll
from utils.encoding import DataParallelModel, DataParallelCriterion
from utils.warmup_scheduler import SGDRScheduler
def get_arguments(argv=None):
    """Parse the command-line options for SCHP training.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, exactly as a
            plain CLI invocation would.

    Returns:
        argparse.Namespace: the parsed options. Attribute names are derived
        from the flags, e.g. ``--batch-size`` becomes ``args.batch_size``.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Network Structure
    parser.add_argument("--arch", type=str, default='resnet101')

    # Data Preference
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--input-size", type=str, default='473,473')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")

    # Training Strategy
    parser.add_argument("--learning-rate", type=float, default=7e-3)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--gpu", type=str, default='0,1,2')
    parser.add_argument("--start-epoch", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=150)
    parser.add_argument("--eval-epochs", type=int, default=10)
    parser.add_argument("--imagenet-pretrain", type=str, default='./pretrain_model/resnet101-imagenet.pth')
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--schp-start", type=int, default=100, help='schp start epoch')
    parser.add_argument("--cycle-epochs", type=int, default=10, help='schp cyclical epoch')
    parser.add_argument("--schp-restore", type=str, default='./log/schp_checkpoint.pth.tar')
    parser.add_argument("--lambda-s", type=float, default=1, help='segmentation loss weight')
    parser.add_argument("--lambda-e", type=float, default=1, help='edge loss weight')
    parser.add_argument("--lambda-c", type=float, default=0.1, help='segmentation-edge consistency loss weight')

    return parser.parse_args(argv)
def _build_transform(input_space, image_mean, image_std):
    """Build the per-image tensor transform matching the backbone's colour space.

    The 'RGB' branch inserts ``BGR2RGB_transform`` after ``ToTensor``, so the
    dataset images are presumably decoded in BGR order (TODO confirm against
    ``LIPDataSet``).

    Raises:
        ValueError: if ``input_space`` is neither 'BGR' nor 'RGB' (otherwise
            no transform could be built).
    """
    if input_space == 'BGR':
        print('BGR Transformation')
        steps = [transforms.ToTensor()]
    elif input_space == 'RGB':
        print('RGB Transformation')
        steps = [transforms.ToTensor(), BGR2RGB_transform()]
    else:
        raise ValueError("Unsupported input space '{}': expected 'BGR' or 'RGB'".format(input_space))
    steps.append(transforms.Normalize(mean=image_mean, std=image_std))
    return transforms.Compose(steps)


def _schp_soft_targets(schp_model, images):
    """Run the aggregated SCHP model and return ``(soft_parsing, soft_edges)``.

    ``schp_model(images)`` yields an iterable of outputs (presumably one per
    device from ``DataParallelModel`` -- TODO confirm). For each output, the
    last parsing map (``out[0][-1]``) and the last edge map (``out[1][-1]``)
    are taken and concatenated along dim 0.

    NOTE(review): ``schp_model`` is not switched to ``eval()`` here, so its
    normalisation/dropout layers run in whatever mode it was last left in;
    confirm this is intended.
    """
    with torch.no_grad():
        outputs = schp_model(images)
        soft_parsing = torch.cat([out[0][-1] for out in outputs], dim=0)
        soft_edges = torch.cat([out[1][-1] for out in outputs], dim=0)
    return soft_parsing, soft_edges


def main():
    """Train the SCHP human-parsing model as configured on the command line.

    Steps:
      * Parse the CLI options and dump them to ``<log_dir>/args.json``.
      * Build the training model and the SCHP (aggregated) model. Resume each
        from ``--model-restore`` / ``--schp-restore`` when those files exist.
      * Train for epochs ``[start_epoch, --epochs)``, saving a checkpoint
        every ``--eval-epochs`` epochs.
      * From epoch ``--schp-start`` onwards, every ``--cycle-epochs`` epochs,
        fold the training weights into the SCHP model via
        ``schp.moving_average`` with weight ``1 / (cycle_n + 1)``.
        Re-estimate its batch-norm statistics and save it.
      * Once at least one such cycle has run (``cycle_n >= 1``), pass the SCHP
        model's predictions to the criterion as soft targets.
    """
    args = get_arguments()
    print(args)

    # Honour --start-epoch; a restored model checkpoint overrides it below.
    start_epoch = args.start_epoch
    cycle_n = 0

    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    # Only parse the device list when one was given. '--gpu None' leaves
    # CUDA_VISIBLE_DEVICES untouched and uses every visible device.
    if args.gpu != 'None':
        gpus = [int(g) for g in args.gpu.split(',') if g.strip()]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    else:
        gpus = list(range(torch.cuda.device_count()))

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Model Initialization
    augment_ce2p = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    model = DataParallelModel(augment_ce2p)
    model.cuda()

    image_mean = augment_ce2p.mean
    image_std = augment_ce2p.std
    input_space = augment_ce2p.input_space
    print('image mean: {}'.format(image_mean))
    print('image std: {}'.format(image_std))
    print('input space:{}'.format(input_space))

    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        # Only weights and epoch are checkpointed; optimizer momentum restarts from zero.
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

    schp_augment_ce2p = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(schp_augment_ce2p)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        schp_model.load_state_dict(schp_checkpoint['state_dict'])
        cycle_n = schp_checkpoint['cycle_n']

    # Loss Function
    criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # Data Loader (the global batch is the per-GPU batch times the GPU count)
    transform = _build_transform(input_space, image_mean, image_std)
    train_dataset = LIPDataSet(args.data_dir, 'train', crop_size=input_size, transform=transform)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * max(len(gpus), 1),
                                   num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100, warmup_epoch=10,
                                 start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    iters_per_epoch = len(train_loader)
    total_iters = args.epochs * iters_per_epoch
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for batch_idx, batch in enumerate(train_loader):
            global_iter = iters_per_epoch * epoch + batch_idx

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Online Self Correction Cycle with Label Refinement
            if cycle_n >= 1:
                soft_preds, soft_edges = _schp_soft_targets(schp_model, images)
            else:
                soft_preds = soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(global_iter, total_iters, lr,
                                                                             loss.item()))

        if (epoch + 1) % args.eval_epochs == 0:
            schp.save_schp_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self Correction Cycle with Model Aggregation
        if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint({
                'state_dict': schp_model.state_dict(),
                'cycle_n': cycle_n,
            }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        # The reported time is the mean wall-clock seconds per epoch so far.
        print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
                                                             (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))
if __name__ == "__main__":
    # Start training only when run as a script; importing this module has no side effects.
    main()
|