#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Positionwise feed forward layer definition."""

import torch


class PositionwiseFeedForward(torch.nn.Module):
|  |     """Positionwise feed forward layer.
 | ||
|  | 
 | ||
|  |     FeedForward are appied on each position of the sequence. | ||
|  |     The output dim is same with the input dim. | ||
|  | 
 | ||
|  |     Args: | ||
|  |         idim (int): Input dimenstion. | ||
|  |         hidden_units (int): The number of hidden units. | ||
|  |         dropout_rate (float): Dropout rate. | ||
|  |         activation (torch.nn.Module): Activation function | ||
|  |     """
 | ||
    def __init__(self,
                 idim: int,
                 hidden_units: int,
                 dropout_rate: float,
                 activation: torch.nn.Module = torch.nn.ReLU()):
        """Construct a PositionwiseFeedForward object."""
        super().__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.activation = activation
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.w_2 = torch.nn.Linear(hidden_units, idim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Forward function.

        Args:
            xs (torch.Tensor): Input tensor of shape (B, L, D).

        Returns:
            torch.Tensor: Output tensor of shape (B, L, D).
        """
        # (B, L, D) -> (B, L, H) -> (B, L, D), with dropout on the hidden units.
        return self.w_2(self.dropout(self.activation(self.w_1(xs))))
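

# A minimal usage sketch (illustrative, not part of the original module),
# assuming example sizes idim=256 and hidden_units=2048: the layer maps
# (B, L, D) -> (B, L, D), so the output shape matches the input shape.
if __name__ == "__main__":
    ffn = PositionwiseFeedForward(idim=256, hidden_units=2048, dropout_rate=0.1)
    xs = torch.randn(2, 50, 256)  # (batch, length, dim)
    ys = ffn(xs)
    assert ys.shape == xs.shape
    print(ys.shape)  # torch.Size([2, 50, 256])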