#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: di.wu@mobvoi.com (DI WU)
"""Positional Encoding Module."""

import math
from typing import Tuple

import torch


class PositionalEncoding(torch.nn.Module):
    """Positional encoding.

    :param int d_model: embedding dim
    :param float dropout_rate: dropout rate
    :param int max_len: maximum input length

    PE(pos, 2i)   = sin(pos/(10000^(2i/dmodel)))
    PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
    """
    def __init__(self,
                 d_model: int,
                 dropout_rate: float,
                 max_len: int = 5000,
                 reverse: bool = False):
        """Construct a PositionalEncoding object.

        Note: ``reverse`` is accepted for interface compatibility with
        subclasses but is not used in this implementation.
        """
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.max_len = max_len

        # Precompute the sinusoidal table once; it is unsqueezed to
        # (1, max_len, d_model) below.
        self.pe = torch.zeros(self.max_len, self.d_model)
        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32) *
            -(math.log(10000.0) / self.d_model))
        self.pe[:, 0::2] = torch.sin(position * div_term)
        self.pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = self.pe.unsqueeze(0)

    def forward(self,
                x: torch.Tensor,
                offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
            x (torch.Tensor): Input. Its shape is (batch, time, ...)
            offset (int): position offset

        Returns:
            torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
            torch.Tensor: Positional embedding, for compatibility with
                RelPositionalEncoding
        """
        assert offset + x.size(1) < self.max_len
        self.pe = self.pe.to(x.device)
        pos_emb = self.pe[:, offset:offset + x.size(1)]
        x = x * self.xscale + pos_emb
        return self.dropout(x), self.dropout(pos_emb)

    def position_encoding(self, offset: int, size: int) -> torch.Tensor:
        """ For getting encoding in a streaming fashion.

        Attention!!!!!
        In the non-streaming case dropout is applied only once, over the
        whole utterance, but in a streaming scenario this function is
        called several times with increasing input size, so dropout is
        applied several times.

        Args:
            offset (int): start offset
            size (int): required size of position encoding

        Returns:
            torch.Tensor: Corresponding encoding
        """
        assert offset + size < self.max_len
        return self.dropout(self.pe[:, offset:offset + size])


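# Usage sketch (added for illustration, not part of the original module; the
# sizes below are hypothetical). PositionalEncoding scales the input by
# sqrt(d_model), adds the sinusoidal embedding, and can also hand out chunks
# of the precomputed table via position_encoding() for streaming decoding:
#
#   pos_enc = PositionalEncoding(d_model=256, dropout_rate=0.1)
#   x = torch.randn(2, 10, 256)             # (batch, time, d_model)
#   y, pos_emb = pos_enc(x)                 # y: (2, 10, 256), pos_emb: (1, 10, 256)
#   chunk = pos_enc.position_encoding(offset=10, size=5)   # (1, 5, 256)

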
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding module.
    See: Appendix B in https://arxiv.org/abs/1901.02860
    Args:
        d_model (int): Embedding dimension.
        dropout_rate (float): Dropout rate.
        max_len (int): Maximum input length.
    """
    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
        """Initialize class."""
        super().__init__(d_model, dropout_rate, max_len, reverse=True)

    def forward(self,
                x: torch.Tensor,
                offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute positional encoding.
        Args:
            x (torch.Tensor): Input tensor (batch, time, `*`).
            offset (int): position offset
        Returns:
            torch.Tensor: Encoded tensor (batch, time, `*`).
            torch.Tensor: Positional embedding tensor (1, time, `*`).
        """
        assert offset + x.size(1) < self.max_len
        self.pe = self.pe.to(x.device)
        x = x * self.xscale
        pos_emb = self.pe[:, offset:offset + x.size(1)]
        return self.dropout(x), self.dropout(pos_emb)


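# Usage sketch (added for illustration; hypothetical sizes). Unlike the base
# class, the relative variant does not add pos_emb to x: it only scales x and
# returns the positional embedding separately, so a downstream relative
# multi-head attention layer can consume it (see the Transformer-XL paper
# linked above):
#
#   rel_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.1)
#   x = torch.randn(2, 10, 256)
#   x_scaled, pos_emb = rel_enc(x)   # x_scaled: (2, 10, 256), pos_emb: (1, 10, 256)

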
class NoPositionalEncoding(torch.nn.Module):
    """ No position encoding.
    """
    def __init__(self, d_model: int, dropout_rate: float):
        super().__init__()
        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def forward(self,
                x: torch.Tensor,
                offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Just return a zero vector for interface compatibility.
        """
        pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
        return self.dropout(x), pos_emb

    def position_encoding(self, offset: int, size: int) -> torch.Tensor:
        return torch.zeros(1, size, self.d_model)
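

# Added demo (not part of the original module): a minimal smoke test that
# exercises all three encodings on a dummy batch and prints the shapes they
# return. The batch size, sequence length, and d_model are arbitrary choices.
if __name__ == "__main__":
    batch, time, d_model = 2, 10, 256
    x = torch.randn(batch, time, d_model)

    abs_enc = PositionalEncoding(d_model, dropout_rate=0.1)
    y, pos_emb = abs_enc(x)
    print("PositionalEncoding:", y.shape, pos_emb.shape)

    rel_enc = RelPositionalEncoding(d_model, dropout_rate=0.1)
    y, pos_emb = rel_enc(x)
    print("RelPositionalEncoding:", y.shape, pos_emb.shape)

    no_enc = NoPositionalEncoding(d_model, dropout_rate=0.1)
    y, pos_emb = no_enc(x)
    print("NoPositionalEncoding:", y.shape, pos_emb.shape)

    # Streaming-style access: fetch the encoding for frames 10..14 directly.
    print("position_encoding(10, 5):", abs_enc.position_encoding(10, 5).shape)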
