|
import torch |
|
from torch import nn |
|
from torchvision.models.resnet import ResNet, Bottleneck |
|
|
|
class ResNet50Encoder(ResNet):
    """ResNet-50 trunk that returns intermediate feature maps instead of logits.

    The network is built from ``Bottleneck`` blocks with the ``[3, 4, 6, 3]``
    stage layout, and ``replace_stride_with_dilation=[False, False, True]`` is
    forwarded to the torchvision ``ResNet`` constructor. The ``avgpool`` and
    ``fc`` attributes created by the parent class are deleted, so only the
    convolutional stages remain on the module.
    """

    def __init__(self, pretrained: bool = False):
        """Build the encoder and optionally load a ResNet-50 checkpoint.

        Args:
            pretrained: When true, download the ``resnet50-0676ba61.pth``
                state dict through ``torch.hub`` and load it into the model.
        """
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            replace_stride_with_dilation=[False, False, True],
            norm_layer=None,
        )

        if pretrained:
            # The checkpoint is loaded while ``avgpool``/``fc`` still exist,
            # i.e. before the head is removed below.
            checkpoint = torch.hub.load_state_dict_from_url(
                'https://download.pytorch.org/models/resnet50-0676ba61.pth')
            self.load_state_dict(checkpoint)

        # The pooling/classification head is never called by this encoder.
        del self.avgpool
        del self.fc

    def forward_single_frame(self, x):
        """Run the trunk on one batch of frames and collect four feature maps.

        Args:
            x: Input tensor fed straight to ``conv1``
                (presumably ``(N, C, H, W)`` -- confirm against callers).

        Returns:
            A list ``[f1, f2, f3, f4]`` holding, in order, the output of the
            stem (``conv1`` -> ``bn1`` -> ``relu``), of ``layer1``, of
            ``layer2`` and of ``layer4``. The ``layer3`` output is not
            returned on its own.
        """
        stem_out = self.relu(self.bn1(self.conv1(x)))
        stage1_out = self.layer1(self.maxpool(stem_out))
        stage2_out = self.layer2(stage1_out)
        stage4_out = self.layer4(self.layer3(stage2_out))
        return [stem_out, stage1_out, stage2_out, stage4_out]

    def forward_time_series(self, x):
        """Run the trunk on a tensor whose first two dimensions are (B, T).

        The two leading dimensions are merged so that every time step is
        processed as an independent frame, then split back on every
        returned feature map.

        Args:
            x: Input tensor with leading dimensions ``(B, T)``
                (presumably ``(B, T, C, H, W)`` -- confirm against callers).

        Returns:
            The same four feature maps as ``forward_single_frame``, each with
            its first dimension unflattened back to ``(B, T)``.
        """
        batch_size, num_steps = x.shape[:2]
        per_frame = self.forward_single_frame(x.flatten(0, 1))
        return [fmap.unflatten(0, (batch_size, num_steps)) for fmap in per_frame]

    def forward(self, x):
        """Dispatch on rank: 5-D inputs take the time-series path, all others the single-frame path."""
        if x.ndim != 5:
            return self.forward_single_frame(x)
        return self.forward_time_series(x)
|
|