# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Xiaoyu Chen)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os

import torch
import torch.distributed as dist
import torch.optim as optim
import yaml
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from wenet.dataset.dataset import AudioDataset, CollateFunc
from wenet.transformer.asr_model import init_asr_model
from wenet.utils.checkpoint import load_checkpoint, save_checkpoint
from wenet.utils.executor import Executor
from wenet.utils.scheduler import WarmupLR

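# A minimal example invocation for a single-GPU run; the config and data
# paths below are placeholders for your own experiment files:
#
#   python train.py --config conf/train_conformer.yaml \
#       --train_data data/train/format.data \
#       --cv_data data/dev/format.data \
#       --gpu 0 --model_dir exp/conformer
#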
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='training your network')
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--gpu',
                        type=int,
                        default=-1,
                        help='gpu id for this local rank, -1 for cpu')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    # The --ddp.* options are ordinary argparse flags; dest= maps them onto
    # args.rank, args.world_size, etc. (there is no namespace nesting).
    parser.add_argument('--ddp.rank',
                        dest='rank',
                        default=0,
                        type=int,
                        help='global rank for distributed training')
    parser.add_argument('--ddp.world_size',
                        dest='world_size',
                        default=-1,
                        type=int,
                        help='''number of total processes/gpus for
                        distributed training''')
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo'],
                        help='distributed backend')
    parser.add_argument('--ddp.init_method',
                        dest='init_method',
                        default=None,
                        help='ddp init method')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers for reading')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--cmvn', default=None, help='global cmvn file')

    args = parser.parse_args()

    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    # Pin this process to its GPU; with one process per GPU, CUDA device 0
    # inside the process maps to physical device args.gpu.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    # Set random seed for reproducible model initialization
    torch.manual_seed(777)
    print(args)
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

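    # A world size above one implies DistributedDataParallel training, with
    # one process per GPU, each launched with its own --ddp.rank and --gpu.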
    distributed = args.world_size > 1

    # raw_wav=True: the dataset yields raw waveforms and features are
    # extracted on the fly in CollateFunc; False: precomputed (e.g. Kaldi)
    # features are loaded instead.
    raw_wav = configs['raw_wav']

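    # Augmentation (SpecAugment, SpecSub, dithering, speed perturb, wav
    # distortion) is applied inside CollateFunc at batch-assembly time, so
    # the CV collate function below disables all of it to keep the
    # validation loss deterministic and comparable across epochs.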
    train_collate_func = CollateFunc(**configs['collate_conf'],
                                     raw_wav=raw_wav)

    cv_collate_conf = copy.deepcopy(configs['collate_conf'])
    # no augmentation on cv set
    cv_collate_conf['spec_aug'] = False
    cv_collate_conf['spec_sub'] = False
    if raw_wav:
        cv_collate_conf['feature_dither'] = 0.0
        cv_collate_conf['speed_perturb'] = False
        cv_collate_conf['wav_distortion_conf']['wav_distortion_rate'] = 0
    cv_collate_func = CollateFunc(**cv_collate_conf, raw_wav=raw_wav)

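    # AudioDataset pre-batches utterances itself (grouping a length-sorted
    # data list into batches under the configured size limits), so each
    # dataset item is already a full batch.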
    dataset_conf = configs.get('dataset_conf', {})
    train_dataset = AudioDataset(args.train_data,
                                 **dataset_conf,
                                 raw_wav=raw_wav)
    cv_dataset = AudioDataset(args.cv_data, **dataset_conf, raw_wav=raw_wav)

    if distributed:
        logging.info('training on multiple gpus, this gpu {}'.format(args.gpu))
        dist.init_process_group(args.dist_backend,
                                init_method=args.init_method,
                                world_size=args.world_size,
                                rank=args.rank)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, shuffle=True)
        cv_sampler = torch.utils.data.distributed.DistributedSampler(
            cv_dataset, shuffle=False)
    else:
        train_sampler = None
        cv_sampler = None

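    # batch_size=1 because each dataset item is already a pre-built batch
    # (see AudioDataset above); shuffling is done either by the sampler or
    # by the loader, never both.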
    train_data_loader = DataLoader(train_dataset,
                                   collate_fn=train_collate_func,
                                   sampler=train_sampler,
                                   shuffle=(train_sampler is None),
                                   pin_memory=args.pin_memory,
                                   batch_size=1,
                                   num_workers=args.num_workers)
    cv_data_loader = DataLoader(cv_dataset,
                                collate_fn=cv_collate_func,
                                sampler=cv_sampler,
                                shuffle=False,
                                batch_size=1,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers)

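    # The model input dimension is the number of mel bins when features are
    # extracted on the fly; with precomputed features the dataset reports
    # the dimension itself.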
    if raw_wav:
        input_dim = configs['collate_conf']['feature_extraction_conf'][
            'mel_bins']
    else:
        input_dim = train_dataset.input_dim
    vocab_size = train_dataset.output_dim

    # Save configs to model_dir/train.yaml for inference and export
    configs['input_dim'] = input_dim
    configs['output_dim'] = vocab_size
    configs['cmvn_file'] = args.cmvn
    configs['is_json_cmvn'] = raw_wav
    if args.rank == 0:
        # Make sure model_dir exists before writing into it
        os.makedirs(args.model_dir, exist_ok=True)
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    # Init asr model from configs
    model = init_asr_model(configs)
    print(model)
    num_params = sum(p.numel() for p in model.parameters())
    print('the number of model params: {}'.format(num_params))

    # !!!IMPORTANT!!!
    # Try to export the model via torch.jit.script; if this fails, the model
    # code must be refined to satisfy TorchScript's requirements. Doing it
    # before training starts surfaces incompatibilities immediately.
    if args.rank == 0:
        script_model = torch.jit.script(model)
        script_model.save(os.path.join(args.model_dir, 'init.zip'))
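
    # Executor drives the inner train/CV loops; the info dict restored from
    # a checkpoint (epoch, step, cv_loss) puts the schedule back where it
    # left off when resuming.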
    executor = Executor()
    # If a checkpoint is specified, restore training info from it
    if args.checkpoint is not None:
        infos = load_checkpoint(model, args.checkpoint)
    else:
        infos = {}
    start_epoch = infos.get('epoch', -1) + 1
    cv_loss = infos.get('cv_loss', 0.0)
    step = infos.get('step', -1)

    num_epochs = configs.get('max_epoch', 100)
    model_dir = args.model_dir
    writer = None
    if args.rank == 0:
        os.makedirs(model_dir, exist_ok=True)
        exp_id = os.path.basename(model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))

    if distributed:
        assert (torch.cuda.is_available())
        # the model must be on CUDA before it is wrapped with
        # nn.parallel.DistributedDataParallel
        model.cuda()
        # find_unused_parameters=True because the hybrid CTC/attention model
        # may not touch every parameter in every forward pass
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
        device = torch.device("cuda")
    else:
        use_cuda = args.gpu >= 0 and torch.cuda.is_available()
        device = torch.device('cuda' if use_cuda else 'cpu')
        model = model.to(device)

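    # WarmupLR is the inverse-square-root ("Noam") schedule from the
    # Transformer paper: the learning rate ramps up linearly for
    # warmup_steps, then decays proportionally to step**-0.5.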
    optimizer = optim.Adam(model.parameters(), **configs['optim_conf'])
    scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
    final_epoch = None
    configs['rank'] = args.rank
    configs['is_distributed'] = distributed
    configs['use_amp'] = args.use_amp
    if start_epoch == 0 and args.rank == 0:
        save_model_path = os.path.join(model_dir, 'init.pt')
        save_checkpoint(model, save_model_path)

    # Start training loop
    executor.step = step
    scheduler.set_step(step)
    # GradScaler is used for pytorch amp mixed precision training
    scaler = None
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()
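
    # Each epoch: reshuffle the distributed shards (set_epoch), run one
    # training pass, then compute the CV loss aggregated across all ranks.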
    for epoch in range(start_epoch, num_epochs):
        if distributed:
            train_sampler.set_epoch(epoch)
        lr = optimizer.param_groups[0]['lr']
        logging.info('Epoch {} TRAIN info lr {}'.format(epoch, lr))
        executor.train(model, optimizer, scheduler, train_data_loader, device,
                       writer, configs, scaler)
        total_loss, num_seen_utts = executor.cv(model, cv_data_loader, device,
                                                configs)
        if args.world_size > 1:
            # all_reduce expects a tensor, so wrap the count as [num_seen_utts]
            num_seen_utts = torch.Tensor([num_seen_utts]).to(device)
            # the default reduce op of all_reduce is SUM
            dist.all_reduce(num_seen_utts)
            total_loss = torch.Tensor([total_loss]).to(device)
            dist.all_reduce(total_loss)
            cv_loss = total_loss[0] / num_seen_utts[0]
            cv_loss = cv_loss.item()
        else:
            cv_loss = total_loss / num_seen_utts

        logging.info('Epoch {} CV info cv_loss {}'.format(epoch, cv_loss))
        if args.rank == 0:
            save_model_path = os.path.join(model_dir, '{}.pt'.format(epoch))
            save_checkpoint(
                model, save_model_path, {
                    'epoch': epoch,
                    'lr': lr,
                    'cv_loss': cv_loss,
                    'step': executor.step
                })
            writer.add_scalars('epoch', {'cv_loss': cv_loss, 'lr': lr}, epoch)
        final_epoch = epoch

    if final_epoch is not None and args.rank == 0:
        final_model_path = os.path.join(model_dir, 'final.pt')
        os.symlink('{}.pt'.format(final_epoch), final_model_path)
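
    # final.pt is a relative symlink to the last epoch's checkpoint, so the
    # link stays valid if model_dir is moved; recipe scripts typically load
    # it by default for decoding and export.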