File size: 3,001 Bytes
8b79d57 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
import torch
from torch import nn
from torchvision.models.mobilenetv3 import MobileNetV3, InvertedResidualConfig
from torchvision.transforms.functional import normalize
class MobileNetV3LargeEncoder(MobileNetV3):
    """MobileNetV3-Large backbone that returns four intermediate feature maps.

    The pooled classification head (``avgpool`` / ``classifier``) is removed
    after construction; only ``self.features`` is used at inference time.
    ``forward`` accepts a batch of images, or a 5-D batch of frame sequences,
    and returns the list ``[f1, f2, f3, f4]``.
    """

    def __init__(self, pretrained: bool = False):
        """Build the backbone, optionally loading torchvision's ImageNet weights.

        Args:
            pretrained: when True, download and load the torchvision
                ``mobilenet_v3_large`` checkpoint before the head is removed.
        """
        # One row per inverted-residual block. Each row is passed positionally
        # to torchvision's InvertedResidualConfig:
        #   (input_channels, kernel, expanded_channels, out_channels,
        #    use_se, activation, stride, dilation, width_mult)
        # NOTE(review): the last three rows use dilation 2. torchvision appears
        # to force stride 1 on dilated blocks, so the "C4" stage presumably
        # does not downsample; confirm against the installed torchvision.
        block_specs = [
            ( 16, 3,  16,  16, False, "RE", 1, 1, 1),
            ( 16, 3,  64,  24, False, "RE", 2, 1, 1),  # C1
            ( 24, 3,  72,  24, False, "RE", 1, 1, 1),
            ( 24, 5,  72,  40, True,  "RE", 2, 1, 1),  # C2
            ( 40, 5, 120,  40, True,  "RE", 1, 1, 1),
            ( 40, 5, 120,  40, True,  "RE", 1, 1, 1),
            ( 40, 3, 240,  80, False, "HS", 2, 1, 1),  # C3
            ( 80, 3, 200,  80, False, "HS", 1, 1, 1),
            ( 80, 3, 184,  80, False, "HS", 1, 1, 1),
            ( 80, 3, 184,  80, False, "HS", 1, 1, 1),
            ( 80, 3, 480, 112, True,  "HS", 1, 1, 1),
            (112, 3, 672, 112, True,  "HS", 1, 1, 1),
            (112, 5, 672, 160, True,  "HS", 2, 2, 1),  # C4
            (160, 5, 960, 160, True,  "HS", 1, 2, 1),
            (160, 5, 960, 160, True,  "HS", 1, 2, 1),
        ]
        super().__init__(
            inverted_residual_setting=[InvertedResidualConfig(*spec) for spec in block_specs],
            # Presumably only sizes the classifier head, which is deleted below.
            # It is kept at 1280 so the pretrained checkpoint's classifier
            # weights still match during the strict load.
            last_channel=1280,
        )

        if pretrained:
            # Load BEFORE trimming the head: the checkpoint also carries
            # classifier weights, and a strict load_state_dict requires the
            # corresponding modules to still exist.
            checkpoint_url = 'https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth'
            self.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint_url))

        # This encoder only exposes convolutional feature maps, so the pooled
        # classification head is dropped.
        del self.avgpool
        del self.classifier

    def forward_single_frame(self, x):
        """Encode one batch of images and return four tapped feature maps.

        The input is normalized with the per-channel mean/std commonly used
        for ImageNet-pretrained torchvision models. It is presumably RGB in
        [0, 1]; TODO confirm with callers. The returned list holds the outputs
        of ``self.features[1]``, ``[3]``, ``[6]`` and ``[16]``, in that order.
        Layers after index 16 (if any) are not run.
        """
        x = normalize(x, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Indices into self.features whose outputs are returned as f1..f4.
        tap_indices = (1, 3, 6, 16)
        taps = []
        for idx in range(tap_indices[-1] + 1):
            x = self.features[idx](x)
            if idx in tap_indices:
                taps.append(x)
        return taps

    def forward_time_series(self, x):
        """Encode a batch of frame sequences.

        The two leading dims of ``x`` are treated as (batch, time). They are
        merged into one batch dim for the backbone and split back apart on
        every returned feature map.
        """
        batch, frames = x.shape[:2]
        per_frame = self.forward_single_frame(x.flatten(0, 1))
        return [feat.unflatten(0, (batch, frames)) for feat in per_frame]

    def forward(self, x):
        """Dispatch on rank: 5-D input is a sequence batch, anything else is a frame batch."""
        if x.ndim != 5:
            return self.forward_single_frame(x)
        return self.forward_time_series(x)
|