Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn | |
class MLP(nn.Module):
    """MLP head applied to the first token (``x[:, 0, :]``) of its input.

    When ``aux_data`` is non-empty, a second MLP (``mlp_aux``) is built whose
    output is split column-wise into one slice per requested auxiliary target.
    """

    # Output width of the auxiliary head for each supported auxiliary target.
    _AUX_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head

        Args:
            initial_dim (int): width of the input features
            hidden_dim (list): list of hidden dimensions for the MLP; the list
                passed in (or the default) is never modified
            final_dim (int): output width of the main ``mlp`` head
            norm (callable): normalization layer factory, called as
                ``norm(4, num_channels)``
            activation (nn.Module): activation layer class, called with no args
            aux_data (list): names of auxiliary targets; each must be a key of
                ``_AUX_DIMS``

        Raises:
            KeyError: if ``aux_data`` contains an unsupported target name.
        """
        super().__init__()
        self.aux_data = aux_data
        self.aux = len(self.aux_data) > 0
        if self.aux:
            # Copy: mutating `hidden_dim` in place would change the caller's
            # list, the shared default list, and the main head built below.
            hidden_dim_aux = list(hidden_dim)
            hidden_dim_aux[-1] = 128
            # self.idx maps each aux target to its column indices in the
            # concatenated output of mlp_aux.
            self.idx = {}
            final_dim_aux = 0
            for col in self.aux_data:
                width = self._AUX_DIMS[col]
                self.idx[col] = list(range(final_dim_aux, final_dim_aux + width))
                final_dim_aux += width
            dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
            self.mlp_aux = nn.Sequential(*self.init_layers(dim, norm, activation))
        dim = [initial_dim] + list(hidden_dim) + [final_dim]
        self.mlp = nn.Sequential(*self.init_layers(dim, norm, activation))

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per consecutive
        pair in ``dim``; every Linear except the last is followed by
        ``norm(4, width)`` and ``activation()``.

        Args:
            dim (list): layer widths ``[input, hidden..., output]``
            norm (callable): normalization layer factory
            activation (nn.Module): activation layer class

        Returns:
            list of nn.Module, ready to unpack into ``nn.Sequential``.
        """
        args = [nn.LayerNorm(dim[0])]
        for i in range(len(dim) - 1):
            args.append(nn.Linear(dim[i], dim[i + 1]))
            if i < len(dim) - 2:
                # NOTE: this call matches nn.GroupNorm(num_groups, num_channels).
                # With the nn.InstanceNorm1d default it would instead mean
                # num_features=4, eps=<width>; pass norm=nn.GroupNorm explicitly.
                args.append(norm(4, dim[i + 1]))
                args.append(activation())
        return args

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features; only ``x[:, 0, :]`` is used
               (presumably (batch, tokens, initial_dim) — TODO confirm)

        Returns:
            Without aux targets: the output of ``mlp`` on ``x[:, 0, :]``.
            With aux targets: a dict with key ``"gps"`` (main head output) and
            one key per aux target holding its columns of ``mlp_aux``'s output.
        """
        feats = x[:, 0, :]
        if not self.aux:
            return self.mlp(feats)
        out = {"gps": self.mlp(feats)}
        aux_out = self.mlp_aux(feats)
        for col, cols in self.idx.items():
            out[col] = aux_out[:, cols]
        return out
class MLPResNet(nn.Module):
    """MLP head with an optional auxiliary MLP, split per auxiliary target.

    Same architecture as ``MLP``; only the input handling in ``forward``
    differs (see the note there).
    """

    # Output width of the auxiliary head for each supported auxiliary target.
    _AUX_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head

        Args:
            initial_dim (int): width of the input features
            hidden_dim (list): list of hidden dimensions for the MLP; the list
                passed in (or the default) is never modified
            final_dim (int): output width of the main ``mlp`` head
            norm (callable): normalization layer factory, called as
                ``norm(4, num_channels)``
            activation (nn.Module): activation layer class, called with no args
            aux_data (list): names of auxiliary targets; each must be a key of
                ``_AUX_DIMS``

        Raises:
            KeyError: if ``aux_data`` contains an unsupported target name.
        """
        super().__init__()
        self.aux_data = aux_data
        self.aux = len(self.aux_data) > 0
        if self.aux:
            # Copy: mutating `hidden_dim` in place would change the caller's
            # list, the shared default list, and the main head built below.
            hidden_dim_aux = list(hidden_dim)
            hidden_dim_aux[-1] = 128
            # self.idx maps each aux target to its column indices in the
            # concatenated output of mlp_aux.
            self.idx = {}
            final_dim_aux = 0
            for col in self.aux_data:
                width = self._AUX_DIMS[col]
                self.idx[col] = list(range(final_dim_aux, final_dim_aux + width))
                final_dim_aux += width
            dim = [initial_dim] + hidden_dim_aux + [final_dim_aux]
            self.mlp_aux = nn.Sequential(*self.init_layers(dim, norm, activation))
        dim = [initial_dim] + list(hidden_dim) + [final_dim]
        self.mlp = nn.Sequential(*self.init_layers(dim, norm, activation))

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per consecutive
        pair in ``dim``; every Linear except the last is followed by
        ``norm(4, width)`` and ``activation()``.

        Args:
            dim (list): layer widths ``[input, hidden..., output]``
            norm (callable): normalization layer factory
            activation (nn.Module): activation layer class

        Returns:
            list of nn.Module, ready to unpack into ``nn.Sequential``.
        """
        args = [nn.LayerNorm(dim[0])]
        for i in range(len(dim) - 1):
            args.append(nn.Linear(dim[i], dim[i + 1]))
            if i < len(dim) - 2:
                # NOTE: this call matches nn.GroupNorm(num_groups, num_channels).
                # With the nn.InstanceNorm1d default it would instead mean
                # num_features=4, eps=<width>; pass norm=nn.GroupNorm explicitly.
                args.append(norm(4, dim[i + 1]))
                args.append(activation())
        return args

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features

        Returns:
            Without aux targets: ``mlp(x)`` on the full input.
            With aux targets: a dict with key ``"gps"`` (main head output on
            ``x[:, 0, :]``) and one key per aux target holding its columns of
            ``mlp_aux``'s output.
        """
        if not self.aux:
            # NOTE(review): unlike the aux branch (and MLP.forward), this path
            # feeds the whole `x`, presumably (batch, initial_dim) features —
            # confirm which input rank callers actually pass.
            return self.mlp(x)
        feats = x[:, 0, :]
        out = {"gps": self.mlp(feats)}
        aux_out = self.mlp_aux(feats)
        for col, cols in self.idx.items():
            out[col] = aux_out[:, cols]
        return out
