#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Label smoothing module."""

import torch
from torch import nn


class LabelSmoothingLoss(nn.Module):
    """Label-smoothing loss.

    In a standard CE loss, the label's data distribution is:
    [0,1,2] ->
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]

    In the smoothed version of the CE loss, some probability mass
    is taken from the true label (1.0) and divided evenly
    among the other labels.

    e.g.
    smoothing=0.1
    [0,1,2] ->
    [
        [0.9, 0.05, 0.05],
        [0.05, 0.9, 0.05],
        [0.05, 0.05, 0.9],
    ]

    Args:
        size (int): the number of classes
        padding_idx (int): padding class id which will be ignored for loss
        smoothing (float): smoothing rate (0.0 means the conventional CE)
        normalize_length (bool):
            normalize loss by sequence length if True
            normalize loss by batch size if False
    """

    def __init__(self,
                 size: int,
                 padding_idx: int,
                 smoothing: float,
                 normalize_length: bool = False):
        """Construct a LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="none")
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute loss between x and target.

        The model output and data label tensors are flattened to
        (batch * seqlen, class) shape, and a mask is applied to the
        padding positions, which should not contribute to the loss.

        Args:
            x (torch.Tensor): prediction (batch, seqlen, class)
            target (torch.Tensor):
                target signal masked with self.padding_idx (batch, seqlen)
        Returns:
            loss (torch.Tensor): the KL loss, scalar float value
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        # use zeros_like instead of torch.no_grad() for true_dist,
        # since no_grad() can not be exported by JIT
        true_dist = torch.zeros_like(x)
        true_dist.fill_(self.smoothing / (self.size - 1))
        ignore = target == self.padding_idx  # (B,)
        total = len(target) - ignore.sum().item()
        target = target.masked_fill(ignore, 0)  # avoid -1 index
        true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else batch_size
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
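

# A minimal usage sketch (not part of the original file): it builds random
# logits for an assumed 5-class vocabulary, marks one target position as
# padding with -1, and compares the batch-size-normalized and
# length-normalized variants of the loss. The shapes, vocabulary size, and
# smoothing value below are illustrative assumptions only.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seqlen, num_classes = 2, 4, 5
    logits = torch.randn(batch, seqlen, num_classes)
    target = torch.randint(0, num_classes, (batch, seqlen))
    target[0, -1] = -1  # last position of the first sequence is padding

    # normalize by batch size (default)
    loss_fn = LabelSmoothingLoss(size=num_classes,
                                 padding_idx=-1,
                                 smoothing=0.1)
    print("loss normalized by batch size:", loss_fn(logits, target).item())

    # normalize by the number of non-padding tokens instead
    loss_fn_len = LabelSmoothingLoss(size=num_classes,
                                     padding_idx=-1,
                                     smoothing=0.1,
                                     normalize_length=True)
    print("loss normalized by token count:", loss_fn_len(logits, target).item())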