File size: 4,933 Bytes
e6043d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os.path as osp
import glob
import logging
import insightface
from insightface.model_zoo.model_zoo import ModelRouter, PickableInferenceSession
from insightface.model_zoo.retinaface import RetinaFace
from insightface.model_zoo.landmark import Landmark
from insightface.model_zoo.attribute import Attribute
from insightface.model_zoo.inswapper import INSwapper
from insightface.model_zoo.arcface_onnx import ArcFaceONNX
from insightface.app import FaceAnalysis
from insightface.utils import DEFAULT_MP_NAME, ensure_available
from insightface.model_zoo import model_zoo
import onnxruntime
import onnx
from onnx import numpy_helper
from scripts.reactor_logger import logger


def patched_get_model(self, **kwargs):
    """Open ``self.onnx_file`` and wrap it in the matching insightface model class.

    The model type is picked from the session's input/output signature; the
    first rule that matches wins:

    * 5 or more outputs                                -> ``RetinaFace``
    * first input is 192 x 192                         -> ``Landmark``
    * first input is 96 x 96                           -> ``Attribute``
    * exactly two inputs, first one 128 x 128          -> ``INSwapper``
    * square first input, side >= 112 and side % 16 == 0 -> ``ArcFaceONNX``

    Args:
        **kwargs: Forwarded unchanged to ``PickableInferenceSession``.

    Returns:
        The constructed model wrapper, or ``None`` when no rule matches.
    """
    # NOTE(review): presumably a copy of insightface's own ModelRouter.get_model
    # with console output removed -- confirm against the installed version.
    session = PickableInferenceSession(self.onnx_file, **kwargs)
    model_inputs = session.get_inputs()
    first_shape = model_inputs[0].shape
    model_outputs = session.get_outputs()

    # The detector is recognised by output count alone, before the input
    # shape is inspected.
    if len(model_outputs) >= 5:
        return RetinaFace(model_file=self.onnx_file, session=session)

    # Only indices 2 and 3 of the first input's shape take part in routing
    # (assumed to be height/width of an NCHW tensor -- TODO confirm).
    dim_h = first_shape[2]
    dim_w = first_shape[3]

    if (dim_h, dim_w) == (192, 192):
        return Landmark(model_file=self.onnx_file, session=session)
    if (dim_h, dim_w) == (96, 96):
        return Attribute(model_file=self.onnx_file, session=session)
    if len(model_inputs) == 2 and (dim_h, dim_w) == (128, 128):
        return INSwapper(model_file=self.onnx_file, session=session)
    if dim_h == dim_w and dim_h >= 112 and dim_h % 16 == 0:
        return ArcFaceONNX(model_file=self.onnx_file, session=session)

    # Unknown signature: let the caller decide what to do.
    return None


def patched_faceanalysis_init(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs):
    """Replacement ``FaceAnalysis.__init__`` that loads every ``*.onnx`` model of a pack.

    Each model file found in the pack directory is loaded through
    ``model_zoo.get_model`` and stored in ``self.models`` under its
    ``taskname``. The first model seen for a task is kept; later ones are
    reported and dropped. Accepted models are registered without any console
    output.

    Args:
        name: Model pack name passed to ``ensure_available``.
        root: Root directory passed to ``ensure_available``.
        allowed_modules: Optional collection of task names to keep; ``None``
            keeps every task.
        **kwargs: Forwarded to ``model_zoo.get_model`` for each file.

    Raises:
        AssertionError: If no model with task name ``'detection'`` was loaded.
    """
    onnxruntime.set_default_logger_severity(3)
    self.models = {}
    self.model_dir = ensure_available('models', name, root=root)

    # Sorted so that "first model per task wins" is deterministic.
    for onnx_file in sorted(glob.glob(osp.join(self.model_dir, '*.onnx'))):
        loaded = model_zoo.get_model(onnx_file, **kwargs)
        if loaded is None:
            print('model not recognized:', onnx_file)
            continue

        task = loaded.taskname
        if allowed_modules is not None and task not in allowed_modules:
            print('model ignore:', onnx_file, task)
            del loaded  # drop the reference before the next model is loaded
            continue

        if task in self.models:
            print('duplicated model task type, ignore:', onnx_file, task)
            del loaded  # drop the reference before the next model is loaded
            continue

        self.models[task] = loaded

    assert 'detection' in self.models
    self.det_model = self.models['detection']