class MLPCentroid(nn.Module):
    """Two-branch head (``classif`` + ``reg``) on the first token of the input.

    The main prediction is ``cat([classif(f), reg(f)], dim=1)`` with
    ``f = x[:, 0, :]``. When ``aux_data`` is non-empty, one extra MLP per
    supported auxiliary target is built and ``forward`` returns a dict.
    """

    # Output width of the head for each supported auxiliary target.
    _AUX_DIMS = {
        "land_cover": 12,
        "climate": 30,
        "soil": 14,
        "road_index": 1,
        "drive_side": 1,
        "dist_sea": 1,
    }
    # Creation order of auxiliary heads; fixed so that submodule registration
    # (and hence state_dict / output ordering) does not depend on aux_data order.
    _AUX_ORDER = ("land_cover", "road_index", "drive_side", "climate", "soil", "dist_sea")

    def __init__(
        self,
        initial_dim=512,
        hidden_dim=[128, 32, 2],
        final_dim=2,
        norm=nn.InstanceNorm1d,
        activation=nn.ReLU,
        aux_data=[],
    ):
        """
        Initializes an MLP Classification Head

        Args:
            initial_dim (int): width of the input features
            hidden_dim (list): list of hidden dimensions for the MLP
            final_dim (int): total output width; ``classif`` outputs
                ``final_dim // 3`` and ``reg`` outputs ``2 * final_dim // 3``
                columns (these only sum to ``final_dim`` when it is divisible by 3)
            norm (callable): normalization layer factory, called as
                ``norm(4, num_channels)``
            activation (nn.Module): activation layer class, called with no args
            aux_data (list): names of auxiliary targets; names not in
                ``_AUX_DIMS`` are ignored
        """
        super().__init__()
        self.aux_data = aux_data
        self.aux = len(self.aux_data) > 0
        dim = [initial_dim] + list(hidden_dim) + [final_dim // 3]
        self.classif = nn.Sequential(*self.init_layers(dim, norm, activation))
        dim = [initial_dim] + list(hidden_dim) + [2 * final_dim // 3]
        self.reg = nn.Sequential(*self.init_layers(dim, norm, activation))
        if self.aux:
            self.dim = [initial_dim] + list(hidden_dim)
            # The class has no `self.mlp`; the "gps" predictor is the same
            # classif+reg concatenation used in the non-aux forward path.
            self.predictors = {"gps": self._predict_gps}
            self.init_aux(dim, norm, activation)

    def init_layers(self, dim, norm, activation):
        """Initializes the MLP layers.

        Builds ``LayerNorm(dim[0])`` followed by one ``Linear`` per consecutive
        pair in ``dim``; every Linear except the last is followed by
        ``norm(4, width)`` and ``activation()``.

        Args:
            dim (list): layer widths ``[input, hidden..., output]``
            norm (callable): normalization layer factory
            activation (nn.Module): activation layer class

        Returns:
            list of nn.Module, ready to unpack into ``nn.Sequential``.
        """
        args = [nn.LayerNorm(dim[0])]
        for i in range(len(dim) - 1):
            args.append(nn.Linear(dim[i], dim[i + 1]))
            if i < len(dim) - 2:
                # NOTE: this call matches nn.GroupNorm(num_groups, num_channels).
                # With the nn.InstanceNorm1d default it would instead mean
                # num_features=4, eps=<width>; pass norm=nn.GroupNorm explicitly.
                args.append(norm(4, dim[i + 1]))
                args.append(activation())
        return args

    def init_aux(self, dim, norm, activation):
        """Build one head per requested auxiliary target.

        Each head maps ``self.dim`` (input + hidden widths) to the target's
        width from ``_AUX_DIMS``. It is registered as an attribute named after
        the target (e.g. ``self.land_cover``) and added to ``self.predictors``.

        Args:
            dim: unused; kept for call compatibility.
            norm (callable): normalization layer factory
            activation (nn.Module): activation layer class
        """
        for col in self._AUX_ORDER:
            if col not in self.aux_data:
                continue
            layers = self.init_layers(self.dim + [self._AUX_DIMS[col]], norm, activation)
            head = nn.Sequential(*layers)
            # setattr on an nn.Module registers the head as a submodule, exactly
            # like the explicit `self.<name> = ...` assignments it replaces.
            setattr(self, col, head)
            self.predictors[col] = head

    def _predict_gps(self, feats):
        """Concatenate the ``classif`` and ``reg`` outputs along dim 1."""
        return torch.cat([self.classif(feats), self.reg(feats)], dim=1)

    def forward(self, x):
        """Predicts GPS coordinates from an image.

        Args:
            x: torch.Tensor with features; only ``x[:, 0, :]`` is used
               (presumably (batch, tokens, initial_dim) — TODO confirm)

        Returns:
            Without aux targets: ``cat([classif, reg], dim=1)`` on ``x[:, 0, :]``.
            With aux targets: a dict with ``"gps"`` (that same tensor) plus one
            entry per built auxiliary head.
        """
        feats = x[:, 0, :]
        if self.aux:
            return {col: head(feats) for col, head in self.predictors.items()}
        return self._predict_gps(feats)
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands back its argument untouched."""

    def __init__(self):
        """Set up the module; it owns no parameters, buffers or submodules."""
        super(Identity, self).__init__()

    def forward(self, x):
        """Return ``x`` itself (the very same object, no copy is made)."""
        result = x
        return result