Spaces:
No application file
No application file
import os | |
import torch | |
import torch.optim as optim | |
from tqdm import tqdm | |
from torch.autograd import Variable | |
from network_v0.model import PointModel | |
from loss_function import KeypointLoss | |
class Trainer(object):
    """Train a ``PointModel`` with ``KeypointLoss`` and save a checkpoint per epoch.

    All settings come from ``config``: ``max_epoch``, ``start_epoch``,
    ``momentum``, ``init_lr``, ``lr_factor``, ``display``, ``use_gpu``, ``seed``,
    ``gpu``, ``ckpt_dir`` and ``ckpt_name``. Checkpoints are written to
    ``<ckpt_dir>/<ckpt_name>-<seed>_<epoch>.pth``; a positive ``start_epoch``
    resumes training from the matching file.
    """

    def __init__(self, config, train_loader=None):
        self.config = config
        # data parameters
        self.train_loader = train_loader
        # len() of the loader (for a torch DataLoader this is presumably the
        # number of batches). Guarded so the default train_loader=None does
        # not crash construction.
        self.num_train = len(train_loader) if train_loader is not None else 0
        # training parameters
        self.max_epoch = config.max_epoch
        # Cast once with int() (the value may arrive as a string) so range()
        # in train() works and agrees with the resume check below.
        self.start_epoch = int(config.start_epoch)
        self.momentum = config.momentum  # kept for compatibility; Adam ignores it
        self.lr = config.init_lr  # refreshed from the optimizer as it decays
        self.lr_factor = config.lr_factor
        self.display = config.display
        # misc params
        self.use_gpu = config.use_gpu
        self.random_seed = config.seed
        self.gpu = config.gpu
        self.ckpt_dir = config.ckpt_dir
        self.ckpt_name = "{}-{}".format(config.ckpt_name, config.seed)

        # build model
        self.model = PointModel(is_test=False)
        # training on GPU
        if self.use_gpu:
            torch.cuda.set_device(self.gpu)
            self.model.cuda()
        print(
            "Number of model parameters: {:,}".format(
                sum(p.numel() for p in self.model.parameters())
            )
        )

        # build loss functional
        self.loss_func = KeypointLoss(config)
        # build optimizer and scheduler: the lr is multiplied by lr_factor
        # after 4 and 8 scheduler steps (train() steps once per epoch).
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=[4, 8], gamma=self.lr_factor
        )

        # resume
        if self.start_epoch > 0:
            (
                self.config.start_epoch,
                self.model,
                self.optimizer,
                self.lr_scheduler,
            ) = self.load_checkpoint(
                self.start_epoch, self.model, self.optimizer, self.lr_scheduler
            )
            # Keep the trainer in sync with the restored state so train()
            # continues from the right epoch and reports the restored lr.
            self.start_epoch = int(self.config.start_epoch)
            self.lr = self._current_lr()

    def train(self):
        """Train epochs ``start_epoch .. max_epoch - 1``, checkpointing after each.

        The checkpoint for epoch ``e`` is saved after the scheduler step, so it
        holds the state needed to resume at epoch ``e``.
        """
        print("\nTrain on {} samples".format(self.num_train))
        # Save the untrained weights as checkpoint 0 only on a fresh run; when
        # resuming, this would overwrite "_0.pth" with the resumed weights.
        if self.start_epoch <= 0:
            self.save_checkpoint(0, self.model, self.optimizer, self.lr_scheduler)
        for epoch in range(self.start_epoch, self.max_epoch):
            # Read the lr from the optimizer: the scheduler decays it in place.
            self.lr = self._current_lr()
            print(
                "\nEpoch: {}/{} --lr: {:.6f}".format(epoch + 1, self.max_epoch, self.lr)
            )
            # train for one epoch
            self.train_one_epoch(epoch)
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.save_checkpoint(
                epoch + 1, self.model, self.optimizer, self.lr_scheduler
            )

    def train_one_epoch(self, epoch):
        """Run one pass over ``train_loader``, taking an optimizer step per batch.

        Args:
            epoch: zero-based epoch index (shown 1-based in the log line).

        Each batch is expected to be a mapping with ``"image_aug"``, ``"image"``
        and ``"homography"`` entries; ``loss_func`` must return
        ``(loss, loc_loss, desc_loss, score_loss, corres_loss)``.
        """
        self.model.train()
        for i, data in enumerate(tqdm(self.train_loader)):
            # Inputs go to the GPU only when use_gpu is set; on CPU they are
            # used as-is. (Variable wrapping is unnecessary in modern PyTorch.)
            source_img = self._to_device(data["image_aug"])
            target_img = self._to_device(data["image"])
            homography = self._to_device(data["homography"])
            # forward propagation
            output = self.model(source_img, target_img, homography)
            # compute loss
            loss, loc_loss, desc_loss, score_loss, corres_loss = self.loss_func(output)
            # compute gradients and update
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Print training info. Formatting reads the loss values back to the
            # host (a device sync), so only do it on display iterations; a
            # display of 0 disables logging instead of dividing by zero.
            if self.display and i % self.display == 0:
                msg_batch = (
                    "Epoch:{} Iter:{} lr:{:.4f} "
                    "loc_loss={:.4f} desc_loss={:.4f} score_loss={:.4f} corres_loss={:.4f} "
                    "loss={:.4f} ".format(
                        (epoch + 1),
                        i,
                        self.lr,
                        loc_loss.item(),
                        desc_loss.item(),
                        score_loss.item(),
                        corres_loss.item(),
                        loss.item(),
                    )
                )
                print(msg_batch)

    def _to_device(self, tensor):
        """Return ``tensor`` on the current CUDA device if ``use_gpu``, else unchanged."""
        return tensor.cuda() if self.use_gpu else tensor

    def _current_lr(self):
        """Return the learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]["lr"]

    def _ckpt_filename(self, epoch):
        """Return the checkpoint file name (without directory) for ``epoch``."""
        return self.ckpt_name + "_" + str(epoch) + ".pth"

    def save_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Save model, optimizer and scheduler state for ``epoch`` under ``ckpt_dir``.

        The directory is created if it does not exist yet.
        """
        filename = self._ckpt_filename(epoch)
        if self.ckpt_dir:
            os.makedirs(self.ckpt_dir, exist_ok=True)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
            },
            os.path.join(self.ckpt_dir, filename),
        )

    def load_checkpoint(self, epoch, model, optimizer, lr_scheduler):
        """Restore the state written by ``save_checkpoint`` for ``epoch``.

        Returns:
            ``(epoch, model, optimizer, lr_scheduler)``, where ``epoch`` is the
            value stored in the checkpoint and the three objects are the same
            instances passed in, updated in place.
        """
        filename = self._ckpt_filename(epoch)
        # Load onto CPU: load_state_dict copies into the existing (possibly GPU)
        # tensors, and this also lets a GPU-saved checkpoint resume on CPU.
        ckpt = torch.load(os.path.join(self.ckpt_dir, filename), map_location="cpu")
        epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        lr_scheduler.load_state_dict(ckpt["lr_scheduler"])
        print("[*] Loaded {} checkpoint @ epoch {}".format(filename, ckpt["epoch"]))
        return epoch, model, optimizer, lr_scheduler