import torch
import torch.nn.functional as F
from typeguard import check_argument_types


class CTC(torch.nn.Module):
    """CTC module."""

    def __init__(
        self,
        odim: int,
        encoder_output_size: int,
        dropout_rate: float = 0.0,
        reduce: bool = True,
    ):
        """Construct CTC module.

        Args:
            odim: dimension of outputs
            encoder_output_size: number of encoder projection units
            dropout_rate: dropout rate (0.0 ~ 1.0)
            reduce: reduce the CTC loss into a scalar
        """
        assert check_argument_types()
        super().__init__()
        eprojs = encoder_output_size
        self.dropout_rate = dropout_rate
        self.ctc_lo = torch.nn.Linear(eprojs, odim)

        # "sum" yields one summed loss over the whole batch;
        # "none" keeps a separate loss value per utterance.
        reduction_type = "sum" if reduce else "none"
        self.ctc_loss = torch.nn.CTCLoss(reduction=reduction_type)

    def forward(self, hs_pad: torch.Tensor, hlens: torch.Tensor,
                ys_pad: torch.Tensor, ys_lens: torch.Tensor) -> torch.Tensor:
        """Calculate CTC loss.

        Args:
            hs_pad: batch of padded hidden state sequences (B, Tmax, D)
            hlens: batch of lengths of hidden state sequences (B)
            ys_pad: batch of padded character id sequence tensor (B, Lmax)
            ys_lens: batch of lengths of character sequence (B)

        Returns:
            torch.Tensor: CTC loss value
        """
        # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
        # torch.nn.CTCLoss expects time-major log-probabilities of shape
        # (T, B, C), so transpose: ys_hat: (B, L, Nvocab) -> (L, B, Nvocab)
        ys_hat = ys_hat.transpose(0, 1)
        ys_hat = ys_hat.log_softmax(2)
        loss = self.ctc_loss(ys_hat, ys_pad, hlens, ys_lens)
        # The "sum" reduction sums over the batch, so divide by the batch
        # size (ys_hat.size(1) after the transpose) to get the average.
        loss = loss / ys_hat.size(1)
        return loss

    def log_softmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """log_softmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
        """
        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)

    def argmax(self, hs_pad: torch.Tensor) -> torch.Tensor:
        """argmax of frame activations.

        Args:
            hs_pad: 3d tensor (B, Tmax, eprojs)

        Returns:
            torch.Tensor: argmax applied 2d tensor (B, Tmax)
        """
        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
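

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): every size, shape, and value
    # below is an assumption for demonstration, not part of the module above.
    torch.manual_seed(0)

    # A vocabulary of 100 output units (id 0 is the CTC blank, the
    # torch.nn.CTCLoss default) over 256-dim encoder projections.
    ctc = CTC(odim=100, encoder_output_size=256)

    hs_pad = torch.randn(2, 50, 256)         # (B, Tmax, eprojs) encoder outputs
    hlens = torch.tensor([50, 42])           # valid frames per utterance
    ys_pad = torch.randint(1, 100, (2, 12))  # (B, Lmax) label ids, no blanks
    ys_lens = torch.tensor([12, 9])          # valid labels per utterance

    loss = ctc(hs_pad, hlens, ys_pad, ys_lens)
    print(loss)                              # scalar loss, averaged over the batch

    # Greedy (best-path) frame-level predictions, e.g. as input to a
    # collapse-repeats-then-drop-blanks decoding step:
    best_path = ctc.argmax(hs_pad)           # (B, Tmax)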