File size: 1,349 Bytes
8b79d57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torch
from torch import nn
from torchvision.models.resnet import ResNet, Bottleneck

class ResNet50Encoder(ResNet):
    """ResNet-50 trunk that exposes four intermediate feature maps.

    The network is torchvision's ``ResNet`` assembled from ``Bottleneck``
    blocks in the ``[3, 4, 6, 3]`` layout, with the stride of the last stage
    replaced by dilation. The classification head (``avgpool`` and ``fc``)
    is removed, so only the convolutional layers remain on the module.
    """

    def __init__(self, pretrained: bool = False):
        """Build the trunk and optionally load a pretrained checkpoint.

        Args:
            pretrained: When true, fetch the ResNet-50 state dict through
                ``torch.hub.load_state_dict_from_url`` and load it into
                this module.
        """
        super().__init__(
            block=Bottleneck,
            layers=[3, 4, 6, 3],
            replace_stride_with_dilation=[False, False, True],
            norm_layer=None,
        )

        if pretrained:
            weights_url = 'https://download.pytorch.org/models/resnet50-0676ba61.pth'
            state = torch.hub.load_state_dict_from_url(weights_url)
            self.load_state_dict(state)

        # The head is dropped only after the optional weight load above.
        # NOTE(review): presumably the checkpoint still carries ``fc``
        # parameters, which a strict load would reject once the attribute is
        # gone -- keep this ordering.
        del self.avgpool
        del self.fc

    def forward_single_frame(self, x):
        """Run the trunk on a batch of frames and collect the feature maps.

        Args:
            x: Input accepted by ``conv1`` (presumably ``(N, C, H, W)``
                images -- confirm against callers).

        Returns:
            A list ``[f1, f2, f3, f4]`` holding the outputs of the stem
            (``conv1``/``bn1``/``relu``), ``layer1``, ``layer2`` and
            ``layer4``; the trailing comments give each map's scale
            relative to the input.
        """
        stem = self.relu(self.bn1(self.conv1(x)))  # 1/2
        stage1 = self.layer1(self.maxpool(stem))  # 1/4
        stage2 = self.layer2(stage1)  # 1/8
        # layer3's output is not returned; only layer4's is.
        stage4 = self.layer4(self.layer3(stage2))  # 1/16
        return [stem, stage1, stage2, stage4]

    def forward_time_series(self, x):
        """Encode a sequence input by folding its two leading axes together.

        The first two axes of ``x`` are merged, the merged batch is passed
        through ``forward_single_frame``, and the two axes are restored on
        every returned feature map.

        Args:
            x: Tensor whose first two axes are unpacked as ``(B, T)``
                (presumably batch and time).

        Returns:
            The four feature maps, each with dimension 0 split back into
            ``(B, T)``.
        """
        batch, steps = x.shape[:2]
        per_frame = self.forward_single_frame(x.flatten(0, 1))
        return [feat.unflatten(0, (batch, steps)) for feat in per_frame]

    def forward(self, x):
        """Dispatch on rank: 5-D inputs are sequences, anything else frames."""
        if x.ndim != 5:
            return self.forward_single_frame(x)
        return self.forward_time_series(x)