yunusserhat's picture
Upload 40 files
94f372a verified
import torch
from torch import nn
class MLP(nn.Module):
    """MLP head applied to ``x[:, 0, :]`` of the incoming feature tensor.

    Without auxiliary targets the module is a single MLP (``self.mlp``)
    producing ``final_dim`` outputs. When ``aux_data`` is non-empty a second
    MLP (``self.mlp_aux``) predicts every auxiliary target at once and
    ``self.idx`` maps each target name to its output columns.
    """

    # Output width of each supported auxiliary target.
    _AUX_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=(128, 32, 2),
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=(),
    ):
        """
        Initializes an MLP Classification Head
        Args:
            initial_dim (int): size of the input features
            hidden_dim (sequence of int): hidden dimensions for the MLP
            final_dim (int): size of the main ("gps") output
            norm (callable): normalization factory, called as ``norm(4, width)``
            activation (callable): activation factory, called without arguments
            aux_data (sequence of str): auxiliary targets to predict; each name
                must be a key of ``_AUX_DIMS`` (``KeyError`` otherwise)
        """
        super().__init__()
        # Work on private copies so that neither the caller's sequences nor
        # the default values are ever modified by this constructor.
        hidden_dim = list(hidden_dim)
        self.aux_data = list(aux_data)
        self.aux = len(self.aux_data) > 0
        if self.aux:
            # With auxiliary targets the last hidden layer is widened to 128.
            # The widened list is used for BOTH heads: this is what the
            # previous (aliasing) implementation produced, and keeping it
            # preserves layer shapes for existing checkpoints.
            hidden_dim[-1] = 128
            self.idx = {}
            offset = 0
            for col in self.aux_data:
                width = self._AUX_DIMS[col]
                # Columns of the shared auxiliary output owned by ``col``.
                self.idx[col] = list(range(offset, offset + width))
                offset += width
            aux_dims = [initial_dim] + hidden_dim + [offset]
            self.mlp_aux = nn.Sequential(
                *self.init_layers(aux_dims, norm, activation)
            )
        main_dims = [initial_dim] + hidden_dim + [final_dim]
        self.mlp = nn.Sequential(*self.init_layers(main_dims, norm, activation))

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per pair of
        consecutive entries of ``dim``; every ``Linear`` except the last is
        followed by ``norm(4, width)`` and ``activation()``.

        Returns:
            list of nn.Module, ready to be unpacked into ``nn.Sequential``.
        """
        layers = [nn.LayerNorm(dim[0])]
        last = len(dim) - 2
        for i, (d_in, d_out) in enumerate(zip(dim[:-1], dim[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                # NOTE(review): this call shape matches
                # nn.GroupNorm(num_groups, num_channels). With the default
                # nn.InstanceNorm1d the second positional argument is ``eps``
                # -- confirm which normalization is intended.
                layers.append(norm(4, d_out))
                layers.append(activation())
        return layers

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features; indexed as ``x[:, 0, :]``, so it
                must be 3-dimensional (presumably batch x tokens x features
                -- TODO confirm)
        Returns:
            The ``self.mlp`` output when no auxiliary targets are configured,
            otherwise a dict with key "gps" plus one entry per auxiliary
            target holding that target's columns of the ``mlp_aux`` output.
        """
        feats = x[:, 0, :]
        if not self.aux:
            return self.mlp(feats)
        out = {"gps": self.mlp(feats)}
        aux_out = self.mlp_aux(feats)
        for col, cols in self.idx.items():
            out[col] = aux_out[:, cols]
        return out