def patched_faceanalysis_prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)):
    """Replacement ``FaceAnalysis.prepare`` that prepares every loaded model.

    Stores the detection settings on the instance and calls ``prepare`` on
    each entry of ``self.models``. Only the ``'detection'`` model receives the
    input size and threshold; all other models are prepared with ``ctx_id``
    alone.

    Args:
        ctx_id: Passed as the first argument to every model's ``prepare``.
        det_thresh: Detection threshold, stored as ``self.det_thresh``.
        det_size: Detection input size, stored as ``self.det_size``; must not
            be ``None``.

    Raises:
        AssertionError: If ``det_size`` is ``None``.
    """
    self.det_thresh = det_thresh
    assert det_size is not None
    self.det_size = det_size

    for task, net in self.models.items():
        if task != 'detection':
            net.prepare(ctx_id)
            continue
        net.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh)


def patched_inswapper_init(self, model_file=None, session=None):
    """Replacement ``INSwapper.__init__`` that reads model metadata from the ONNX file.

    Loads the ONNX graph to extract ``self.emap`` from the graph's last
    initializer, creates an inference session when none is supplied, and
    records the session's input/output names and the first input's shape.

    Args:
        model_file: Path to the ONNX model; loaded with ``onnx.load``.
        session: Existing inference session to reuse; when ``None`` a new
            ``onnxruntime.InferenceSession`` is created for ``model_file``.

    Raises:
        AssertionError: If the session does not expose exactly one output.
    """
    self.model_file = model_file
    self.session = session

    onnx_graph = onnx.load(self.model_file).graph
    # The last initializer of the graph is taken as the embedding map.
    self.emap = numpy_helper.to_array(onnx_graph.initializer[-1])
    self.input_mean = 0.0
    self.input_std = 255.0

    if self.session is None:
        self.session = onnxruntime.InferenceSession(self.model_file, None)

    session_inputs = self.session.get_inputs()
    self.input_names = [node.name for node in session_inputs]
    self.output_names = [node.name for node in self.session.get_outputs()]
    assert len(self.output_names) == 1

    first_shape = session_inputs[0].shape
    self.input_shape = first_shape
    # Entries 2 and 3 of the shape, stored in swapped order.
    self.input_size = tuple(reversed(first_shape[2:4]))


def patch_insightface(get_model, faceanalysis_init, faceanalysis_prepare, inswapper_init):
    """Install the given callables on the insightface classes, replacing four methods.

    Args:
        get_model: Installed as ``ModelRouter.get_model``.
        faceanalysis_init: Installed as ``FaceAnalysis.__init__``.
        faceanalysis_prepare: Installed as ``FaceAnalysis.prepare``.
        inswapper_init: Installed as ``INSwapper.__init__``.
    """
    zoo = insightface.model_zoo
    zoo.model_zoo.ModelRouter.get_model = get_model

    face_analysis_cls = insightface.app.FaceAnalysis
    face_analysis_cls.__init__ = faceanalysis_init
    face_analysis_cls.prepare = faceanalysis_prepare

    zoo.inswapper.INSwapper.__init__ = inswapper_init


# Both lists follow the parameter order of patch_insightface() so they can be
# unpacked straight into it.
# The originals are captured here, at import time, before any patch is applied,
# so that they can be restored later.
original_functions = [
    ModelRouter.get_model,
    FaceAnalysis.__init__,
    FaceAnalysis.prepare,
    INSwapper.__init__,
]
# The replacements defined in this module.
patched_functions = [
    patched_get_model,
    patched_faceanalysis_init,
    patched_faceanalysis_prepare,
    patched_inswapper_init,
]


def apply_logging_patch(console_log_level):
    """Select the insightface method set and logger level for a console log level.

    * ``0`` -> patched methods, logger level ``logging.WARNING``
    * ``1`` -> patched methods, logger level ``logging.STATUS``
    * ``2`` -> original insightface methods, logger level ``logging.INFO``

    Any other value leaves both insightface and the logger untouched.

    Args:
        console_log_level: Verbosity selector, expected to be 0, 1 or 2.
    """
    if console_log_level == 0:
        functions, level_name = patched_functions, 'WARNING'
    elif console_log_level == 1:
        # NOTE(review): STATUS is not a standard ``logging`` level; presumably
        # it is registered on the ``logging`` module by
        # scripts.reactor_logger -- confirm.
        functions, level_name = patched_functions, 'STATUS'
    elif console_log_level == 2:
        functions, level_name = original_functions, 'INFO'
    else:
        return

    # Patch first, then resolve the level attribute.
    patch_insightface(*functions)
    logger.setLevel(getattr(logging, level_name))