# guesstimatelocation / metrics / distance_based.py
# yunusserhat's picture
# Upload 40 files
# 94f372a verified
import torch
from metrics.utils import haversine, reverse
from torchmetrics import Metric
class HaversineMetrics(Metric):
    """Accumulate distance-based geolocation metrics over batches.

    Across all calls to :meth:`update` the metric tracks:

    * the mean haversine distance between predicted and ground-truth GPS
      points, as returned by ``metrics.utils.haversine``;
    * the mean Geoguessr-style score ``5000 * exp(-distance / 1492.7)``;
    * for every radius in ``acc_radiuses``, the fraction of samples whose
      distance is strictly below that radius;
    * for every name in ``acc_area``, the fraction of samples whose predicted
      area label equals the ground-truth one (labels are produced by
      ``metrics.utils.reverse``);
    * for every name in ``aux_data``, an auxiliary score (see :meth:`update`).

    Args:
        acc_radiuses (list, optional): radii at which to compute accuracy
            (presumably kilometres, judging by the ``Accuracy_<r>_km_radius``
            output key -- confirm against ``haversine``). ``None`` means no
            radius.
        acc_area (list, optional): area levels at which to compute accuracy.
            ``None`` selects ``["country", "region", "sub-region", "city"]``;
            pass an empty list to disable area accuracy.
        aux_data (list, optional): names of auxiliary prediction heads to
            score. ``None`` means none.
    """

    # Area levels scored when the caller does not pass ``acc_area``.
    _DEFAULT_ACC_AREA = ("country", "region", "sub-region", "city")

    # Auxiliary heads scored by comparing the argmax over dim 1 of the
    # prediction and of the ground truth.
    _ARGMAX_AUX_COLUMNS = ("land_cover", "road_index", "climate", "soil")

    def __init__(
        self,
        acc_radiuses=None,
        acc_area=None,
        aux_data=None,
    ):
        super().__init__()

        # ``None`` sentinels replace mutable default arguments; the inputs are
        # copied so the metric never aliases a list owned by the caller (or
        # shared between instances).
        acc_radiuses = [] if acc_radiuses is None else list(acc_radiuses)
        acc_area = (
            list(self._DEFAULT_ACC_AREA) if acc_area is None else list(acc_area)
        )
        aux_data = [] if aux_data is None else list(aux_data)

        self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")

        for acc in acc_radiuses:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )

        for acc in acc_area:
            self.add_state(
                f"close_enough_points_{acc}",
                default=torch.tensor(0.0),
                dist_reduce_fx="sum",
            )
            # Area accuracy has its own denominator, accumulated in update()
            # from the length of the ground-truth labels of that area.
            self.add_state(
                f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum"
            )

        self.acc_radius = acc_radiuses
        self.acc_area = acc_area
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

        self.aux = len(aux_data) > 0
        self.aux_list = aux_data
        if self.aux:
            # Not read anywhere in this class; kept in case external code
            # relies on the attribute being present.
            self.aux_count = {}
            for col in self.aux_list:
                self.add_state(
                    f"aux_{col}",
                    default=torch.tensor(0.0),
                    dist_reduce_fx="sum",
                )

    def _increment_state(self, name, value):
        """Add ``value`` in place to the dynamically named state ``name``."""
        state = getattr(self, name)
        state += value
        setattr(self, name, state)

    def update(self, pred, gt):
        """Accumulate statistics for one batch.

        Args:
            pred: mapping holding the predictions. ``pred["gps"]`` is always
                read; ``pred[col]`` is read for each handled name in
                ``aux_data``.
            gt: mapping holding the ground truth, with the same keys. It is
                also passed as a whole to ``metrics.utils.reverse`` when area
                accuracy is enabled.

        Auxiliary heads are scored as follows:

        * ``land_cover``, ``road_index``, ``climate``, ``soil``: number of
          samples where the argmax over dim 1 of ``pred`` and ``gt`` agree;
        * ``drive_side``: number of entries where ``pred > 0.5`` (as float)
          equals ``gt``;
        * ``dist_sea``: squared error summed over dim 1 and over the batch.

        Any other name in ``aux_data`` has a state registered but is never
        updated here.
        """
        haversine_distance = haversine(pred["gps"], gt["gps"])

        for radius in self.acc_radius:
            self._increment_state(
                f"close_enough_points_{radius}",
                (haversine_distance < radius).sum(),
            )

        if self.acc_area:
            area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area)
            for area in self.acc_area:
                gt_key = "_".join(["unique", area])
                self._increment_state(
                    f"close_enough_points_{area}",
                    (area_pred[area] == area_gt[gt_key]).sum(),
                )
                self._increment_state(f"count_{area}", len(area_gt[gt_key]))

        self.haversine_sum += haversine_distance.sum()
        self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum()

        if self.aux:
            for col in self._ARGMAX_AUX_COLUMNS:
                if col in self.aux_list:
                    self._increment_state(
                        f"aux_{col}",
                        (pred[col].argmax(dim=1) == gt[col].argmax(dim=1)).sum(),
                    )
            if "drive_side" in self.aux_list:
                col = "drive_side"
                self._increment_state(
                    f"aux_{col}",
                    ((pred[col] > 0.5).float() == gt[col]).sum(),
                )
            if "dist_sea" in self.aux_list:
                col = "dist_sea"
                self._increment_state(
                    f"aux_{col}",
                    (pred[col] - gt[col]).pow(2).sum(dim=1).sum(),
                )

        self.count += pred["gps"].shape[0]

    def compute(self):
        """Return the averaged metrics as a dict.

        Keys:

        * ``"Haversine"`` and ``"Geoguessr"``: sums divided by the sample
          count;
        * ``"Accuracy_<radius>_km_radius"``: per-radius hits divided by the
          sample count;
        * ``"Accuracy_<area>"``: per-area hits divided by that area's own
          count;
        * ``"Accuracy_<col>"``: auxiliary sums divided by the sample count.
          NOTE: for ``dist_sea`` this value is a mean squared error, not an
          accuracy; the key name is kept for backward compatibility.
        """
        output = {
            "Haversine": self.haversine_sum / self.count,
            "Geoguessr": self.geoguessr_sum / self.count,
        }
        for radius in self.acc_radius:
            output[f"Accuracy_{radius}_km_radius"] = (
                getattr(self, f"close_enough_points_{radius}") / self.count
            )
        for area in self.acc_area:
            output[f"Accuracy_{area}"] = getattr(
                self, f"close_enough_points_{area}"
            ) / getattr(self, f"count_{area}")
        if self.aux:
            for col in self.aux_list:
                output["_".join(["Accuracy", col])] = (
                    getattr(self, f"aux_{col}") / self.count
                )
        return output