class MLPResNet(nn.Module):
    """MLP head whose main path consumes the feature tensor as given.

    Without auxiliary targets ``forward`` returns ``self.mlp(x)``. When
    ``aux_data`` is non-empty a second MLP (``self.mlp_aux``) predicts every
    auxiliary target at once and ``self.idx`` maps each target name to its
    output columns.
    """

    # Output width of each supported auxiliary target.
    _AUX_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=(128, 32, 2),
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=(),
    ):
        """
        Initializes an MLP Classification Head
        Args:
            initial_dim (int): size of the input features
            hidden_dim (sequence of int): hidden dimensions for the MLP
            final_dim (int): size of the main ("gps") output
            norm (callable): normalization factory, called as ``norm(4, width)``
            activation (callable): activation factory, called without arguments
            aux_data (sequence of str): auxiliary targets to predict; each name
                must be a key of ``_AUX_DIMS`` (``KeyError`` otherwise)
        """
        super().__init__()
        # Work on private copies so that neither the caller's sequences nor
        # the default values are ever modified by this constructor.
        hidden_dim = list(hidden_dim)
        self.aux_data = list(aux_data)
        self.aux = len(self.aux_data) > 0
        if self.aux:
            # With auxiliary targets the last hidden layer is widened to 128.
            # The widened list is used for BOTH heads: this is what the
            # previous (aliasing) implementation produced, and keeping it
            # preserves layer shapes for existing checkpoints.
            hidden_dim[-1] = 128
            self.idx = {}
            offset = 0
            for col in self.aux_data:
                width = self._AUX_DIMS[col]
                # Columns of the shared auxiliary output owned by ``col``.
                self.idx[col] = list(range(offset, offset + width))
                offset += width
            aux_dims = [initial_dim] + hidden_dim + [offset]
            self.mlp_aux = nn.Sequential(
                *self.init_layers(aux_dims, norm, activation)
            )
        main_dims = [initial_dim] + hidden_dim + [final_dim]
        self.mlp = nn.Sequential(*self.init_layers(main_dims, norm, activation))

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per pair of
        consecutive entries of ``dim``; every ``Linear`` except the last is
        followed by ``norm(4, width)`` and ``activation()``.

        Returns:
            list of nn.Module, ready to be unpacked into ``nn.Sequential``.
        """
        layers = [nn.LayerNorm(dim[0])]
        last = len(dim) - 2
        for i, (d_in, d_out) in enumerate(zip(dim[:-1], dim[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                # NOTE(review): this call shape matches
                # nn.GroupNorm(num_groups, num_channels). With the default
                # nn.InstanceNorm1d the second positional argument is ``eps``
                # -- confirm which normalization is intended.
                layers.append(norm(4, d_out))
                layers.append(activation())
        return layers

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features
        Returns:
            ``self.mlp(x)`` when no auxiliary targets are configured,
            otherwise a dict with key "gps" plus one entry per auxiliary
            target holding that target's columns of the ``mlp_aux`` output.
        """
        if not self.aux:
            return self.mlp(x)
        # The plain path above takes ``x`` as given, so the auxiliary path
        # must accept the same 2-D input instead of failing on
        # ``x[:, 0, :]``. Any other rank keeps the original indexing.
        feats = x if x.dim() == 2 else x[:, 0, :]
        out = {"gps": self.mlp(feats)}
        aux_out = self.mlp_aux(feats)
        for col, cols in self.idx.items():
            out[col] = aux_out[:, cols]
        return out
