138 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			138 lines
		
	
	
		
			4.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | ||
|  | # encoding: utf-8 | ||
|  | 
 | ||
|  | import sys | ||
|  | import argparse | ||
|  | import json | ||
|  | import codecs | ||
|  | import yaml | ||
|  | 
 | ||
|  | import torch | ||
|  | import torchaudio | ||
|  | import torchaudio.compliance.kaldi as kaldi | ||
|  | from torch.utils.data import Dataset, DataLoader | ||
|  | 
 | ||
# Select torchaudio's "sox_io" audio backend for the whole process.
# NOTE(review): newer torchaudio releases deprecate set_audio_backend in
# favour of dispatcher-based I/O; confirm this call against the pinned
# torchaudio version.
torchaudio.set_audio_backend("sox_io")
|  | 
 | ||
|  | 
 | ||
class CollateFunc(object):
    ''' Collate function for AudioDataset that yields CMVN statistics.

    Rather than stacking features, each call turns a batch of
    ``(key, value)`` items into partial statistics that the caller sums
    over all batches. ``value`` is a wav.scp entry: either ``path`` or
    ``path,start_sec,end_sec`` for a segment of a longer recording.

    Args:
        feat_dim: number of mel bins passed to ``kaldi.fbank``; also the
            length of the returned statistic vectors.
        resample_rate: target sample rate in Hz; 0 disables resampling.
    '''

    def __init__(self, feat_dim, resample_rate):
        self.feat_dim = feat_dim
        self.resample_rate = resample_rate
        # Resample transforms keyed by (orig_freq, new_freq). Building one
        # computes a filter kernel, so reuse it across utterances instead of
        # rebuilding it for every file.
        self._resamplers = {}

    def _load(self, key, value):
        ''' Load the audio referenced by one wav.scp value.

        Returns:
            ``(waveform, sample_rate)`` as returned by ``torchaudio.load``.

        Raises:
            ValueError: if ``value`` has neither 1 nor 3 comma-separated
                fields, or the segment it describes is empty.
        '''
        fields = value.strip().split(",")
        wav_path = fields[0]
        if len(fields) == 1:
            # Plain wav.scp entry: read the whole file.
            return torchaudio.load(wav_path)
        if len(fields) != 3:
            raise ValueError(
                f'{key}: expected "path" or "path,start,end", got {value!r}')
        # Segmented entry: start/end are given in seconds, so the file's
        # sample rate is needed to turn them into frame indices.
        sample_rate = torchaudio.info(wav_path).sample_rate
        start_frame = int(float(fields[1]) * sample_rate)
        end_frame = int(float(fields[2]) * sample_rate)
        if end_frame <= start_frame:
            # num_frames <= 0 would either fail or (for -1) silently read
            # to the end of the file.
            raise ValueError(f'{key}: empty segment {value!r}')
        return torchaudio.load(wav_path,
                               frame_offset=start_frame,
                               num_frames=end_frame - start_frame)

    def _resample(self, waveform, sample_rate):
        ''' Resample ``waveform`` from ``sample_rate`` to self.resample_rate. '''
        rates = (sample_rate, self.resample_rate)
        resampler = self._resamplers.get(rates)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=self.resample_rate)
            self._resamplers[rates] = resampler
        return resampler(waveform)

    def __call__(self, batch):
        ''' Compute fbank statistics for one batch of (key, value) items.

        Returns:
            ``(number, mean_stat, var_stat)``: the total number of fbank
            frames, the per-bin sum of features and the per-bin sum of
            squared features (float64 tensors of length ``feat_dim``).

        Raises:
            ValueError: on a malformed wav.scp value.
        '''
        # float64 accumulators: sums of squares over many frames lose
        # precision quickly in float32.
        mean_stat = torch.zeros(self.feat_dim, dtype=torch.float64)
        var_stat = torch.zeros(self.feat_dim, dtype=torch.float64)
        number = 0
        for key, value in batch:
            waveform, sample_rate = self._load(key, value)
            # Scale samples up to the 16-bit integer range (torchaudio.load
            # normalizes to [-1, 1] by default), presumably to match
            # features computed on raw int16 audio.
            waveform = waveform * (1 << 15)
            if self.resample_rate != 0 and self.resample_rate != sample_rate:
                waveform = self._resample(waveform, sample_rate)
                sample_rate = self.resample_rate

            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.feat_dim,
                              dither=0.0,
                              energy_floor=0.0,
                              sample_frequency=sample_rate)
            mean_stat += torch.sum(mat, dim=0, dtype=torch.float64)
            var_stat += torch.sum(torch.square(mat), dim=0,
                                  dtype=torch.float64)
            number += mat.shape[0]
        return number, mean_stat, var_stat
