competition update
This commit is contained in:
137
language_model/tools/compute_cmvn_stats.py
Executable file
137
language_model/tools/compute_cmvn_stats.py
Executable file
@@ -0,0 +1,137 @@
|
||||
#!/usr/bin/env python3
|
||||
# encoding: utf-8
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import json
|
||||
import codecs
|
||||
import yaml
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
# Select torchaudio's "sox_io" backend for audio file I/O; the collate
# function in this file also calls torchaudio.backend.sox_io_backend directly.
# NOTE(review): set_audio_backend appears to be deprecated in newer torchaudio
# releases -- confirm against the pinned torchaudio version.
torchaudio.set_audio_backend("sox_io")
|
||||
|
||||
|
||||
class CollateFunc(object):
    """Collate function for AudioDataset.

    Rather than returning features, each call reduces a batch of wav entries
    to accumulation statistics of their Kaldi fbank features: the number of
    feature rows, the per-dimension sum and the per-dimension sum of squares.
    """

    def __init__(self, feat_dim, resample_rate):
        """Store the fbank dimension and the target sample rate.

        Args:
            feat_dim: number of mel bins passed to ``kaldi.fbank``.
            resample_rate: target sample rate; ``0`` disables resampling.
        """
        self.resample_rate = resample_rate
        self.feat_dim = feat_dim

    def __call__(self, batch):
        """Accumulate fbank statistics over ``batch``.

        Args:
            batch: sequence of items whose second element is either
                ``"path"`` (original wav.scp) or ``"path,start,end"``
                (segmented wav.scp, start/end multiplied by the sample rate
                to obtain frame offsets).

        Returns:
            Tuple ``(number, mean_stat, var_stat)``: total feature rows, the
            summed features and the summed squared features.
        """
        frame_total = 0
        feat_sum = torch.zeros(self.feat_dim)
        feat_sq_sum = torch.zeros(self.feat_dim)

        for sample in batch:
            entry = sample[1]
            fields = entry.strip().split(",")
            # 3 fields: segmented wav.scp; 1 field: original wav.scp.
            assert len(fields) in (1, 3)

            wav_path = fields[0]
            sample_rate = torchaudio.backend.sox_io_backend.info(
                wav_path).sample_rate
            # Rate handed to fbank; replaced below only when resampling.
            out_rate = sample_rate

            if len(fields) != 3:
                waveform, sample_rate = torchaudio.load(entry)
            else:
                begin = int(float(fields[1]) * sample_rate)
                finish = int(float(fields[2]) * sample_rate)
                waveform, sample_rate = torchaudio.backend.sox_io_backend.load(
                    filepath=wav_path,
                    frame_offset=begin,
                    num_frames=finish - begin)

            # Scale by 2**15 (presumably to the 16-bit PCM amplitude range
            # that Kaldi-style features expect -- TODO confirm).
            waveform = waveform * (1 << 15)

            skip_resample = (self.resample_rate == 0
                             or self.resample_rate == sample_rate)
            if not skip_resample:
                out_rate = self.resample_rate
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=out_rate)
                waveform = resampler(waveform)

            feats = kaldi.fbank(waveform,
                                num_mel_bins=self.feat_dim,
                                dither=0.0,
                                energy_floor=0.0,
                                sample_frequency=out_rate)

            feat_sum += torch.sum(feats, dim=0)
            feat_sq_sum += torch.sum(torch.square(feats), dim=0)
            frame_total += feats.shape[0]

        return frame_total, feat_sum, feat_sq_sum
|
||||
|
||||
|
||||
class AudioDataset(Dataset):
    """In-memory dataset of ``(key, wav_entry)`` pairs read from a wav.scp file.

    Each non-blank line of the file is split on whitespace; the first field
    is kept as the key and the second as the wav entry.
    """

    def __init__(self, data_file):
        """Load all entries of ``data_file`` (UTF-8 text) into ``self.items``.

        Args:
            data_file: path to the scp file.

        Raises:
            IndexError: if a non-blank line has fewer than two fields.
        """
        self.items = []
        with codecs.open(data_file, 'r', encoding='utf-8') as f:
            for line in f:
                arr = line.strip().split()
                if not arr:
                    # Blank / whitespace-only lines carry no entry; skipping
                    # them avoids an IndexError on arr[0].
                    continue
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        """Return the number of loaded entries."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return the ``(key, wav_entry)`` tuple at position ``idx``."""
        return self.items[idx]
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: read the feature configuration from the training
    # YAML, accumulate fbank statistics over every wav listed in --in_scp and
    # write the sums, squared sums and frame count to --out_cmvn as JSON.
    parser = argparse.ArgumentParser(description='extract CMVN stats')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for processing')
    # Both inputs are mandatory: without them the script previously failed
    # later with an obscure open()/codecs.open() error.
    parser.add_argument('--train_config',
                        required=True,
                        help='training yaml conf')
    parser.add_argument('--in_scp', required=True, help='wav scp file')
    parser.add_argument('--out_cmvn',
                        default='global_cmvn',
                        help='global cmvn file')

    args = parser.parse_args()

    with open(args.train_config, 'r') as fin:
        # NOTE(review): yaml.FullLoader can construct more than plain data
        # types; prefer yaml.safe_load if the config may come from an
        # untrusted source.
        configs = yaml.load(fin, Loader=yaml.FullLoader)
    feature_conf = configs['collate_conf']['feature_extraction_conf']
    feat_dim = feature_conf['mel_bins']
    resample_rate = 0  # 0 tells CollateFunc not to resample
    if 'resample' in feature_conf:
        resample_rate = feature_conf['resample']
        print('using resample and new sample rate is {}'.format(resample_rate))

    collate_func = CollateFunc(feat_dim, resample_rate)
    dataset = AudioDataset(args.in_scp)
    batch_size = 20
    # The statistics are plain sums, so ordering is irrelevant; a fixed order
    # keeps the floating-point accumulation reproducible between runs.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             sampler=None,
                             num_workers=args.num_workers,
                             collate_fn=collate_func)

    with torch.no_grad():
        all_number = 0
        all_mean_stat = torch.zeros(feat_dim)
        all_var_stat = torch.zeros(feat_dim)
        wav_number = 0
        total_wavs = len(dataset)
        for number, mean_stat, var_stat in data_loader:
            all_mean_stat += mean_stat
            all_var_stat += var_stat
            all_number += number
            # Only the last batch can be short; clamp so the progress count
            # never exceeds the real number of wavs.
            wav_number = min(wav_number + batch_size, total_wavs)
            if wav_number % 1000 == 0:
                print(f'processed {wav_number} wavs, {all_number} frames',
                      file=sys.stderr,
                      flush=True)

    cmvn_info = {
        'mean_stat': all_mean_stat.tolist(),
        'var_stat': all_var_stat.tolist(),
        'frame_num': all_number
    }

    with open(args.out_cmvn, 'w') as fout:
        json.dump(cmvn_info, fout)
|
Reference in New Issue
Block a user