# LN3Diff / nsr / superresolution.py
# NIRVANALAN
# release file
# 87c126b
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""Superresolution network architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""
import torch
from nsr.networks_stylegan2 import Conv2dLayer, SynthesisLayer, ToRGBLayer
from utils.torch_utils.ops import upfirdn2d
from utils.torch_utils import persistence
from utils.torch_utils import misc
from nsr.networks_stylegan2 import SynthesisBlock
import numpy as np
from pdb import set_trace as st
@persistence.persistent_class
class SynthesisBlockNoUp(torch.nn.Module):
    """Synthesis block that keeps activations at a fixed spatial resolution.

    The block runs up to two ``SynthesisLayer`` convolutions on ``x`` and, for
    the ``'skip'`` architecture or the last block, adds a ``ToRGBLayer``
    output to the running image ``img``.  The input ``x`` is asserted to
    already be ``resolution x resolution`` and ``img`` is accumulated without
    being upsampled first.
    """

    def __init__(
        self,
        in_channels,  # Number of input channels, 0 = first block.
        out_channels,  # Number of output channels.
        w_dim,  # Intermediate latent (W) dimensionality.
        resolution,  # Resolution of this block.
        img_channels,  # Number of output color channels.
        is_last,  # Is this the last block?
        architecture='skip',  # Architecture: 'orig', 'skip', 'resnet'.
        resample_filter=[
            1, 3, 3, 1
        ],  # Low-pass filter to apply when resampling activations (never mutated here).
        conv_clamp=256,  # Clamp the output of convolution layers to +-X, None = disable clamping.
        use_fp16=False,  # Use FP16 for this block?
        fp16_channels_last=False,  # Use channels-last memory format with FP16?
        fused_modconv_default=True,  # Default value of fused_modconv. 'inference_only' = True for inference, False for training.
        **layer_kwargs,  # Arguments for SynthesisLayer.
    ):
        assert architecture in ['orig', 'skip', 'resnet']
        super().__init__()
        self.in_channels = in_channels
        self.w_dim = w_dim
        self.resolution = resolution
        self.img_channels = img_channels
        self.is_last = is_last
        self.architecture = architecture
        self.use_fp16 = use_fp16
        self.channels_last = (use_fp16 and fp16_channels_last)
        self.fused_modconv_default = fused_modconv_default
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter(resample_filter))

        # Counters used by forward() to validate how many latents `ws` carries.
        self.num_conv = 0
        self.num_torgb = 0

        if in_channels == 0:
            # First block: start from a learned constant instead of an input.
            self.const = torch.nn.Parameter(
                torch.randn([out_channels, resolution, resolution]))
        else:
            self.conv0 = SynthesisLayer(in_channels,
                                        out_channels,
                                        w_dim=w_dim,
                                        resolution=resolution,
                                        conv_clamp=conv_clamp,
                                        channels_last=self.channels_last,
                                        **layer_kwargs)
            self.num_conv += 1

        self.conv1 = SynthesisLayer(out_channels,
                                    out_channels,
                                    w_dim=w_dim,
                                    resolution=resolution,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last,
                                    **layer_kwargs)
        self.num_conv += 1

        if is_last or architecture == 'skip':
            self.torgb = ToRGBLayer(out_channels,
                                    img_channels,
                                    w_dim=w_dim,
                                    conv_clamp=conv_clamp,
                                    channels_last=self.channels_last)
            self.num_torgb += 1

        if in_channels != 0 and architecture == 'resnet':
            # The residual branch must not upsample: the main path (conv0 is
            # built without an `up` argument and forward() asserts the input
            # is already `resolution x resolution`) stays at the input
            # resolution, so an `up=2` skip would not match it at the
            # residual add in forward().
            self.skip = Conv2dLayer(in_channels,
                                    out_channels,
                                    kernel_size=1,
                                    bias=False,
                                    resample_filter=resample_filter,
                                    channels_last=self.channels_last)

    def forward(self,
                x,
                img,
                ws,
                force_fp32=False,
                fused_modconv=None,
                update_emas=False,
                **layer_kwargs):
        """Run the block and accumulate its RGB contribution.

        Args:
            x: Input activations, asserted to have shape
                ``[N, in_channels, resolution, resolution]``.  Replaced by the
                learned constant when ``in_channels == 0``.
            img: Running image to add the ToRGB output to, or ``None``.
                The tensor passed in is not modified.
            ws: Latents, asserted to have shape
                ``[N, num_conv + num_torgb, w_dim]``; one slice is consumed
                per conv / ToRGB layer, in order.
            force_fp32: Disable FP16 / channels-last for this call.  Forced to
                ``True`` when ``ws`` is not on a CUDA device.
            fused_modconv: ``None`` selects ``fused_modconv_default``;
                ``'inference_only'`` resolves to ``not self.training``.
            update_emas: Accepted for interface compatibility; unused.
            **layer_kwargs: Forwarded to the ``SynthesisLayer`` calls.

        Returns:
            Tuple ``(x, img)`` of output activations and the updated image
            (``img`` is returned unchanged if this block has no ToRGB layer).
        """
        _ = update_emas  # unused
        misc.assert_shape(ws,
                          [None, self.num_conv + self.num_torgb, self.w_dim])
        w_iter = iter(ws.unbind(dim=1))

        if ws.device.type != 'cuda':
            force_fp32 = True
        reduced_precision = self.use_fp16 and not force_fp32
        dtype = torch.float16 if reduced_precision else torch.float32
        memory_format = (torch.channels_last
                         if self.channels_last and not force_fp32 else
                         torch.contiguous_format)

        if fused_modconv is None:
            fused_modconv = self.fused_modconv_default
        if fused_modconv == 'inference_only':
            fused_modconv = (not self.training)

        if self.in_channels == 0:
            # Input: broadcast the learned constant over the batch.
            x = self.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
            # Main layers: the first block only has conv1.
            x = self.conv1(x,
                           next(w_iter),
                           fused_modconv=fused_modconv,
                           **layer_kwargs)
        else:
            # Input.
            misc.assert_shape(
                x, [None, self.in_channels, self.resolution, self.resolution])
            x = x.to(dtype=dtype, memory_format=memory_format)

            # Main layers.
            if self.architecture == 'resnet':
                # Both branches are scaled by sqrt(0.5) before being summed.
                y = self.skip(x, gain=np.sqrt(0.5))
                x = self.conv0(x,
                               next(w_iter),
                               fused_modconv=fused_modconv,
                               **layer_kwargs)
                x = self.conv1(x,
                               next(w_iter),
                               fused_modconv=fused_modconv,
                               gain=np.sqrt(0.5),
                               **layer_kwargs)
                x = y.add_(x)
            else:
                x = self.conv0(x,
                               next(w_iter),
                               fused_modconv=fused_modconv,
                               **layer_kwargs)
                x = self.conv1(x,
                               next(w_iter),
                               fused_modconv=fused_modconv,
                               **layer_kwargs)

        # ToRGB.
        if self.is_last or self.architecture == 'skip':
            y = self.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32,
                     memory_format=torch.contiguous_format)
            # Out-of-place add: `img` is the caller's tensor here (it is not
            # upsampled into a fresh tensor first), so an in-place add would
            # overwrite the caller's data.
            img = (img + y) if img is not None else y

        # assert x.dtype == dtype # support AMP in this library
        assert img is None or img.dtype == torch.float32
        return x, img

    def extra_repr(self):
        """Return the resolution/architecture summary shown in ``repr(module)``."""
        return f'resolution={self.resolution:d}, architecture={self.architecture:s}'
