import torch
import torch.nn as nn

from models.init_weight import init_net
from models.model_blocks import AdaInResBlock
from models.model_blocks import ResBlock
from models.semantic_face_fusion_model import SemanticFaceFusionModule
from models.shape_aware_identity_model import ShapeAwareIdentityExtractor


class Encoder(nn.Module):
    """
    Hififace encoder.

    A 3 -> 64 channel 3x3 convolution followed by seven ``ResBlock`` stages
    whose channel widths and ``down_sample`` flags are listed in
    ``channel_list`` and ``down_sample``.
    """

    def __init__(self):
        super().__init__()
        self.conv_first = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)

        # One more channel entry than stages: stage i maps
        # channel_list[i] -> channel_list[i + 1].
        self.channel_list = [64, 128, 256, 512, 512, 512, 512, 512]
        self.down_sample = [True, True, True, True, True, False, False]

        self.block_list = nn.ModuleList()
        stage_specs = zip(self.channel_list, self.channel_list[1:], self.down_sample)
        for in_channels, out_channels, down in stage_specs:
            self.block_list.append(
                ResBlock(in_channels, out_channels, down_sample=down)
            )

    def forward(self, x):
        """
        Parameters:
        -----------
        x: input image tensor with 3 channels

        Returns:
        --------
        z_enc: output of the second residual stage (index 1)
        x: output of the last residual stage
        """
        feature = self.conv_first(x)
        z_enc = None
        for stage_idx, stage in enumerate(self.block_list):
            feature = stage(feature)
            if stage_idx == 1:
                # Keep the intermediate feature map; it is returned
                # alongside the final encoding.
                z_enc = feature
        return z_enc, feature


class Decoder(nn.Module):
    """
    Hififace decoder.

    Five ``AdaInResBlock`` stages, each of which receives the identity vector
    together with the running feature map.
    """

    def __init__(self):
        super().__init__()
        self.block_list = nn.ModuleList()

        # Stage i maps channel_list[i] -> channel_list[i + 1].
        self.channel_list = [512, 512, 512, 512, 512, 256]
        self.up_sample = [False, False, True, True, True]

        stage_specs = zip(self.channel_list, self.channel_list[1:], self.up_sample)
        for in_channels, out_channels, up in stage_specs:
            self.block_list.append(
                AdaInResBlock(in_channels, out_channels, up_sample=up)
            )

    def forward(self, x, id_vector):
        """
        Parameters:
        -----------
        x: encoder encoded feature map
        id_vector: 3d shape aware identity vector

        Returns:
        --------
        z_dec
        """
        feature = x
        for stage in self.block_list:
            feature = stage(feature, id_vector)
        return feature


class Generator(nn.Module):
    """
    Hififace Generator.

    Combines a ``ShapeAwareIdentityExtractor`` (parameters set to not require
    gradients), the ``Encoder``, the ``Decoder`` and the
    ``SemanticFaceFusionModule``. The last three are passed through
    ``init_net`` on construction.
    """

    def __init__(self, identity_extractor_config):
        super().__init__()

        self.id_extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
        self.id_extractor.requires_grad_(False)

        self.encoder = init_net(Encoder())
        self.decoder = init_net(Decoder())
        self.sff_module = init_net(SemanticFaceFusionModule())

    def _synthesize(self, i_target, id_vector):
        """
        Shared tail of ``interp`` and ``forward``: encode the target image,
        decode it with the identity vector and run the fusion module.

        Returns:
        --------
        (i_r, i_low, m_r, m_low) as produced by ``sff_module``
        """
        z_enc, encoded = self.encoder(i_target)
        z_dec = self.decoder(encoded, id_vector)
        i_r, i_low, m_r, m_low = self.sff_module(i_target, z_enc, z_dec, id_vector)
        return i_r, i_low, m_r, m_low

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Gradient-free generation using the identity extractor's ``interp``.

        Parameters:
        -----------
        i_source: source face image
        i_target: target face image
        shape_rate: forwarded unchanged to ``id_extractor.interp``
        id_rate: forwarded unchanged to ``id_extractor.interp``

        Returns:
        --------
        i_r, i_low, m_r, m_low
        """
        id_vector = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
        return self._synthesize(i_target, id_vector)

    def forward(self, i_source, i_target, need_id_grad=False):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
        need_id_grad: bool, whether to calculate id extractor module's gradient

        Returns:
        --------
        i_r: torch.Tensor
        i_low: torch.Tensor
        m_r: torch.Tensor
        m_low: torch.Tensor
        """
        if not need_id_grad:
            # Identity extraction runs without autograd tracking unless the
            # caller explicitly asks for it.
            with torch.no_grad():
                id_vector = self.id_extractor(i_source, i_target)
        else:
            id_vector = self.id_extractor(i_source, i_target)

        return self._synthesize(i_target, id_vector)