48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|  | 
 | ||
|  | import torch | ||
|  | 
 | ||
|  | 
 | ||
class GlobalCMVN(torch.nn.Module):
    """Normalize features with fixed, precomputed mean and inverse-std stats.

    The statistics are stored as module buffers, so they are not trainable
    parameters; ``forward`` subtracts the mean and, optionally, multiplies by
    the inverse standard deviation.
    """

    def __init__(self,
                 mean: torch.Tensor,
                 istd: torch.Tensor,
                 norm_var: bool = True) -> None:
        """Store the normalization statistics.

        Args:
            mean (torch.Tensor): mean stats
            istd (torch.Tensor): inverse standard deviation, i.e. 1.0 / std;
                must have the same shape as ``mean``
            norm_var (bool): if True, ``forward`` also scales the
                mean-subtracted input by ``istd``; if False, only the mean
                is subtracted

        Raises:
            AssertionError: if ``mean`` and ``istd`` differ in shape.
        """
        super().__init__()
        assert mean.shape == istd.shape, (
            "mean and istd must have the same shape, got "
            "{} and {}".format(tuple(mean.shape), tuple(istd.shape)))
        self.norm_var = norm_var
        # Registered as buffers rather than parameters: they are reachable
        # as self.mean / self.istd but are not returned by parameters().
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stored normalization to a batch of features.

        Args:
            x (torch.Tensor): (batch, max_len, feat_dim)

        Returns:
            (torch.Tensor): normalized feature
        """
        # Out-of-place arithmetic: the caller's tensor is left unmodified.
        x = x - self.mean
        if self.norm_var:
            x = x * self.istd
        return x