guesstimatelocation / models /huggingface.py
yunusserhat's picture
Upload 40 files
94f372a verified
raw
history blame contribute delete
854 Bytes
import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin
class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    """Image geolocation model wrapper built from a Hydra/OmegaConf config.

    Mixing in ``PyTorchModelHubMixin`` gives the class the Hugging Face Hub
    save/load helpers (e.g. ``from_pretrained`` / ``push_to_hub``).

    The instantiated ``model`` is run as three stages in sequence:
    ``backbone`` -> ``mid`` -> ``head``; the ``"gps"`` entry of the head's
    output is returned.
    """

    def __init__(self, config):
        """Build the transform and the model from ``config``.

        Args:
            config: anything ``OmegaConf.create`` accepts (e.g. a plain dict).
                It must contain a ``transform`` entry and a ``model`` entry,
                each instantiable with ``hydra.utils.instantiate``. The
                instantiated model must expose ``head``, ``mid`` and
                ``backbone`` attributes.
        """
        super().__init__()
        cfg = OmegaConf.create(config)
        self.config = cfg

        # Not used by either forward method below; presumably exposed so
        # callers can preprocess raw images before calling the model — confirm.
        self.transform = instantiate(cfg.transform)

        model = instantiate(cfg.model)
        self.model = model
        # Shortcut aliases to the model's stages. If these are nn.Modules,
        # PyTorch registers them a second time under these top-level names, so
        # state_dict carries both ``model.*`` and ``head.*``/``mid.*``/
        # ``backbone.*`` keys. Keep names and assignment order unchanged:
        # altering them would likely change the checkpoint key layout.
        self.head = model.head
        self.mid = model.mid
        self.backbone = model.backbone

    def _gps_from_features(self, features):
        """Run ``mid`` then ``head`` on backbone output and return ``"gps"``.

        The head is called with ``None`` as its second argument; what that
        slot means is defined by the head implementation (not visible here).
        """
        head_out = self.head(self.mid(features), None)
        return head_out["gps"]

    def forward(self, img: torch.Tensor):
        """Predict the ``"gps"`` output for ``img``.

        ``img`` is wrapped as ``{"img": img}`` before reaching the backbone,
        so the backbone is expected to take a dict batch.

        Returns:
            Whatever the head returns under the ``"gps"`` key (presumably the
            predicted coordinates; shape/units not visible here — confirm).
        """
        return self._gps_from_features(self.backbone({"img": img}))

    def forward_tensor(self, img: torch.Tensor):
        """Same pipeline as ``forward`` but hands ``img`` to the backbone as-is.

        No ``{"img": ...}`` wrapping is applied, so ``img`` must already be in
        whatever form the backbone accepts.
        """
        return self._gps_from_features(self.backbone(img))