#----------------------------------------------------------------------------
# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8X(torch.nn.Module):
    """Two-block superresolution head (blocks built at resolutions 256 and 512).

    Inputs whose last dimension is not 128 are first bilinearly resized to
    128x128, then passed through two ``SynthesisBlock`` stages.
    """

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        # assert img_resolution == 512
        self.sr_antialias = sr_antialias
        self.input_resolution = 128

        # FP16 (and the accompanying conv clamp) is all-or-nothing here.
        fp16_enabled = sr_num_fp16_res > 0
        shared = dict(
            w_dim=512,
            img_channels=3,
            use_fp16=fp16_enabled,
            conv_clamp=(256 if fp16_enabled else None),
            **block_kwargs,
        )
        self.block0 = SynthesisBlock(channels,
                                     128,
                                     resolution=256,
                                     is_last=False,
                                     **shared)
        self.block1 = SynthesisBlock(128,
                                     64,
                                     resolution=512,
                                     is_last=True,
                                     **shared)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Return the image produced by running ``rgb``/``x`` through both blocks.

        Args:
            rgb: Low-resolution image fed to the blocks as the running image.
            x: Feature map fed to the blocks as activations.
            ws: Latents; only the last entry along dim 1 is used.
            **block_kwargs: Forwarded to both blocks,
                e.g. ``{'noise_mode': 'none'}``.
        """
        # Every layer of both blocks is driven by the same (last) latent.
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        side = self.input_resolution
        if x.shape[-1] != side:
            resize = dict(size=(side, side),
                          mode='bilinear',
                          align_corners=False,
                          antialias=self.sr_antialias)
            x = torch.nn.functional.interpolate(x, **resize)
            rgb = torch.nn.functional.interpolate(rgb, **resize)

        for stage in (self.block0, self.block1):
            x, rgb = stage(x, rgb, ws, **block_kwargs)
        return rgb
#----------------------------------------------------------------------------
# for 256x256 generation
@persistence.persistent_class
class SuperresolutionHybrid4X(torch.nn.Module):
    """Two-block superresolution head (blocks built at resolutions 128 and 256).

    The first stage is a ``SynthesisBlockNoUp`` at 128, the second a
    ``SynthesisBlock`` at 256.  Inputs smaller than 128 are bilinearly
    resized to 128x128 first; larger inputs are passed through untouched.
    """

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        # assert img_resolution == 256
        self.input_resolution = 128
        self.sr_antialias = sr_antialias

        # FP16 (and the accompanying conv clamp) is all-or-nothing here.
        fp16_enabled = sr_num_fp16_res > 0
        shared = dict(
            w_dim=512,
            img_channels=3,
            use_fp16=fp16_enabled,
            conv_clamp=(256 if fp16_enabled else None),
            **block_kwargs,
        )
        self.block0 = SynthesisBlockNoUp(channels,
                                         128,
                                         resolution=128,
                                         is_last=False,
                                         **shared)
        self.block1 = SynthesisBlock(128,
                                     64,
                                     resolution=256,
                                     is_last=True,
                                     **shared)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Return the image produced by running ``rgb``/``x`` through both blocks.

        Args:
            rgb: Low-resolution image fed to the blocks as the running image.
            x: Feature map fed to the blocks as activations.
            ws: Latents; only the last entry along dim 1 is used.
            **block_kwargs: Forwarded to both blocks.
        """
        # Every layer of both blocks is driven by the same (last) latent.
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        side = self.input_resolution
        # Only upsizes: inputs at or above `side` are left as they are.
        if x.shape[-1] < side:
            resize = dict(size=(side, side),
                          mode='bilinear',
                          align_corners=False,
                          antialias=self.sr_antialias)
            x = torch.nn.functional.interpolate(x, **resize)
            rgb = torch.nn.functional.interpolate(rgb, **resize)

        for stage in (self.block0, self.block1):
            x, rgb = stage(x, rgb, ws, **block_kwargs)
        return rgb
