# xuehongyang
# ser
# 83d8d3c
from typing import Optional
import cv2
import numpy as np
import torch
from configs.train_config import TrainConfig
from models.model import HifiFace
def _load_face(image_path: str, device: str, size: int = 256) -> torch.Tensor:
    """Read an image file into a ``(1, 3, size, size)`` float tensor scaled to [0, 1].

    The image is decoded by OpenCV (BGR), converted to RGB, resized to
    ``size`` x ``size``, reordered from HWC to CHW and given a batch axis.

    Raises:
        FileNotFoundError: if OpenCV cannot read or decode ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than deep inside cvtColor.
        raise FileNotFoundError(f"cannot read image file: {image_path}")
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (size, size))
    chw = rgb.transpose(2, 0, 1)
    return torch.from_numpy(chw).unsqueeze(0).to(device).float() / 255.0


def inference(
    source_face: str, target_face: str, model_path: str, model_idx: Optional[int], swapped_face: str
) -> None:
    """Run a HifiFace model on one source/target image pair and save the result.

    Both images are loaded as 256x256 RGB tensors in [0, 1], passed to
    ``model.forward(src, tgt)`` on the CPU, and the first item of the returned
    batch is clamped to [0, 1], scaled to 8-bit and written to ``swapped_face``.

    Args:
        source_face: path of the image passed as the first ``forward`` argument.
        target_face: path of the image passed as the second ``forward`` argument.
        model_path: checkpoint location; forwarded to ``HifiFace`` together
            with ``model_idx`` as the ``load_checkpoint`` tuple.
        model_idx: checkpoint index, or ``None``.
        swapped_face: output image path; the format is chosen by OpenCV from
            the file extension.

    Raises:
        FileNotFoundError: if either input image cannot be read.
        OSError: if the result image cannot be written.
    """
    device = "cpu"

    # Load the inputs first so a bad path fails before the checkpoint is loaded.
    src = _load_face(source_face, device)
    tgt = _load_face(target_face, device)

    opt = TrainConfig()
    opt.use_ddp = False

    checkpoint = (model_path, model_idx)
    model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
    model.eval()

    with torch.no_grad():
        result_face = model.forward(src, tgt).cpu()

    result_face = torch.clamp(result_face, 0, 1) * 255
    result_face = result_face.numpy()[0].astype(np.uint8)
    result_face = result_face.transpose(1, 2, 0)
    # The tensor is RGB (inputs were converted on load); OpenCV writes BGR.
    result_face = cv2.cvtColor(result_face, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(swapped_face, result_face):
        # cv2.imwrite returns False instead of raising when the save fails.
        raise OSError(f"failed to write image file: {swapped_face}")
if __name__ == "__main__":
    # Hard-coded demo run against local files and a fixed checkpoint.
    # NOTE(review): the output name says "male_1_to_male_2" while the inputs
    # are female_1/female_2 — confirm the intended file name.
    inference(
        source_face="/home/xuehongyang/data/female_1.jpg",
        target_face="/home/xuehongyang/data/female_2.jpg",
        model_path="/data/checkpoints/hififace/baseline_1k_ddp_with_cyc_1681278017147",
        model_idx=80000,
        swapped_face="/home/xuehongyang/data/male_1_to_male_2.jpg",
    )