Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import math | |
import torch | |
import torch.nn as nn | |
from torch.autograd import Function | |
from torch.autograd.function import once_differentiable | |
from torch.nn.modules.utils import _pair | |
from ..utils import ext_loader | |
ext_module = ext_loader.load_ext( | |
'_ext', ['masked_im2col_forward', 'masked_col2im_forward']) | |
class MaskedConv2dFunction(Function):
    """Autograd function that evaluates a 2-D convolution at masked positions.

    Only the forward pass is implemented; ``backward`` yields no gradient for
    any input. The stride must be 1.
    """

    @staticmethod
    def symbolic(g, features, mask, weight, bias, padding, stride):
        """Emit the custom ``mmcv::MMCVMaskedConv2d`` op into graph ``g``."""
        return g.op(
            'mmcv::MMCVMaskedConv2d',
            features,
            mask,
            weight,
            bias,
            padding_i=padding,
            stride_i=stride)

    @staticmethod
    def forward(ctx, features, mask, weight, bias, padding=0, stride=1):
        """Convolve ``features`` at the locations where ``mask`` is positive.

        Args:
            ctx: Autograd context (unused, nothing is saved for backward).
            features (Tensor): 4-D input whose first dimension is 1.
            mask (Tensor): 3-D mask whose first dimension is 1 and whose last
                two dimensions equal the last two dimensions of ``features``.
                Positions with ``mask > 0`` are computed.
            weight (Tensor): 4-D kernel of shape
                (out_channel, in_channel, kernel_h, kernel_w).
            bias (Tensor | None): Per-output-channel bias, or None to skip
                the bias term.
            padding (int | tuple[int, int]): Padding along height and width.
            stride (int | tuple[int, int]): Must be 1 in both directions.

        Returns:
            Tensor: Output of shape (1, out_channel, out_h, out_w). Positions
            that are not selected by the mask stay zero.

        Raises:
            AssertionError: If ``features`` or ``mask`` have unexpected shapes.
            ValueError: If the stride is not 1.
        """
        assert mask.dim() == 3 and mask.size(0) == 1
        assert features.dim() == 4 and features.size(0) == 1
        assert features.size()[2:] == mask.size()[1:]
        pad_h, pad_w = _pair(padding)
        stride_h, stride_w = _pair(stride)
        if stride_h != 1 or stride_w != 1:
            raise ValueError(
                'Stride can only be 1 in masked_conv2d currently.')
        out_channel, in_channel, kernel_h, kernel_w = weight.size()

        batch_size = features.size(0)
        out_h = int(
            math.floor((features.size(2) + 2 * pad_h -
                        (kernel_h - 1) - 1) / stride_h + 1))
        # The width must be derived from kernel_w; using kernel_h here gives a
        # wrong output width for non-square kernels.
        out_w = int(
            math.floor((features.size(3) + 2 * pad_w -
                        (kernel_w - 1) - 1) / stride_w + 1))
        mask_inds = torch.nonzero(mask[0] > 0, as_tuple=False)
        output = features.new_zeros(batch_size, out_channel, out_h, out_w)
        if mask_inds.numel() > 0:
            mask_h_idx = mask_inds[:, 0].contiguous()
            mask_w_idx = mask_inds[:, 1].contiguous()
            # One column per selected position; filled in place by the
            # extension kernel.
            data_col = features.new_zeros(in_channel * kernel_h * kernel_w,
                                          mask_inds.size(0))
            ext_module.masked_im2col_forward(
                features,
                mask_h_idx,
                mask_w_idx,
                data_col,
                kernel_h=kernel_h,
                kernel_w=kernel_w,
                pad_h=pad_h,
                pad_w=pad_w)

            flat_weight = weight.view(out_channel, -1)
            if bias is None:
                # Layers built with bias=False have no bias tensor to add.
                masked_output = torch.mm(flat_weight, data_col)
            else:
                # bias + flat_weight @ data_col (beta = alpha = 1); avoids the
                # deprecated positional beta/alpha overload of addmm.
                masked_output = torch.addmm(bias[:, None], flat_weight,
                                            data_col)
            # Writes masked_output into `output` in place.
            ext_module.masked_col2im_forward(
                masked_output,
                mask_h_idx,
                mask_w_idx,
                output,
                height=out_h,
                width=out_w,
                channels=out_channel)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return no gradients; the masked convolution is forward-only."""
        # One entry per forward input (features, mask, weight, bias, padding,
        # stride). Surplus None entries are tolerated by autograd when a
        # caller omits trailing defaulted arguments, missing ones are not.
        return (None, ) * 6


masked_conv2d = MaskedConv2dFunction.apply
class MaskedConv2d(nn.Conv2d):
    """``nn.Conv2d`` variant with an optional mask-restricted forward pass.

    Without a mask the layer behaves exactly like ``nn.Conv2d``. With a mask,
    the computation is delegated to ``masked_conv2d``, which has no backward
    implementation and currently supports a stride of 1 only. Note that only
    ``self.padding`` is forwarded to the masked path; stride, dilation and
    groups are not passed on.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, input, mask=None):
        """Run the masked convolution if ``mask`` is given, else ``Conv2d``.

        Args:
            input (Tensor): Input feature map.
            mask (Tensor | None): Mask handed to ``masked_conv2d``; when None
                the regular ``nn.Conv2d`` forward is used.

        Returns:
            Tensor: Convolution output.
        """
        if mask is not None:
            return masked_conv2d(input, mask, self.weight, self.bias,
                                 self.padding)
        # No mask: plain convolution from the parent class.
        return super().forward(input)