import argparse
import os
from typing import List
from typing import Optional
from typing import Tuple

import cv2
import numpy as np
import torch

from configs.train_config import TrainConfig
from models.model import HifiFace

# Side length (pixels) every input image is resized to before inference.
IMAGE_SIZE = 256

# Benchmark suites: (output-file suffix, source image names, target image names).
# Source and target lists are paired element-wise, so each suite's lists must have equal length.
_BENCHMARKS = [
    (
        "1",
        [
            "male_1.jpg",
            "male_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "male_1.jpg",
            "male_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "test1.jpg",
            "test1.jpg",
            "test1.jpg",
        ],
        [
            "male_2.jpg",
            "male_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "male_2.jpg",
            "male_1.jpg",
            "male_1.jpg",
            "male_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "male_1.jpg",
        ],
    ),
    (
        "2",
        [
            "male_2.jpg",
            "male_1.jpg",
            "male_1.jpg",
            "male_2.jpg",
            "male_1.jpg",
            "male_2.jpg",
            "male_1.jpg",
            "male_2.jpg",
            "male_1.jpg",
            "male_2.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_1.jpg",
        ],
        [
            "male_1.jpg",
            "male_2.jpg",
            "minlu_1.jpg",
            "minlu_2.jpg",
            "shizong_1.jpg",
            "shizong_2.jpg",
            "tianxin_1.jpg",
            "tianxin_2.jpg",
            "xiaohui_1.jpg",
            "xiaohui_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_3.jpg",
            "female_4.jpg",
            "female_5.jpg",
            "female_6.jpg",
            "lixia_1.jpg",
            "lixia_2.jpg",
            "qq_1.jpg",
            "qq_2.jpg",
            "pink_1.jpg",
            "pink_2.jpg",
            "xulie_1.jpg",
            "xulie_2.jpg",
        ],
    ),
    (
        "3",
        [
            "male_2.jpg",
            "male_1.jpg",
            "shizong_1.jpg",
            "shizong_2.jpg",
            "minlu_1.jpg",
            "minlu_2.jpg",
            "xiaohui_1.jpg",
            "xiaohui_2.jpg",
            "tianxin_1.jpg",
            "tianxin_2.jpg",
            "female_2.jpg",
            "female_1.jpg",
            "female_5.jpg",
            "female_6.jpg",
            "female_3.jpg",
            "female_4.jpg",
            "qq_1.jpg",
            "qq_2.jpg",
            "pink_1.jpg",
            "pink_2.jpg",
            "xulie_1.jpg",
            "xulie_2.jpg",
            "lixia_1.jpg",
            "lixia_2.jpg",
        ],
        [
            "male_2.jpg",
            "male_1.jpg",
            "minlu_1.jpg",
            "minlu_2.jpg",
            "shizong_1.jpg",
            "shizong_2.jpg",
            "tianxin_1.jpg",
            "tianxin_2.jpg",
            "xiaohui_1.jpg",
            "xiaohui_2.jpg",
            "female_1.jpg",
            "female_2.jpg",
            "female_3.jpg",
            "female_4.jpg",
            "female_5.jpg",
            "female_6.jpg",
            "lixia_1.jpg",
            "lixia_2.jpg",
            "qq_1.jpg",
            "qq_2.jpg",
            "pink_1.jpg",
            "pink_2.jpg",
            "xulie_1.jpg",
            "xulie_2.jpg",
        ],
    ),
]


def _build_model(model_path: str, model_idx: Optional[int], device: str) -> HifiFace:
    """Construct a HifiFace in eval mode, loading checkpoint ``(model_path, model_idx)``."""
    opt = TrainConfig()
    opt.use_ddp = False
    model = HifiFace(
        opt.identity_extractor_config,
        is_training=False,
        device=device,
        load_checkpoint=(model_path, model_idx),
    )
    model.eval()
    return model


