competition update

This commit is contained in:
nckcard
2025-07-02 12:18:09 -07:00
parent 9e17716a4a
commit 77dbcf868f
2615 changed files with 1648116 additions and 125 deletions

View File

@@ -0,0 +1,46 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)
import logging
import os
import re
import yaml
import torch
def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
if torch.cuda.is_available():
logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
checkpoint = torch.load(path)
else:
logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
checkpoint = torch.load(path, map_location='cpu')
model.load_state_dict(checkpoint)
    info_path = re.sub(r'\.pt$', '.yaml', path)
configs = {}
if os.path.exists(info_path):
with open(info_path, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)
return configs
def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
'''
Args:
infos (dict or None): any info you want to save.
'''
logging.info('Checkpoint: save to checkpoint %s' % path)
if isinstance(model, torch.nn.DataParallel):
state_dict = model.module.state_dict()
elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
state_dict = model.module.state_dict()
else:
state_dict = model.state_dict()
torch.save(state_dict, path)
    info_path = re.sub(r'\.pt$', '.yaml', path)
if infos is None:
infos = {}
with open(info_path, 'w') as fout:
data = yaml.dump(infos)
fout.write(data)
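# A minimal usage sketch of the helpers above; the Linear module and the
# 'model.pt' path are stand-ins for a real model and checkpoint location.
model = torch.nn.Linear(80, 4)
save_checkpoint(model, 'model.pt', infos={'epoch': 3, 'cv_loss': 1.23})
# Restores the weights and returns the side-car YAML info dict.
configs = load_checkpoint(model, 'model.pt')
print(configs)  # {'cv_loss': 1.23, 'epoch': 3}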

View File

@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import math
import sys
import numpy as np
def _load_json_cmvn(json_cmvn_file):
""" Load the json format cmvn stats file and calculate cmvn
Args:
json_cmvn_file: cmvn stats file in json format
Returns:
a numpy array of [means, vars]
"""
with open(json_cmvn_file) as f:
cmvn_stats = json.load(f)
means = cmvn_stats['mean_stat']
variance = cmvn_stats['var_stat']
count = cmvn_stats['frame_num']
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def _load_kaldi_cmvn(kaldi_cmvn_file):
""" Load the kaldi format cmvn stats file and calculate cmvn
Args:
kaldi_cmvn_file: kaldi text style global cmvn file, which
is generated by:
compute-cmvn-stats --binary=false scp:feats.scp global_cmvn
Returns:
a numpy array of [means, vars]
"""
means = []
variance = []
with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary files start with '\0B'
if fid.read(2) == '\0B':
logging.error('kaldi cmvn binary file is not supported, please '
'recompute it by: compute-cmvn-stats --binary=false '
' scp:feats.scp global_cmvn')
sys.exit(1)
fid.seek(0)
arr = fid.read().split()
assert (arr[0] == '[')
assert (arr[-2] == '0')
assert (arr[-1] == ']')
feat_dim = int((len(arr) - 2 - 2) / 2)
for i in range(1, feat_dim + 1):
means.append(float(arr[i]))
count = float(arr[feat_dim + 1])
for i in range(feat_dim + 2, 2 * feat_dim + 2):
variance.append(float(arr[i]))
for i in range(len(means)):
means[i] /= count
variance[i] = variance[i] / count - means[i] * means[i]
if variance[i] < 1.0e-20:
variance[i] = 1.0e-20
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn
def load_cmvn(cmvn_file, is_json):
if is_json:
cmvn = _load_json_cmvn(cmvn_file)
else:
cmvn = _load_kaldi_cmvn(cmvn_file)
return cmvn[0], cmvn[1]
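# A self-contained sketch of applying the returned stats; the stats values
# and the synthetic features below are illustrative only.
stats = {'mean_stat': [100.0, 200.0], 'var_stat': [1500.0, 6000.0],
         'frame_num': 10}
with open('global_cmvn.json', 'w') as fout:
    json.dump(stats, fout)
means, istd = load_cmvn('global_cmvn.json', is_json=True)
feats = np.random.randn(7, 2) + means  # (T, feat_dim) dummy features
normalized = (feats - means) * istd    # subtract mean, scale by 1/sqrt(var)
print(normalized.mean(axis=0))         # roughly zero per dimension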

View File

@@ -0,0 +1,186 @@
"""Unility functions for Transformer."""
import math
from typing import Tuple, List
import torch
from torch.nn.utils.rnn import pad_sequence
IGNORE_ID = -1
def pad_list(xs: List[torch.Tensor], pad_value: int):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch = len(xs)
max_len = max([x.size(0) for x in xs])
pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
pad = pad.fill_(pad_value)
for i in range(n_batch):
pad[i, :xs[i].size(0)] = xs[i]
return pad
def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
        eos (int): index of <eos>
ignore_id (int): index of padding
Returns:
ys_in (torch.Tensor) : (B, Lmax + 1)
ys_out (torch.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=torch.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
_sos = torch.tensor([sos],
dtype=torch.long,
requires_grad=False,
device=ys_pad.device)
_eos = torch.tensor([eos],
dtype=torch.long,
requires_grad=False,
device=ys_pad.device)
ys = [y[y != ignore_id] for y in ys_pad] # parse padded ys
ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
def reverse_pad_list(ys_pad: torch.Tensor,
ys_lens: torch.Tensor,
pad_value: float = -1.0) -> torch.Tensor:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0.0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
r_ys_pad = pad_sequence([(torch.flip(y.int()[:i], [0]))
for y, i in zip(ys_pad, ys_lens)], True,
pad_value)
return r_ys_pad
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
ignore_label: int) -> float:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
pad_outputs.size(1)).argmax(2)
mask = pad_targets != ignore_label
numerator = torch.sum(
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
denominator = torch.sum(mask)
return float(numerator) / float(denominator)
def get_activation(act):
"""Return activation function."""
# Lazy load to avoid unused import
from wenet.transformer.swish import Swish
activation_funcs = {
"hardtanh": torch.nn.Hardtanh,
"tanh": torch.nn.Tanh,
"relu": torch.nn.ReLU,
"selu": torch.nn.SELU,
"swish": Swish,
"gelu": torch.nn.GELU
}
return activation_funcs[act]()
def get_subsample(config):
input_layer = config["encoder_conf"]["input_layer"]
assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
if input_layer == "conv2d":
return 4
elif input_layer == "conv2d6":
return 6
elif input_layer == "conv2d8":
return 8
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def log_add(args: List[float]) -> float:
"""
Stable log add
"""
if all(a == -float('inf') for a in args):
return -float('inf')
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args))
return a_max + lsp
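# A short sketch exercising the helpers above on toy data; all ids are
# arbitrary and chosen only for illustration.
ys_pad = torch.tensor([[1, 2, 3, 4, 5],
                       [4, 5, 6, IGNORE_ID, IGNORE_ID]], dtype=torch.long)
ys_in, ys_out = add_sos_eos(ys_pad, sos=10, eos=11, ignore_id=IGNORE_ID)
# ys_in prepends <sos> and pads with <eos>; ys_out appends <eos> and pads
# with IGNORE_ID, matching the docstring example.
hyp = [0, 0, 3, 3, 0, 4, 4, 4, 0, 5]
print(remove_duplicates_and_blank(hyp))          # [3, 4, 5] (CTC collapse)
print(log_add([math.log(0.5), math.log(0.5)]))   # 0.0, i.e. log(0.5 + 0.5)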

View File

@@ -0,0 +1,72 @@
# Copyright 2021 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Di Wu)
import numpy as np
import torch
def insert_blank(label, blank_id=0):
"""Insert blank token between every two label token."""
label = np.expand_dims(label, 1)
blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
label = np.concatenate([blanks, label], axis=1)
label = label.reshape(-1)
label = np.append(label, label[0])
return label
def forced_align(ctc_probs: torch.Tensor,
y: torch.Tensor,
blank_id=0) -> list:
"""ctc forced alignment.
Args:
        ctc_probs (torch.Tensor): hidden state sequence, 2d tensor (T, D)
        y (torch.Tensor): id sequence, 1d tensor (L)
        blank_id (int): blank symbol index
    Returns:
        List[int]: alignment result
"""
y_insert_blank = insert_blank(y, blank_id)
log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
log_alpha = log_alpha - float('inf') # log of zero
state_path = (torch.zeros(
(ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
) # state path
# init start state
log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]
for t in range(1, ctc_probs.size(0)):
for s in range(len(y_insert_blank)):
if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
s] == y_insert_blank[s - 2]:
candidates = torch.tensor(
[log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
prev_state = [s, s - 1]
else:
candidates = torch.tensor([
log_alpha[t - 1, s],
log_alpha[t - 1, s - 1],
log_alpha[t - 1, s - 2],
])
prev_state = [s, s - 1, s - 2]
log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]]
state_path[t, s] = prev_state[torch.argmax(candidates)]
state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)
candidates = torch.tensor([
log_alpha[-1, len(y_insert_blank) - 1],
log_alpha[-1, len(y_insert_blank) - 2]
])
prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
state_seq[-1] = prev_state[torch.argmax(candidates)]
for t in range(ctc_probs.size(0) - 2, -1, -1):
state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]
output_alignment = []
for t in range(0, ctc_probs.size(0)):
output_alignment.append(y_insert_blank[state_seq[t, 0]])
return output_alignment
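# A toy sketch of forced alignment on random log-probabilities; the shapes
# and label ids are illustrative only.
T, vocab_size = 6, 5
ctc_probs = torch.randn(T, vocab_size).log_softmax(dim=1)  # (T, D) log-probs
y = torch.tensor([1, 2])                                   # label ids (L,)
alignment = forced_align(ctc_probs, y, blank_id=0)
print(alignment)  # length-T path over {blank=0, 1, 2}, e.g. [0, 1, 0, 0, 2, 0]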

View File

@@ -0,0 +1,135 @@
# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)
import logging
from contextlib import nullcontext
# if your python version < 3.7 use the below one
# from contextlib import suppress as nullcontext
import torch
from torch.nn.utils import clip_grad_norm_
class Executor:
def __init__(self):
self.step = 0
def train(self, model, optimizer, scheduler, data_loader, device, writer,
args, scaler):
''' Train one epoch
'''
model.train()
clip = args.get('grad_clip', 50.0)
log_interval = args.get('log_interval', 10)
rank = args.get('rank', 0)
accum_grad = args.get('accum_grad', 1)
is_distributed = args.get('is_distributed', True)
use_amp = args.get('use_amp', False)
        logging.info('using accumulate grad, new batch size is {} times '
                     'larger than before'.format(accum_grad))
if use_amp:
assert scaler is not None
num_seen_utts = 0
num_total_batch = len(data_loader)
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
target_lengths = target_lengths.to(device)
num_utts = target_lengths.size(0)
if num_utts == 0:
continue
context = None
# Disable gradient synchronizations across DDP processes.
# Within this context, gradients will be accumulated on module
# variables, which will later be synchronized.
if is_distributed and batch_idx % accum_grad != 0:
context = model.no_sync
# Used for single gpu training and DDP gradient synchronization
# processes.
else:
context = nullcontext
with context():
# autocast context
# The more details about amp can be found in
# https://pytorch.org/docs/stable/notes/amp_examples.html
                with torch.cuda.amp.autocast(enabled=scaler is not None):
loss, loss_att, loss_ctc = model(feats, feats_lengths,
target, target_lengths)
loss = loss / accum_grad
if use_amp:
scaler.scale(loss).backward()
else:
loss.backward()
num_seen_utts += num_utts
if batch_idx % accum_grad == 0:
if rank == 0 and writer is not None:
writer.add_scalar('train_loss', loss, self.step)
# Use mixed precision training
if use_amp:
scaler.unscale_(optimizer)
grad_norm = clip_grad_norm_(model.parameters(), clip)
# Must invoke scaler.update() if unscale_() is used in the
# iteration to avoid the following error:
# RuntimeError: unscale_() has already been called
# on this optimizer since the last update().
                    # We don't check grad here because, if the gradient has
# inf/nan values, scaler.step will skip optimizer.step().
scaler.step(optimizer)
scaler.update()
else:
grad_norm = clip_grad_norm_(model.parameters(), clip)
if torch.isfinite(grad_norm):
optimizer.step()
optimizer.zero_grad()
scheduler.step()
self.step += 1
if batch_idx % log_interval == 0:
lr = optimizer.param_groups[0]['lr']
log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format(
batch_idx, num_total_batch,
loss.item() * accum_grad)
if loss_att is not None:
log_str += 'loss_att {:.6f} '.format(loss_att.item())
if loss_ctc is not None:
log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
log_str += 'lr {:.8f} rank {}'.format(lr, rank)
logging.debug(log_str)
def cv(self, model, data_loader, device, args):
        ''' Cross validation over the given data loader
'''
model.eval()
log_interval = args.get('log_interval', 10)
# in order to avoid division by 0
num_seen_utts = 1
total_loss = 0.0
num_total_batch = len(data_loader)
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths, target_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
target_lengths = target_lengths.to(device)
num_utts = target_lengths.size(0)
if num_utts == 0:
continue
loss, loss_att, loss_ctc = model(feats, feats_lengths, target,
target_lengths)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts
if batch_idx % log_interval == 0:
log_str = 'CV Batch {}/{} loss {:.6f} '.format(
batch_idx, num_total_batch, loss.item())
if loss_att is not None:
log_str += 'loss_att {:.6f} '.format(loss_att.item())
if loss_ctc is not None:
log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
log_str += 'history loss {:.6f}'.format(total_loss /
num_seen_utts)
logging.debug(log_str)
return total_loss, num_seen_utts
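# A skeletal wiring sketch; the model, loaders, and writer are placeholders
# that a real training script would construct.
executor = Executor()
args = {'grad_clip': 5.0, 'accum_grad': 4, 'log_interval': 100,
        'rank': 0, 'is_distributed': False, 'use_amp': False}
# executor.train(model, optimizer, scheduler, train_loader, device,
#                writer, args, scaler=None)
# total_loss, num_seen_utts = executor.cv(model, cv_loader, device, args)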

View File

@@ -0,0 +1,251 @@
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
def subsequent_mask(
size: int,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size).
    This mask is used only in the decoder, which works in an auto-regressive
    mode: the current step may only attend to the steps on its left.
    In the encoder, full attention is used when streaming is not necessary
    and the sequence is not long; in that case, no attention mask is needed.
    When streaming is needed, chunk-based attention is used in the encoder.
    See subsequent_chunk_mask for the chunk-based attention mask.
Args:
size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_mask(3)
[[1, 0, 0],
[1, 1, 0],
[1, 1, 1]]
"""
ret = torch.ones(size, size, device=device, dtype=torch.bool)
return torch.tril(ret, out=ret)
def subsequent_chunk_mask(
size: int,
chunk_size: int,
num_left_chunks: int = -1,
device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
"""Create mask for subsequent steps (size, size) with chunk size,
this is for streaming encoder
Args:
size (int): size of mask
chunk_size (int): size of chunk
num_left_chunks (int): number of left chunks
<0: use full chunk
>=0: use num_left_chunks
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
Returns:
torch.Tensor: mask
Examples:
>>> subsequent_chunk_mask(4, 2)
[[1, 1, 0, 0],
[1, 1, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1]]
"""
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
for i in range(size):
if num_left_chunks < 0:
start = 0
else:
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
ending = min((i // chunk_size + 1) * chunk_size, size)
ret[i, start:ending] = True
return ret
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
use_dynamic_chunk: bool,
use_dynamic_left_chunk: bool,
decoding_chunk_size: int, static_chunk_size: int,
num_decoding_left_chunks: int):
""" Apply optional mask for encoder.
Args:
xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
use_dynamic_chunk (bool): whether to use dynamic chunk or not
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
training.
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
0: default for training, use random dynamic chunk.
<0: for decoding, use full chunk.
>0: for decoding, use fixed chunk size as set.
static_chunk_size (int): chunk size for static chunk training/decoding
if it's greater than 0, if use_dynamic_chunk is true,
this parameter will be ignored
num_decoding_left_chunks: number of left chunks, this is for decoding,
the chunk size is decoding_chunk_size.
>=0: use num_decoding_left_chunks
<0: use all left chunks
Returns:
torch.Tensor: chunk mask of the input xs.
"""
# Whether to use chunk mask or not
if use_dynamic_chunk:
max_len = xs.size(1)
if decoding_chunk_size < 0:
chunk_size = max_len
num_left_chunks = -1
elif decoding_chunk_size > 0:
chunk_size = decoding_chunk_size
num_left_chunks = num_decoding_left_chunks
else:
# chunk size is either [1, 25] or full context(max_len).
# Since we use 4 times subsampling and allow up to 1s(100 frames)
# delay, the maximum frame is 100 / 4 = 25.
chunk_size = torch.randint(1, max_len, (1, )).item()
num_left_chunks = -1
if chunk_size > max_len // 2:
chunk_size = max_len
else:
chunk_size = chunk_size % 25 + 1
if use_dynamic_left_chunk:
max_left_chunks = (max_len - 1) // chunk_size
num_left_chunks = torch.randint(0, max_left_chunks,
(1, )).item()
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
elif static_chunk_size > 0:
num_left_chunks = num_decoding_left_chunks
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
num_left_chunks,
xs.device) # (L, L)
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
chunk_masks = masks & chunk_masks # (B, L, L)
else:
chunk_masks = masks
return chunk_masks
def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Make mask tensor containing indices of padded part.
See description of make_non_pad_mask.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
torch.Tensor: Mask tensor containing indices of padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_pad_mask(lengths)
masks = [[0, 0, 0, 0 ,0],
[0, 0, 0, 1, 1],
[0, 0, 1, 1, 1]]
"""
batch_size = int(lengths.size(0))
max_len = int(lengths.max().item())
seq_range = torch.arange(0,
max_len,
dtype=torch.int64,
device=lengths.device)
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_length_expand = lengths.unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
"""Make mask tensor containing indices of non-padded part.
    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To keep the padded part from passing values into context-dependent
    blocks such as attention or convolution, the padded part is masked.
This pad_mask is used in both encoder and decoder.
1 for non-padded part and 0 for padded part.
Args:
lengths (torch.Tensor): Batch of lengths (B,).
Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.
Examples:
>>> lengths = [5, 3, 2]
>>> make_non_pad_mask(lengths)
masks = [[1, 1, 1, 1 ,1],
[1, 1, 1, 0, 0],
[1, 1, 0, 0, 0]]
"""
return ~make_pad_mask(lengths)
def mask_finished_scores(score: torch.Tensor,
flag: torch.Tensor) -> torch.Tensor:
"""
If a sequence is finished, we only allow one alive branch. This function
aims to give one branch a zero score and the rest -inf score.
Args:
score (torch.Tensor): A real value array with shape
(batch_size * beam_size, beam_size).
flag (torch.Tensor): A bool array with shape
(batch_size * beam_size, 1).
Returns:
torch.Tensor: (batch_size * beam_size, beam_size).
"""
beam_size = score.size(-1)
zero_mask = torch.zeros_like(flag, dtype=torch.bool)
if beam_size > 1:
unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
dim=1)
finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
dim=1)
else:
unfinished = zero_mask
finished = flag
score.masked_fill_(unfinished, -float('inf'))
score.masked_fill_(finished, 0)
return score
def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
eos: int) -> torch.Tensor:
"""
    If a sequence is finished, all of its branches should be <eos>.
Args:
        pred (torch.Tensor): An int array with shape
(batch_size * beam_size, beam_size).
flag (torch.Tensor): A bool array with shape
(batch_size * beam_size, 1).
Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
"""
beam_size = pred.size(-1)
finished = flag.repeat([1, beam_size])
return pred.masked_fill_(finished, eos)
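# A small sketch of the mask helpers on toy lengths; the shapes mirror the
# docstring examples above.
lengths = torch.tensor([5, 3, 2])
pad_mask = make_non_pad_mask(lengths)   # (B, L), True on real frames
enc_mask = pad_mask.unsqueeze(1)        # (B, 1, L), as the encoder expects
chunk_mask = add_optional_chunk_mask(
    torch.zeros(3, 5, 8), enc_mask,
    use_dynamic_chunk=False, use_dynamic_left_chunk=False,
    decoding_chunk_size=0, static_chunk_size=2,
    num_decoding_left_chunks=-1)        # (B, L, L) static-chunk attention mask
print(subsequent_mask(3))               # lower-triangular decoder mask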

View File

@@ -0,0 +1,52 @@
from typing import Union
import torch
from torch.optim.lr_scheduler import _LRScheduler
from typeguard import check_argument_types
class WarmupLR(_LRScheduler):
"""The WarmupLR scheduler
    This scheduler is almost the same as the NoamLR scheduler except for the
    following difference:
NoamLR:
lr = optimizer.lr * model_size ** -0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
WarmupLR:
lr = optimizer.lr * warmup_step ** 0.5
* min(step ** -0.5, step * warmup_step ** -1.5)
    Note that the maximum lr equals optimizer.lr in this scheduler.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_steps: Union[int, float] = 25000,
last_epoch: int = -1,
):
assert check_argument_types()
self.warmup_steps = warmup_steps
# __init__() must be invoked before setting field
# because step() is also invoked in __init__()
super().__init__(optimizer, last_epoch)
def __repr__(self):
return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
def get_lr(self):
step_num = self.last_epoch + 1
return [
lr
* self.warmup_steps ** 0.5
* min(step_num ** -0.5, step_num * self.warmup_steps ** -1.5)
for lr in self.base_lrs
]
def set_step(self, step: int):
self.last_epoch = step
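# A quick sketch of the schedule's shape; the toy optimizer and step count
# are illustrative only.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)
scheduler = WarmupLR(optimizer, warmup_steps=100)
lrs = []
for _ in range(300):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])
# lr climbs linearly to the configured 0.002 at step 100,
# then decays proportionally to step ** -0.5.
print(max(lrs), lrs[-1])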