class MLPCentroid(nn.Module):
    """Two-branch MLP head: ``classif`` and ``reg`` outputs are concatenated.

    ``classif`` produces ``final_dim // 3`` values and ``reg`` produces
    ``2 * final_dim // 3`` values from ``x[:, 0, :]``; the concatenation is
    the main ("gps") output. Each configured auxiliary target gets its own
    MLP, stored both as an attribute named after the target and in
    ``self.predictors``.
    """

    # Output width of each supported auxiliary target, in the order the
    # heads are created.
    _AUX_DIMS = {
        "land_cover": 12,
        "road_index": 1,
        "drive_side": 1,
        "climate": 30,
        "soil": 14,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=(128, 32, 2),
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=(),
    ):
        """
        Initializes an MLP Classification Head
        Args:
            initial_dim (int): size of the input features
            hidden_dim (sequence of int): hidden dimensions for the MLP
            final_dim (int): total output size, split 1/3 (``classif``) and
                2/3 (``reg``) using integer division
            norm (callable): normalization factory, called as ``norm(4, width)``
            activation (callable): activation factory, called without arguments
            aux_data (sequence of str): auxiliary targets to predict; names
                not listed in ``_AUX_DIMS`` are ignored
        """
        super().__init__()
        # Private copies: defaults and caller-owned sequences stay untouched.
        hidden_dim = list(hidden_dim)
        self.aux_data = list(aux_data)
        self.aux = len(self.aux_data) > 0

        trunk = [initial_dim] + hidden_dim
        self.classif = nn.Sequential(
            *self.init_layers(trunk + [final_dim // 3], norm, activation)
        )
        dim = trunk + [2 * final_dim // 3]
        self.reg = nn.Sequential(*self.init_layers(dim, norm, activation))

        # Auxiliary heads only. The "gps" output is computed in ``forward``
        # from ``classif`` and ``reg``; this class has no ``self.mlp``, so the
        # former ``{"gps": self.mlp}`` entry raised AttributeError whenever
        # auxiliary targets were requested.
        self.predictors = {}
        if self.aux:
            self.dim = trunk
            self.init_aux(dim, norm, activation)

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per pair of
        consecutive entries of ``dim``; every ``Linear`` except the last is
        followed by ``norm(4, width)`` and ``activation()``.

        Returns:
            list of nn.Module, ready to be unpacked into ``nn.Sequential``.
        """
        layers = [nn.LayerNorm(dim[0])]
        last = len(dim) - 2
        for i, (d_in, d_out) in enumerate(zip(dim[:-1], dim[1:])):
            layers.append(nn.Linear(d_in, d_out))
            if i < last:
                # NOTE(review): this call shape matches
                # nn.GroupNorm(num_groups, num_channels). With the default
                # nn.InstanceNorm1d the second positional argument is ``eps``
                # -- confirm which normalization is intended.
                layers.append(norm(4, d_out))
                layers.append(activation())
        return layers

    def init_aux(self, dim, norm, activation):
        """Creates one MLP per auxiliary target listed in ``self.aux_data``.

        Each head is built from ``self.dim`` plus the target's output width,
        assigned to the attribute named after the target (which registers it
        as a submodule) and recorded in ``self.predictors``.

        Args:
            dim: unused; the layer sizes come from ``self.dim``. Kept so the
                signature stays unchanged.
            norm (callable): normalization factory passed to ``init_layers``
            activation (callable): activation factory passed to ``init_layers``
        """
        for name, width in self._AUX_DIMS.items():
            if name not in self.aux_data:
                continue
            head = nn.Sequential(
                *self.init_layers(self.dim + [width], norm, activation)
            )
            setattr(self, name, head)
            self.predictors[name] = head

    def forward(self, x):
        """Predicts GPS coordinates from an image.
        Args:
            x: torch.Tensor with features; indexed as ``x[:, 0, :]``, so it
                must be 3-dimensional (presumably batch x tokens x features
                -- TODO confirm)
        Returns:
            The concatenation of the ``classif`` and ``reg`` outputs along
            dim 1 when no auxiliary targets are configured, otherwise a dict
            holding that tensor under "gps" plus one entry per auxiliary head.
        """
        feats = x[:, 0, :]
        gps = torch.cat([self.classif(feats), self.reg(feats)], dim=1)
        if not self.aux:
            return gps
        out = {"gps": gps}
        for col, head in self.predictors.items():
            out[col] = head(feats)
        return out
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands its argument back unchanged."""

    def __init__(self):
        """Sets up the base ``nn.Module`` state; the module has no parameters."""
        super().__init__()

    def forward(self, x):
        """Returns ``x`` itself, without copying or modifying it."""
        return x