#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
"""Encoder definition."""
from typing import Tuple, List, Optional

import torch
from typeguard import check_argument_types

from wenet.transformer.attention import MultiHeadedAttention
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.convolution import ConvolutionModule
from wenet.transformer.embedding import PositionalEncoding
from wenet.transformer.embedding import RelPositionalEncoding
from wenet.transformer.embedding import NoPositionalEncoding
from wenet.transformer.encoder_layer import TransformerEncoderLayer
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
from wenet.transformer.subsampling import Conv2dSubsampling4
from wenet.transformer.subsampling import Conv2dSubsampling6
from wenet.transformer.subsampling import Conv2dSubsampling8
from wenet.transformer.subsampling import LinearNoSubsampling
from wenet.utils.common import get_activation
from wenet.utils.mask import make_pad_mask
from wenet.utils.mask import add_optional_chunk_mask


class BaseEncoder(torch.nn.Module):
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """
        Args:
            input_size (int): input dim
            output_size (int): dimension of attention
            attention_heads (int): the number of heads of multi head attention
            linear_units (int): the hidden units number of position-wise feed
                forward
            num_blocks (int): the number of encoder blocks
            dropout_rate (float): dropout rate
            attention_dropout_rate (float): dropout rate in attention
            positional_dropout_rate (float): dropout rate after adding
                positional encoding
            input_layer (str): input layer type.
                optional [linear, conv2d, conv2d6, conv2d8]
            pos_enc_layer_type (str): Encoder positional encoding layer type.
                optional [abs_pos, rel_pos, no_pos]
            normalize_before (bool):
                True: use layer_norm before each sub-block of a layer.
                False: use layer_norm after each sub-block of a layer.
            concat_after (bool): whether to concat attention layer's input
                and output.
                True: x -> x + linear(concat(x, att(x)))
                False: x -> x + att(x)
            static_chunk_size (int): chunk size for static chunk training and
                decoding
            use_dynamic_chunk (bool): whether to use dynamic chunk size for
                training or not. You can only use a fixed chunk
                (static_chunk_size > 0) or a dynamic chunk size
                (use_dynamic_chunk = True)
            global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
            use_dynamic_left_chunk (bool): whether to use dynamic left chunk in
                dynamic chunk training
        """
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "no_pos":
            pos_enc_class = NoPositionalEncoding
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            subsampling_class = LinearNoSubsampling
        elif input_layer == "conv2d":
            subsampling_class = Conv2dSubsampling4
        elif input_layer == "conv2d6":
            subsampling_class = Conv2dSubsampling6
        elif input_layer == "conv2d8":
            subsampling_class = Conv2dSubsampling8
        else:
            raise ValueError("unknown input_layer: " + input_layer)

        self.global_cmvn = global_cmvn
        self.embed = subsampling_class(
            input_size,
            output_size,
            dropout_rate,
            pos_enc_class(output_size, positional_dropout_rate),
        )

        self.normalize_before = normalize_before
        self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-12)
        self.static_chunk_size = static_chunk_size
        self.use_dynamic_chunk = use_dynamic_chunk
        self.use_dynamic_left_chunk = use_dynamic_left_chunk

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        decoding_chunk_size: int = 0,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Embed positions in tensor.

        Args:
            xs: padded input tensor (B, T, D)
            xs_lens: input length (B)
            decoding_chunk_size: decoding chunk size for dynamic chunk
                0: default for training, use random dynamic chunk.
                <0: for decoding, use full chunk.
                >0: for decoding, use fixed chunk size as set.
            num_decoding_left_chunks: number of left chunks for decoding;
                the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
        Returns:
            encoder output tensor xs, and subsampled masks
            xs: padded output tensor (B, T' ~= T/subsample_rate, D)
            masks: torch.Tensor batch padding mask after subsample
                (B, 1, T' ~= T/subsample_rate)
        """
        # make_pad_mask is True at padded positions, so invert it to get
        # True at valid positions.
        masks = ~make_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, T)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, masks = self.embed(xs, masks)
        mask_pad = masks  # (B, 1, T/subsample_rate)
        chunk_masks = add_optional_chunk_mask(xs, masks,
                                              self.use_dynamic_chunk,
                                              self.use_dynamic_left_chunk,
                                              decoding_chunk_size,
                                              self.static_chunk_size,
                                              num_decoding_left_chunks)
        for layer in self.encoders:
            xs, chunk_masks, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
        if self.normalize_before:
            xs = self.after_norm(xs)
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before the encoder layers; the masks will be used
        # for cross attention with the decoder later.
        return xs, masks

    def forward_chunk(
        self,
        xs: torch.Tensor,
        offset: int,
        required_cache_size: int,
        subsampling_cache: Optional[torch.Tensor] = None,
        elayers_output_cache: Optional[List[torch.Tensor]] = None,
        conformer_cnn_cache: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor],
               List[torch.Tensor]]:
        """ Forward just one chunk

        Args:
            xs (torch.Tensor): chunk input
            offset (int): current offset in encoder output timestamp
            required_cache_size (int): cache size required for next chunk
                computation
                >=0: actual cache size
                <0: means all history cache is required
            subsampling_cache (Optional[torch.Tensor]): subsampling cache
            elayers_output_cache (Optional[List[torch.Tensor]]):
                transformer/conformer encoder layers output cache
            conformer_cnn_cache (Optional[List[torch.Tensor]]): conformer
                cnn cache

        Returns:
            torch.Tensor: output of current input xs
            torch.Tensor: subsampling cache required for next chunk computation
            List[torch.Tensor]: encoder layers output cache required for next
                chunk computation
            List[torch.Tensor]: conformer cnn cache

        """
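        # Illustrative note (added for clarity, not from the original source):
        # when forward_chunk_by_chunk() drives this method with, e.g.,
        # decoding_chunk_size=16 and num_decoding_left_chunks=2, it passes
        # required_cache_size = 16 * 2 = 32, so only the last 32 encoder
        # frames of the concatenated input are kept (via next_cache_start
        # below) as the cache for the next chunk.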
        assert xs.size(0) == 1
        # tmp_masks is just for interface compatibility
        tmp_masks = torch.ones(1,
                               xs.size(1),
                               device=xs.device,
                               dtype=torch.bool)
        tmp_masks = tmp_masks.unsqueeze(1)
        if self.global_cmvn is not None:
            xs = self.global_cmvn(xs)
        xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
        if subsampling_cache is not None:
            cache_size = subsampling_cache.size(1)
            xs = torch.cat((subsampling_cache, xs), dim=1)
        else:
            cache_size = 0
        # Recompute the position encoding for the full (cache + current) input
        pos_emb = self.embed.position_encoding(offset - cache_size, xs.size(1))
        if required_cache_size < 0:
            next_cache_start = 0
        elif required_cache_size == 0:
            next_cache_start = xs.size(1)
        else:
            next_cache_start = max(xs.size(1) - required_cache_size, 0)
        r_subsampling_cache = xs[:, next_cache_start:, :]
        # Real mask for transformer/conformer layers
        masks = torch.ones(1, xs.size(1), device=xs.device, dtype=torch.bool)
        masks = masks.unsqueeze(1)
        r_elayers_output_cache = []
        r_conformer_cnn_cache = []
        for i, layer in enumerate(self.encoders):
            if elayers_output_cache is None:
                attn_cache = None
            else:
                attn_cache = elayers_output_cache[i]
            if conformer_cnn_cache is None:
                cnn_cache = None
            else:
                cnn_cache = conformer_cnn_cache[i]
            xs, _, new_cnn_cache = layer(xs,
                                         masks,
                                         pos_emb,
                                         output_cache=attn_cache,
                                         cnn_cache=cnn_cache)
            r_elayers_output_cache.append(xs[:, next_cache_start:, :])
            r_conformer_cnn_cache.append(new_cnn_cache)
        if self.normalize_before:
            xs = self.after_norm(xs)

        # Only the newly computed frames (after the cached prefix) are returned
        return (xs[:, cache_size:, :], r_subsampling_cache,
                r_elayers_output_cache, r_conformer_cnn_cache)

    def forward_chunk_by_chunk(
        self,
        xs: torch.Tensor,
        decoding_chunk_size: int,
        num_decoding_left_chunks: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Forward input chunk by chunk with chunk_size in a streaming
            fashion

        Here we should pay special attention to the computation cache in the
        streaming-style forward chunk by chunk. Three things should be taken
        into account for computation in the current network:
            1. transformer/conformer encoder layers output cache
            2. convolution in conformer
            3. convolution in subsampling

        However, we don't implement a subsampling cache, because:
            1. We can make the subsampling module output the right result by
               overlapping the input instead of caching the left context.
               This wastes some computation, but subsampling only takes a
               very small fraction of the computation in the whole model.
            2. Typically, the subsampling module stacks several convolution
               layers with subsampling, and it is tricky and complicated to
               cache across convolution layers with different subsampling
               rates.
            3. Currently, nn.Sequential is used to stack all the convolution
               layers in subsampling; we would need to rewrite it to make it
               work with a cache, which is not preferred.
        Args:
            xs (torch.Tensor): (1, max_len, dim)
            decoding_chunk_size (int): decoding chunk size
            num_decoding_left_chunks (int): number of left chunks, <0 means
                use all left chunks
        """
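        # Worked example (illustrative, assuming the default conv2d front end,
        # i.e. Conv2dSubsampling4 with subsampling_rate=4 and right_context=6):
        # with decoding_chunk_size=16, stride = 4 * 16 = 64 input frames and
        # decoding_window = (16 - 1) * 4 + 6 + 1 = 67 input frames, so each
        # iteration below feeds 67 frames and advances by 64 frames.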
        assert decoding_chunk_size > 0
        # The model is trained by static or dynamic chunk
        assert self.static_chunk_size > 0 or self.use_dynamic_chunk
        subsampling = self.embed.subsampling_rate
        context = self.embed.right_context + 1  # Add current frame
        stride = subsampling * decoding_chunk_size
        decoding_window = (decoding_chunk_size - 1) * subsampling + context
        num_frames = xs.size(1)
        subsampling_cache: Optional[torch.Tensor] = None
        elayers_output_cache: Optional[List[torch.Tensor]] = None
        conformer_cnn_cache: Optional[List[torch.Tensor]] = None
        outputs = []
        offset = 0
        required_cache_size = decoding_chunk_size * num_decoding_left_chunks

        # Feed forward the overlapped input step by step
        for cur in range(0, num_frames - context + 1, stride):
            end = min(cur + decoding_window, num_frames)
            chunk_xs = xs[:, cur:end, :]
            (y, subsampling_cache, elayers_output_cache,
             conformer_cnn_cache) = self.forward_chunk(chunk_xs, offset,
                                                       required_cache_size,
                                                       subsampling_cache,
                                                       elayers_output_cache,
                                                       conformer_cnn_cache)
            outputs.append(y)
            offset += y.size(1)
        ys = torch.cat(outputs, 1)
        masks = torch.ones(1, ys.size(1), device=ys.device, dtype=torch.bool)
        masks = masks.unsqueeze(1)
        return ys, masks


class TransformerEncoder(BaseEncoder):
    """Transformer encoder module."""
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "abs_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
    ):
        """ Construct TransformerEncoder

        See BaseEncoder for the meaning of each parameter.
        """
        assert check_argument_types()
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         concat_after, static_chunk_size, use_dynamic_chunk,
                         global_cmvn, use_dynamic_left_chunk)
        self.encoders = torch.nn.ModuleList([
            TransformerEncoderLayer(
                output_size,
                MultiHeadedAttention(attention_heads, output_size,
                                     attention_dropout_rate),
                PositionwiseFeedForward(output_size, linear_units,
                                        dropout_rate), dropout_rate,
                normalize_before, concat_after) for _ in range(num_blocks)
        ])


class ConformerEncoder(BaseEncoder):
    """Conformer encoder module."""
    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        pos_enc_layer_type: str = "rel_pos",
        normalize_before: bool = True,
        concat_after: bool = False,
        static_chunk_size: int = 0,
        use_dynamic_chunk: bool = False,
        global_cmvn: torch.nn.Module = None,
        use_dynamic_left_chunk: bool = False,
        positionwise_conv_kernel_size: int = 1,
        macaron_style: bool = True,
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        cnn_module_kernel: int = 15,
        causal: bool = False,
        cnn_module_norm: str = "batch_norm",
    ):
        """Construct ConformerEncoder

        Args:
            input_size to use_dynamic_chunk: see BaseEncoder
            positionwise_conv_kernel_size (int): Kernel size of positionwise
                conv1d layer.
            macaron_style (bool): Whether to use macaron style for
                positionwise layer.
            selfattention_layer_type (str): Encoder attention layer type,
                the parameter has no effect now, it's just for configuration
                compatibility.
            activation_type (str): Encoder activation function type.
            use_cnn_module (bool): Whether to use convolution module.
            cnn_module_kernel (int): Kernel size of convolution module.
            causal (bool): whether to use causal convolution or not.
            cnn_module_norm (str): normalization type used in the convolution
                module.
        """
        assert check_argument_types()
        super().__init__(input_size, output_size, attention_heads,
                         linear_units, num_blocks, dropout_rate,
                         positional_dropout_rate, attention_dropout_rate,
                         input_layer, pos_enc_layer_type, normalize_before,
                         concat_after, static_chunk_size, use_dynamic_chunk,
                         global_cmvn, use_dynamic_left_chunk)
        activation = get_activation(activation_type)

        # self-attention module definition
        if pos_enc_layer_type == "no_pos":
            encoder_selfattn_layer = MultiHeadedAttention
        else:
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
        encoder_selfattn_layer_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        # feed-forward module definition
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
            activation,
        )
        # convolution module definition
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation,
                                  cnn_module_norm, causal)

        self.encoders = torch.nn.ModuleList([
            ConformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(
                    *positionwise_layer_args) if macaron_style else None,
                convolution_layer(
                    *convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ) for _ in range(num_blocks)
        ])
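

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes 80-dim acoustic features and the default configuration above;
# shapes follow the docstrings of forward() and forward_chunk_by_chunk().
if __name__ == "__main__":
    encoder = ConformerEncoder(input_size=80, use_dynamic_chunk=True)
    encoder.eval()
    xs = torch.randn(1, 200, 80)      # (B, T, D) padded acoustic features
    xs_lens = torch.tensor([200])     # (B,) valid lengths before padding
    with torch.no_grad():
        # Non-streaming forward over the whole utterance
        ys, masks = encoder(xs, xs_lens)
        print(ys.shape, masks.shape)  # (1, T', 256), (1, 1, T')
        # Streaming-style forward, 16 encoder frames per chunk
        ys_chunk, _ = encoder.forward_chunk_by_chunk(xs, decoding_chunk_size=16)
        print(ys_chunk.shape)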
