import os
import time
import itertools
from dataset import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
# NOTE(review): torch, nn, np, cv2 and the network / image helpers used below
# (ResnetGenerator, Discriminator, RhoClipper, WClipper, RGB2BGR, tensor2numpy,
# denorm, cam) are not imported explicitly anywhere in this module, so they are
# presumably supplied by the two star imports. Keep them and keep their order.
from .networks import *
from utils import *
from glob import glob
from .face_features import FaceFeatures


class UgatitSadalinHourglass(object):
    """Trainer / tester for a two-domain (A <-> B) image translation model.

    The object owns two generators (``genA2B``, ``genB2A``), four
    discriminators (``disGA``, ``disGB``, ``disLA``, ``disLB``), a face
    feature extractor used for an identity term, and the Adam optimizers that
    update them.  Typical use is ``build_model()`` followed by ``train()`` or
    ``test()``.
    """

    # Attribute names of the six networks that are trained and checkpointed.
    _NET_NAMES = ('genA2B', 'genB2A', 'disGA', 'disGB', 'disLA', 'disLB')

    def __init__(self, args):
        """Copy the run configuration from ``args`` and print a summary.

        Args:
            args: namespace-like object providing the attributes read below
                (``light``, ``result_dir``, ``dataset``, ``iteration``,
                ``gpu_ids``, the loss weights, ...).  ``args.gpu_ids`` must be
                a non-empty sequence; its first entry selects the main device.
        """
        self.light = args.light

        if self.light:
            self.model_name = 'UGATIT_light'
        else:
            self.model_name = 'UGATIT'

        self.result_dir = args.result_dir
        self.dataset = args.dataset

        self.iteration = args.iteration
        self.decay_flag = args.decay_flag

        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq

        self.lr = args.lr
        self.ch = args.ch

        # Loss weights
        self.adv_weight = args.adv_weight
        self.cycle_weight = args.cycle_weight
        self.identity_weight = args.identity_weight
        self.cam_weight = args.cam_weight
        self.faceid_weight = args.faceid_weight

        # Discriminator
        self.n_dis = args.n_dis

        self.img_size = args.img_size
        self.img_ch = args.img_ch

        self.device = f'cuda:{args.gpu_ids[0]}'
        self.gpu_ids = args.gpu_ids
        self.benchmark_flag = args.benchmark_flag
        self.resume = args.resume
        self.rho_clipper = args.rho_clipper
        self.w_clipper = args.w_clipper
        self.pretrained_weights = args.pretrained_weights

        # Lazily created data-loader iterators, keyed by split name
        # ('trainA', 'trainB', 'testA', 'testB'); see _next_images().
        self._iters = {}

        if torch.backends.cudnn.enabled and self.benchmark_flag:
            print('set benchmark !')
            torch.backends.cudnn.benchmark = True

        print("##### Information #####")
        print("# light : ", self.light)
        print("# dataset : ", self.dataset)
        print("# batch_size : ", self.batch_size)
        print("# iteration per epoch : ", self.iteration)

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.n_dis)

        print()

        print("##### Weight #####")
        print("# adv_weight : ", self.adv_weight)
        print("# cycle_weight : ", self.cycle_weight)
        print("# faceid_weight : ", self.faceid_weight)
        print("# identity_weight : ", self.identity_weight)
        print("# cam_weight : ", self.cam_weight)
        print("# rho_clipper: ", self.rho_clipper)
        print("# w_clipper: ", self.w_clipper)

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """Create data loaders, networks, losses, optimizers and clippers.

        Datasets are read from ``dataset/<self.dataset>/{trainA,trainB,testA,testB}``
        and the face feature weights from ``models/model_mobilefacenet.pth``.
        """
        # DataLoader
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize((self.img_size + 30, self.img_size+30)),
            transforms.RandomCrop(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])
        test_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.trainA = ImageFolder(os.path.join('dataset', self.dataset, 'trainA'), train_transform)
        self.trainB = ImageFolder(os.path.join('dataset', self.dataset, 'trainB'), train_transform)
        self.testA = ImageFolder(os.path.join('dataset', self.dataset, 'testA'), test_transform)
        self.testB = ImageFolder(os.path.join('dataset', self.dataset, 'testB'), test_transform)
        self.trainA_loader = DataLoader(self.trainA, batch_size=self.batch_size, shuffle=True)
        self.trainB_loader = DataLoader(self.trainB, batch_size=self.batch_size, shuffle=True)
        self.testA_loader = DataLoader(self.testA, batch_size=1, shuffle=False)
        self.testB_loader = DataLoader(self.testB, batch_size=1, shuffle=False)

        # Define Generator, Discriminator
        self.genA2B = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
        self.genB2A = ResnetGenerator(ngf=self.ch, img_size=self.img_size, light=self.light).to(self.device)
        self.disGA = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disGB = Discriminator(input_nc=3, ndf=self.ch, n_layers=7).to(self.device)
        self.disLA = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)
        self.disLB = Discriminator(input_nc=3, ndf=self.ch, n_layers=5).to(self.device)

        self.facenet = FaceFeatures('models/model_mobilefacenet.pth', self.device)

        # Define Loss
        self.L1_loss = nn.L1Loss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)
        self.BCE_loss = nn.BCEWithLogitsLoss().to(self.device)

        # Trainer
        self.G_optim = torch.optim.Adam(
            itertools.chain(self.genA2B.parameters(), self.genB2A.parameters()),
            lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001
        )
        self.D_optim = torch.optim.Adam(
            itertools.chain(self.disGA.parameters(), self.disGB.parameters(),
                            self.disLA.parameters(), self.disLB.parameters()),
            lr=self.lr, betas=(0.5, 0.999), weight_decay=0.0001
        )

        # Define Rho clipper to constrain the value of rho in AdaLIN and LIN
        self.Rho_clipper = RhoClipper(0, self.rho_clipper)
        self.W_Clipper = WClipper(0, self.w_clipper)

    ##################################################################################
    # Private helpers
    ##################################################################################

    @staticmethod
    def _unwrap(net):
        """Return the underlying module if ``net`` is wrapped in ``nn.DataParallel``."""
        return net.module if isinstance(net, nn.DataParallel) else net

    def _set_train_mode(self, training=True):
        """Switch all six networks to train (``True``) or eval (``False``) mode."""
        for name in self._NET_NAMES:
            getattr(self, name).train(training)

    def _collect_params(self):
        """Return ``{network name: state_dict}`` for the six networks.

        State dicts are always taken from the unwrapped module so checkpoints
        have the same keys whether or not ``nn.DataParallel`` is in use.
        """
        return {name: self._unwrap(getattr(self, name)).state_dict()
                for name in self._NET_NAMES}

    def _load_params(self, params):
        """Load a ``{network name: state_dict}`` mapping into the six networks."""
        for name in self._NET_NAMES:
            self._unwrap(getattr(self, name)).load_state_dict(params[name])

    def _latest_checkpoint_step(self):
        """Return the step number of the newest checkpoint, or ``None``.

        Looks for ``*.pt`` files in ``<result_dir>/<dataset>/model`` and parses
        the integer between the last ``_`` and the first following ``.`` of
        the lexicographically last file name.
        """
        model_list = glob(os.path.join(self.result_dir, self.dataset, 'model', '*.pt'))
        if not model_list:
            return None
        model_list.sort()
        return int(model_list[-1].split('_')[-1].split('.')[0])

    def _next_images(self, split):
        """Return the next image batch of ``self.<split>_loader`` on ``self.device``.

        The loader is cycled indefinitely: its iterator is created on first
        use and re-created whenever it is exhausted.  Only ``StopIteration``
        is handled, so genuine data-loading errors propagate to the caller.
        """
        loader = getattr(self, split + '_loader')
        iterator = self._iters.get(split)
        if iterator is None:
            iterator = self._iters[split] = iter(loader)
        try:
            images, _ = next(iterator)
        except StopIteration:
            iterator = self._iters[split] = iter(loader)
            images, _ = next(iterator)
        return images.to(self.device)

    def _adv_d_terms(self, dis, real, fake):
        """Least-squares discriminator terms for one discriminator.

        Returns ``(logit_loss, cam_logit_loss)``: each is the MSE of the real
        output against ones plus the MSE of the fake output against zeros.
        The real batch is evaluated before the fake batch.
        """
        real_logit, real_cam_logit, _ = dis(real)
        fake_logit, fake_cam_logit, _ = dis(fake)
        ad_loss = self.MSE_loss(real_logit, torch.ones_like(real_logit)) + \
            self.MSE_loss(fake_logit, torch.zeros_like(fake_logit))
        ad_cam_loss = self.MSE_loss(real_cam_logit, torch.ones_like(real_cam_logit)) + \
            self.MSE_loss(fake_cam_logit, torch.zeros_like(fake_cam_logit))
        return ad_loss, ad_cam_loss

    def _adv_g_terms(self, dis, fake):
        """Least-squares generator terms: MSE of ``dis(fake)`` outputs against ones."""
        fake_logit, fake_cam_logit, _ = dis(fake)
        ad_loss = self.MSE_loss(fake_logit, torch.ones_like(fake_logit))
        ad_cam_loss = self.MSE_loss(fake_cam_logit, torch.ones_like(fake_cam_logit))
        return ad_loss, ad_cam_loss

    def _update_discriminators(self, real_A, real_B):
        """Run one optimizer step on the four discriminators; return the loss tensor."""
        self.D_optim.zero_grad()

        # The generators are not updated in this step, so the fakes are
        # produced without building an autograd graph.  The discriminator
        # gradients are unchanged; this only avoids a wasted backward pass
        # through the generators.
        with torch.no_grad():
            fake_A2B, _, _ = self.genA2B(real_A)
            fake_B2A, _, _ = self.genB2A(real_B)

        D_ad_loss_GA, D_ad_cam_loss_GA = self._adv_d_terms(self.disGA, real_A, fake_B2A)
        D_ad_loss_LA, D_ad_cam_loss_LA = self._adv_d_terms(self.disLA, real_A, fake_B2A)
        D_ad_loss_GB, D_ad_cam_loss_GB = self._adv_d_terms(self.disGB, real_B, fake_A2B)
        D_ad_loss_LB, D_ad_cam_loss_LB = self._adv_d_terms(self.disLB, real_B, fake_A2B)

        D_loss_A = self.adv_weight * (D_ad_loss_GA + D_ad_cam_loss_GA + D_ad_loss_LA + D_ad_cam_loss_LA)
        D_loss_B = self.adv_weight * (D_ad_loss_GB + D_ad_cam_loss_GB + D_ad_loss_LB + D_ad_cam_loss_LB)

        Discriminator_loss = D_loss_A + D_loss_B
        Discriminator_loss.backward()
        self.D_optim.step()
        return Discriminator_loss.detach()

    def _update_generators(self, real_A, real_B):
        """Run one optimizer step on both generators; return the loss tensor.

        The loss is the weighted sum of adversarial, cycle-reconstruction,
        identity, CAM and face-identity terms for each direction.  After the
        step the rho / w clippers are applied to both generators.
        """
        self.G_optim.zero_grad()

        fake_A2B, fake_A2B_cam_logit, _ = self.genA2B(real_A)
        fake_B2A, fake_B2A_cam_logit, _ = self.genB2A(real_B)

        fake_A2B2A, _, _ = self.genB2A(fake_A2B)
        fake_B2A2B, _, _ = self.genA2B(fake_B2A)

        fake_A2A, fake_A2A_cam_logit, _ = self.genB2A(real_A)
        fake_B2B, fake_B2B_cam_logit, _ = self.genA2B(real_B)

        G_ad_loss_GA, G_ad_cam_loss_GA = self._adv_g_terms(self.disGA, fake_B2A)
        G_ad_loss_LA, G_ad_cam_loss_LA = self._adv_g_terms(self.disLA, fake_B2A)
        G_ad_loss_GB, G_ad_cam_loss_GB = self._adv_g_terms(self.disGB, fake_A2B)
        G_ad_loss_LB, G_ad_cam_loss_LB = self._adv_g_terms(self.disLB, fake_A2B)

        G_recon_loss_A = self.L1_loss(fake_A2B2A, real_A)
        G_recon_loss_B = self.L1_loss(fake_B2A2B, real_B)

        G_identity_loss_A = self.L1_loss(fake_A2A, real_A)
        G_identity_loss_B = self.L1_loss(fake_B2B, real_B)

        G_id_loss_A = self.facenet.cosine_distance(real_A, fake_A2B)
        G_id_loss_B = self.facenet.cosine_distance(real_B, fake_B2A)
        if len(self.gpu_ids) > 1:
            G_id_loss_A = torch.mean(G_id_loss_A)
            G_id_loss_B = torch.mean(G_id_loss_B)

        G_cam_loss_A = self.BCE_loss(fake_B2A_cam_logit, torch.ones_like(fake_B2A_cam_logit)) + \
            self.BCE_loss(fake_A2A_cam_logit, torch.zeros_like(fake_A2A_cam_logit))
        G_cam_loss_B = self.BCE_loss(fake_A2B_cam_logit, torch.ones_like(fake_A2B_cam_logit)) + \
            self.BCE_loss(fake_B2B_cam_logit, torch.zeros_like(fake_B2B_cam_logit))

        G_loss_A = self.adv_weight * (G_ad_loss_GA + G_ad_cam_loss_GA + G_ad_loss_LA + G_ad_cam_loss_LA) + \
            self.cycle_weight * G_recon_loss_A + self.identity_weight * G_identity_loss_A + \
            self.cam_weight * G_cam_loss_A + self.faceid_weight * G_id_loss_A
        G_loss_B = self.adv_weight * (G_ad_loss_GB + G_ad_cam_loss_GB + G_ad_loss_LB + G_ad_cam_loss_LB) + \
            self.cycle_weight * G_recon_loss_B + self.identity_weight * G_identity_loss_B + \
            self.cam_weight * G_cam_loss_B + self.faceid_weight * G_id_loss_B

        Generator_loss = G_loss_A + G_loss_B
        Generator_loss.backward()
        self.G_optim.step()

        # clip parameter of Soft-AdaLIN and LIN, applied after optimizer step
        self.genA2B.apply(self.Rho_clipper)
        self.genB2A.apply(self.Rho_clipper)

        self.genA2B.apply(self.W_Clipper)
        self.genB2A.apply(self.W_Clipper)
        return Generator_loss.detach()

    def _translation_column(self, real, gen_forward, gen_backward):
        """Build one vertical visualisation strip for the first image of ``real``.

        Stacked top to bottom (axis 0): input, identity heatmap, identity
        output, translation heatmap, translation output, cycle heatmap,
        cycle output.  ``gen_forward`` maps the input's domain to the other
        domain and ``gen_backward`` maps back.
        """
        fake, _, fake_heatmap = gen_forward(real)
        cycle, _, cycle_heatmap = gen_backward(fake)
        ident, _, ident_heatmap = gen_backward(real)

        return np.concatenate((RGB2BGR(tensor2numpy(denorm(real[0]))),
                               cam(tensor2numpy(ident_heatmap[0]), self.img_size),
                               RGB2BGR(tensor2numpy(denorm(ident[0]))),
                               cam(tensor2numpy(fake_heatmap[0]), self.img_size),
                               RGB2BGR(tensor2numpy(denorm(fake[0]))),
                               cam(tensor2numpy(cycle_heatmap[0]), self.img_size),
                               RGB2BGR(tensor2numpy(denorm(cycle[0])))), 0)

    def _save_samples(self, step):
        """Write ``A2B_<step>.png`` / ``B2A_<step>.png`` preview grids.

        Each grid holds five columns drawn from the training loaders followed
        by five from the test loaders.  Training batches consumed here are
        taken from the same cycling iterators that feed the training loop.
        """
        train_sample_num = 5
        test_sample_num = 5
        A2B = np.zeros((self.img_size * 7, 0, 3))
        B2A = np.zeros((self.img_size * 7, 0, 3))

        self._set_train_mode(False)
        with torch.no_grad():
            for split_A, split_B, sample_num in (('trainA', 'trainB', train_sample_num),
                                                 ('testA', 'testB', test_sample_num)):
                for _ in range(sample_num):
                    real_A = self._next_images(split_A)
                    real_B = self._next_images(split_B)

                    A2B = np.concatenate(
                        (A2B, self._translation_column(real_A, self.genA2B, self.genB2A)), 1)
                    B2A = np.concatenate(
                        (B2A, self._translation_column(real_B, self.genB2A, self.genA2B)), 1)

        img_dir = os.path.join(self.result_dir, self.dataset, 'img')
        # cv2.imwrite fails silently when the directory does not exist.
        os.makedirs(img_dir, exist_ok=True)
        cv2.imwrite(os.path.join(img_dir, 'A2B_%07d.png' % step), A2B * 255.0)
        cv2.imwrite(os.path.join(img_dir, 'B2A_%07d.png' % step), B2A * 255.0)
        self._set_train_mode(True)

    ##################################################################################
    # Public API
    ##################################################################################

    def train(self):
        """Run the training loop up to ``self.iteration`` steps.

        Optionally resumes from the newest checkpoint in
        ``<result_dir>/<dataset>/model`` (``self.resume``), then optionally
        overrides the weights with ``self.pretrained_weights``.  With
        ``self.decay_flag`` the learning rate is decreased linearly during the
        second half of training.  Progress is printed every 10 steps, preview
        images are written every ``print_freq`` steps, numbered checkpoints
        every ``save_freq`` steps and a ``*_params_latest.pt`` snapshot every
        1000 steps.
        """
        self._set_train_mode(True)
        self._iters = {}

        half = self.iteration // 2
        # Per-step learning-rate decrement; 0.0 avoids a division by zero
        # when fewer than two iterations are requested.
        lr_step = self.lr / half if half else 0.0

        start_iter = 1
        if self.resume:
            latest_step = self._latest_checkpoint_step()
            if latest_step is not None:
                start_iter = latest_step
                self.load(os.path.join(self.result_dir, self.dataset, 'model'), start_iter)
                print(" [*] Load SUCCESS")
                if self.decay_flag and start_iter > half:
                    # Re-apply the decay that had accumulated before the checkpoint.
                    self.G_optim.param_groups[0]['lr'] -= lr_step * (start_iter - half)
                    self.D_optim.param_groups[0]['lr'] -= lr_step * (start_iter - half)

        if self.pretrained_weights:
            params = torch.load(self.pretrained_weights, map_location=self.device)
            self._load_params(params)
            print(" [*] Load {} Success".format(self.pretrained_weights))

        if len(self.gpu_ids) > 1:
            for name in self._NET_NAMES:
                net = getattr(self, name)
                # Skip nets that are already wrapped (e.g. train() called twice).
                if not isinstance(net, nn.DataParallel):
                    setattr(self, name, nn.DataParallel(net, device_ids=self.gpu_ids))

        # training loop
        print('training start !')
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            if self.decay_flag and step > half:
                self.G_optim.param_groups[0]['lr'] -= lr_step
                self.D_optim.param_groups[0]['lr'] -= lr_step

            real_A = self._next_images('trainA')
            real_B = self._next_images('trainB')

            # Update D
            Discriminator_loss = self._update_discriminators(real_A, real_B)

            # Update G
            Generator_loss = self._update_generators(real_A, real_B)

            if step % 10 == 0:
                print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" % (
                    step, self.iteration, time.time() - start_time,
                    Discriminator_loss.item(), Generator_loss.item()))

            if step % self.print_freq == 0:
                self._save_samples(step)

            if step % self.save_freq == 0:
                self.save(os.path.join(self.result_dir, self.dataset, 'model'), step)

            if step % 1000 == 0:
                os.makedirs(self.result_dir, exist_ok=True)
                torch.save(self._collect_params(),
                           os.path.join(self.result_dir, self.dataset + '_params_latest.pt'))

    def save(self, dir, step):
        """Write the six networks' state dicts to ``<dir>/<dataset>_params_<step>.pt``.

        Args:
            dir: target directory; created if it does not exist.
            step: training step, zero-padded to seven digits in the file name.
        """
        os.makedirs(dir, exist_ok=True)
        torch.save(self._collect_params(),
                   os.path.join(dir, self.dataset + '_params_%07d.pt' % step))

    def load(self, dir, step):
        """Load the six networks from ``<dir>/<dataset>_params_<step>.pt``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        different device can be restored.
        """
        params = torch.load(os.path.join(dir, self.dataset + '_params_%07d.pt' % step),
                            map_location=self.device)
        self._load_params(params)

    def test(self):
        """Translate every test image with the newest checkpoint.

        Writes one ``A2B_<n>.png`` per ``testA`` image and one ``B2A_<n>.png``
        per ``testB`` image (``n`` starting at 1) into
        ``<result_dir>/<dataset>/test``.  Prints a failure message and returns
        without writing anything when no checkpoint is found.
        """
        latest_step = self._latest_checkpoint_step()
        if latest_step is None:
            print(" [*] Load FAILURE")
            return

        self.load(os.path.join(self.result_dir, self.dataset, 'model'), latest_step)
        print(" [*] Load SUCCESS")

        test_dir = os.path.join(self.result_dir, self.dataset, 'test')
        # cv2.imwrite fails silently when the directory does not exist.
        os.makedirs(test_dir, exist_ok=True)

        self.genA2B.eval(), self.genB2A.eval()
        with torch.no_grad():
            for n, (real_A, _) in enumerate(self.testA_loader):
                real_A = real_A.to(self.device)
                A2B = self._translation_column(real_A, self.genA2B, self.genB2A)
                cv2.imwrite(os.path.join(test_dir, 'A2B_%d.png' % (n + 1)), A2B * 255.0)

            for n, (real_B, _) in enumerate(self.testB_loader):
                real_B = real_B.to(self.device)
                B2A = self._translation_column(real_B, self.genB2A, self.genA2B)
                cv2.imwrite(os.path.join(test_dir, 'B2A_%d.png' % (n + 1)), B2A * 255.0)