138 lines
4.9 KiB
Python
Executable File
138 lines
4.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# encoding: utf-8
|
|
|
|
import sys
|
|
import argparse
|
|
import json
|
|
import codecs
|
|
import yaml
|
|
|
|
import torch
|
|
import torchaudio
|
|
import torchaudio.compliance.kaldi as kaldi
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
torchaudio.set_audio_backend("sox_io")
|
|
|
|
|
|
class CollateFunc(object):
    """Collate function for AudioDataset.

    Called with a batch of ``(key, wav_entry)`` items, it computes Kaldi
    fbank features for every utterance and returns the batch-level
    sufficient statistics for CMVN instead of the features themselves.

    Args:
        feat_dim: number of mel bins, i.e. the dimension of each stat vector.
        resample_rate: target sample rate; ``0`` means "do not resample".
    """

    def __init__(self, feat_dim, resample_rate):
        self.feat_dim = feat_dim
        self.resample_rate = resample_rate

    def _load_waveform(self, wav_entry):
        """Load the audio described by one wav.scp entry.

        ``wav_entry`` is either ``path`` (original wav.scp) or
        ``path,start,end`` (segmented wav.scp, times in seconds).

        Returns:
            The ``(waveform, sample_rate)`` pair produced by torchaudio.

        Raises:
            ValueError: if the entry has neither 1 nor 3 comma-separated
                fields.
        """
        fields = wav_entry.strip().split(",")
        # Explicit raise rather than ``assert`` so that the check survives
        # running the interpreter with -O.
        if len(fields) not in (1, 3):
            raise ValueError(
                "wav entry must be 'path' or 'path,start,end', got: "
                "{!r}".format(wav_entry))

        wav_path = fields[0]
        if len(fields) == 3:
            # The sample rate is needed up front only here, to convert the
            # segment boundaries from seconds to sample offsets.
            sample_rate = torchaudio.backend.sox_io_backend.info(
                wav_path).sample_rate
            start_frame = int(float(fields[1]) * sample_rate)
            end_frame = int(float(fields[2]) * sample_rate)
            return torchaudio.backend.sox_io_backend.load(
                filepath=wav_path,
                num_frames=end_frame - start_frame,
                frame_offset=start_frame)
        # Whole-file entry: load() already reports the sample rate, so no
        # separate info() call is required.
        return torchaudio.load(wav_path)

    def __call__(self, batch):
        """Accumulate fbank statistics over ``batch``.

        Args:
            batch: sequence of ``(key, wav_entry)`` items; only the second
                element of each item is used.

        Returns:
            A tuple ``(number, mean_stat, var_stat)`` where ``number`` is the
            total frame count, ``mean_stat`` the per-bin sum of features and
            ``var_stat`` the per-bin sum of squared features.
        """
        mean_stat = torch.zeros(self.feat_dim)
        var_stat = torch.zeros(self.feat_dim)
        number = 0
        for item in batch:
            waveform, sample_rate = self._load_waveform(item[1])

            # Scale samples by 2**15 before feature extraction
            # (presumably to move normalized floats into the 16-bit integer
            # range used by Kaldi-style features -- confirm against the
            # training pipeline).
            waveform = waveform * (1 << 15)

            target_rate = sample_rate
            if self.resample_rate != 0 and self.resample_rate != sample_rate:
                target_rate = self.resample_rate
                waveform = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=target_rate)(waveform)

            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.feat_dim,
                              dither=0.0,
                              energy_floor=0.0,
                              sample_frequency=target_rate)
            mean_stat += torch.sum(mat, dim=0)
            var_stat += torch.sum(torch.square(mat), dim=0)
            number += mat.shape[0]
        return number, mean_stat, var_stat
|
|
|
|
|
|
class AudioDataset(Dataset):
    """Dataset over the entries of a wav.scp-style file.

    Each non-blank line must hold at least two whitespace-separated fields;
    the first two are stored as a ``(key, wav_entry)`` tuple.

    Args:
        data_file: path to the UTF-8 encoded scp file.

    Raises:
        ValueError: if a non-blank line has fewer than two fields.
    """

    def __init__(self, data_file):
        self.items = []
        with open(data_file, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                arr = line.strip().split()
                if not arr:
                    # Tolerate blank lines (e.g. a trailing newline) instead
                    # of failing with an IndexError.
                    continue
                if len(arr) < 2:
                    raise ValueError(
                        "{}:{}: expected '<key> <wav>', got: {!r}".format(
                            data_file, line_no, line.rstrip('\n')))
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        """Return the number of entries read from the scp file."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return the ``(key, wav_entry)`` tuple at position ``idx``."""
        return self.items[idx]
|
|
|
|
|
|
def main():
    """Compute global CMVN statistics for a wav.scp and write them as JSON.

    Reads the mel-bin count (and optional resample rate) from the training
    YAML config, accumulates the per-bin feature sum, squared-feature sum and
    frame count over every entry of ``--in_scp``, and writes
    ``{'mean_stat', 'var_stat', 'frame_num'}`` to ``--out_cmvn``.
    """
    parser = argparse.ArgumentParser(description='extract CMVN stats')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for processing')
    # Both inputs are mandatory: let argparse report a missing one instead of
    # failing later inside open() with an empty or None path.
    parser.add_argument('--train_config',
                        required=True,
                        help='training yaml conf')
    parser.add_argument('--in_scp', required=True, help='wav scp file')
    parser.add_argument('--out_cmvn',
                        default='global_cmvn',
                        help='global cmvn file')

    args = parser.parse_args()

    with open(args.train_config, 'r', encoding='utf-8') as fin:
        # NOTE(review): yaml.load with FullLoader is kept as in the original;
        # the training config is assumed to be a trusted local file.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    feature_conf = configs['collate_conf']['feature_extraction_conf']
    feat_dim = feature_conf['mel_bins']
    resample_rate = 0  # 0 disables resampling in CollateFunc
    if 'resample' in feature_conf:
        resample_rate = feature_conf['resample']
        print('using resample and new sample rate is {}'.format(resample_rate))

    collate_func = CollateFunc(feat_dim, resample_rate)
    dataset = AudioDataset(args.in_scp)
    total_wavs = len(dataset)
    batch_size = 20
    # The statistics are plain sums, so shuffling adds nothing; a fixed order
    # keeps floating-point accumulation reproducible from run to run.
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             sampler=None,
                             num_workers=args.num_workers,
                             collate_fn=collate_func)

    with torch.no_grad():
        all_number = 0
        all_mean_stat = torch.zeros(feat_dim)
        all_var_stat = torch.zeros(feat_dim)
        wav_number = 0
        for number, mean_stat, var_stat in data_loader:
            all_mean_stat += mean_stat
            all_var_stat += var_stat
            all_number += number
            # The last batch may be partial; clamp so that the progress
            # counter never exceeds the real number of wavs.
            wav_number = min(wav_number + batch_size, total_wavs)
            if wav_number % 1000 == 0:
                print(f'processed {wav_number} wavs, {all_number} frames',
                      file=sys.stderr,
                      flush=True)

    cmvn_info = {
        'mean_stat': all_mean_stat.tolist(),
        'var_stat': all_var_stat.tolist(),
        'frame_num': all_number
    }

    with open(args.out_cmvn, 'w', encoding='utf-8') as fout:
        fout.write(json.dumps(cmvn_info))


if __name__ == '__main__':
    main()
|