import math

import torch


def _positive_indices(labels: torch.Tensor):
    """Return ``(rows, cols)`` of the target entries for samples whose label is not -1.

    Accepts labels shaped ``(N,)`` or ``(N, 1)``. ``rows`` are the sample indices
    with a valid label, and ``cols`` is the flat vector of their class indices.
    """
    rows = torch.where(labels != -1)[0]
    cols = labels[rows].view(-1)
    return rows, cols


def _add_angular_margin_(logits: torch.Tensor, rows, cols, margin: float) -> torch.Tensor:
    """Replace ``cos(theta)`` by ``cos(theta + margin)`` at ``(rows, cols)``.

    Every entry of ``logits`` is mapped through arccos and then cos, and
    ``margin`` is added to the angle at the target positions only. All of this
    runs in place under ``torch.no_grad()``, so autograd does not differentiate
    through arccos/cos. Gradients reaching ``logits`` bypass the margin
    transform (straight-through). ``logits`` is assumed to hold cosine
    similarities; values are clamped to [-1, 1] first because arccos is NaN
    outside that range.

    Returns the same (mutated) tensor.
    """
    with torch.no_grad():
        # Guard against rounding drift just outside [-1, 1] (arccos -> NaN).
        logits.clamp_(-1.0, 1.0)
        target = logits[rows, cols]  # advanced indexing returns a copy
        target.arccos_()
        logits.arccos_()
        logits[rows, cols] = target + margin
        logits.cos_()
    return logits


def _add_cosine_margin_(logits: torch.Tensor, rows, cols, margin: float) -> torch.Tensor:
    """Subtract ``margin`` from the target logits at ``(rows, cols)``, in place.

    Unlike the angular margin, this runs with autograd enabled.
    Returns the same (mutated) tensor.
    """
    logits[rows, cols] = logits[rows, cols] - margin
    return logits


class CombinedMarginLoss(torch.nn.Module):
    """Apply an ArcFace- or CosFace-style margin to the target logits, then scale by ``s``.

    Supported settings:

    * ``m1 == 1.0`` and ``m3 == 0.0``: additive angular margin ``m2``
      (``cos(theta + m2)`` at the target class).
    * ``m3 > 0``: additive cosine margin ``m3`` (``cos(theta) - m3`` at the
      target class). NOTE: ``m1`` and ``m2`` are not used in this branch.

    Any other combination raises ``NotImplementedError`` from ``forward``.

    Args:
        s: Scale multiplied into all logits after the margin is applied.
        m1, m2, m3: Margin parameters (see above).
        interclass_filtering_threshold: If > 0, non-target logits of labelled
            samples that exceed this value are zeroed before the margin is applied.
    """

    def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # Precomputed ArcFace constants; not read by forward() in this module,
        # kept so the attributes remain available.
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def _filter_interclass(self, logits, rows, cols):
        """Zero non-target logits above the threshold for labelled rows; returns a new tensor."""
        with torch.no_grad():
            dirty = (logits > self.interclass_filtering_threshold).float()
            mask = torch.ones([rows.size(0), logits.size(1)], device=logits.device)
            # scatter_ needs an index of the same rank as `mask`, hence (K, 1).
            mask.scatter_(1, cols.view(-1, 1), 0)
            dirty[rows] *= mask
            keep = 1 - dirty
        return keep * logits

    def forward(self, logits, labels):
        """Return margin-adjusted, scaled logits.

        Args:
            logits: ``(N, C)`` tensor, presumably cosine similarities in [-1, 1].
                It may be modified in place.
            labels: ``(N,)`` or ``(N, 1)`` class indices; rows labelled -1 get no margin.

        Raises:
            NotImplementedError: if ``(m1, m2, m3)`` is not a supported combination.
        """
        rows, cols = _positive_indices(labels)

        if self.interclass_filtering_threshold > 0:
            logits = self._filter_interclass(logits, rows, cols)

        if self.m1 == 1.0 and self.m3 == 0.0:
            logits = _add_angular_margin_(logits, rows, cols, self.m2)
        elif self.m3 > 0:
            logits = _add_cosine_margin_(logits, rows, cols, self.m3)
        else:
            raise NotImplementedError(
                "CombinedMarginLoss supports only ArcFace (m1 == 1.0, m3 == 0.0) "
                f"or CosFace (m3 > 0); got m1={self.m1}, m2={self.m2}, m3={self.m3}"
            )
        return logits * self.s


class ArcFace(torch.nn.Module):
    """ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): additive angular margin.

    Args:
        s: Scale multiplied into all logits after the margin is applied.
        margin: Angle (radians) added to the target class angle.
    """

    def __init__(self, s=64.0, margin=0.5):
        super().__init__()
        self.scale = s
        self.margin = margin
        # Precomputed constants; not read by forward() in this module.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return ``scale * cos(theta + margin)`` at targets and ``scale * cos(theta)`` elsewhere.

        ``logits`` (presumably cosine similarities) is modified in place.
        Rows labelled -1 get no margin.
        """
        rows, cols = _positive_indices(labels)
        logits = _add_angular_margin_(logits, rows, cols, self.margin)
        # The scale is stored as `self.scale`; there is no `self.s` on this class.
        return logits * self.scale


class CosFace(torch.nn.Module):
    """CosFace: subtract a fixed margin ``m`` from the target logit, then scale by ``s``."""

    def __init__(self, s=64.0, m=0.40):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return ``s * (logit - m)`` at targets and ``s * logit`` elsewhere.

        ``logits`` is modified in place. Rows labelled -1 get no margin.
        """
        rows, cols = _positive_indices(labels)
        logits = _add_cosine_margin_(logits, rows, cols, self.m)
        return logits * self.s