Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .registry import CONV_LAYERS | |
def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    """2D convolution with Weight Standardization.

    Each output-channel filter of ``weight`` is shifted to zero mean and
    divided by its standard deviation plus ``eps``. The convolution is then
    applied with :func:`torch.nn.functional.conv2d`.

    Args:
        input (Tensor): Input forwarded to ``F.conv2d``.
        weight (Tensor): Conv2d kernel whose dim 0 indexes output channels.
            It may be non-contiguous.
        bias, stride, padding, dilation, groups: Forwarded unchanged to
            ``F.conv2d``.
        eps (float): Added to the standard deviation to avoid dividing by
            zero for constant filters. Default: 1e-5.

    Returns:
        Tensor: Result of ``F.conv2d`` with the standardized kernel.
    """
    # Statistics are taken per *output* channel (dim 0 of the kernel).
    out_channels = weight.size(0)
    # reshape (not view) so non-contiguous kernels are accepted as well;
    # for contiguous kernels this is still a zero-copy view.
    weight_flat = weight.reshape(out_channels, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(out_channels, 1, 1, 1)
    # NOTE: Tensor.std is the unbiased estimator, so a filter with a single
    # element yields NaN here; eps does not guard against that case.
    std = weight_flat.std(dim=1, keepdim=True).view(out_channels, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
class ConvWS2d(nn.Conv2d):
    """Conv2d whose kernel is weight-standardized on every forward pass.

    Construction mirrors :class:`torch.nn.Conv2d`, plus one extra ``eps``
    argument. ``eps`` is stored on the module and passed to ``conv_ws_2d``
    together with the layer's weight, bias and convolution settings.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: Same meaning as in :class:`torch.nn.Conv2d`.
        eps (float): Stability term forwarded to ``conv_ws_2d``.
            Default: 1e-5.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        conv_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        super().__init__(in_channels, out_channels, kernel_size,
                         **conv_options)
        self.eps = eps

    def forward(self, x):
        # The standardization itself lives in conv_ws_2d; this module only
        # supplies its own parameters and hyper-parameters.
        return conv_ws_2d(
            x,
            self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            eps=self.eps,
        )
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    The kernel is standardized per output channel, then scaled by the
    ``weight_gamma`` buffer and shifted by the ``weight_beta`` buffer. Both
    buffers have shape ``(out_channels, 1, 1, 1)`` and start as ones and
    zeros, so a freshly constructed layer uses plain standardized weights.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # Buffers (not Parameters): saved in state_dict, not returned by
        # self.parameters().
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    @staticmethod
    def _weight_stats(weight):
        """Return the per-output-channel ``(mean, std)`` of ``weight``.

        Both tensors have shape ``(weight.size(0), 1, 1, 1)``. ``std`` is
        ``sqrt(var + 1e-5)`` with the unbiased variance. This single helper
        is shared by ``_get_weight`` and ``_load_from_state_dict``, so the
        recovered gamma/beta exactly invert the standardization.
        """
        # reshape (not view) so non-contiguous kernels are accepted too.
        weight_flat = weight.reshape(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        return mean, std

    def _get_weight(self, weight):
        """Standardize ``weight``, then apply ``weight_gamma``/``weight_beta``."""
        mean, std = self._weight_stats(weight)
        weight = (weight - mean) / std
        return self.weight_gamma * weight + self.weight_beta

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """
        # Sentinel: fill gamma with -1 before loading. If the checkpoint
        # provides weight_gamma, the fill is overwritten.
        # NOTE(review): a checkpoint whose stored gamma has a non-positive
        # mean is indistinguishable from "missing" and will be recomputed.
        self.weight_gamma.data.fill_(-1)
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            # gamma/beta came from the checkpoint; report missing keys as-is.
            missing_keys.extend(local_missing_keys)
            return

        # Plain-conv checkpoint. gamma = std and beta = mean map the
        # standardized kernel back onto the loaded kernel, so the layer
        # initially computes the same output as the pretrained conv.
        mean, std = self._weight_stats(self.weight.data)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)

        # gamma/beta were recovered above, so they are not reported missing.
        missing_keys.extend(
            k for k in local_missing_keys
            if not k.endswith(('weight_gamma', 'weight_beta')))