from typing import Optional

import cv2
import numpy as np
import torch

from configs.train_config import TrainConfig
from models.model import HifiFace


def inference(source_face: str, target_face: str, model_path: str, model_idx: Optional[int], swapped_face: str):
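    """Swap the identity of source_face onto target_face and write the result to swapped_face."""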
    opt = TrainConfig()
    opt.use_ddp = False

    device = "cpu"  # switch to "cuda" here for GPU inference if supported
    checkpoint = (model_path, model_idx)
    model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
    model.eval()

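    # Source image: BGR -> RGB, resize to 256x256, HWC -> CHW, scale to [0, 1]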
    src = cv2.cvtColor(cv2.imread(source_face), cv2.COLOR_BGR2RGB)
    src = cv2.resize(src, (256, 256))
    src = src.transpose(2, 0, 1)
    src = torch.from_numpy(src).unsqueeze(0).to(device).float()
    src = src / 255.0

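    # Target image: same preprocessing as the source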
    tgt = cv2.cvtColor(cv2.imread(target_face), cv2.COLOR_BGR2RGB)
    tgt = cv2.resize(tgt, (256, 256))
    tgt = tgt.transpose(2, 0, 1)
    tgt = torch.from_numpy(tgt).unsqueeze(0).to(device).float()
    tgt = tgt / 255.0

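    # Run the swap without gradient tracking, then convert the output back to an 8-bit HWC image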
    with torch.no_grad():
        result_face = model.forward(src, tgt).cpu()
        result_face = torch.clamp(result_face, 0, 1) * 255
    result_face = result_face.numpy()[0].astype(np.uint8)
    result_face = result_face.transpose(1, 2, 0)

    # The network output is RGB; convert back to BGR before writing with OpenCV
    result_face = cv2.cvtColor(result_face, cv2.COLOR_RGB2BGR)
    cv2.imwrite(swapped_face, result_face)


if __name__ == "__main__":
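    # Example paths; adjust to your own environment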
    source_face = "/home/xuehongyang/data/female_1.jpg"
    target_face = "/home/xuehongyang/data/female_2.jpg"
    model_path = "/data/checkpoints/hififace/baseline_1k_ddp_with_cyc_1681278017147"
    model_idx = 80000
    swapped_face = "/home/xuehongyang/data/male_1_to_male_2.jpg"
    inference(source_face, target_face, model_path, model_idx, swapped_face)