Spaces:
Runtime error
Runtime error
# Copyright (c) 2019 Shigeki Karita | |
# 2020 Mobvoi Inc (Binbin Zhang) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Label smoothing module.""" | |
import torch | |
from torch import nn | |
class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the label's data distribution is::

        [0,1,2] ->
        [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]

    In the smoothed version of the CE loss, some probability mass is taken
    from the true label prob (1.0) and divided evenly among the other
    labels, e.g. with ``smoothing=0.1``::

        [0,1,2] ->
        [
            [0.9, 0.05, 0.05],
            [0.05, 0.9, 0.05],
            [0.05, 0.05, 0.9],
        ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            if True, normalize the loss by the number of non-padding tokens;
            if False, normalize the loss by the batch size
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        # Element-wise KL divergence; masking and reduction are done in forward().
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        # Probability mass kept on the true label.
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed KL loss between x and target.

        The model outputs and data labels tensors are flattened to
        (batch*seqlen, class) shape and a mask is applied to the
        padding part which should not be counted in the loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class); treated as
                unnormalized scores, since log_softmax is applied here
            target (torch.Tensor):
                target labels masked with self.padding_idx (batch, seqlen)

        Returns:
            loss (torch.Tensor): the KL loss, a scalar float value. It is 0
            (not NaN) when every target position is padding.
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        # reshape (not view): inputs may be non-contiguous, e.g. after a
        # transpose, and view() would raise in that case.
        x = x.reshape(-1, self.size)
        target = target.reshape(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B*T,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        # An all-padding (or empty) batch gives denom == 0. The masked sum is
        # 0 in that case, so clamp to avoid 0 / 0 = NaN.
        denom = max(denom, 1)
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom