|
|
|
|
|
|
|
|
|
import torch.nn as nn |
|
from blocks import (InitLayer, UpBlockBig, UpBlockBigCond, UpBlockSmall, UpBlockSmallCond, SEBlock, conv2d) |
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
|
def normalize_second_moment(x, dim=1, eps=1e-8):
    """Rescale ``x`` so its mean square along ``dim`` is approximately one.

    Args:
        x: Tensor to normalize.
        dim: Dimension over which the mean of the squared values is taken.
        eps: Small constant added to the mean square before taking the
            reciprocal square root, so an all-zero slice does not divide
            by zero.

    Returns:
        ``x * rsqrt(mean(x ** 2, dim) + eps)``, with the same shape as ``x``.
    """
    mean_square = x.square().mean(dim=dim, keepdim=True)
    inv_rms = (mean_square + eps).rsqrt()
    return x * inv_rms
|
|
|
|
|
class DummyMapping(nn.Module):
    """Parameter-free stand-in for a mapping network.

    It performs no learned transformation; it only reshapes the latent so
    that it carries an extra axis of size one at position 1.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z, c, **kwargs):
        """Return ``z`` with a singleton axis inserted at position 1.

        Args:
            z: Latent tensor.
            c: Conditioning input; accepted for interface compatibility and
                ignored.
            **kwargs: Ignored.
        """
        expanded = z.unsqueeze(1)
        return expanded
|
|
|
|
|
class FastganSynthesis(nn.Module):
    """Unconditional FastGAN-style synthesis network.

    A latent is passed through an initial layer and a chain of up-blocks
    (``feat_8`` ... ``feat_1024``); ``SEBlock`` modules combine an earlier
    feature map with a later one, and ``to_big`` converts the final feature
    map to ``nc`` output channels.

    Args:
        ngf: Base channel width; per-stage widths are ``ngf`` times a fixed
            multiplier.
        z_dim: Dimensionality of the latent fed to the initial layer.
        nc: Number of output channels.
        img_resolution: Target output resolution. Must be one of
            128, 256, 512 or 1024.
        lite: If true, use ``UpBlockSmall`` instead of ``UpBlockBig``.

    Raises:
        ValueError: If ``img_resolution`` is not a supported resolution.
    """

    # ``forward`` only defines an output path from the 128 stage upwards.
    _SUPPORTED_RESOLUTIONS = (128, 256, 512, 1024)

    def __init__(self, ngf=128, z_dim=256, nc=3, img_resolution=256, lite=False):
        super().__init__()

        # Fail fast: other values used to surface as a KeyError here or as an
        # UnboundLocalError on the first forward pass.
        if img_resolution not in self._SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"img_resolution must be one of {self._SUPPORTED_RESOLUTIONS}, "
                f"got {img_resolution!r}"
            )

        self.img_resolution = img_resolution
        self.z_dim = z_dim

        # Channel multiplier per stage, keyed by the stage's resolution.
        nfc_multi = {
            2: 16, 4: 16, 8: 8, 16: 4, 32: 2, 64: 2,
            128: 1, 256: 0.5, 512: 0.25, 1024: 0.125,
        }
        nfc = {res: int(mult * ngf) for res, mult in nfc_multi.items()}

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict layout stay unchanged.
        self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

        UpBlock = UpBlockSmall if lite else UpBlockBig

        self.feat_8 = UpBlock(nfc[4], nfc[8])
        self.feat_16 = UpBlock(nfc[8], nfc[16])
        self.feat_32 = UpBlock(nfc[16], nfc[32])
        self.feat_64 = UpBlock(nfc[32], nfc[64])
        self.feat_128 = UpBlock(nfc[64], nfc[128])
        self.feat_256 = UpBlock(nfc[128], nfc[256])

        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        # Final projection from the last stage's channels to image channels.
        self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

        if img_resolution > 256:
            self.feat_512 = UpBlock(nfc[256], nfc[512])
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if img_resolution > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024])

    def forward(self, input, c, **kwargs):
        """Synthesize an image tensor from the latent ``input``.

        Args:
            input: Latent tensor; only index 0 along dim 1 is used
                (presumably shaped ``(batch, 1, z_dim)`` as produced by
                ``DummyMapping`` - confirm against the caller).
            c: Conditioning input; unused by this unconditional network.
            **kwargs: Ignored.

        Returns:
            The output of ``to_big`` applied to the last feature map.
        """
        # Take the first latent along dim 1 and normalize its second moment.
        latent = normalize_second_moment(input[:, 0])

        feat_4 = self.init(latent)
        feat_8 = self.feat_8(feat_4)
        feat_16 = self.feat_16(feat_8)
        feat_32 = self.feat_32(feat_16)
        feat_64 = self.se_64(feat_4, self.feat_64(feat_32))

        # The constructor guarantees img_resolution >= 128, so the 128 stage
        # is always a valid starting point for the final feature map.
        feat_last = self.se_128(feat_8, self.feat_128(feat_64))

        if self.img_resolution >= 256:
            feat_last = self.se_256(feat_16, self.feat_256(feat_last))

        if self.img_resolution >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last))

        if self.img_resolution >= 1024:
            feat_last = self.feat_1024(feat_last)

        return self.to_big(feat_last)
|
|
|
|
|
class FastganSynthesisCond(nn.Module):
    """Class-conditional FastGAN-style synthesis network.

    Mirrors the unconditional layout (initial layer, chain of up-blocks,
    ``SEBlock`` combinations, final ``to_big`` projection), but every
    up-block additionally receives a class embedding of size ``z_dim``.

    Args:
        ngf: Base channel width; per-stage widths are ``ngf`` times a fixed
            multiplier.
        z_dim: Dimensionality of the latent and of the class embedding.
        nc: Number of output channels.
        img_resolution: Target output resolution. Must be one of
            128, 256, 512 or 1024.
        num_classes: Number of rows in the class-embedding table.
        lite: If true, use ``UpBlockSmallCond`` instead of ``UpBlockBigCond``.

    Raises:
        ValueError: If ``img_resolution`` is not a supported resolution.
    """

    # ``forward`` only defines an output path from the 128 stage upwards.
    _SUPPORTED_RESOLUTIONS = (128, 256, 512, 1024)

    def __init__(self, ngf=64, z_dim=256, nc=3, img_resolution=256, num_classes=1000, lite=False):
        super().__init__()

        # Fail fast: other values used to surface as a KeyError here or as an
        # UnboundLocalError on the first forward pass.
        if img_resolution not in self._SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"img_resolution must be one of {self._SUPPORTED_RESOLUTIONS}, "
                f"got {img_resolution!r}"
            )

        self.z_dim = z_dim
        self.img_resolution = img_resolution

        # Channel multiplier per stage, keyed by the stage's resolution.
        nfc_multi = {
            2: 16, 4: 16, 8: 8, 16: 4, 32: 2, 64: 2,
            128: 1, 256: 0.5, 512: 0.25, 1024: 0.125, 2048: 0.125,
        }
        nfc = {res: int(mult * ngf) for res, mult in nfc_multi.items()}

        # NOTE: sub-modules are created in a fixed order so that parameter
        # initialisation and state_dict layout stay unchanged.
        self.init = InitLayer(z_dim, channel=nfc[2], sz=4)

        UpBlock = UpBlockSmallCond if lite else UpBlockBigCond

        self.feat_8 = UpBlock(nfc[4], nfc[8], z_dim)
        self.feat_16 = UpBlock(nfc[8], nfc[16], z_dim)
        self.feat_32 = UpBlock(nfc[16], nfc[32], z_dim)
        self.feat_64 = UpBlock(nfc[32], nfc[64], z_dim)
        self.feat_128 = UpBlock(nfc[64], nfc[128], z_dim)
        self.feat_256 = UpBlock(nfc[128], nfc[256], z_dim)

        self.se_64 = SEBlock(nfc[4], nfc[64])
        self.se_128 = SEBlock(nfc[8], nfc[128])
        self.se_256 = SEBlock(nfc[16], nfc[256])

        # Final projection from the last stage's channels to image channels.
        self.to_big = conv2d(nfc[img_resolution], nc, 3, 1, 1, bias=True)

        # The high-resolution blocks are conditional too (forward passes them
        # the class embedding), so they need z_dim like every other up-block.
        if img_resolution > 256:
            self.feat_512 = UpBlock(nfc[256], nfc[512], z_dim)
            self.se_512 = SEBlock(nfc[32], nfc[512])
        if img_resolution > 512:
            self.feat_1024 = UpBlock(nfc[512], nfc[1024], z_dim)

        self.embed = nn.Embedding(num_classes, z_dim)

    def forward(self, input, c, update_emas=False):
        """Synthesize an image tensor from ``input`` conditioned on ``c``.

        Args:
            input: Latent tensor; only index 0 along dim 1 is used
                (presumably shaped ``(batch, 1, z_dim)`` - confirm against
                the caller).
            c: Class scores; the arg-max along dim 1 selects the embedding
                row (presumably one-hot labels - confirm against the caller).
            update_emas: Accepted for interface compatibility; unused here.

        Returns:
            The output of ``to_big`` applied to the last feature map.
        """
        # Turn the class scores into an index, then into a learned embedding.
        c = self.embed(c.argmax(1))

        # Take the first latent along dim 1 and normalize its second moment.
        latent = normalize_second_moment(input[:, 0])

        feat_4 = self.init(latent)
        feat_8 = self.feat_8(feat_4, c)
        feat_16 = self.feat_16(feat_8, c)
        feat_32 = self.feat_32(feat_16, c)
        feat_64 = self.se_64(feat_4, self.feat_64(feat_32, c))

        # The constructor guarantees img_resolution >= 128, so the 128 stage
        # is always a valid starting point for the final feature map.
        feat_last = self.se_128(feat_8, self.feat_128(feat_64, c))

        if self.img_resolution >= 256:
            feat_last = self.se_256(feat_16, self.feat_256(feat_last, c))

        if self.img_resolution >= 512:
            feat_last = self.se_512(feat_32, self.feat_512(feat_last, c))

        if self.img_resolution >= 1024:
            feat_last = self.feat_1024(feat_last, c)

        return self.to_big(feat_last)
|
|
|
|
|
class MyGenerator(nn.Module, PyTorchModelHubMixin):
    """Generator pairing a pass-through mapping with a FastGAN synthesis net.

    Args:
        z_dim: Latent dimensionality, forwarded to the synthesis network.
        c_dim: Stored on the instance; not otherwise used here.
        w_dim: Stored on the instance; not otherwise used here.
        img_resolution: Output resolution, forwarded to the synthesis network.
        img_channels: Output channel count, forwarded as ``nc``.
        ngf: Base channel width, forwarded to the synthesis network.
        cond: If truthy, build ``FastganSynthesisCond``; otherwise build
            ``FastganSynthesis``.
        mapping_kwargs: Accepted but unused.
        synthesis_kwargs: Extra keyword arguments for the synthesis
            constructor.
    """

    # NOTE: the dict defaults are only read (unpacked), never mutated.
    def __init__(
        self,
        z_dim=256,
        c_dim=0,
        w_dim=0,
        img_resolution=256,
        img_channels=3,
        ngf=128,
        cond=0,
        mapping_kwargs={},
        synthesis_kwargs={}
    ):
        super().__init__()

        # Keep the configuration available on the instance.
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Pick the synthesis backbone according to the conditioning flag.
        if cond:
            synthesis_cls = FastganSynthesisCond
        else:
            synthesis_cls = FastganSynthesis

        self.mapping = DummyMapping()
        self.synthesis = synthesis_cls(
            ngf=ngf,
            z_dim=z_dim,
            nc=img_channels,
            img_resolution=img_resolution,
            **synthesis_kwargs,
        )

    def forward(self, z, c, **kwargs):
        """Map ``z`` through the mapping module, then synthesize an image.

        Args:
            z: Latent input passed to the mapping module.
            c: Conditioning input passed to both mapping and synthesis.
            **kwargs: Ignored.

        Returns:
            The output of the synthesis network.
        """
        return self.synthesis(self.mapping(z, c), c)
|
|