48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			48 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #   http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import torch
 | |
| 
 | |
| 
 | |
class GlobalCMVN(torch.nn.Module):
    """Apply global mean (and optionally variance) normalization to features.

    The statistics are stored as module buffers, so they are reachable as
    ``self.mean`` and ``self.istd``.
    """

    def __init__(
        self,
        mean: torch.Tensor,
        istd: torch.Tensor,
        norm_var: bool = True,
    ):
        """Store the normalization statistics.

        Args:
            mean (torch.Tensor): mean stats, subtracted from the input.
            istd (torch.Tensor): inverse standard deviation (1.0 / std);
                must have the same shape as ``mean``.
            norm_var (bool): when True, the mean-subtracted input is also
                multiplied by ``istd``; when False only the mean is removed.
        """
        super().__init__()
        assert mean.shape == istd.shape
        # Registered as buffers (not plain attributes), in this order:
        # "mean" first, then "istd".
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)
        self.norm_var = norm_var

    def forward(self, x: torch.Tensor):
        """Normalize a batch of features.

        Args:
            x (torch.Tensor): (batch, max_len, feat_dim)

        Returns:
            (torch.Tensor): ``x - mean`` when ``norm_var`` is False,
                otherwise ``(x - mean) * istd``.
        """
        centered = x - self.mean
        if not self.norm_var:
            return centered
        return centered * self.istd
 | 
