# Copyright 2019 Mobvoi Inc. All Rights Reserved.
# Author: binbinzhang@mobvoi.com (Binbin Zhang)

import logging

from contextlib import nullcontext
# If your Python version is < 3.7, use this instead:
# from contextlib import suppress as nullcontext

import torch
from torch.nn.utils import clip_grad_norm_


class Executor:

    def __init__(self):
        self.step = 0

    def train(self, model, optimizer, scheduler, data_loader, device, writer,
              args, scaler):
        ''' Train one epoch
        '''
        model.train()
        clip = args.get('grad_clip', 50.0)
        log_interval = args.get('log_interval', 10)
        rank = args.get('rank', 0)
        accum_grad = args.get('accum_grad', 1)
        is_distributed = args.get('is_distributed', True)
        use_amp = args.get('use_amp', False)
        logging.info('using accumulate grad, new batch size is {} times '
                     'larger than before'.format(accum_grad))
        if use_amp:
            assert scaler is not None
        num_seen_utts = 0
        num_total_batch = len(data_loader)
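        # Note on gradient accumulation: the optimizer only steps once every
        # accum_grad batches (see below), so the effective batch size per
        # update is accum_grad times the data loader's batch size.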
        for batch_idx, batch in enumerate(data_loader):
            key, feats, target, feats_lengths, target_lengths = batch
            feats = feats.to(device)
            target = target.to(device)
            feats_lengths = feats_lengths.to(device)
            target_lengths = target_lengths.to(device)
            num_utts = target_lengths.size(0)
            if num_utts == 0:
                continue
            # Disable gradient synchronization across DDP processes on
            # batches where gradients are only being accumulated: within
            # model.no_sync(), gradients accumulate on module variables
            # and are synchronized later.
            if is_distributed and batch_idx % accum_grad != 0:
                context = model.no_sync
            else:
                # Synchronize as usual: single-GPU training, or the DDP
                # step on which the accumulated gradients are applied.
                context = nullcontext
            with context():
                # autocast context: more details about amp can be found at
                # https://pytorch.org/docs/stable/notes/amp_examples.html
                with torch.cuda.amp.autocast(enabled=scaler is not None):
                    loss, loss_att, loss_ctc = model(feats, feats_lengths,
                                                     target, target_lengths)
                    loss = loss / accum_grad
                if use_amp:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            num_seen_utts += num_utts
            if batch_idx % accum_grad == 0:
                if rank == 0 and writer is not None:
                    writer.add_scalar('train_loss', loss, self.step)
                # Use mixed precision training
                if use_amp:
                    scaler.unscale_(optimizer)
                    grad_norm = clip_grad_norm_(model.parameters(), clip)
                    # scaler.update() must be invoked whenever unscale_()
                    # was used in the iteration, to avoid the error:
                    #   RuntimeError: unscale_() has already been called
                    #   on this optimizer since the last update().
                    # The grad norm is not checked here: if the gradients
                    # contain inf/nan values, scaler.step() skips
                    # optimizer.step() automatically.
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    grad_norm = clip_grad_norm_(model.parameters(), clip)
                    if torch.isfinite(grad_norm):
                        optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                self.step += 1
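            # The logged loss is multiplied back by accum_grad so reported
            # numbers stay comparable across different accum_grad settings.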
            if batch_idx % log_interval == 0:
                lr = optimizer.param_groups[0]['lr']
                log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format(
                    batch_idx, num_total_batch, loss.item() * accum_grad)
                if loss_att is not None:
                    log_str += 'loss_att {:.6f} '.format(loss_att.item())
                if loss_ctc is not None:
                    log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
                log_str += 'lr {:.8f} rank {}'.format(lr, rank)
                logging.debug(log_str)

    def cv(self, model, data_loader, device, args):
        ''' Cross validation
        '''
        model.eval()
        log_interval = args.get('log_interval', 10)
        # Start from 1 in order to avoid division by 0.
        num_seen_utts = 1
        total_loss = 0.0
        num_total_batch = len(data_loader)
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                loss, loss_att, loss_ctc = model(feats, feats_lengths, target,
                                                 target_lengths)
                if torch.isfinite(loss):
                    num_seen_utts += num_utts
                    total_loss += loss.item() * num_utts
                if batch_idx % log_interval == 0:
                    log_str = 'CV Batch {}/{} loss {:.6f} '.format(
                        batch_idx, num_total_batch, loss.item())
                    if loss_att is not None:
                        log_str += 'loss_att {:.6f} '.format(loss_att.item())
                    if loss_ctc is not None:
                        log_str += 'loss_ctc {:.6f} '.format(loss_ctc.item())
                    log_str += 'history loss {:.6f}'.format(total_loss /
                                                            num_seen_utts)
                    logging.debug(log_str)

        return total_loss, num_seen_utts
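

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical
# driver showing how Executor.train()/cv() are typically wired together.
# ToyModel, the batch shapes, and the args values below are all assumptions
# for illustration; a real ASR model only has to accept the same
# (feats, feats_lengths, target, target_lengths) batch layout and return a
# (loss, loss_att, loss_ctc) triple for this sketch to run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    class ToyModel(nn.Module):
        # Hypothetical stand-in for a real model: projects 80-dim features
        # and classifies the first target token, so loss is differentiable.
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(80, 10)

        def forward(self, feats, feats_lengths, target, target_lengths):
            logits = self.proj(feats).mean(dim=1)  # (B, T, 80) -> (B, 10)
            loss = nn.functional.cross_entropy(logits, target[:, 0])
            return loss, loss, None  # attention loss only, no CTC branch

    device = torch.device('cpu')
    model = ToyModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
    # One fake batch: (keys, feats, targets, feats_lengths, target_lengths).
    batch = (['utt1'] * 4,
             torch.randn(4, 50, 80),
             torch.randint(0, 10, (4, 5)),
             torch.full((4,), 50, dtype=torch.int),
             torch.full((4,), 5, dtype=torch.int))
    data_loader = [batch] * 3  # any iterable with len() works here
    args = {'is_distributed': False, 'use_amp': False, 'rank': 0}

    executor = Executor()
    executor.train(model, optimizer, scheduler, data_loader, device,
                   writer=None, args=args, scaler=None)
    total_loss, num_seen_utts = executor.cv(model, data_loader, device, args)
    print('cv loss per utt: {:.6f}'.format(total_loss / num_seen_utts))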