#!/usr/bin/env python3
# encoding: utf-8
"""Accumulate global CMVN statistics (per-bin sum, sum of squares, frame
count) of Kaldi-style fbank features over the utterances of a wav.scp file
and write them to a JSON file.
"""

import sys
import argparse
import json
import codecs
import yaml

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import Dataset, DataLoader

torchaudio.set_audio_backend("sox_io")


class CollateFunc(object):
    """Collate function for AudioDataset.

    Instead of returning padded features, each call reduces a batch of
    ``(key, wav_spec)`` items to the partial statistics needed for CMVN.

    Args:
        feat_dim: number of mel bins of the fbank features.
        resample_rate: target sample rate; ``0`` disables resampling.
    """

    def __init__(self, feat_dim, resample_rate):
        self.feat_dim = feat_dim
        self.resample_rate = resample_rate

    def __call__(self, batch):
        """Compute partial CMVN statistics for one batch.

        Args:
            batch: list of ``(key, wav_spec)`` tuples, where ``wav_spec`` is
                either ``"path"`` or ``"path,start,end"``. ``start`` and
                ``end`` are multiplied by the sample rate to get frame
                offsets, so they are expressed in seconds.

        Returns:
            Tuple ``(number, mean_stat, var_stat)``: total fbank frame count,
            per-bin sum of features and per-bin sum of squared features.

        Raises:
            ValueError: if a ``wav_spec`` has neither 1 nor 3 comma-separated
                fields.
        """
        mean_stat = torch.zeros(self.feat_dim)
        var_stat = torch.zeros(self.feat_dim)
        number = 0
        for item in batch:
            value = item[1].strip().split(",")
            # Explicit check rather than `assert`, which vanishes under -O
            # and would let a malformed entry be silently misread.
            if len(value) not in (1, 3):
                raise ValueError(
                    'invalid wav entry for key {!r}: expected "path" or '
                    '"path,start,end", got {!r}'.format(item[0], item[1]))
            wav_path = value[0]
            sample_rate = torchaudio.backend.sox_io_backend.info(
                wav_path).sample_rate
            resample_rate = sample_rate
            # len(value) == 3 means segmented wav.scp,
            # len(value) == 1 means original wav.scp
            if len(value) == 3:
                start_frame = int(float(value[1]) * sample_rate)
                end_frame = int(float(value[2]) * sample_rate)
                waveform, sample_rate = torchaudio.backend.sox_io_backend.load(
                    filepath=wav_path,
                    num_frames=end_frame - start_frame,
                    frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(wav_path)

            # Scale the loaded samples by 2^15 before Kaldi-style feature
            # extraction (presumably to reach the 16-bit integer range that
            # Kaldi operates on -- loaders return normalized floats by
            # default).
            waveform = waveform * (1 << 15)
            if self.resample_rate != 0 and self.resample_rate != sample_rate:
                resample_rate = self.resample_rate
                waveform = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate)(waveform)

            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.feat_dim,
                              dither=0.0,
                              energy_floor=0.0,
                              sample_frequency=resample_rate)
            mean_stat += torch.sum(mat, dim=0)
            var_stat += torch.sum(torch.square(mat), dim=0)
            number += mat.shape[0]
        return number, mean_stat, var_stat


class AudioDataset(Dataset):
    """Dataset over the entries of a wav.scp-style file.

    Every non-blank line must contain at least two whitespace-separated
    fields; the first two are kept as a ``(key, wav_spec)`` tuple.

    Args:
        data_file: path to the UTF-8 encoded scp file.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """

    def __init__(self, data_file):
        self.items = []
        with codecs.open(data_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                arr = line.strip().split()
                if not arr:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                if len(arr) < 2:
                    raise ValueError(
                        '{}:{}: expected "<key> <wav>", got {!r}'.format(
                            data_file, line_no, line.strip()))
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def main():
    """Parse arguments, accumulate the statistics and dump them as JSON."""
    parser = argparse.ArgumentParser(description='extract CMVN stats')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for processing')
    # Both inputs are mandatory: without them the script previously failed
    # later with an opaque error from open().
    parser.add_argument('--train_config',
                        required=True,
                        help='training yaml conf')
    parser.add_argument('--in_scp', required=True, help='wav scp file')
    parser.add_argument('--out_cmvn',
                        default='global_cmvn',
                        help='global cmvn file')

    args = parser.parse_args()

    with open(args.train_config, 'r') as fin:
        # NOTE(security): FullLoader can instantiate some Python objects;
        # only use with trusted configs (consider yaml.safe_load).
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    feat_conf = configs['collate_conf']['feature_extraction_conf']
    feat_dim = feat_conf['mel_bins']
    resample_rate = 0
    if 'resample' in feat_conf:
        resample_rate = feat_conf['resample']
        print('using resample and new sample rate is {}'.format(resample_rate))

    collate_func = CollateFunc(feat_dim, resample_rate)
    dataset = AudioDataset(args.in_scp)
    batch_size = 20
    # The statistics are plain sums, so the visiting order is irrelevant;
    # a fixed order keeps the floating-point accumulation reproducible.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             sampler=None,
                             num_workers=args.num_workers,
                             collate_fn=collate_func)

    with torch.no_grad():
        all_number = 0
        all_mean_stat = torch.zeros(feat_dim)
        all_var_stat = torch.zeros(feat_dim)
        wav_number = 0
        total_wavs = len(dataset)
        for number, mean_stat, var_stat in data_loader:
            all_mean_stat += mean_stat
            all_var_stat += var_stat
            all_number += number
            # Clamp so a short final batch is not counted as a full one.
            wav_number = min(wav_number + batch_size, total_wavs)

            if wav_number % 1000 == 0:
                print(f'processed {wav_number} wavs, {all_number} frames',
                      file=sys.stderr,
                      flush=True)

    cmvn_info = {
        'mean_stat': all_mean_stat.tolist(),
        'var_stat': all_var_stat.tolist(),
        'frame_num': all_number
    }

    with open(args.out_cmvn, 'w') as fout:
        fout.write(json.dumps(cmvn_info))


if __name__ == '__main__':
    main()