import torch.nn as nn
import torch.nn.functional as F

from models.model_blocks import AdaInResBlock
from models.model_blocks import ResBlock
from models.model_blocks import UpSamplingBlock


class SemanticFaceFusionModule(nn.Module):
    """Semantic Face Fusion Module to preserve lighting and background.

    Blends low-level encoder features (``z_enc``) and decoder features
    (``z_dec``) with a predicted soft face mask. It refines the fused features
    with two ``AdaInResBlock``s conditioned on ``v_sid``. It then produces a
    1/4-size and a full-size re-targeted image, each of which keeps the target
    image outside the predicted face mask.
    """

    def __init__(self):
        """Build the sub-networks (the low-level feature maps have 256 channels)."""
        super().__init__()

        # Residual block applied to the encoder features before blending.
        self.sigma = ResBlock(256, 256)
        # Single-channel soft face mask in [0, 1] predicted from decoder features.
        self.low_mask_predict = nn.Sequential(
            nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1),
            nn.Sigmoid(),
        )
        # Refinement of the fused features, conditioned on v_sid.
        self.z_fuse_block_1 = AdaInResBlock(256, 256)
        self.z_fuse_block_2 = AdaInResBlock(256, 256)
        # Low-resolution RGB head. The activation must NOT be in-place: its
        # input is ``z_fuse``, which is reused afterwards by ``self.f_up``.
        # An in-place LeakyReLU would silently overwrite ``z_fuse``, so the
        # upsampler would receive LeakyReLU(z_fuse) instead of z_fuse.
        self.i_low_block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 3, kernel_size=3, stride=1, padding=1),
        )
        # Produces the full-size image and face mask from the fused features.
        self.f_up = UpSamplingBlock()

    def forward(self, target_image, z_enc, z_dec, v_sid):
        """Fuse encoder/decoder features and produce re-targeted images.

        Parameters:
        ----------
        target_image: target face image (e.g. (1, 3, 256, 256) in the demo
            below).
        z_enc: low-level encoder feature map at 1/4 of the image size,
            256 channels.
        z_dec: low-level decoder feature map at 1/4 of the image size,
            256 channels.
        v_sid: the 3D shape aware identity vector. Its expected width is
            defined by ``AdaInResBlock`` (the demo uses 769; confirm there).

        Returns:
        --------
        i_r: re-targeted full-size image
        i_low: 1/4-size re-targeted image
        m_r: full-size face mask
        m_low: 1/4-size face mask
        """
        z_enc = self.sigma(z_enc)

        # Estimate the low-level face mask corresponding to z_dec.
        m_low = self.low_mask_predict(z_dec)

        # Fused low-level feature map: decoder features inside the face mask,
        # encoder features outside it.
        z_fuse = m_low * z_dec + (1 - m_low) * z_enc
        z_fuse = self.z_fuse_block_1(z_fuse, v_sid)
        z_fuse = self.z_fuse_block_2(z_fuse, v_sid)

        # Low-resolution output keeps the downsampled target image outside
        # the mask. Resizing to the mask's exact spatial size (instead of a
        # fixed scale_factor=0.25) guarantees the blend shapes agree, even if
        # the feature size is not exactly floor(H * 0.25). For exact
        # quarter-size features this matches scale_factor=0.25 (nearest mode).
        i_low = self.i_low_block(z_fuse)
        target_low = F.interpolate(target_image, size=m_low.shape[-2:])
        i_low = m_low * i_low + (1 - m_low) * target_low

        # Full-resolution output, blended with the target image the same way.
        i_r, m_r = self.f_up(z_fuse)
        i_r = m_r * i_r + (1 - m_r) * target_image

        return i_r, i_low, m_r, m_low


if __name__ == "__main__":
    import torch

    timg = torch.randn(1, 3, 256, 256)
    z_enc = torch.randn(1, 256, 64, 64)
    z_dec = torch.randn(1, 256, 64, 64)
    v_sid = torch.randn(1, 769)

    model = SemanticFaceFusionModule()
    i_r, i_low, m_r, m_low = model(timg, z_enc, z_dec, v_sid)
    print(i_r.shape, i_low.shape, m_r.shape, m_low.shape)