def _load_face(path: str, device: str) -> Tuple[np.ndarray, torch.Tensor]:
    """Read the image at *path* and return ``(bgr_image, rgb_tensor)``.

    ``bgr_image`` is the image resized to IMAGE_SIZE x IMAGE_SIZE in OpenCV's
    (H, W, 3) BGR uint8 layout, kept for the output mosaic. ``rgb_tensor`` holds
    the same pixels as a (1, 3, H, W) float RGB tensor scaled to [0, 1] on *device*.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"could not read image: {path}")
    img = cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE))
    chw = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)  # HWC -> CHW
    tensor = torch.from_numpy(chw).unsqueeze(0).to(device).float() / 255.0
    return img, tensor


def test(
    data_root: str,
    result_path: str,
    source_face: List[str],
    target_face: List[str],
    model_path: str,
    model_idx: Optional[int],
    *,
    device: str = "cpu",
    model: Optional[HifiFace] = None,
):
    """Run the model on each (source, target) pair and save one comparison mosaic.

    For each pair, one column is built by stacking (top to bottom) the resized
    source image, the resized target image and the model output. Columns are
    laid out left to right and written as a single image.

    Args:
        data_root: directory that the image names and ``result_path`` are joined onto.
        result_path: output image path, joined onto ``data_root`` (os.path.join keeps
            an absolute path as-is).
        source_face: image names passed to the model as its first input.
        target_face: image names passed as the second input, paired element-wise
            with ``source_face``.
        model_path: checkpoint path given to HifiFace; unused when ``model`` is given.
        model_idx: checkpoint index given to HifiFace; unused when ``model`` is given.
        device: torch device for the input tensors (and the model, if built here).
        model: an already-built model to reuse; built from the checkpoint when None.

    Raises:
        ValueError: if the two name lists differ in length or are empty.
        FileNotFoundError: if an input image cannot be read.
        OSError: if the mosaic cannot be written.
    """
    # zip() would silently drop unmatched trailing names; fail loudly instead.
    if len(source_face) != len(target_face):
        raise ValueError(
            f"source_face has {len(source_face)} entries but target_face has {len(target_face)}"
        )
    if not source_face:
        raise ValueError("at least one source/target pair is required")

    if model is None:
        model = _build_model(model_path, model_idx, device)

    columns = []
    for source, target in zip(source_face, target_face):
        src_img, src = _load_face(os.path.join(data_root, source), device)
        tgt_img, tgt = _load_face(os.path.join(data_root, target), device)

        with torch.no_grad():
            result = model.forward(src, tgt).cpu()
        # Assumes the model returns (1, 3, H, W) in [0, 1] at the input resolution,
        # in the same RGB order it was fed - TODO confirm against HifiFace.
        result = (torch.clamp(result, 0, 1) * 255).numpy()[0].astype(np.uint8)
        # CHW RGB -> HWC BGR so it matches the cv2 images and cv2.imwrite's expectation.
        result = cv2.cvtColor(result.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)

        columns.append(np.concatenate((src_img, tgt_img, result), axis=0))

    mosaic = np.concatenate(columns, axis=1)
    out_path = os.path.join(data_root, result_path)
    # cv2.imwrite returns False instead of raising when it cannot write the file.
    if not cv2.imwrite(out_path, mosaic):
        raise OSError(f"failed to write result image: {out_path}")


def main() -> None:
    """Parse CLI arguments, load the checkpoint once and render every benchmark suite."""
    parser = argparse.ArgumentParser(
        prog="benchmark",
        description=(
            "Run a HifiFace checkpoint on fixed source/target image pairs and save "
            "one comparison mosaic (source / target / result) per benchmark suite."
        ),
    )
    parser.add_argument("-m", "--model_name", required=True, help="checkpoint sub-directory name")
    parser.add_argument("-i", "--model_index", required=True, type=int, help="checkpoint index to load")
    parser.add_argument(
        "--data_root",
        default="/home/xuehongyang/data/face_swap_test",
        help="directory containing the test images",
    )
    parser.add_argument(
        "--checkpoint_root",
        default="/data/checkpoints/hififace/",
        help="directory containing checkpoint sub-directories",
    )
    parser.add_argument("--device", default="cpu", help="torch device to run inference on")
    args = parser.parse_args()

    model_path = os.path.join(args.checkpoint_root, args.model_name)
    name = f"{args.model_name}_{args.model_index}"

    # Build once and reuse for every suite (assumes inference does not mutate model state).
    model = _build_model(model_path, args.model_index, args.device)

    for suffix, sources, targets in _BENCHMARKS:
        # Relative to data_root, so each mosaic lands in data_root's parent directory.
        result_path = f"../{name}_{suffix}.jpg"
        test(
            args.data_root,
            result_path,
            sources,
            targets,
            model_path,
            args.model_index,
            device=args.device,
            model=model,
        )


if __name__ == "__main__":
    main()