#----------------------------------------------------------------------------
# for 128 x 128 generation
@persistence.persistent_class
class SuperresolutionHybrid2X(torch.nn.Module):
    """Two-block superresolution head (blocks built at resolutions 64 and 128).

    The first stage is a ``SynthesisBlockNoUp`` at 64, the second a
    ``SynthesisBlock`` at 128.  Inputs whose last dimension is not 64 are
    bilinearly resized to 64x64 first.
    """

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        assert img_resolution == 128
        self.sr_antialias = sr_antialias
        # NOTE: a 128 -> 256 variant (input_resolution=128, block
        # resolutions 128 and 256) was previously sketched here.
        self.input_resolution = 64

        # FP16 (and the accompanying conv clamp) is all-or-nothing here.
        fp16_enabled = sr_num_fp16_res > 0
        shared = dict(
            w_dim=512,
            img_channels=3,
            use_fp16=fp16_enabled,
            conv_clamp=(256 if fp16_enabled else None),
            **block_kwargs,
        )
        self.block0 = SynthesisBlockNoUp(channels,
                                         128,
                                         resolution=64,
                                         is_last=False,
                                         **shared)
        self.block1 = SynthesisBlock(128,
                                     64,
                                     resolution=128,
                                     is_last=True,
                                     **shared)
        self.register_buffer('resample_filter',
                             upfirdn2d.setup_filter([1, 3, 3, 1]))

    def forward(self, rgb, x, ws, **block_kwargs):
        """Return the image produced by running ``rgb``/``x`` through both blocks.

        Args:
            rgb: Low-resolution image fed to the blocks as the running image.
            x: Feature map fed to the blocks as activations.
            ws: Latents; only the last entry along dim 1 is used.
            **block_kwargs: Forwarded to both blocks.
        """
        # Every layer of both blocks is driven by the same (last) latent.
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        side = self.input_resolution
        if x.shape[-1] != side:
            resize = dict(size=(side, side),
                          mode='bilinear',
                          align_corners=False,
                          antialias=self.sr_antialias)
            x = torch.nn.functional.interpolate(x, **resize)
            rgb = torch.nn.functional.interpolate(rgb, **resize)

        for stage in (self.block0, self.block1):
            x, rgb = stage(x, rgb, ws, **block_kwargs)
        return rgb
