#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the label's data distribution is:
    [0,1,2] ->
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]

    In the label-smoothed version of the CE loss, some probability mass
    is taken from the true label's probability (1.0) and distributed
    among the other labels.

    e.g.
    smoothing=0.1
    [0,1,2] ->
    [
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
        [0.05, 0.05, 0.9],
    ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            normalize loss by sequence length if True
            normalize loss by batch size if False
    """
    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model output and data label tensors are flattened to
        (batch*seqlen, class) shape, and a mask is applied to the
        padding positions so they do not contribute to the loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_idx (batch, seqlen)
        Returns:
            loss (torch.Tensor): the KL loss, a scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        # distribute the smoothing mass uniformly over all classes;
        # the true class is overwritten with the confidence below
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        # place the remaining confidence (1 - smoothing) at the true class
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        # zero out padding rows, then normalize by token count or batch size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
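
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The class count,
# smoothing value, tensor shapes, and padding_idx of -1 below are illustrative
# assumptions chosen to mirror the docstring example (3 classes, smoothing=0.1).
if __name__ == "__main__":
    torch.manual_seed(0)
    num_classes, padding_idx = 3, -1
    criterion = LabelSmoothingLoss(size=num_classes,
                                   padding_idx=padding_idx,
                                   smoothing=0.1,
                                   normalize_length=True)

    # Reproduce the smoothed target distribution shown in the class docstring:
    # labels [0, 1, 2] -> rows with 0.9 on the true class and 0.05 elsewhere.
    labels = torch.tensor([0, 1, 2])
    smoothed = torch.full((3, num_classes), 0.1 / (num_classes - 1))
    smoothed.scatter_(1, labels.unsqueeze(1), 1.0 - 0.1)
    print(smoothed)

    # Forward pass on random logits (batch, seqlen, class); the last position
    # of each sequence is marked as padding and excluded from the loss.
    batch, seqlen = 2, 4
    logits = torch.randn(batch, seqlen, num_classes)
    target = torch.randint(0, num_classes, (batch, seqlen))
    target[:, -1] = padding_idx
    loss = criterion(logits, target)
    print(loss.item())  # scalar KL loss averaged over the 6 non-padding tokens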
