269 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			269 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | ||
| # -*- coding: utf-8 -*-
 | ||
| 
 | ||
| # Copyright 2019 Mobvoi Inc. All Rights Reserved.
 | ||
| # Author: di.wu@mobvoi.com (DI WU)
 | ||
| """Encoder self-attention layer definition."""
 | ||
| 
 | ||
| from typing import Optional, Tuple
 | ||
| 
 | ||
| import torch
 | ||
| from torch import nn
 | ||
| 
 | ||
| 
 | ||
class TransformerEncoderLayer(nn.Module):
    """Encoder layer: a self-attention sub-block followed by a feed-forward one.

    Each sub-block is wrapped in a residual connection and a LayerNorm that is
    applied either before (pre-norm) or after (post-norm) the sub-block.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance, called as
            ``self_attn(query, key, value, mask)``.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)

    """
    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: torch.nn.Module,
        dropout_rate: float,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = nn.LayerNorm(size, eps=1e-12)
        self.norm2 = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # concat_linear may not be used in the forward function, but it is
        # always created so that it is saved in the *.pt checkpoint.
        self.concat_linear = nn.Linear(size + size, size)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: Optional[torch.Tensor] = None,
        output_cache: Optional[torch.Tensor] = None,
        cnn_cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input. It is sliced on
                its second axis when ``output_cache`` is given, so it must be
                3-D (presumably (#batch, time, time)).
            pos_emb (torch.Tensor): Not used here; just for interface
                compatibility with ConformerEncoderLayer.
            mask_pad (torch.Tensor): Not used in the transformer layer; just
                for a unified API with the conformer layer.
            output_cache (torch.Tensor): Cache tensor of the output
                (#batch, time2, size), time2 < time in x. When given, only
                the last ``time - time2`` frames are used as attention
                queries and the cache is prepended to the result.
            cnn_cache (torch.Tensor): Not used here; it's for interface
                compatibility with ConformerEncoderLayer.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor; restricted to the rows of the newly
                computed frames when ``output_cache`` is given.
            torch.Tensor: Placeholder cnn cache (a single 0.0), returned only
                to keep the same return signature as ConformerEncoderLayer.

        """
        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        if output_cache is None:
            x_q = x
        else:
            assert output_cache.size(0) == x.size(0)
            assert output_cache.size(2) == self.size
            assert output_cache.size(1) < x.size(1)
            # Only the frames that are not yet in the cache act as queries.
            chunk = x.size(1) - output_cache.size(1)
            x_q = x[:, -chunk:, :]
            residual = residual[:, -chunk:, :]
            mask = mask[:, -chunk:, :]

        x_att = self.self_attn(x_q, x, x, mask)
        if self.concat_after:
            # Concatenate with the query slice (x_q), not the full x: the
            # attention output has the query's time length, so using x would
            # fail whenever output_cache shortens the query.
            x_concat = torch.cat((x_q, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if output_cache is not None:
            # Restore the full time axis: cached frames + new frames.
            x = torch.cat([output_cache, x], dim=1)

        fake_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
        return x, mask, fake_cnn_cache
 | ||
| 
 | ||
| 
 | ||
class ConformerEncoderLayer(nn.Module):
    """Conformer encoder layer.

    Runs, in order: an optional macaron feed-forward, self-attention, an
    optional convolution module and an optional feed-forward. Each sub-block
    has a residual connection and its own LayerNorm, applied before or after
    the sub-block depending on ``normalize_before``.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance, called as
            ``self_attn(query, key, value, mask, pos_emb)``.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
            instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward` instance can be used as the argument.
            If None, the feed-forward sub-block is skipped.
        feed_forward_macaron (torch.nn.Module): Additional feed-forward module
            instance, run before self-attention. When given, both
            feed-forward outputs are scaled by 0.5.
            `PositionwiseFeedForward` instance can be used as the argument.
        conv_module (torch.nn.Module): Convolution module instance, called as
            ``conv_module(x, mask_pad, cnn_cache)`` and expected to return
            ``(output, new_cnn_cache)``.
        dropout_rate (float): Dropout rate.
        normalize_before (bool):
            True: use layer_norm before each sub-block.
            False: use layer_norm after each sub-block.
        concat_after (bool): Whether to concat attention layer's input and
            output.
            True: x -> x + linear(concat(x, att(x)))
            False: x -> x + att(x)
    """
    def __init__(
        self,
        size: int,
        self_attn: torch.nn.Module,
        feed_forward: Optional[nn.Module] = None,
        feed_forward_macaron: Optional[nn.Module] = None,
        conv_module: Optional[nn.Module] = None,
        dropout_rate: float = 0.1,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.conv_module = conv_module
        self.norm_ff = nn.LayerNorm(size, eps=1e-12)  # for the FFN module
        self.norm_mha = nn.LayerNorm(size, eps=1e-12)  # for the MHA module
        if feed_forward_macaron is not None:
            self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
            # Two half-step feed-forwards (macaron style).
            self.ff_scale = 0.5
        else:
            self.ff_scale = 1.0
        if self.conv_module is not None:
            # for the CNN module
            self.norm_conv = nn.LayerNorm(size, eps=1e-12)
            # for the final output of the block
            self.norm_final = nn.LayerNorm(size, eps=1e-12)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # Always created (even when concat_after is False) so that the
        # parameter set of the layer does not depend on that flag.
        self.concat_linear = nn.Linear(size + size, size)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
        pos_emb: torch.Tensor,
        mask_pad: Optional[torch.Tensor] = None,
        output_cache: Optional[torch.Tensor] = None,
        cnn_cache: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute encoded features.

        Args:
            x (torch.Tensor): (#batch, time, size)
            mask (torch.Tensor): Mask tensor for the input
                (#batch, time, time).
            pos_emb (torch.Tensor): Positional encoding, passed straight to
                ``self_attn``; must not be None for ConformerEncoderLayer.
            mask_pad (torch.Tensor): Batch padding mask used for the conv
                module (#batch, 1, time).
            output_cache (torch.Tensor): Cache tensor of the output
                (#batch, time2, size), time2 < time in x. When given, only
                the last ``time - time2`` frames are used as attention
                queries and the cache is prepended to the result.
            cnn_cache (torch.Tensor): Convolution cache, passed straight to
                ``conv_module``.

        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor; restricted to the rows of the newly
                computed frames when ``output_cache`` is given.
            torch.Tensor: New cnn cache returned by ``conv_module``, or a
                placeholder tensor holding a single 0.0 when there is no
                conv module.
        """
        # whether to use macaron style
        if self.feed_forward_macaron is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(
                self.feed_forward_macaron(x))
            if not self.normalize_before:
                x = self.norm_ff_macaron(x)

        # multi-headed self-attention module
        residual = x
        if self.normalize_before:
            x = self.norm_mha(x)

        if output_cache is None:
            x_q = x
        else:
            assert output_cache.size(0) == x.size(0)
            assert output_cache.size(2) == self.size
            assert output_cache.size(1) < x.size(1)
            # Only the frames that are not yet in the cache act as queries.
            chunk = x.size(1) - output_cache.size(1)
            x_q = x[:, -chunk:, :]
            residual = residual[:, -chunk:, :]
            mask = mask[:, -chunk:, :]

        x_att = self.self_attn(x_q, x, x, mask, pos_emb)
        if self.concat_after:
            # Concatenate with the query slice (x_q), not the full x: the
            # attention output has the query's time length, so using x would
            # fail whenever output_cache shortens the query.
            x_concat = torch.cat((x_q, x_att), dim=-1)
            x = residual + self.concat_linear(x_concat)
        else:
            x = residual + self.dropout(x_att)
        if not self.normalize_before:
            x = self.norm_mha(x)

        # convolution module
        # Placeholder cnn cache so a tensor is always returned; it is
        # replaced by the conv module's own cache when that module exists.
        new_cnn_cache = torch.tensor([0.0], dtype=x.dtype, device=x.device)
        if self.conv_module is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_conv(x)
            x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
            x = residual + self.dropout(x)

            if not self.normalize_before:
                x = self.norm_conv(x)

        # feed forward module (optional, like the macaron and conv modules)
        if self.feed_forward is not None:
            residual = x
            if self.normalize_before:
                x = self.norm_ff(x)

            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
            if not self.normalize_before:
                x = self.norm_ff(x)

        if self.conv_module is not None:
            x = self.norm_final(x)

        if output_cache is not None:
            # Restore the full time axis: cached frames + new frames.
            x = torch.cat([output_cache, x], dim=1)

        return x, mask, new_cnn_cache
 | 
