competition update
language_model/wenet/utils/checkpoint.py (new file, 46 lines)
@@ -0,0 +1,46 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)

import logging
import os
import re

import yaml
import torch


def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint)
    # The side-car metadata yaml sits next to the .pt file.
    info_path = re.sub(r'\.pt$', '.yaml', path)
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    return configs


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    '''
    Args:
        infos (dict or None): any info you want to save.
    '''
    logging.info('Checkpoint: save to checkpoint %s' % path)
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)
    info_path = re.sub(r'\.pt$', '.yaml', path)
    if infos is None:
        infos = {}
    with open(info_path, 'w') as fout:
        data = yaml.dump(infos)
        fout.write(data)
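
Not part of the commit, just a minimal usage sketch for the two helpers above; the tiny stand-in model, the exp/ directory, and the epoch/step info dict are hypothetical:

import os
import torch

os.makedirs('exp', exist_ok=True)
model = torch.nn.Linear(80, 4)  # hypothetical stand-in model
save_checkpoint(model, 'exp/demo.pt', infos={'epoch': 3, 'step': 1200})

# Restore the weights and read the side-car yaml info back.
infos = load_checkpoint(model, 'exp/demo.pt')
print(infos)  # {'epoch': 3, 'step': 1200}
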
language_model/wenet/utils/cmvn.py (new file, 94 lines)
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import math
import sys

import numpy as np


def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, vars]
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        # Store the inverse standard deviation instead of the variance.
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, vars]
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # Kaldi binary files start with '\0B'
        if fid.read(2) == '\0B':
            logging.error('kaldi cmvn binary file is not supported, please '
                          'recompute it by: compute-cmvn-stats --binary=false '
                          ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file, is_json):
    if is_json:
        cmvn = _load_json_cmvn(cmvn_file)
    else:
        cmvn = _load_kaldi_cmvn(cmvn_file)
    return cmvn[0], cmvn[1]
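
Not part of the commit: a sketch of how the returned statistics would typically be applied, assuming features are normalized as (feats - mean) * istd; the stats values and the file name are made up. Note the second row returned by load_cmvn already holds inverse standard deviations:

import json
import numpy as np

# Hypothetical global stats for 2-dim features accumulated over 4 frames.
with open('global_cmvn.json', 'w') as f:
    json.dump({'mean_stat': [4.0, 8.0], 'var_stat': [8.0, 32.0],
               'frame_num': 4}, f)

mean, istd = load_cmvn('global_cmvn.json', is_json=True)
feats = np.array([[1.0, 2.0], [3.0, 6.0]])
print((feats - mean) * istd)  # mean = [1, 2], istd = [1, 0.5]
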
language_model/wenet/utils/common.py (new file, 186 lines)
@@ -0,0 +1,186 @@
"""Utility functions for Transformer."""

import math
from typing import Tuple, List

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_ID = -1


def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max([x.size(0) for x in xs])
    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i in range(n_batch):
        pad[i, :xs[i].size(0)] = xs[i]

    return pad


def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _sos = torch.tensor([sos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    _eos = torch.tensor([eos],
                        dtype=torch.long,
                        requires_grad=False,
                        device=ys_pad.device)
    ys = [y[y != ignore_id] for y in ys_pad]  # parse padded ys
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse padding for the list of tensors.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (int): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])

    """
    r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
                             for y, i in zip(ys_pad, ys_lens)], True,
                            pad_value)
    return r_ys_pad


def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> float:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
                                pad_outputs.size(1)).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from wenet.transformer.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
        "gelu": torch.nn.GELU
    }

    return activation_funcs[act]()


def get_subsample(config):
    """Return the subsampling rate of the encoder input layer."""
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    if input_layer == "conv2d":
        return 4
    elif input_layer == "conv2d6":
        return 6
    elif input_layer == "conv2d8":
        return 8


def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    """Collapse repeated labels and drop blanks (id 0) from a CTC path."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        prev = cur
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


def log_add(args: List[float]) -> float:
    """
    Stable log add
    """
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp
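
Not part of the commit: a small illustration of the two CTC helpers above, with made-up values (blank id 0, as hard-coded in remove_duplicates_and_blank):

import math

# CTC-style frame labels: blanks (0) and repeats collapse away.
print(remove_duplicates_and_blank([0, 3, 3, 0, 5, 5, 3, 0]))  # [3, 5, 3]

# log_add([log p1, log p2]) == log(p1 + p2), computed stably.
print(log_add([math.log(0.25), math.log(0.25)]))  # ~log(0.5) = -0.693...
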
language_model/wenet/utils/ctc_util.py (new file, 72 lines)
@@ -0,0 +1,72 @@
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Di Wu)

import numpy as np
import torch


def insert_blank(label, blank_id=0):
    """Insert blank token between every two label tokens."""
    label = np.expand_dims(label, 1)
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    label = np.concatenate([blanks, label], axis=1)
    label = label.reshape(-1)
    label = np.append(label, label[0])
    return label


def forced_align(ctc_probs: torch.Tensor,
                 y: torch.Tensor,
                 blank_id=0) -> list:
    """CTC forced alignment.

    Args:
        ctc_probs (torch.Tensor): log-probabilities over symbols, 2d tensor (T, D)
        y (torch.Tensor): id sequence, 1d tensor (L)
        blank_id (int): blank symbol index
    Returns:
        list: alignment result, one symbol id per frame
    """
    y_insert_blank = insert_blank(y, blank_id)

    log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
    log_alpha = log_alpha - float('inf')  # log of zero
    state_path = (torch.zeros(
        (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
    )  # state path

    # init start state
    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]

    for t in range(1, ctc_probs.size(0)):
        for s in range(len(y_insert_blank)):
            if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
                    s] == y_insert_blank[s - 2]:
                candidates = torch.tensor(
                    [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
                prev_state = [s, s - 1]
            else:
                candidates = torch.tensor([
                    log_alpha[t - 1, s],
                    log_alpha[t - 1, s - 1],
                    log_alpha[t - 1, s - 2],
                ])
                prev_state = [s, s - 1, s - 2]
            log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][
                y_insert_blank[s]]
            state_path[t, s] = prev_state[torch.argmax(candidates)]

    state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)

    # Backtrack from the best of the two valid final states.
    candidates = torch.tensor([
        log_alpha[-1, len(y_insert_blank) - 1],
        log_alpha[-1, len(y_insert_blank) - 2]
    ])
    prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
    state_seq[-1] = prev_state[torch.argmax(candidates)]
    for t in range(ctc_probs.size(0) - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

    output_alignment = []
    for t in range(0, ctc_probs.size(0)):
        output_alignment.append(y_insert_blank[state_seq[t, 0]])

    return output_alignment
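
Not part of the commit: a toy forced_align run on a hand-made (T=4, D=3) log-probability matrix; the probabilities are arbitrary and only illustrate the expected shapes (blank id 0, label sequence [1, 2]):

import torch

ctc_probs = torch.log(torch.tensor([
    [0.1, 0.8, 0.1],   # frame 0: symbol 1 most likely
    [0.1, 0.8, 0.1],   # frame 1: symbol 1 again
    [0.1, 0.1, 0.8],   # frame 2: symbol 2
    [0.8, 0.1, 0.1],   # frame 3: blank
]))
print(forced_align(ctc_probs, torch.tensor([1, 2])))  # [1, 1, 2, 0]
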
language_model/wenet/utils/executor.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)

import logging
from contextlib import nullcontext
# if your python version < 3.7 use the below one
# from contextlib import suppress as nullcontext
import torch
from torch.nn.utils import clip_grad_norm_


class Executor:
    def __init__(self):
        self.step = 0

    def train(self, model, optimizer, scheduler, data_loader, device, writer,
              args, scaler):
        ''' Train one epoch
        '''
        model.train()
        clip = args.get('grad_clip', 50.0)
        log_interval = args.get('log_interval', 10)
        rank = args.get('rank', 0)
        accum_grad = args.get('accum_grad', 1)
        is_distributed = args.get('is_distributed', True)
        use_amp = args.get('use_amp', False)
        logging.info('using accumulate grad, new batch size is {} times '
                     'larger than before'.format(accum_grad))
        if use_amp:
            assert scaler is not None
        num_seen_utts = 0
        num_total_batch = len(data_loader)
        for batch_idx, batch in enumerate(data_loader):
            key, feats, target, feats_lengths, target_lengths = batch
            feats = feats.to(device)
            target = target.to(device)
            feats_lengths = feats_lengths.to(device)
            target_lengths = target_lengths.to(device)
            num_utts = target_lengths.size(0)
            if num_utts == 0:
                continue
            context = None
            # Disable gradient synchronizations across DDP processes.
            # Within this context, gradients will be accumulated on module
            # variables, which will later be synchronized.
            if is_distributed and batch_idx % accum_grad != 0:
                context = model.no_sync
            # Used for single gpu training and DDP gradient synchronization
            # processes.
            else:
                context = nullcontext
            with context():
                # autocast context
                # More details about amp can be found at
                # https://pytorch.org/docs/stable/notes/amp_examples.html
                with torch.cuda.amp.autocast(scaler is not None):
                    loss, loss_att, loss_ctc = model(feats, feats_lengths,
                                                     target, target_lengths)
                    loss = loss / accum_grad
                if use_amp:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            num_seen_utts += num_utts
            if batch_idx % accum_grad == 0:
                if rank == 0 and writer is not None:
                    writer.add_scalar('train_loss', loss, self.step)
                # Use mixed precision training
                if use_amp:
                    scaler.unscale_(optimizer)
                    grad_norm = clip_grad_norm_(model.parameters(), clip)
                    # Must invoke scaler.update() if unscale_() is used in
                    # the iteration to avoid the following error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # We don't check the gradient here since, if it has
                    # inf/nan values, scaler.step will skip optimizer.step().
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    grad_norm = clip_grad_norm_(model.parameters(), clip)
                    if torch.isfinite(grad_norm):
                        optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                self.step += 1
            if batch_idx % log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format(
                    batch_idx, num_total_batch,
                    loss.item() * accum_grad)
                if loss_att is not None:
                    log_str += 'loss_att {:.6f} '.format(loss_att.item())
                if loss_ctc is not None:
                    log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
                log_str += 'lr {:.8f} rank {}'.format(lr, rank)
                logging.debug(log_str)

    def cv(self, model, data_loader, device, args):
        ''' Cross validation
        '''
        model.eval()
        log_interval = args.get('log_interval', 10)
        # in order to avoid division by 0
        num_seen_utts = 1
        total_loss = 0.0
        num_total_batch = len(data_loader)
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                loss, loss_att, loss_ctc = model(feats, feats_lengths, target,
                                                 target_lengths)
                if torch.isfinite(loss):
                    num_seen_utts += num_utts
                    total_loss += loss.item() * num_utts
                if batch_idx % log_interval == 0:
                    log_str = 'CV Batch {}/{} loss {:.6f} '.format(
                        batch_idx, num_total_batch, loss.item())
                    if loss_att is not None:
                        log_str += 'loss_att {:.6f} '.format(loss_att.item())
                    if loss_ctc is not None:
                        log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
                    log_str += 'history loss {:.6f}'.format(total_loss /
                                                            num_seen_utts)
                    logging.debug(log_str)

        return total_loss, num_seen_utts
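
Not part of the commit: a minimal single-process wiring sketch for Executor. The dummy model and collate function are hypothetical stand-ins that only mimic the interface a WeNet model exposes, i.e. forward(feats, feats_lengths, target, target_lengths) returning (loss, loss_att, loss_ctc):

import torch
from torch.utils.data import DataLoader, TensorDataset


class DummyModel(torch.nn.Module):
    # Returns the (loss, loss_att, loss_ctc) triple Executor expects.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(80, 4)

    def forward(self, feats, feats_lengths, target, target_lengths):
        return self.proj(feats).mean(), None, None


def collate(batch):
    # (keys, feats (B, T, 80), target (B, L), feats_lengths, target_lengths)
    feats = torch.stack([b[0] for b in batch])
    target = torch.stack([b[1] for b in batch])
    n = len(batch)
    return (['utt'] * n, feats, target,
            torch.full((n,), feats.size(1), dtype=torch.long),
            torch.full((n,), target.size(1), dtype=torch.long))


model = DummyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: 1.0)
loader = DataLoader(TensorDataset(torch.randn(8, 20, 80),
                                  torch.zeros(8, 5, dtype=torch.long)),
                    batch_size=4, collate_fn=collate)

executor = Executor()
args = {'grad_clip': 5.0, 'is_distributed': False, 'use_amp': False}
executor.train(model, optimizer, scheduler, loader, torch.device('cpu'),
               None, args, scaler=None)
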
language_model/wenet/utils/mask.py (new file, 251 lines)
@@ -0,0 +1,251 @@
# -*- coding: utf-8 -*-

# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)

import torch


def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in the decoder, which works in an auto-regressive
    mode. This means the current step may only attend to its left steps.

    In the encoder, full attention is used when streaming is not necessary
    and the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in the encoder.
    See subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret, out=ret)


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret


def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # Chunk size is either in [1, 25] or the full context (max_len).
            # Since we use 4x subsampling and allow up to 1 s (100 frames)
            # of delay, the maximum chunk is 100 / 4 = 25 frames.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks


def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = int(lengths.size(0))
    max_len = int(lengths.max().item())
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask


def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To keep the padding part from passing values to context
    dependent blocks such as attention or convolution, this padding
    part is masked.

    This pad_mask is used in both the encoder and the decoder.

    1 for the non-padded part and 0 for the padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)


def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score


def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)
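
Not part of the commit: a quick look at how the padding and chunk masks compose; the lengths and the static chunk size of 2 are arbitrary example values:

import torch

lengths = torch.tensor([4, 2])
feats = torch.zeros(2, 4, 8)  # dummy (B, L, D) input

pad_mask = make_non_pad_mask(lengths).unsqueeze(1)  # (B, 1, L)
chunk_masks = add_optional_chunk_mask(
    feats, pad_mask,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=2,
    num_decoding_left_chunks=-1)
print(chunk_masks.int())  # (B, L, L): chunked attention, padding zeroed
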
language_model/wenet/utils/scheduler.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from typing import Union

import torch
from torch.optim.lr_scheduler import _LRScheduler

from typeguard import check_argument_types


class WarmupLR(_LRScheduler):
    """The WarmupLR scheduler

    This scheduler is almost the same as the NoamLR scheduler except for
    the following difference:

    NoamLR:
        lr = optimizer.lr * model_size ** -0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)
    WarmupLR:
        lr = optimizer.lr * warmup_step ** 0.5
             * min(step ** -0.5, step * warmup_step ** -1.5)

    Note that the maximum lr equals optimizer.lr in this scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        assert check_argument_types()
        self.warmup_steps = warmup_steps

        # __init__() must be invoked before setting field
        # because step() is also invoked in __init__()
        super().__init__(optimizer, last_epoch)

    def __repr__(self):
        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"

    def get_lr(self):
        step_num = self.last_epoch + 1
        return [
            lr
            * self.warmup_steps ** 0.5
            * min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
            for lr in self.base_lrs
        ]

    def set_step(self, step: int):
        self.last_epoch = step
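
Not part of the commit: a short demonstration of the warmup shape; the base lr of 1e-3 and warmup of 100 steps are arbitrary example values. The lr rises linearly to the optimizer's lr at the warmup step, then decays as step ** -0.5:

import torch

optimizer = torch.optim.Adam(torch.nn.Linear(4, 4).parameters(), lr=1e-3)
scheduler = WarmupLR(optimizer, warmup_steps=100)
for step in (1, 50, 100, 400):
    scheduler.set_step(step - 1)  # get_lr() evaluates last_epoch + 1
    print(step, scheduler.get_lr()[0])
# 1 1e-05, 50 0.0005, 100 0.001, 400 0.0005
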