#----------------------------------------------------------------------------
# for 512x512 generation
@persistence.persistent_class
class SuperresolutionHybrid8XDC(torch.nn.Module):
    """Two-block superresolution head with 256/128-channel blocks.

    Both stages are ``SynthesisBlock`` instances built at resolutions 256 and
    512.  Inputs whose last dimension is not 128 are bilinearly resized to
    128x128 first.  No ``resample_filter`` buffer is registered on this
    module.
    """

    def __init__(
            self,
            channels,
            img_resolution,
            sr_num_fp16_res,
            sr_antialias,
            num_fp16_res=4,
            conv_clamp=None,
            channel_base=None,
            channel_max=None,  # IGNORE
            **block_kwargs):
        super().__init__()
        # assert img_resolution == 512
        self.sr_antialias = sr_antialias
        self.input_resolution = 128

        # FP16 (and the accompanying conv clamp) is all-or-nothing here.
        fp16_enabled = sr_num_fp16_res > 0
        shared = dict(
            w_dim=512,
            img_channels=3,
            use_fp16=fp16_enabled,
            conv_clamp=(256 if fp16_enabled else None),
            **block_kwargs,
        )
        self.block0 = SynthesisBlock(channels,
                                     256,
                                     resolution=256,
                                     is_last=False,
                                     **shared)
        self.block1 = SynthesisBlock(256,
                                     128,
                                     resolution=512,
                                     is_last=True,
                                     **shared)

    def forward(self, rgb, x, ws, base_x=None, **block_kwargs):
        """Return the image produced by running ``rgb``/``x`` through both blocks.

        Args:
            rgb: Low-resolution image fed to the blocks as the running image.
            x: Feature map fed to the blocks as activations.
            ws: Latents; only the last entry along dim 1 is used.
            base_x: Accepted but not used by this method.
            **block_kwargs: Forwarded to both blocks.
        """
        # Same (last) latent for every layer; result is [batch, 3, w_dim].
        ws = ws[:, -1:, :].repeat(1, 3, 1)

        side = self.input_resolution
        if x.shape[-1] != side:  # e.g. resize 64 => 128
            resize = dict(size=(side, side),
                          mode='bilinear',
                          align_corners=False,
                          antialias=self.sr_antialias)
            x = torch.nn.functional.interpolate(x, **resize)
            rgb = torch.nn.functional.interpolate(rgb, **resize)

        for stage in (self.block0, self.block1):
            x, rgb = stage(x, rgb, ws, **block_kwargs)
        return rgb
#----------------------------------------------------------------------------