""" |
|
@Author : Peike Li |
|
@Contact : [email protected] |
|
@File : train.py |
|
@Time : 8/4/19 3:36 PM |
|
@Desc : |
|
@License : This source code is licensed under the license found in the |
|
LICENSE file in the root directory of this source tree. |
|
""" |

import os
import json
import timeit
import argparse

import torch
import torch.optim as optim
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
from torch.utils import data

import networks
import utils.schp as schp
from datasets.datasets import LIPDataSet
from datasets.target_generation import generate_edge_tensor
from utils.transforms import BGR2RGB_transform
from utils.criterion import CriterionAll
from utils.encoding import DataParallelModel, DataParallelCriterion
from utils.warmup_scheduler import SGDRScheduler


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
        The parsed arguments as an argparse.Namespace.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

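    # Network structure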
    parser.add_argument("--arch", type=str, default='resnet101')

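    # Data preference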
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--input-size", type=str, default='473,473')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")

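    # Training strategy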
    parser.add_argument("--learning-rate", type=float, default=7e-3)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--gpu", type=str, default='0,1,2')
    parser.add_argument("--start-epoch", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=150)
    parser.add_argument("--eval-epochs", type=int, default=10)
    parser.add_argument("--imagenet-pretrain", type=str, default='./pretrain_model/resnet101-imagenet.pth')
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--schp-start", type=int, default=100, help='schp start epoch')
    parser.add_argument("--cycle-epochs", type=int, default=10, help='schp cyclical epoch')
    parser.add_argument("--schp-restore", type=str, default='./log/schp_checkpoint.pth.tar')
    parser.add_argument("--lambda-s", type=float, default=1, help='segmentation loss weight')
    parser.add_argument("--lambda-e", type=float, default=1, help='edge loss weight')
    parser.add_argument("--lambda-c", type=float, default=0.1, help='segmentation-edge consistency loss weight')
    return parser.parse_args()


def main():
    args = get_arguments()
    print(args)

    start_epoch = 0
    cycle_n = 0

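    # Create the log directory and persist the run configuration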
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

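    # GPU and cuDNN setup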
    gpus = [int(i) for i in args.gpu.split(',')]
    if args.gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

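    # Model initialization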
    AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    model = DataParallelModel(AugmentCE2P)
    model.cuda()

    IMAGE_MEAN = AugmentCE2P.mean
    IMAGE_STD = AugmentCE2P.std
    INPUT_SPACE = AugmentCE2P.input_space
    print('image mean: {}'.format(IMAGE_MEAN))
    print('image std: {}'.format(IMAGE_STD))
    print('input space: {}'.format(INPUT_SPACE))

    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']

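    # SCHP model: holds the cyclically averaged weights that later produce soft labels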
    SCHP_AugmentCE2P = networks.init_model(args.arch, num_classes=args.num_classes, pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(SCHP_AugmentCE2P)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        schp_model_state_dict = schp_checkpoint['state_dict']
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_model_state_dict)

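    # Loss function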
    criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

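    # Normalization pipeline matching the backbone's expected input space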
    if INPUT_SPACE == 'BGR':
        print('BGR Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])
    elif INPUT_SPACE == 'RGB':
        print('RGB Transformation')
        transform = transforms.Compose([
            transforms.ToTensor(),
            BGR2RGB_transform(),
            transforms.Normalize(mean=IMAGE_MEAN,
                                 std=IMAGE_STD),
        ])

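    # Data loader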
    train_dataset = LIPDataSet(args.data_dir, 'train', crop_size=input_size, transform=transform)
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * len(gpus),
                                   num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

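    # Optimizer and SGDR learning-rate schedule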
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)

    lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100, warmup_epoch=10,
                                 start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    total_iters = args.epochs * len(train_loader)
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            i_iter += len(train_loader) * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)

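            # Derive edge supervision from the parsing labels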
            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

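            # Online self-correction: once the first cycle has run, the SCHP model supplies soft parsing and edge labels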
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_preds = schp_model(images)
                    soft_parsing = []
                    soft_edge = []
                    for soft_pred in soft_preds:
                        soft_parsing.append(soft_pred[0][-1])
                        soft_edge.append(soft_pred[1][-1])
                    soft_preds = torch.cat(soft_parsing, dim=0)
                    soft_edges = torch.cat(soft_edge, dim=0)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(i_iter, total_iters, lr,
                                                                             loss.data.cpu().numpy()))

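        # Save a checkpoint every eval-epochs epochs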
        if (epoch + 1) % args.eval_epochs == 0:
            schp.save_schp_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

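        # Self-correction cycle: fold the current weights into the SCHP model, then re-estimate its BN statistics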
        if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint({
                'state_dict': schp_model.state_dict(),
                'cycle_n': cycle_n,
            }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
                                                             (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))


if __name__ == '__main__':
    main()