|
''' |
|
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. |
|
BSD License. All rights reserved. |
|
|
|
Redistribution and use in source and binary forms, with or without |
|
modification, are permitted provided that the following conditions are met: |
|
|
|
* Redistributions of source code must retain the above copyright notice, this |
|
list of conditions and the following disclaimer. |
|
|
|
* Redistributions in binary form must reproduce the above copyright notice, |
|
this list of conditions and the following disclaimer in the documentation |
|
and/or other materials provided with the distribution. |
|
|
|
THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL |
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. |
|
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL |
|
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, |
|
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING |
|
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
|
''' |
|
import torch |
|
import torch.nn as nn |
|
import functools |
|
from torch.autograd import Variable |
|
import numpy as np |
|
|
|
|
|
|
|
|
|
def weights_init(m):
    """Initialise one module's weights in place; meant to be used via ``net.apply``.

    * Any module whose class name contains ``'Conv'`` (e.g. Conv2d,
      ConvTranspose2d) gets its weight drawn from N(0, 0.02). Its bias, if
      any, is left as PyTorch initialised it.
    * Any module whose class name contains ``'BatchNorm2d'`` gets its weight
      drawn from N(1, 0.02) and its bias set to zero.
    * Every other module is left untouched.
    """
    layer_name = type(m).__name__
    if 'Conv' in layer_name:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in layer_name:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
|
|
|
def get_norm_layer(norm_type='instance'):
    """Return a factory ``f(num_features)`` building a 2-D normalisation layer.

    Args:
        norm_type: ``'batch'`` selects ``nn.BatchNorm2d(affine=True)``;
            ``'instance'`` (default) selects ``nn.InstanceNorm2d(affine=False)``.

    Returns:
        A ``functools.partial`` wrapping the chosen layer class.

    Raises:
        NotImplementedError: for any other ``norm_type``.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
|
|
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=None, last_op=nn.Tanh()):
    """Build a generator, optionally move it to a GPU, and initialise its weights.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: architecture name -- 'global' (GlobalGenerator), 'local'
            (LocalEnhancer) or 'encoder' (Encoder).
        n_downsample_global: downsampling steps of the global generator; also
            passed as the Encoder's ``n_downsampling``.
        n_blocks_global: residual blocks in the global generator.
        n_local_enhancers: number of enhancer stages ('local' only).
        n_blocks_local: residual blocks per enhancer stage ('local' only).
        norm: normalisation type forwarded to ``get_norm_layer``.
        gpu_ids: CUDA device ids. If non-empty, the network is moved to
            ``gpu_ids[0]``. ``None`` (default) or an empty list leaves it where
            it was constructed.
        last_op: final module appended by GlobalGenerator ('global' only).
            The default ``nn.Tanh()`` instance is created once when the module
            is imported. Every call that uses the default shares it, which is
            harmless because Tanh has no parameters.

    Returns:
        The constructed ``nn.Module`` with ``weights_init`` applied.

    Raises:
        NotImplementedError: if ``netG`` or ``norm`` is not recognised.
    """
    norm_layer = get_norm_layer(norm_type=norm)

    if netG == 'global':
        net = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer,
                              last_op=last_op)
    elif netG == 'local':
        net = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                            n_local_enhancers, n_blocks_local, norm_layer)
    elif netG == 'encoder':
        net = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
    else:
        # Raise a real exception class: raising a plain str is itself a
        # TypeError in Python 3 and would hide this message.
        raise NotImplementedError('generator [%s] is not implemented' % netG)

    if gpu_ids is not None and len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        net.cuda(gpu_ids[0])
    net.apply(weights_init)
    return net
|
|
|
def print_network(net):
    """Print a network's structure followed by its total parameter count.

    Args:
        net: an ``nn.Module``, or a list whose first element is the module to
            report (the remaining elements are ignored).
    """
    if isinstance(net, list):
        net = net[0]
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
|
|
|
|
|
|
|
|
|
class LocalEnhancer(nn.Module):
    """Coarse-to-fine generator: a global trunk plus ``n_local_enhancers`` stages.

    The global trunk is a GlobalGenerator with its last three layers removed.
    Those are the ReflectionPad2d, the Conv2d and the default Tanh last_op.
    It is built with ``ngf * 2**n_local_enhancers`` filters and runs on the
    most-downsampled copy of the input. Each enhancer stage n then:

    * encodes its own (less downsampled) copy of the input with ``model{n}_1``;
    * adds the previous output; and
    * decodes/upsamples the sum with ``model{n}_2``.

    Only the last stage ends with a conv + Tanh output head.

    Submodule names (``model``, ``model{n}_1``, ``model{n}_2``, ``downsample``)
    appear in state_dict keys and must stay stable.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
                 n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        # Global trunk without its output head (pad, conv, Tanh).
        trunk_width = ngf * (2 ** n_local_enhancers)
        trunk = GlobalGenerator(input_nc, output_nc, trunk_width, n_downsample_global,
                                n_blocks_global, norm_layer).model
        self.model = nn.Sequential(*list(trunk)[:-3])

        for stage in range(1, n_local_enhancers + 1):
            # Stages get narrower as resolution increases; the last one uses ngf.
            width = ngf * (2 ** (n_local_enhancers - stage))

            front = [nn.ReflectionPad2d(3),
                     nn.Conv2d(input_nc, width, kernel_size=7, padding=0),
                     norm_layer(width), nn.ReLU(True),
                     nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                     norm_layer(width * 2), nn.ReLU(True)]

            back = [ResnetBlock(width * 2, padding_type=padding_type, norm_layer=norm_layer)
                    for _ in range(n_blocks_local)]
            back += [nn.ConvTranspose2d(width * 2, width, kernel_size=3, stride=2, padding=1, output_padding=1),
                     norm_layer(width), nn.ReLU(True)]
            if stage == n_local_enhancers:
                # Output head only on the final (full-resolution) stage.
                back += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]

            setattr(self, 'model%d_1' % stage, nn.Sequential(*front))
            setattr(self, 'model%d_2' % stage, nn.Sequential(*back))

        # Used in forward() to build the input pyramid.
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input):
        # pyramid[k] is the input average-pooled k times; pyramid[-1] is coarsest.
        pyramid = [input]
        for _ in range(self.n_local_enhancers):
            pyramid.append(self.downsample(pyramid[-1]))

        out = self.model(pyramid[-1])
        for stage in range(1, self.n_local_enhancers + 1):
            front = getattr(self, 'model%d_1' % stage)
            back = getattr(self, 'model%d_2' % stage)
            scaled_input = pyramid[self.n_local_enhancers - stage]
            out = back(front(scaled_input) + out)
        return out
