File size: 3,778 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn

from models.init_weight import init_net
from models.model_blocks import AdaInResBlock
from models.model_blocks import ResBlock
from models.semantic_face_fusion_model import SemanticFaceFusionModule
from models.shape_aware_identity_model import ShapeAwareIdentityExtractor


class Encoder(nn.Module):
    """
    Hififace encoder part

    A 3x3 stem convolution (3 -> 64 channels) followed by a stack of seven
    ``ResBlock`` stages. Stage ``k`` maps ``channel_list[k]`` channels to
    ``channel_list[k + 1]`` channels and receives ``down_sample[k]`` as its
    ``down_sample`` flag.
    """

    def __init__(self):
        super().__init__()
        self.conv_first = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

        # Channel width before/after every stage (one more entry than stages).
        self.channel_list = [64, 128, 256, 512, 512, 512, 512, 512]
        # Per-stage flag forwarded to ResBlock; only the first five are set.
        self.down_sample = [True, True, True, True, True, False, False]

        # Pair consecutive widths with the matching flag: seven stages total.
        stage_specs = zip(self.channel_list[:-1], self.channel_list[1:], self.down_sample)
        self.block_list = nn.ModuleList(
            [
                ResBlock(in_ch, out_ch, down_sample=shrink)
                for in_ch, out_ch, shrink in stage_specs
            ]
        )

    def forward(self, x):
        """
        Parameters:
        -----------
        x: input passed to the stem convolution (3 input channels)

        Returns:
        --------
        z_enc: output of the second stage (index 1) of ``block_list``
        x: output of the last stage of ``block_list``
        """
        x = self.conv_first(x)
        z_enc = None

        for stage_idx, stage in enumerate(self.block_list):
            x = stage(x)
            # Keep the intermediate feature produced by the second stage.
            if stage_idx == 1:
                z_enc = x
        return z_enc, x


class Decoder(nn.Module):
    """
    Hififace decoder part

    A stack of five ``AdaInResBlock`` stages. Stage ``k`` maps
    ``channel_list[k]`` channels to ``channel_list[k + 1]`` channels and
    receives ``up_sample[k]`` as its ``up_sample`` flag.
    """

    def __init__(self):
        super().__init__()
        # Channel width before/after every stage (one more entry than stages).
        self.channel_list = [512, 512, 512, 512, 512, 256]
        # Per-stage flag forwarded to AdaInResBlock; only the last three are set.
        self.up_sample = [False, False, True, True, True]

        # Pair consecutive widths with the matching flag: five stages total.
        stage_specs = zip(self.channel_list[:-1], self.channel_list[1:], self.up_sample)
        self.block_list = nn.ModuleList(
            [
                AdaInResBlock(in_ch, out_ch, up_sample=grow)
                for in_ch, out_ch, grow in stage_specs
            ]
        )

    def forward(self, x, id_vector):
        """
        Parameters:
        -----------
        x: encoder encoded feature map
        id_vector: 3d shape aware identity vector, handed to every stage

        Returns:
        --------
        z_dec: output of the last stage of ``block_list``
        """
        for stage in self.block_list:
            x = stage(x, id_vector)
        return x


class Generator(nn.Module):
    """
    Hififace Generator

    Combines a ``ShapeAwareIdentityExtractor`` (its parameters are frozen with
    ``requires_grad_(False)``), the ``Encoder``, the ``Decoder`` and a
    ``SemanticFaceFusionModule``. The last three are passed through
    ``init_net`` on construction.
    """

    def __init__(self, identity_extractor_config):
        super().__init__()
        self.id_extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
        # The identity extractor's own parameters are never trained here.
        self.id_extractor.requires_grad_(False)
        self.encoder = init_net(Encoder())
        self.decoder = init_net(Decoder())
        self.sff_module = init_net(SemanticFaceFusionModule())

    def _fuse(self, i_target, id_vector):
        """Encode the target, decode with the identity vector and run the fusion module."""
        z_enc, feat = self.encoder(i_target)
        z_dec = self.decoder(feat, id_vector)
        return self.sff_module(i_target, z_enc, z_dec, id_vector)

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Same pipeline as ``forward`` but the identity vector comes from
        ``id_extractor.interp``; the whole call runs under ``torch.no_grad()``.

        Parameters:
        -----------
        i_source: source face image
        i_target: target face image
        shape_rate: forwarded unchanged to ``id_extractor.interp``
        id_rate: forwarded unchanged to ``id_extractor.interp``

        Returns:
        --------
        i_r, i_low, m_r, m_low: the four outputs of the fusion module
        """
        id_vector = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
        i_r, i_low, m_r, m_low = self._fuse(i_target, id_vector)
        return i_r, i_low, m_r, m_low

    def forward(self, i_source, i_target, need_id_grad=False):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
        need_id_grad: bool, whether to calculate id extractor module's gradient

        Returns:
        --------
        i_r:    torch.Tensor
        i_low:  torch.Tensor
        m_r:    torch.Tensor
        m_low:  torch.Tensor
        """
        # Only ever narrow the ambient grad mode: when need_id_grad is set the
        # current mode is kept as-is, otherwise autograd is switched off for
        # the identity extraction alone.
        keep_grad = bool(need_id_grad) and torch.is_grad_enabled()
        with torch.set_grad_enabled(keep_grad):
            id_vector = self.id_extractor(i_source, i_target)

        i_r, i_low, m_r, m_low = self._fuse(i_target, id_vector)
        return i_r, i_low, m_r, m_low