Spaces:
Runtime error
Runtime error
import math | |
import torch | |
class CombinedMarginLoss(torch.nn.Module):
    """Apply an angular or cosine margin to target-class logits, then scale.

    Two ``(m1, m2, m3)`` configurations are supported by ``forward``:

    * ``m1 == 1.0 and m3 == 0.0``: additive angular margin. ``m2`` (radians)
      is added to the angle ``arccos(logit)`` of each target entry.
    * ``m3 > 0``: additive cosine margin. ``m3`` is subtracted from each
      target entry.

    Any other combination raises ``NotImplementedError`` when ``forward``
    runs.

    Args:
        s: Scale multiplied into every logit after the margin is applied.
        m1: Only ``1.0`` is accepted on the angular-margin path.
        m2: Additive angular margin used on the angular-margin path.
        m3: Additive cosine margin used on the cosine-margin path.
        interclass_filtering_threshold: When ``> 0``, non-target logits
            greater than this value are zeroed before the margin step.
    """

    def __init__(self, s, m1, m2, m3, interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold
        # ArcFace helper constants derived from m2. forward() below does not
        # read them (it uses the arccos/cos formulation); they are kept as
        # public attributes in case other code inspects them.
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        """Return ``s``-scaled logits with the margin applied to target entries.

        Args:
            logits: 2-D tensor indexed as ``logits[row, class]``; presumably
                cosine similarities in [-1, 1] (TODO confirm with caller).
                On the angular-margin path it is modified in place, unless
                interclass filtering already replaced it with a new tensor.
            labels: Target class per row, shaped ``(N,)`` or ``(N, 1)``.
                Rows labelled ``-1`` receive no margin.

        Returns:
            Tensor with the same shape as ``logits``, multiplied by ``self.s``.

        Raises:
            NotImplementedError: If ``(m1, m3)`` matches neither supported
                configuration. This is a subclass of ``RuntimeError``, the
                type the previous bare ``raise`` produced.
        """
        index_positive = torch.where(labels != -1)[0]
        # Flat target-column index; view(-1) accepts labels of shape (N,) or (N, 1).
        target_index = labels[index_positive].view(-1)

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = (logits > self.interclass_filtering_threshold).float()
                mask = torch.ones(
                    [index_positive.size(0), logits.size(1)], device=logits.device
                )
                # scatter_ needs an index with the same rank as `mask`, hence
                # the (K, 1) column. Zeroing it keeps the target logit from
                # ever being filtered.
                mask.scatter_(1, target_index.view(-1, 1), 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        # Advanced indexing returns a copy, so in-place ops on target_logit
        # do not touch `logits`.
        target_logit = logits[index_positive, target_index]

        if self.m1 == 1.0 and self.m3 == 0.0:
            # Done in place under no_grad: autograd does not record the
            # arccos / margin / cos steps, so the gradient of the final
            # `logits * self.s` reaches `logits` as if they were identity.
            with torch.no_grad():
                # Float rounding can push cosines slightly past +/-1, where
                # arccos returns NaN; clamping only affects such values.
                target_logit.clamp_(-1.0, 1.0).arccos_()
                logits.clamp_(-1.0, 1.0).arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, target_index] = final_target_logit
                logits.cos_()
            logits = logits * self.s
        elif self.m3 > 0:
            final_target_logit = target_logit - self.m3
            logits[index_positive, target_index] = final_target_logit
            logits = logits * self.s
        else:
            raise NotImplementedError(
                f"Unsupported margin combination (m1={self.m1}, m2={self.m2}, "
                f"m3={self.m3}): expected m1 == 1.0 and m3 == 0.0 (angular "
                f"margin) or m3 > 0 (cosine margin)."
            )
        return logits
class ArcFace(torch.nn.Module):
    """ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): additive angular margin.

    ``forward`` adds ``margin`` to the angle ``arccos(logit)`` of each target
    entry, converts back with ``cos`` and multiplies every logit by the scale.

    Args:
        s: Scale multiplied into every logit; stored as ``self.scale``.
        margin: Additive angular margin added to each target angle.
    """

    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.scale = s
        self.margin = margin
        # Helper constants derived from the margin. forward() below does not
        # read them; they are kept as public attributes.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return ``scale``-multiplied logits with the angular margin applied.

        Args:
            logits: 2-D tensor indexed as ``logits[row, class]``; presumably
                cosine similarities in [-1, 1] (TODO confirm with caller).
                Modified in place.
            labels: Target class per row (``(N,)`` or ``(N, 1)``). Rows
                labelled ``-1`` receive no margin.

        Returns:
            Tensor with the same shape as ``logits``, multiplied by ``self.scale``.
        """
        index = torch.where(labels != -1)[0]
        target_index = labels[index].view(-1)
        # Advanced indexing returns a copy, independent of `logits`.
        target_logit = logits[index, target_index]
        # In place under no_grad: autograd does not record these steps, so the
        # gradient of the final scaling reaches `logits` as if they were identity.
        with torch.no_grad():
            # Float rounding can push cosines slightly past +/-1, where arccos
            # returns NaN; clamping only affects such values.
            target_logit.clamp_(-1.0, 1.0).arccos_()
            logits.clamp_(-1.0, 1.0).arccos_()
            final_target_logit = target_logit + self.margin
            logits[index, target_index] = final_target_logit
            logits.cos_()
        # __init__ stores the scale as `self.scale`; there is no `self.s`
        # attribute, which previously made every call raise AttributeError.
        logits = logits * self.scale
        return logits
class CosFace(torch.nn.Module):
    """CosFace-style additive cosine margin.

    Subtracts ``m`` from each labelled target logit, then multiplies every
    logit by ``s``.

    Args:
        s: Scale multiplied into every logit.
        m: Amount subtracted from each target logit.
    """

    def __init__(self, s=64.0, m=0.40):
        super().__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Return ``s``-scaled logits after subtracting ``m`` at target entries.

        Args:
            logits: 2-D tensor indexed as ``logits[row, class]``; the target
                entries are overwritten in place before scaling.
            labels: Target class per row (``(N,)`` or ``(N, 1)``). Rows
                labelled ``-1`` are skipped.

        Returns:
            Tensor with the same shape as ``logits``, multiplied by ``self.s``.
        """
        rows = (labels != -1).nonzero(as_tuple=True)[0]
        cols = labels[rows].view(-1)
        shifted = logits[rows, cols] - self.m
        logits[rows, cols] = shifted
        return logits * self.s