Spaces:
Running
Running
File size: 839 Bytes
37ee4a4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import Type
from torch import nn
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
class MLPBlock(nn.Module):
    """Stack of (Linear -> activation) layers followed by a final Linear head.

    The first hidden Linear maps ``input_dim -> hidden_dim`` and every later
    hidden Linear maps ``hidden_dim -> hidden_dim``. The head ``fc`` maps to
    ``output_dim`` with no activation after it.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of each hidden layer.
        output_dim: Size of the last dimension of the output.
        num_layers: Number of hidden (Linear + activation) layers. ``0`` yields
            just ``fc = Linear(input_dim, output_dim)``.
        act: Activation module class; a fresh instance is created per layer.

    Raises:
        ValueError: If ``num_layers`` is negative.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        act: Type[nn.Module],
    ) -> None:
        super().__init__()
        if num_layers < 0:
            raise ValueError(f"num_layers must be >= 0, got {num_layers}")
        self.num_layers = num_layers
        # Input width of each hidden Linear: input_dim for the first layer,
        # hidden_dim for every subsequent one.
        in_dims = [input_dim if i == 0 else hidden_dim for i in range(num_layers)]
        # Keep the Sequential(Linear, act) layout so state_dict keys
        # ("layers.<i>.0.weight", ...) match existing checkpoints.
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Linear(in_dim, hidden_dim), act()) for in_dim in in_dims
        )
        # With no hidden layers the head sees the raw input, so it must be
        # sized from input_dim rather than hidden_dim.
        head_in = hidden_dim if num_layers > 0 else input_dim
        self.fc = nn.Linear(head_in, output_dim)

    def forward(self, x):
        """Apply each hidden layer in order, then the output head.

        ``x`` is expected to have ``input_dim`` as its last dimension (what
        the first ``nn.Linear`` consumes); the result has ``output_dim`` as
        its last dimension.
        """
        for layer in self.layers:
            x = layer(x)
        return self.fc(x)
|