| """Unility functions for Transformer."""
 | |
| 
 | |
| import math
 | |
| from typing import Tuple, List
 | |
| 
 | |
| import torch
 | |
| from torch.nn.utils.rnn import pad_sequence
 | |
| 
 | |
| IGNORE_ID = -1
 | |
| 
 | |
| 
 | |
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (int): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    # Allocate with the trailing dims of the inputs so tensors of shape
    # (T_i, `*`) are padded correctly, as the docstring promises; new_full
    # keeps the dtype and device of xs[0].
    pad = xs[0].new_full((n_batch, max_len, *xs[0].shape[1:]), pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]
    return pad


def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _sos = torch.tensor([sos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    _eos = torch.tensor([eos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # strip the padding positions
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse each padded sequence, keeping the padding at the end.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lengths of the token sequences (B,).
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])

    """
    # Flip only the valid prefix of each sequence, then re-pad batch-first.
    r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
                             for y, i in zip(ys_pad, ys_lens)], True,
                            pad_value)
    return r_ys_pad


def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> float:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
                                pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


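# A minimal usage sketch for th_accuracy (this demo function is hypothetical
# and not part of the original module). It shows the expected shapes: logits
# flattened to (B * Lmax, D) and padded targets of shape (B, Lmax), with
# IGNORE_ID positions excluded from the statistics by the mask.
def _demo_th_accuracy() -> float:
    batch, lmax, vocab = 2, 3, 5
    logits = torch.randn(batch * lmax, vocab)
    targets = torch.tensor([[1, 2, IGNORE_ID], [3, 4, 1]])
    return th_accuracy(logits, targets, ignore_label=IGNORE_ID)

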
def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from wenet.transformer.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
        "gelu": torch.nn.GELU
    }

    return activation_funcs[act]()


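# Illustrative only (this helper is not part of the original module):
# get_activation returns a constructed nn.Module, so the result can be
# dropped straight into a model or applied to a tensor directly.
def _demo_get_activation() -> torch.Tensor:
    act = get_activation("relu")
    return act(torch.tensor([-1.0, 0.0, 2.0]))  # tensor([0., 0., 2.])

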
def get_subsample(config):
    """Return the subsampling rate of the encoder input layer."""
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    if input_layer == "conv2d":
        return 4
    elif input_layer == "conv2d6":
        return 6
    elif input_layer == "conv2d8":
        return 8


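# A hypothetical config snippet for get_subsample (the field names follow
# what the function reads; the values are made up for illustration). A
# "conv2d" input layer subsamples the frame rate by a factor of 4.
def _demo_get_subsample() -> int:
    config = {"encoder_conf": {"input_layer": "conv2d"}}
    return get_subsample(config)  # 4

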
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    """Apply CTC-style collapse: merge repeats, then drop blanks (id 0)."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        # Keep the first token of each run unless it is the blank id.
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        # Skip the rest of the run of identical tokens.
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


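# A quick sketch of the collapse rule above (example values are illustrative;
# this demo is not part of the original module): consecutive duplicates are
# merged and blank tokens (id 0) are removed.
def _demo_remove_duplicates_and_blank() -> List[int]:
    hyp = [0, 3, 3, 0, 0, 4, 4, 4, 0, 5]
    return remove_duplicates_and_blank(hyp)  # [3, 4, 5]

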
def log_add(args: List[float]) -> float:
    """Stable log add (log-sum-exp): log(sum(exp(a) for a in args))."""
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    # Factor out the maximum so the exponentials cannot overflow.
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp


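# Numeric sanity check for log_add (illustrative, not part of the original
# module): for log-probabilities that sum to one, the result is log(1) = 0.
def _demo_log_add() -> float:
    logs = [math.log(p) for p in [0.5, 0.3, 0.2]]
    return log_add(logs)  # ~0.0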
 | 
