File size: 1,918 Bytes
eb7d2bb
 
 
 
 
 
 
 
 
 
 
c47cf0c
eb7d2bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os
import cv2
import numpy as np
import onnxruntime
from utils import Preprocess


class Photo2Cartoon:
    """Turn a portrait photo into a cartoon-style face with an ONNX model.

    Face alignment/segmentation is delegated to ``Preprocess`` (from
    ``utils``); the stylisation itself runs through an
    ``onnxruntime.InferenceSession``.
    """

    # Side length the aligned face is resized to before inference.
    INPUT_SIZE = 256

    def __init__(self, model_path=None, providers=None):
        """Load the face preprocessor and the ONNX cartoonisation model.

        Args:
            model_path: Path to the ONNX weights. Defaults to
                ``models/photo2cartoon_weights.onnx`` next to this file, so
                loading does not depend on the current working directory.
            providers: Optional execution-provider list forwarded to
                ``onnxruntime.InferenceSession``; ``None`` keeps
                onnxruntime's default selection.

        Raises:
            FileNotFoundError: If the weights file does not exist.
        """
        if model_path is None:
            cur_path = os.path.abspath(os.path.dirname(__file__))
            model_path = os.path.join(cur_path, 'models', 'photo2cartoon_weights.onnx')
        print(model_path)
        # Fail fast with a clear message instead of an opaque onnxruntime
        # error, and before spending time constructing the preprocessor.
        if not os.path.exists(model_path):
            raise FileNotFoundError(
                f"[Step1: load weights] Can not find weights file: {model_path}")
        self.session = onnxruntime.InferenceSession(model_path, providers=providers)
        self.pre = Preprocess()
        print('[Step1: load weights] success!')

    @staticmethod
    def _to_model_input(face_rgba):
        """Composite an RGBA face onto white and pack it as a model batch.

        Channel 3 is treated as the alpha/segmentation mask and divided by
        255, so 8-bit channels are assumed.

        Returns:
            ``(batch, mask)`` where ``batch`` is a float32 ``(1, 3, H, W)``
            array scaled to [-1, 1] and ``mask`` is an ``(H, W, 1)`` float
            array in [0, 1], kept for post-processing.
        """
        mask = face_rgba[:, :, 3:4] / 255.
        face = face_rgba[:, :, :3]
        # Background (mask == 0) becomes white; then map [0, 255] -> [-1, 1].
        face = (face * mask + (1 - mask) * 255) / 127.5 - 1
        batch = np.transpose(face[np.newaxis], (0, 3, 1, 2)).astype(np.float32)
        return batch, mask

    @staticmethod
    def _to_image(output_chw, mask):
        """Map one CHW model output back to an HxWxC uint8 image.

        Inverts the [-1, 1] input scaling, re-whitens the background using
        ``mask`` and clips to [0, 255] so out-of-range activations saturate
        instead of wrapping when cast to uint8.
        """
        cartoon = (np.transpose(output_chw, (1, 2, 0)) + 1) * 127.5
        cartoon = cartoon * mask + 255 * (1 - mask)
        return np.clip(cartoon, 0, 255).astype(np.uint8)

    def inference(self, in_path):
        """Cartoonise the face found in the photo at ``in_path``.

        Args:
            in_path: Path to the input photo (read with ``cv2.imread``).

        Returns:
            A uint8 **RGB** image (spatial size follows the model output;
            presumably INPUT_SIZE x INPUT_SIZE, since it must broadcast
            against the resized mask) with the background whitened, or
            ``None`` if the image cannot be read or no face is detected.
            Convert to BGR before saving with ``cv2.imwrite``.
        """
        bgr = cv2.imread(in_path)
        if bgr is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            print(f'[Step2: load image] can not read image: {in_path}')
            return None
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # face alignment and segmentation
        face_rgba = self.pre.process(img)
        if face_rgba is None:
            print('[Step2: face detect] can not detect face!!!')
            return None
        print('[Step2: face detect] success!')

        size = self.INPUT_SIZE
        face_rgba = cv2.resize(face_rgba, (size, size), interpolation=cv2.INTER_AREA)
        batch, mask = self._to_model_input(face_rgba)

        outputs = self.session.run(['output'], input_feed={'input': batch})
        cartoon = self._to_image(outputs[0][0], mask)

        print('[Step3: photo to cartoon] success!')
        return cartoon


if __name__ == '__main__':
    # Command-line entry point: cartoonise one photo and write the result.
    import argparse

    parser = argparse.ArgumentParser(description='Convert a portrait photo into a cartoon.')
    parser.add_argument('--photo_path', type=str, required=True, help='path of the input photo')
    parser.add_argument('--save_path', type=str, default='cartoon_result.png',
                        help='where to write the cartoon image')
    args = parser.parse_args()

    c2p = Photo2Cartoon()
    cartoon = c2p.inference(args.photo_path)
    if cartoon is not None:
        # inference() returns RGB, while cv2.imwrite expects BGR channel order.
        if cv2.imwrite(args.save_path, cv2.cvtColor(cartoon, cv2.COLOR_RGB2BGR)):
            print('Cartoon portrait has been saved successfully!')
        else:
            print(f'Failed to save cartoon portrait to {args.save_path}')