# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Chao Yang)
# Copyright (c) 2021 Jinsong Pan
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import codecs
import copy
import logging
import random

import numpy as np
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import torchaudio.sox_effects as sox_effects
import yaml
from PIL import Image
from PIL.Image import BICUBIC
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader

import wenet.dataset.kaldi_io as kaldi_io
from wenet.dataset.wav_distortion import distort_wav_conf
from wenet.utils.common import IGNORE_ID

torchaudio.set_audio_backend("sox_io")


def _spec_augmentation(x,
                       warp_for_time=False,
                       num_t_mask=2,
                       num_f_mask=2,
                       max_t=50,
                       max_f=10,
                       max_w=80):
    """ Deep copy x, apply spec augmentation, and return the copy.

    Args:
        x: input feature, T * F 2D
        warp_for_time: whether to apply time warp
        num_t_mask: number of time masks to apply
        num_f_mask: number of freq masks to apply
        max_t: max width of a time mask
        max_f: max width of a freq mask
        max_w: max width of time warp

    Returns:
        augmented feature
    """
    y = np.copy(x)
    max_frames = y.shape[0]
    max_freq = y.shape[1]

    # time warp
    if warp_for_time and max_frames > max_w * 2:
        center = random.randrange(max_w, max_frames - max_w)
        warped = random.randrange(center - max_w, center + max_w) + 1

        # PIL resize takes (width, height), so (max_freq, warped) stretches
        # the time axis while leaving the frequency axis untouched
        left = Image.fromarray(x[:center]).resize((max_freq, warped), BICUBIC)
        right = Image.fromarray(x[center:]).resize(
            (max_freq, max_frames - warped), BICUBIC)
        y = np.concatenate((left, right), 0)
    # time mask
    for i in range(num_t_mask):
        start = random.randint(0, max_frames - 1)
        length = random.randint(1, max_t)
        end = min(max_frames, start + length)
        y[start:end, :] = 0
    # freq mask
    for i in range(num_f_mask):
        start = random.randint(0, max_freq - 1)
        length = random.randint(1, max_f)
        end = min(max_freq, start + length)
        y[:, start:end] = 0
    return y


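# A minimal usage sketch for _spec_augmentation; the shape and values are
# illustrative, not taken from any config:
#
#     feat = np.random.randn(300, 80).astype(np.float32)  # T=300, F=80
#     aug = _spec_augmentation(feat, num_t_mask=2, num_f_mask=2)
#     assert aug.shape == feat.shape  # masking zeroes regions of a copy;
#                                     # the shape never changes

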
def _spec_substitute(x, max_t=20, num_t_sub=3):
    """ Deep copy x, apply spec substitution, and return the copy.

    Args:
        x: input feature, T * F 2D
        max_t: max width of a time substitution
        num_t_sub: number of time substitutions to apply

    Returns:
        augmented feature
    """
    y = np.copy(x)
    max_frames = y.shape[0]
    for i in range(num_t_sub):
        start = random.randint(0, max_frames - 1)
        length = random.randint(1, max_t)
        end = min(max_frames, start + length)
        # only substitute with a randomly chosen earlier segment,
        # never with frames from the future
        pos = random.randint(0, start)
        y[start:end, :] = y[start - pos:end - pos, :]
    return y


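# Illustrative effect of _spec_substitute: with start=100, end=110 and
# pos=30, frames y[100:110] are overwritten by a copy of frames y[70:80].

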
def _waveform_distortion(waveform, distortion_methods_conf):
    """ Apply distortion on waveform.

    The distortion does not change the length of the waveform.

    Args:
        waveform: numpy float tensor, (length,)
        distortion_methods_conf: a list of configs, one per distortion
            method. A method is randomly selected according to its
            'method_rate' and applied to the waveform.

    Returns:
        distorted waveform.
    """
    r = random.uniform(0, 1)
    acc = 0.0
    # treat the method_rate values as a cumulative distribution and apply
    # the first method whose cumulative rate exceeds the random draw
    for distortion_method in distortion_methods_conf:
        method_rate = distortion_method['method_rate']
        acc += method_rate
        if r < acc:
            distortion_type = distortion_method['name']
            distortion_conf = distortion_method['params']
            point_rate = distortion_method['point_rate']
            return distort_wav_conf(waveform, distortion_type,
                                    distortion_conf, point_rate)
    return waveform


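# The keys each entry of distortion_methods_conf must carry, as read by the
# loop above. The name and values here are placeholders; the supported method
# names are defined in wenet/dataset/wav_distortion.py:
#
#     distortion_methods_conf = [{
#         'name': '<distortion name>',  # passed to distort_wav_conf
#         'method_rate': 0.5,           # selection probability mass
#         'point_rate': 0.1,            # forwarded to distort_wav_conf
#         'params': {},                 # method-specific parameters
#     }]

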
# add speed perturb when loading wav
# return augmented, sr
def _load_wav_with_speed(wav_file, speed):
    """ Load the wave from file and apply speed perturbation.

    Args:
        wav_file: path to the wav file
        speed: speed factor, e.g. 0.9, 1.0 or 1.1

    Returns:
        (waveform, sample_rate)
    """
    if speed == 1.0:
        wav, sr = torchaudio.load(wav_file)
    else:
        sample_rate = torchaudio.backend.sox_io_backend.info(
            wav_file).sample_rate
        # get torchaudio version
        ta_no = torchaudio.__version__.split(".")
        ta_version = 100 * int(ta_no[0]) + 10 * int(ta_no[1])

        if ta_version < 80:
            # Note: deprecated in torchaudio>=0.8.0
            E = sox_effects.SoxEffectsChain()
            E.append_effect_to_chain('speed', speed)
            E.append_effect_to_chain("rate", sample_rate)
            E.set_input_file(wav_file)
            wav, sr = E.sox_build_flow_effects()
        else:
            # Note: the preferred API in torchaudio>=0.8.0
            wav, sr = sox_effects.apply_effects_file(
                wav_file,
                [['speed', str(speed)], ['rate', str(sample_rate)]])

    return wav, sr


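# A minimal sketch of a call (the path is illustrative):
#
#     wav, sr = _load_wav_with_speed('utt1.wav', 1.1)
#
# The sox 'speed' effect changes tempo and pitch together and alters the
# nominal sample rate, so the trailing 'rate' effect resamples back to the
# original rate; sr therefore matches the input file.

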
def _extract_feature(batch, speed_perturb, wav_distortion_conf,
                     feature_extraction_conf):
    """ Extract acoustic fbank features from the original waveform.

    Speed perturbation and wave amplitude distortion are optional.

    Args:
        batch: a list of (wav id, wav path) tuples.
        speed_perturb: bool, whether or not to use speed perturbation.
        wav_distortion_conf: a dict, the config of wave amplitude distortion.
        feature_extraction_conf: a dict, the config of fbank extraction.

    Returns:
        (keys, feats, labels)
    """
    keys = []
    feats = []
    lengths = []
    wav_dither = wav_distortion_conf['wav_dither']
    wav_distortion_rate = wav_distortion_conf['wav_distortion_rate']
    distortion_methods_conf = wav_distortion_conf['distortion_methods']
    if speed_perturb:
        speeds = [1.0, 1.1, 0.9]
        weights = [1, 1, 1]
        speed = random.choices(speeds, weights, k=1)[0]
        # speed = random.choice(speeds)
    for i, x in enumerate(batch):
        try:
            wav = x[1]
            value = wav.strip().split(",")
            # 1 for general wav.scp, 3 for segmented wav.scp
            assert len(value) == 1 or len(value) == 3
            wav_path = value[0]
            sample_rate = torchaudio.backend.sox_io_backend.info(
                wav_path).sample_rate
            if 'resample' in feature_extraction_conf:
                resample_rate = feature_extraction_conf['resample']
            else:
                resample_rate = sample_rate
            if speed_perturb:
                if len(value) == 3:
                    logging.error(
                        "speed perturb does not support segmented wav.scp now")
                assert len(value) == 1
                waveform, sample_rate = _load_wav_with_speed(wav_path, speed)
            else:
                # value length 3 means using segmented wav.scp,
                # which includes wav path, start time and end time
                if len(value) == 3:
                    start_frame = int(float(value[1]) * sample_rate)
                    end_frame = int(float(value[2]) * sample_rate)
                    waveform, sample_rate = torchaudio.backend.sox_io_backend.load(
                        filepath=wav_path,
                        num_frames=end_frame - start_frame,
                        frame_offset=start_frame)
                else:
                    waveform, sample_rate = torchaudio.load(wav_path)
            # scale the [-1.0, 1.0] float waveform back to the 16-bit
            # integer range expected by kaldi.fbank
            waveform = waveform * (1 << 15)
            if resample_rate != sample_rate:
                waveform = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate)(waveform)

            if wav_distortion_rate > 0.0:
                r = random.uniform(0, 1)
                if r < wav_distortion_rate:
                    waveform = waveform.detach().numpy()
                    waveform = _waveform_distortion(waveform,
                                                    distortion_methods_conf)
                    waveform = torch.from_numpy(waveform)
            mat = kaldi.fbank(
                waveform,
                num_mel_bins=feature_extraction_conf['mel_bins'],
                frame_length=feature_extraction_conf['frame_length'],
                frame_shift=feature_extraction_conf['frame_shift'],
                dither=wav_dither,
                energy_floor=0.0,
                sample_frequency=resample_rate)
            mat = mat.detach().numpy()
            feats.append(mat)
            keys.append(x[0])
            lengths.append(mat.shape[0])
        except Exception as e:
            print(e)
            logging.warning('read utterance {} error'.format(x[0]))
    # Sort it because sorting is required in pack/pad operation
    order = np.argsort(lengths)[::-1]
    sorted_keys = [keys[i] for i in order]
    sorted_feats = [feats[i] for i in order]
    labels = [x[2].split() for x in batch]
    labels = [np.fromiter(map(int, x), dtype=np.int32) for x in labels]
    sorted_labels = [labels[i] for i in order]
    return sorted_keys, sorted_feats, sorted_labels


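# The keys read from feature_extraction_conf above; the values shown are
# illustrative, not defaults:
#
#     feature_extraction_conf = {
#         'mel_bins': 80,        # num_mel_bins for kaldi.fbank
#         'frame_length': 25,    # window length in ms
#         'frame_shift': 10,     # frame shift in ms
#         'resample': 16000,     # optional target sample rate in Hz
#     }

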
def _load_feature(batch):
    """ Load acoustic features from files.

    The features have been prepared in a previous step, usually by Kaldi.

    Args:
        batch: a list of (wav id, feature ark path) tuples.

    Returns:
        (keys, feats, labels)
    """
    keys = []
    feats = []
    lengths = []
    for i, x in enumerate(batch):
        try:
            mat = kaldi_io.read_mat(x[1])
            feats.append(mat)
            keys.append(x[0])
            lengths.append(mat.shape[0])
        except Exception:
            # logging.warning('read utterance {} error'.format(x[0]))
            pass
    # Sort it because sorting is required in pack/pad operation
    order = np.argsort(lengths)[::-1]
    sorted_keys = [keys[i] for i in order]
    sorted_feats = [feats[i] for i in order]
    labels = [x[2].split() for x in batch]
    labels = [np.fromiter(map(int, x), dtype=np.int32) for x in labels]
    sorted_labels = [labels[i] for i in order]
    return sorted_keys, sorted_feats, sorted_labels


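# kaldi_io.read_mat takes Kaldi-style "path.ark:offset" specifiers, so an
# illustrative batch entry is ('utt1', 'tmp/data/fbank.ark:30', '23 51 842'),
# where the third field is the space-separated token id string.

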
class CollateFunc(object):
    """ Collate function for AudioDataset
    """
    def __init__(
        self,
        feature_dither=0.0,
        speed_perturb=False,
        spec_aug=False,
        spec_aug_conf=None,
        spec_sub=False,
        spec_sub_conf=None,
        raw_wav=True,
        feature_extraction_conf=None,
        wav_distortion_conf=None,
    ):
        """
        Args:
            raw_wav:
                True if input is raw wav and feature extraction is needed.
                False if input is extracted feature.
        """
        self.wav_distortion_conf = wav_distortion_conf
        self.feature_extraction_conf = feature_extraction_conf
        self.spec_aug = spec_aug
        self.feature_dither = feature_dither
        self.speed_perturb = speed_perturb
        self.raw_wav = raw_wav
        self.spec_aug_conf = spec_aug_conf
        self.spec_sub = spec_sub
        self.spec_sub_conf = spec_sub_conf

    def __call__(self, batch):
        assert (len(batch) == 1)
        if self.raw_wav:
            keys, xs, ys = _extract_feature(batch[0], self.speed_perturb,
                                            self.wav_distortion_conf,
                                            self.feature_extraction_conf)
        else:
            keys, xs, ys = _load_feature(batch[0])

        train_flag = True
        if ys is None:
            train_flag = False

        # optional feature dither d ~ (-a, a) on fbank feature
        # a ~ (0, 0.5)
        if self.feature_dither != 0.0:
            a = random.uniform(0, self.feature_dither)
            xs = [x + (np.random.random_sample(x.shape) - 0.5) * a for x in xs]

        # optional spec substitution
        if self.spec_sub:
            xs = [_spec_substitute(x, **self.spec_sub_conf) for x in xs]

        # optional spec augmentation
        if self.spec_aug:
            xs = [_spec_augmentation(x, **self.spec_aug_conf) for x in xs]

        # padding
        xs_lengths = torch.from_numpy(
            np.array([x.shape[0] for x in xs], dtype=np.int32))

        # pad_sequence will FAIL in case xs is empty
        if len(xs) > 0:
            xs_pad = pad_sequence([torch.from_numpy(x).float() for x in xs],
                                  True, 0)
        else:
            xs_pad = torch.Tensor(xs)
        if train_flag:
            ys_lengths = torch.from_numpy(
                np.array([y.shape[0] for y in ys], dtype=np.int32))
            if len(ys) > 0:
                ys_pad = pad_sequence([torch.from_numpy(y).int() for y in ys],
                                      True, IGNORE_ID)
            else:
                ys_pad = torch.Tensor(ys)
        else:
            ys_pad = None
            ys_lengths = None
        return keys, xs_pad, ys_pad, xs_lengths, ys_lengths


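# Note on the batch[0] indexing above: AudioDataset (below) already groups
# utterances into minibatches, so the DataLoader is driven with batch_size=1
# and each "batch" it delivers is a one-element list holding the real
# minibatch.

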
class AudioDataset(Dataset):
    def __init__(self,
                 data_file,
                 max_length=10240,
                 min_length=0,
                 token_max_length=200,
                 token_min_length=1,
                 batch_type='static',
                 batch_size=1,
                 max_frames_in_batch=0,
                 sort=True,
                 raw_wav=True):
        """Dataset for loading audio data.

        Attributes:
            data_file: input data file
                Plain text data file, each line contains the following
                7 fields, separated by '\t':
                    utt:utt1
                    feat:tmp/data/file1.wav or feat:tmp/data/fbank.ark:30
                    feat_shape:4.95 (in seconds) or feat_shape:495,80
                        (495 is in frames)
                    text:i love you
                    token:i <space> l o v e <space> y o u
                    tokenid: int ids of the tokens
                    token_shape:M,N  # M is the number of tokens,
                                     # N is the vocab size
            max_length: drop utterances longer than max_length (unit: 10ms)
            min_length: drop utterances shorter than min_length (unit: 10ms)
            token_max_length: drop utterances with more than
                token_max_length tokens, especially when using char units
                for English modeling
            token_min_length: drop utterances with fewer than
                token_min_length tokens
            batch_type: static or dynamic, see max_frames_in_batch (dynamic)
            batch_size: number of utterances in a batch,
                used when batch_type is static.
            max_frames_in_batch: max feature frames in a batch,
                used when batch_type is dynamic. batch_size is then ignored;
                we keep filling a batch until its total number of frames
                reaches max_frames_in_batch.
            sort: whether to sort all data by length, so that utterances
                with similar lengths are grouped into the same batch.
            raw_wav: use raw wave or extracted features.
                If raw wave is used, dynamic waveform-level augmentation
                can be applied and the features are extracted by torchaudio.
                If extracted features (e.g. by Kaldi) are used, only
                feature-level augmentation such as specaug can be applied.
        """
        assert batch_type in ['static', 'dynamic']
        data = []

        # Open in utf-8 mode to avoid encoding problems
        with codecs.open(data_file, 'r', encoding='utf-8') as f:
            for line in f:
                arr = line.strip().split('\t')
                if len(arr) != 7:
                    continue
                key = arr[0].split(':')[1]
                tokenid = arr[5].split(':')[1]
                output_dim = int(arr[6].split(':')[1].split(',')[1])
                if raw_wav:
                    wav_path = ':'.join(arr[1].split(':')[1:])
                    # duration in units of 10ms
                    duration = int(float(arr[2].split(':')[1]) * 1000 / 10)
                    data.append((key, wav_path, duration, tokenid))
                else:
                    feat_ark = ':'.join(arr[1].split(':')[1:])
                    feat_info = arr[2].split(':')[1].split(',')
                    feat_dim = int(feat_info[1].strip())
                    num_frames = int(feat_info[0].strip())
                    data.append((key, feat_ark, num_frames, tokenid))
                    self.input_dim = feat_dim
                self.output_dim = output_dim
        if sort:
            data = sorted(data, key=lambda x: x[2])
        valid_data = []
        for i in range(len(data)):
            length = data[i][2]
            token_length = len(data[i][3].split())
            # remove utterances that are too long or too short, for both
            # input and output, to prevent running out of memory
            if length > max_length or length < min_length:
                # logging.warning('ignore utterance {} feature {}'.format(
                #     data[i][0], length))
                pass
            elif token_length > token_max_length or \
                    token_length < token_min_length:
                pass
            else:
                valid_data.append(data[i])
        data = valid_data
        self.minibatch = []
        num_data = len(data)
        # Dynamic batch size
        if batch_type == 'dynamic':
            assert (max_frames_in_batch > 0)
            self.minibatch.append([])
            num_frames_in_batch = 0
            for i in range(num_data):
                length = data[i][2]
                num_frames_in_batch += length
                if num_frames_in_batch > max_frames_in_batch:
                    self.minibatch.append([])
                    num_frames_in_batch = length
                self.minibatch[-1].append((data[i][0], data[i][1], data[i][3]))
        # Static batch size
        else:
            cur = 0
            while cur < num_data:
                end = min(cur + batch_size, num_data)
                item = []
                for i in range(cur, end):
                    item.append((data[i][0], data[i][1], data[i][3]))
                self.minibatch.append(item)
                cur = end

    def __len__(self):
        return len(self.minibatch)

    def __getitem__(self, idx):
        return self.minibatch[idx]


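# An illustrative data_file line for raw_wav mode, matching the 7 fields
# documented above (tab-separated; '\t' stands for the tab character, and
# the token ids and vocab size are made up for the example):
#
#     utt:utt1\tfeat:tmp/data/file1.wav\tfeat_shape:4.95\ttext:i love you\t
#     token:i <space> l o v e <space> y o u\t
#     tokenid:23 51 842 9 4 12 842 51 23 7\ttoken_shape:10,4233

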
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('type', help='dataset type, raw_wav or feature')
    parser.add_argument('config_file', help='config file')
    parser.add_argument('data_file', help='input data file')
    args = parser.parse_args()

    with open(args.config_file, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    # Init dataset and data loader
    collate_conf = copy.copy(configs['collate_conf'])
    if args.type == 'raw_wav':
        raw_wav = True
    else:
        raw_wav = False
    collate_func = CollateFunc(**collate_conf, raw_wav=raw_wav)
    dataset_conf = configs.get('dataset_conf', {})
    dataset = AudioDataset(args.data_file, **dataset_conf, raw_wav=raw_wav)

    data_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             sampler=None,
                             num_workers=0,
                             collate_fn=collate_func)

    for i, batch in enumerate(data_loader):
        print(i)
        # print(batch[1].shape)
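# Example invocation for a quick smoke test (module path and file names are
# illustrative):
#
#     python dataset.py raw_wav train.yaml format.data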