|
|
|
class GlobalGenerator(nn.Module):
    """Encoder / residual-blocks / decoder image-to-image generator.

    The layers, in order (all kept in ``self.model``), are:

    * a 7x7 stem;
    * ``n_downsampling`` stride-2 3x3 convs, each doubling the channels;
    * ``n_blocks`` ResnetBlocks;
    * ``n_downsampling`` stride-2 transposed convs, each halving the channels;
    * a 7x7 conv to ``output_nc``;
    * ``last_op``, unless it is ``None``.

    A single ReLU instance is reused for every activation in the network.
    """

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', last_op=nn.Tanh()):
        assert n_blocks >= 0
        super(GlobalGenerator, self).__init__()
        relu = nn.ReLU(True)

        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), relu]

        # Downsampling path: width goes ngf -> ngf * 2**n_downsampling.
        width = ngf
        for _ in range(n_downsampling):
            layers.extend([nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                           norm_layer(width * 2), relu])
            width *= 2

        layers.extend(ResnetBlock(width, padding_type=padding_type, activation=relu, norm_layer=norm_layer)
                      for _ in range(n_blocks))

        # Upsampling path mirrors the downsampling one back to ngf channels.
        for _ in range(n_downsampling):
            half = width // 2
            layers.extend([nn.ConvTranspose2d(width, half, kernel_size=3, stride=2, padding=1, output_padding=1),
                           norm_layer(half), relu])
            width = half

        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)])
        if last_op is not None:
            layers.append(last_op)
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        return self.model(input)
|
|
|
|
|
class ResnetBlock(nn.Module):
    """Residual block computing ``x + conv_block(x)``.

    ``conv_block`` is: [pad] conv3x3 -> norm -> activation -> [dropout]
    -> [pad] conv3x3 -> norm, with ``dim`` channels throughout.

    Args:
        dim: number of input and output channels.
        padding_type: 'reflect' or 'replicate' (explicit 1-pixel pad layer), or
            'zero' (padding=1 inside the convolution).
        norm_layer: callable ``norm_layer(num_channels)`` returning a module.
        activation: module placed after the first normalisation. The default
            ``nn.ReLU(True)`` instance is created once and shared between blocks.
        use_dropout: insert ``nn.Dropout(0.5)`` after the activation.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.
    """

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    @staticmethod
    def _pad_layers(padding_type):
        """Return ``(explicit_pad_layers, conv_padding)`` for ``padding_type``."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        # First half: pad, conv, norm, activation, optional dropout.
        layers, conv_pad = self._pad_layers(padding_type)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad))
        layers.append(norm_layer(dim))
        layers.append(activation)
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        # Second half: pad, conv, norm (no activation before the residual sum).
        pad_again, conv_pad = self._pad_layers(padding_type)
        layers.extend(pad_again)
        layers.append(nn.Conv2d(dim, dim, kernel_size=3, padding=conv_pad))
        layers.append(norm_layer(dim))

        return nn.Sequential(*layers)

    def forward(self, x):
        return x + self.conv_block(x)
|
|
|
class Encoder(nn.Module):
    """Encoder-decoder whose output is averaged within each instance region.

    ``self.model`` is:

    * a 7x7 stem;
    * ``n_downsampling`` stride-2 convs, each doubling the channels;
    * the mirrored transposed convs;
    * a 7x7 conv to ``output_nc`` followed by Tanh.

    ``forward`` then replaces, per batch element and per output channel, every
    value inside an instance's pixels with that region's mean.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                  norm_layer(ngf), nn.ReLU(True)]

        width = ngf
        for _ in range(n_downsampling):
            layers += [nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1),
                       norm_layer(width * 2), nn.ReLU(True)]
            width *= 2

        for _ in range(n_downsampling):
            half = width // 2
            layers += [nn.ConvTranspose2d(width, half, kernel_size=3, stride=2, padding=1, output_padding=1),
                       norm_layer(half), nn.ReLU(True)]
            width = half

        layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, input, inst):
        # NOTE(review): `inst` is indexed with 4 coordinates and its channel
        # coordinate is offset by j, so it presumably is (B, 1, H, W) with the
        # same H, W as the output -- TODO confirm with callers.
        features = self.model(input)
        pooled = features.clone()

        instance_ids = np.unique(inst.cpu().numpy().astype(int))
        batch_size = input.size()[0]
        for inst_id in instance_ids:
            label = int(inst_id)
            for b in range(batch_size):
                hits = (inst[b:b + 1] == label).nonzero()
                rows = hits[:, 0] + b
                ys = hits[:, 2]
                xs = hits[:, 3]
                for j in range(self.output_nc):
                    chans = hits[:, 1] + j
                    region = features[rows, chans, ys, xs]
                    pooled[rows, chans, ys, xs] = torch.mean(region).expand_as(region)
        return pooled
|
|