File size: 3,265 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import argparse
import os
from typing import List
from typing import Optional

import cv2
import numpy as np
import torch

from configs.train_config import TrainConfig
from models.model import HifiFace


def _load_face(image_path: str, device: str):
    """Read one image and prepare it for the model.

    Args:
        image_path: Path of the image file to read.
        device: Torch device the returned tensor is moved to.

    Returns:
        A ``(bgr_image, model_input)`` tuple: the 256x256 BGR image as read
        and resized by OpenCV, and a float tensor of shape ``(1, 3, 256, 256)``
        in RGB channel order scaled to ``[0, 1]``.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    bgr_image = cv2.imread(image_path)
    if bgr_image is None:
        # cv2.imread reports a missing/unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    bgr_image = cv2.resize(bgr_image, (256, 256))

    rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    chw = rgb.transpose(2, 0, 1)  # HWC -> CHW
    model_input = torch.from_numpy(chw).unsqueeze(0).to(device).float() / 255.0
    return bgr_image, model_input


def test(
    data_root: str,
    result_path: str,
    source_face: List[str],
    target_face: List[str],
    model_path: str,
    model_idx: Optional[int],
):
    """Run face swapping on source/target pairs and save a comparison grid.

    Each ``(source, target)`` pair becomes one column made of the resized
    source image, the resized target image and the model output stacked
    vertically; the columns are joined side by side and written to disk.

    Args:
        data_root: Directory that the image names and ``result_path`` are
            joined onto.
        result_path: Output image path, joined onto ``data_root`` (an absolute
            path is therefore used unchanged).
        source_face: Source image file names, relative to ``data_root``.
        target_face: Target image file names, relative to ``data_root``;
            paired with ``source_face`` by position (extra entries in the
            longer list are ignored, as with ``zip``).
        model_path: First element of the checkpoint tuple handed to
            ``HifiFace``.
        model_idx: Second element of the checkpoint tuple handed to
            ``HifiFace``.

    Raises:
        ValueError: If there is no source/target pair to process.
        FileNotFoundError: If an input image cannot be read.
        OSError: If the result image cannot be written.
    """
    pairs = list(zip(source_face, target_face))
    if not pairs:
        # Fail before the (potentially slow) checkpoint load; np.concatenate
        # would raise ValueError on an empty list anyway.
        raise ValueError("source_face and target_face must each contain at least one image")

    opt = TrainConfig()
    opt.use_ddp = False

    device = "cpu"
    checkpoint = (model_path, model_idx)
    model = HifiFace(opt.identity_extractor_config, is_training=False, device=device, load_checkpoint=checkpoint)
    model.eval()

    results = []
    for source, target in pairs:
        src_img, src = _load_face(os.path.join(data_root, source), device)
        tgt_img, tgt = _load_face(os.path.join(data_root, target), device)

        with torch.no_grad():
            result_face = model.forward(src, tgt).cpu()
            result_face = torch.clamp(result_face, 0, 1) * 255
        # Take the single batch item and go back to HWC uint8 for OpenCV.
        result_face = result_face.numpy()[0].astype(np.uint8)
        result_face = result_face.transpose(1, 2, 0)

        # The model was fed RGB, so its output is presumably RGB as well;
        # swap back to BGR to match src_img/tgt_img (same channel swap as
        # COLOR_BGR2RGB, named for the direction actually intended).
        result_face = cv2.cvtColor(result_face, cv2.COLOR_RGB2BGR)
        # One column per pair: source on top, target in the middle, swap result below.
        one_result = np.concatenate((src_img, tgt_img, result_face), axis=0)
        results.append(one_result)

    result = np.concatenate(results, axis=1)
    swapped_face = os.path.join(data_root, result_path)
    if not cv2.imwrite(swapped_face, result):
        # cv2.imwrite returns False (instead of raising) when the file cannot be saved.
        raise OSError(f"Could not write result image: {swapped_face}")


if __name__ == "__main__":
    # Command-line entry point: swap one source face onto a fixed list of
    # target images and save the resulting comparison grid.
    parser = argparse.ArgumentParser(
        prog="benchmark",
        description="Swap one source face onto a fixed set of target faces and save a comparison grid.",
    )
    # All three original options are needed below, so mark them required to get
    # a usage error instead of a TypeError traceback when one is omitted.
    parser.add_argument(
        "-m", "--model_name", required=True, help="checkpoint name, joined onto the checkpoint root"
    )
    parser.add_argument(
        "-i", "--model_index", required=True, help="integer checkpoint index passed to the model loader"
    )
    parser.add_argument(
        "-s", "--source_image", required=True, help="source image file name, relative to the data root"
    )
    parser.add_argument(
        "--data_root",
        default="/home/xuehongyang/data/face_swap_test",
        help="directory containing the source and target images (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint_root",
        default="/data/checkpoints/hififace/",
        help="directory containing the model checkpoints (default: %(default)s)",
    )
    args = parser.parse_args()
    data_root = args.data_root

    model_path = os.path.join(args.checkpoint_root, args.model_name)
    # Keep the raw string for the output file name (below) and validate it here
    # so a non-integer index is reported as a normal usage error.
    try:
        model_idx = int(args.model_index)
    except ValueError:
        parser.error(f"--model_index must be an integer, got {args.model_index!r}")

    name = f"{args.model_name}_{args.model_index}"

    target = [
        "male_1.jpg",
        "male_2.jpg",
        "minlu_1.jpg",
        "minlu_2.jpg",
        "shizong_1.jpg",
        "shizong_2.jpg",
        "tianxin_1.jpg",
        "tianxin_2.jpg",
        "xiaohui_1.jpg",
        "xiaohui_2.jpg",
        "female_1.jpg",
        "female_2.jpg",
        "female_3.jpg",
        "female_4.jpg",
        "female_5.jpg",
        "female_6.jpg",
        "lixia_1.jpg",
        "lixia_2.jpg",
        "qq_1.jpg",
        "qq_2.jpg",
        "pink_1.jpg",
        "pink_2.jpg",
        "xulie_1.jpg",
        "xulie_2.jpg",
    ]

    # One-to-many benchmark: the same source face is paired with every target.
    source = [args.source_image] * len(target)
    # The grid is written next to (one level above) the data root.
    target_src = os.path.join(data_root, f"../{name}_1tom_{args.source_image}.jpg")
    test(data_root, target_src, source, target, model_path, model_idx)