Spaces:
Sleeping
Sleeping
import torch | |
from metrics.utils import haversine, reverse | |
from torchmetrics import Metric | |
class HaversineMetrics(Metric): | |
""" | |
Computes the average haversine distance between the predicted and ground truth points. | |
Compute the accuracy given some radiuses. | |
Compute the Geoguessr score given some radiuses. | |
Args: | |
acc_radiuses (list): list of radiuses to compute the accuracy from | |
acc_area (list): list of areas to compute the accuracy from. | |
acc_data (list): list of auxilliary data to compute the accuracy from. | |
""" | |
def __init__( | |
self, | |
acc_radiuses=[], | |
acc_area=["country", "region", "sub-region", "city"], | |
aux_data=[], | |
): | |
super().__init__() | |
self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") | |
self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") | |
for acc in acc_radiuses: | |
self.add_state( | |
f"close_enough_points_{acc}", | |
default=torch.tensor(0.0), | |
dist_reduce_fx="sum", | |
) | |
for acc in acc_area: | |
self.add_state( | |
f"close_enough_points_{acc}", | |
default=torch.tensor(0.0), | |
dist_reduce_fx="sum", | |
) | |
self.add_state( | |
f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum" | |
) | |
self.acc_radius = acc_radiuses | |
self.acc_area = acc_area | |
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") | |
self.aux = len(aux_data) > 0 | |
self.aux_list = aux_data | |
if self.aux: | |
self.aux_count = {} | |
for col in self.aux_list: | |
self.add_state( | |
f"aux_{col}", | |
default=torch.tensor(0.0), | |
dist_reduce_fx="sum", | |
) | |
def update(self, pred, gt): | |
haversine_distance = haversine(pred["gps"], gt["gps"]) | |
for acc in self.acc_radius: | |
self.__dict__[f"close_enough_points_{acc}"] += ( | |
haversine_distance < acc | |
).sum() | |
if len(self.acc_area) > 0: | |
area_pred, area_gt = reverse(pred["gps"], gt, self.acc_area) | |
for acc in self.acc_area: | |
self.__dict__[f"close_enough_points_{acc}"] += ( | |
area_pred[acc] == area_gt["_".join(["unique", acc])] | |
).sum() | |
self.__dict__[f"count_{acc}"] += len(area_gt["_".join(["unique", acc])]) | |
self.haversine_sum += haversine_distance.sum() | |
self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum() | |
if self.aux: | |
if "land_cover" in self.aux_list: | |
col = "land_cover" | |
self.__dict__[f"aux_{col}"] += ( | |
pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
).sum() | |
if "road_index" in self.aux_list: | |
col = "road_index" | |
self.__dict__[f"aux_{col}"] += ( | |
pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
).sum() | |
if "drive_side" in self.aux_list: | |
col = "drive_side" | |
self.__dict__[f"aux_{col}"] += ( | |
(pred[col] > 0.5).float() == gt[col] | |
).sum() | |
if "climate" in self.aux_list: | |
col = "climate" | |
self.__dict__[f"aux_{col}"] += ( | |
pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
).sum() | |
if "soil" in self.aux_list: | |
col = "soil" | |
self.__dict__[f"aux_{col}"] += ( | |
pred[col].argmax(dim=1) == gt[col].argmax(dim=1) | |
).sum() | |
if "dist_sea" in self.aux_list: | |
col = "dist_sea" | |
self.__dict__[f"aux_{col}"] += ( | |
(pred[col] - gt[col]).pow(2).sum(dim=1).sum() | |
) | |
self.count += pred["gps"].shape[0] | |
def compute(self): | |
output = { | |
"Haversine": self.haversine_sum / self.count, | |
"Geoguessr": self.geoguessr_sum / self.count, | |
} | |
for acc in self.acc_radius: | |
output[f"Accuracy_{acc}_km_radius"] = ( | |
self.__dict__[f"close_enough_points_{acc}"] / self.count | |
) | |
for acc in self.acc_area: | |
output[f"Accuracy_{acc}"] = ( | |
self.__dict__[f"close_enough_points_{acc}"] | |
/ self.__dict__[f"count_{acc}"] | |
) | |
if self.aux: | |
for col in self.aux_list: | |
output["_".join(["Accuracy", col])] = ( | |
self.__dict__[f"aux_{col}"] / self.count | |
) | |
return output | |