# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..utils import xavier_init
from .registry import UPSAMPLE_LAYERS

UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
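# Both names resolve to `nn.Upsample`; `build_upsample_layer` below injects
# the registered name as the `mode` argument, so e.g.
# dict(type='bilinear', scale_factor=2) yields bilinear interpolation.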


@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    This module packs `F.pixel_shuffle()` and an `nn.Conv2d` module together
    to achieve a simple upsampling with pixel shuffle.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer used to expand
            the channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        # The conv expands channels by scale_factor ** 2 so that pixel_shuffle
        # can fold them back into a (scale_factor x scale_factor) larger map.
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x
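

# Usage sketch (illustrative, not part of the module): with scale_factor=2,
# PixelShufflePack maps (N, in_channels, H, W) to (N, out_channels, 2H, 2W)
# by first expanding channels with the conv and then rearranging them
# spatially with pixel_shuffle.
#
#   >>> import torch
#   >>> layer = PixelShufflePack(16, 16, scale_factor=2, upsample_kernel=3)
#   >>> layer(torch.rand(1, 16, 8, 8)).shape
#   torch.Size([1, 16, 16, 16])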


def build_upsample_layer(cfg, *args, **kwargs):
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate an upsample layer.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding upsample layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding upsample layer.

    Returns:
        nn.Module: Created upsample layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')

    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    if layer_type not in UPSAMPLE_LAYERS:
        raise KeyError(f'Unrecognized upsample type {layer_type}')
    upsample = UPSAMPLE_LAYERS.get(layer_type)

    # nn.Upsample is registered under both 'nearest' and 'bilinear', so pass
    # the chosen type on as its interpolation mode.
    if upsample is nn.Upsample:
        cfg_['mode'] = layer_type
    layer = upsample(*args, **kwargs, **cfg_)
    return layer
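

# Usage sketch (illustrative, not part of the module): build_upsample_layer
# instantiates whatever is registered under cfg['type'] and forwards the
# remaining keys as keyword arguments. The 'pixel_shuffle' example assumes
# PixelShufflePack is registered under that name, as done above.
#
#   >>> import torch
#   >>> up = build_upsample_layer(dict(type='bilinear', scale_factor=2))
#   >>> isinstance(up, nn.Upsample)
#   True
#   >>> shuffle = build_upsample_layer(
#   ...     dict(type='pixel_shuffle', in_channels=8, out_channels=8,
#   ...          scale_factor=2, upsample_kernel=3))
#   >>> shuffle(torch.rand(1, 8, 4, 4)).shape
#   torch.Size([1, 8, 8, 8])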