import torch
from torch import nn
from torch.nn import KLDivLoss, LogSoftmax


class LabelSmoothingLoss(nn.Module):
    def __init__(self, label_smoothing=0.0, unreliable_label=None, ignore_index=-100):
        """
        KL-divergence loss against a label-smoothed target distribution.
        If label_smoothing == 0.0, it is equivalent to standard cross-entropy.
        """
        assert 0.0 <= label_smoothing <= 1.0
        super().__init__()

        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing

        # KLDivLoss expects log-probabilities as input and probabilities as
        # target; 'batchmean' divides the summed loss by the batch size.
        self.loss_fn = KLDivLoss(reduction='batchmean')
        self.unreliable_label = unreliable_label
        self.max_gap = 100.  # unused in this module
        self.log_softmax = LogSoftmax(dim=1)

    def forward(self, output, target):
        """
        output: logits of shape (batch, vocab_size)
        target: integer labels of shape (batch,)
        """
        vocab_size = output.shape[1]
        # Drop positions labeled ignore_index so they contribute no loss.
        mask = (target != self.ignore_index)
        output, target = output[mask], target[mask]
        output = self.log_softmax(output)

        def get_smooth_prob(ls):
            # Put uniform mass ls / (vocab_size - 1) on the non-target
            # classes and 1 - ls on the target class, so each row sums to 1.
            smoothing_value = ls / (vocab_size - 1)
            prob = output.new_full((target.size(0), vocab_size), smoothing_value)
            prob.scatter_(1, target.unsqueeze(1), 1 - ls)
            return prob

        if self.unreliable_label is not None:
            # Smooth only the positions whose label equals unreliable_label;
            # all other positions keep a hard one-hot target.
            smoothed_prob = get_smooth_prob(self.label_smoothing)
            hard_prob = get_smooth_prob(0.0)
            unreliable_mask = (target == self.unreliable_label).to(torch.float)
            model_prob = ((smoothed_prob.T * unreliable_mask) + (hard_prob.T * (1 - unreliable_mask))).T
        else:
            model_prob = get_smooth_prob(self.label_smoothing)

        loss = self.loss_fn(output, model_prob)
        return loss
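

# Minimal usage sketch (not part of the original module); the batch size,
# vocab size, label values, and the unreliable_label id below are
# illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(4, 10)
    labels = torch.tensor([1, 3, 3, -100])  # the last position is ignored

    # Smooth only the (assumed) unreliable label id 3.
    criterion = LabelSmoothingLoss(label_smoothing=0.1, unreliable_label=3)
    print(criterion(logits, labels))

    # With label_smoothing=0.0 the targets are one-hot, so the loss should
    # match standard cross-entropy over the non-ignored positions.
    hard = LabelSmoothingLoss(label_smoothing=0.0)
    print(hard(logits, labels), nn.functional.cross_entropy(logits[:3], labels[:3]))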