|  | 
 | ||
|  | 
 | ||
class AudioDataset(Dataset):
    ''' Map-style dataset over a Kaldi-style wav.scp file.

    Each non-empty line is ``<key> <value>``; everything after the first
    whitespace run is the value and is passed unchanged to the collate
    function (``path`` or ``path,start,end``). Items are ``(key, value)``
    string tuples in file order.

    Raises:
        ValueError: if a non-empty line does not contain both a key and a
            value.
    '''

    def __init__(self, data_file):
        self.items = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, 1):
                # maxsplit=1 keeps values that contain spaces intact.
                arr = line.strip().split(maxsplit=1)
                if not arr:
                    continue  # tolerate blank lines
                if len(arr) != 2:
                    raise ValueError(
                        f'{data_file}:{line_no}: expected "<key> <value>", '
                        f'got {line.rstrip()!r}')
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
|  | 
 | ||
|  | 
 | ||
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='extract CMVN stats')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for processing')
    # Both inputs are mandatory: without them there is nothing to read, and
    # failing here gives a clearer message than a later open() error.
    parser.add_argument('--train_config',
                        required=True,
                        help='training yaml conf')
    parser.add_argument('--in_scp', required=True, help='wav scp file')
    parser.add_argument('--out_cmvn',
                        default='global_cmvn',
                        help='global cmvn file')

    args = parser.parse_args()

    # NOTE: yaml.load with FullLoader is acceptable only because the training
    # config is a trusted local file; prefer yaml.safe_load for external input.
    with open(args.train_config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    feat_conf = configs['collate_conf']['feature_extraction_conf']
    feat_dim = feat_conf['mel_bins']
    resample_rate = 0
    if 'resample' in feat_conf:
        resample_rate = feat_conf['resample']
        print('using resample and new sample rate is {}'.format(resample_rate))

    collate_func = CollateFunc(feat_dim, resample_rate)
    dataset = AudioDataset(args.in_scp)
    batch_size = 20
    # The statistics are order-independent sums, so iterate in file order:
    # shuffling only made the float summation order (and output) vary per run.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             sampler=None,
                             num_workers=args.num_workers,
                             collate_fn=collate_func)

    with torch.no_grad():
        all_number = 0
        # Accumulate in float64: corpus-wide sums of squared fbank features
        # are far beyond the range float32 can add to without rounding loss.
        all_mean_stat = torch.zeros(feat_dim, dtype=torch.float64)
        all_var_stat = torch.zeros(feat_dim, dtype=torch.float64)
        wav_number = 0
        for number, mean_stat, var_stat in data_loader:
            all_mean_stat += mean_stat
            all_var_stat += var_stat
            all_number += number
            # The final batch may be short; cap the count at the dataset size.
            wav_number = min(wav_number + batch_size, len(dataset))
            if wav_number % 1000 == 0:
                print(f'processed {wav_number} wavs, {all_number} frames',
                      file=sys.stderr,
                      flush=True)

    cmvn_info = {
        'mean_stat': all_mean_stat.tolist(),
        'var_stat': all_var_stat.tolist(),
        'frame_num': all_number
    }

    with open(args.out_cmvn, 'w') as fout:
        json.dump(cmvn_info, fout)