Spaces:
Runtime error
Runtime error
File size: 3,778 Bytes
83d8d3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
import torch.nn as nn
from models.init_weight import init_net
from models.model_blocks import AdaInResBlock
from models.model_blocks import ResBlock
from models.semantic_face_fusion_model import SemanticFaceFusionModule
from models.shape_aware_identity_model import ShapeAwareIdentityExtractor
class Encoder(nn.Module):
    """
    Hififace encoder: a stem convolution followed by seven residual blocks.

    Attributes:
    -----------
    conv_first: 3x3 convolution (stride 1, padding 1) mapping 3 -> 64 channels
    channel_list: channel widths; block k is built with
        (channel_list[k], channel_list[k + 1])
    down_sample: per-block ``down_sample`` flag handed to ResBlock
    block_list: the ResBlock stages, applied in order by ``forward``
    """

    def __init__(self):
        super().__init__()
        self.conv_first = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.channel_list = [64, 128, 256, 512, 512, 512, 512, 512]
        self.down_sample = [True, True, True, True, True, False, False]

        # Pair each width with its successor. There are eight widths and seven
        # flags, so this yields exactly seven (in, out, flag) triples.
        stage_specs = zip(self.channel_list, self.channel_list[1:], self.down_sample)
        self.block_list = nn.ModuleList(
            [
                ResBlock(in_ch, out_ch, down_sample=flag)
                for in_ch, out_ch, flag in stage_specs
            ]
        )

    def forward(self, x):
        """
        Parameters:
        -----------
        x: input image tensor fed to ``conv_first`` (3 input channels)

        Returns:
        --------
        z_enc: feature map captured right after the second block (index 1)
        x: feature map produced by the last block
        """
        feat = self.conv_first(x)
        z_enc = None
        for idx, block in enumerate(self.block_list):
            feat = block(feat)
            # Keep the intermediate feature after the second stage; it is
            # returned alongside the final feature map.
            if idx == 1:
                z_enc = feat
        return z_enc, feat
class Decoder(nn.Module):
    """
    Hififace decoder: five AdaInResBlock stages conditioned on an identity vector.

    Attributes:
    -----------
    block_list: the AdaInResBlock stages, applied in order by ``forward``
    channel_list: channel widths; block k is built with
        (channel_list[k], channel_list[k + 1])
    up_sample: per-block ``up_sample`` flag handed to AdaInResBlock
    """

    def __init__(self):
        super().__init__()
        self.channel_list = [512, 512, 512, 512, 512, 256]
        self.up_sample = [False, False, True, True, True]

        # Six widths paired with their successors and five flags give exactly
        # five (in, out, flag) triples.
        stage_specs = zip(self.channel_list, self.channel_list[1:], self.up_sample)
        self.block_list = nn.ModuleList(
            [
                AdaInResBlock(in_ch, out_ch, up_sample=flag)
                for in_ch, out_ch, flag in stage_specs
            ]
        )

    def forward(self, x, id_vector):
        """
        Parameters:
        -----------
        x: encoder encoded feature map
        id_vector: 3d shape aware identity vector, passed unchanged to every block

        Returns:
        --------
        z_dec: output of the last decoder block
        """
        z_dec = x
        for block in self.block_list:
            z_dec = block(z_dec, id_vector)
        return z_dec
class Generator(nn.Module):
    """
    Hififace Generator

    Combines a shape aware identity extractor (created with its parameters'
    ``requires_grad`` turned off) with the encoder, decoder and semantic face
    fusion module, the latter three being passed through ``init_net``.
    """

    def __init__(self, identity_extractor_config):
        """
        Parameters:
        -----------
        identity_extractor_config: configuration forwarded as-is to
            ShapeAwareIdentityExtractor
        """
        super().__init__()
        extractor = ShapeAwareIdentityExtractor(identity_extractor_config)
        extractor.requires_grad_(False)
        self.id_extractor = extractor

        self.encoder = init_net(Encoder())
        self.decoder = init_net(Decoder())
        self.sff_module = init_net(SemanticFaceFusionModule())

    def _generate(self, i_target, shape_aware_id_vector):
        """Run encoder -> decoder -> fusion for a given identity vector."""
        z_enc, bottleneck = self.encoder(i_target)
        z_dec = self.decoder(bottleneck, shape_aware_id_vector)
        i_r, i_low, m_r, m_low = self.sff_module(
            i_target, z_enc, z_dec, shape_aware_id_vector
        )
        return i_r, i_low, m_r, m_low

    @torch.no_grad()
    def interp(self, i_source, i_target, shape_rate=1.0, id_rate=1.0):
        """
        Gradient-free generation using the extractor's ``interp`` method.

        Parameters:
        -----------
        i_source: source face image
        i_target: target face image
        shape_rate: forwarded to ``id_extractor.interp``
        id_rate: forwarded to ``id_extractor.interp``

        Returns:
        --------
        i_r, i_low, m_r, m_low: the four outputs of the fusion module
        """
        id_vec = self.id_extractor.interp(i_source, i_target, shape_rate, id_rate)
        return self._generate(i_target, id_vec)

    def forward(self, i_source, i_target, need_id_grad=False):
        """
        Parameters:
        -----------
        i_source: torch.Tensor, shape (B, 3, H, W), in range [0, 1], source face image
        i_target: torch.Tensor, shape (B, 3, H, W), in range [0, 1], target face image
        need_id_grad: bool, whether to calculate id extractor module's gradient

        Returns:
        --------
        i_r: torch.Tensor
        i_low: torch.Tensor
        m_r: torch.Tensor
        m_low: torch.Tensor
        """
        if not need_id_grad:
            # Default path: keep the identity extractor out of the autograd graph.
            with torch.no_grad():
                id_vec = self.id_extractor(i_source, i_target)
        else:
            # Run in whatever grad mode the caller is currently in.
            id_vec = self.id_extractor(i_source, i_target)
        return self._generate(i_target, id_vec)
|