Spaces:
Runtime error
Runtime error
File size: 2,141 Bytes
58f667f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
# Copyright (c) OpenMMLab. All rights reserved.
# Code reference from "Temporal Interlacing Network"
# https://github.com/deepcs233/TIN/blob/master/cuda_shift/rtc_wrap.py
# Hao Shao, Shengju Qian, Yu Liu
# [email protected], [email protected], [email protected]
import torch
import torch.nn as nn
from torch.autograd import Function
from ..utils import ext_loader
# Load the compiled `_ext` extension, requesting the two TIN shift kernels
# (forward and backward) that TINShiftFunction calls.
ext_module = ext_loader.load_ext(
    '_ext',
    ['tin_shift_forward', 'tin_shift_backward'],
)
class TINShiftFunction(Function):
    """Autograd Function wrapping the compiled TIN shift kernels.

    The shifting itself is done by ``ext_module.tin_shift_forward`` and
    ``ext_module.tin_shift_backward``. This class validates the inputs,
    allocates the output / gradient buffers and saves ``shift`` for the
    backward pass.
    """

    @staticmethod
    def forward(ctx, input, shift):
        """Shift ``input`` according to ``shift``.

        Args:
            ctx: Autograd context; ``shift`` is saved on it for backward.
            input (Tensor): Feature map. Dim 0 is the batch and
                ``input.size(2)`` is the channel count ``C``.
            shift (Tensor): Shift tensor. Dim 0 must equal the batch size of
                ``input`` and ``shift.size(1)`` must evenly divide ``C``.

        Returns:
            Tensor: A new tensor shaped like ``input`` (same dtype/device),
            filled in by the extension kernel.

        Raises:
            ValueError: If the batch sizes of ``input`` and ``shift`` differ,
                or if ``C`` is not a positive multiple of ``shift.size(1)``
                (including ``shift.size(1) == 0``).
        """
        # NOTE(review): the native kernel presumably pairs each sample of
        # ``input`` with the matching row of ``shift``; a batch mismatch
        # would then read past the end of ``shift``, so reject it up front.
        if input.size(0) != shift.size(0):
            raise ValueError(
                'The first dim (batch) of `input` and `shift` should be '
                f'same, but got {input.size(0)} and {shift.size(0)}.')
        C = input.size(2)
        num_segments = shift.size(1)
        # Test for zero first so an empty ``shift`` raises the documented
        # ValueError instead of a ZeroDivisionError from ``C // 0``.
        if (num_segments <= 0 or C // num_segments <= 0
                or C % num_segments != 0):
            raise ValueError('C should be a multiple of num_segments, '
                             f'but got C={C} and num_segments={num_segments}.')
        ctx.save_for_backward(shift)
        out = torch.zeros_like(input)
        ext_module.tin_shift_forward(input, shift, out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back through the shift.

        Args:
            ctx: Autograd context holding the ``shift`` saved in forward.
            grad_output (Tensor): Gradient w.r.t. the forward output.

        Returns:
            tuple[Tensor, Tensor]: Gradient w.r.t. ``input`` (computed by the
            extension kernel) and an all-zero tensor for ``shift``.
        """
        shift = ctx.saved_tensors[0]
        # ``new_zeros`` keeps dtype/device and, like the legacy
        # ``new(*size).zero_()`` it replaces, yields a contiguous buffer.
        data_grad_input = grad_output.new_zeros(grad_output.size())
        # No gradient is computed for ``shift``; it always receives zeros.
        shift_grad_input = shift.new_zeros(shift.size())
        ext_module.tin_shift_backward(grad_output, shift, data_grad_input)
        return data_grad_input, shift_grad_input


tin_shift = TINShiftFunction.apply
class TINShift(nn.Module):
    """Temporal Interlace Shift packaged as an ``nn.Module``.

    A differentiable, temporal-wise frame shifting operation introduced in
    "Temporal Interlacing Network" (https://arxiv.org/abs/2001.06499).
    The module defines no parameters of its own; calling it simply forwards
    to the ``tin_shift`` autograd function.

    Code is modified from https://github.com/mit-han-lab/temporal-shift-module
    """

    def forward(self, input, shift):
        """Apply temporal interlace shift to ``input``.

        Args:
            input (Tensor): Feature map with shape [N, num_segments, C, H * W].
            shift (Tensor): Shift tensor with shape [N, num_segments].

        Returns:
            Tensor: The feature map after temporal interlace shift.
        """
        shifted = tin_shift(input, shift)
        return shifted
|