b2txt25/model_training_nnn_tpu/rnn_trainer.py
import os
# XLA multi-threading optimization - MUST be set before importing torch_xla
# Set these environment variables early to ensure they take effect
if 'TPU_CORES' in os.environ or 'COLAB_TPU_ADDR' in os.environ:
# Enable XLA multi-threading for compilation speedup
os.environ.setdefault('XLA_FLAGS',
'--xla_cpu_multi_thread_eigen=true ' +
'--xla_cpu_enable_fast_math=true ' +
f'--xla_force_host_platform_device_count={os.cpu_count()}'
)
# Set PyTorch XLA threading
os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))
print(f"Set XLA compilation threads to {os.cpu_count()}")
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle
from contextlib import nullcontext
from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth
import torchaudio.functional as F # for edit distance
from omegaconf import OmegaConf
# Import Accelerate for TPU support
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed
# Import XLA after setting environment variables
import torch_xla.core.xla_model as xm
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64
from rnn_model import TripleGRUDecoder
class BrainToTextDecoder_Trainer:
"""
This class will initialize and train a brain-to-text phoneme decoder
Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
"""
def __init__(self, args):
'''
args : dictionary of training arguments
'''
# Configure DataLoader behavior for TPU compatibility
dataloader_config = DataLoaderConfiguration(
even_batches=False # Required for batch_size=None DataLoaders on TPU
)
# Initialize Accelerator for TPU/multi-device support
self.use_xla = bool(xm.get_xla_supported_devices())
self.amp_requested = args.get('use_amp', True)
mixed_precision_mode = 'bf16' if self.amp_requested else 'no'
self.accelerator = Accelerator(
mixed_precision=mixed_precision_mode,
gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
log_with=None, # We'll use our own logging
project_dir=args.get('output_dir', './output'),
dataloader_config=dataloader_config,
)
# Trainer fields
self.args = args
self.logger = None
self.device = self.accelerator.device # Use accelerator device instead of manual device selection
self.model = None
self.optimizer = None
self.learning_rate_scheduler = None
self.ctc_loss = None
self.best_val_PER = torch.inf # track best PER for checkpointing
self.best_val_loss = torch.inf # track best loss for checkpointing
self.train_dataset = None
self.val_dataset = None
self.train_loader = None
self.val_loader = None
self.transform_args = self.args['dataset']['data_transforms']
# Adversarial training config (safe defaults if not provided)
adv_cfg = self.args.get('adversarial', {})
self.adv_enabled = adv_cfg.get('enabled', False)
self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) # GRL strength
self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) # weight for noisy branch CTC
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) # optional L2 on noise output
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # delay enabling adversarial after N steps
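# A brief note on the gradient reversal layer (GRL) behind grl_lambda (a sketch,
# assuming TripleGRUDecoder uses the standard DANN-style GRL): the GRL is the
# identity in the forward pass and multiplies gradients by -grl_lambda in the
# backward pass, so the noisy-branch CTC loss keeps the decoder accurate while
# pushing the NoiseModel in the opposite, adversarial direction.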
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Set up logging
self.logger = logging.getLogger(__name__)
for handler in self.logger.handlers[:]: # make a copy of the list
self.logger.removeHandler(handler)
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
if args['mode']=='train':
# During training, save logs to file in output directory
fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
fh.setFormatter(formatter)
self.logger.addHandler(fh)
# Always print logs to stdout
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
self.logger.addHandler(sh)
# Log device information (managed by Accelerator)
self.logger.info(f'Using device: {self.device}')
self.logger.info(f'Accelerator state: {self.accelerator.state}')
if self.accelerator.num_processes > 1:
self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')
if self.use_xla and self.amp_requested:
self.logger.info('AMP requested on TPU; converting model weights to bfloat16 for memory efficiency.')
# Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
if self.args['seed'] != -1:
set_seed(self.args['seed'])
# Initialize the model
self.model = TripleGRUDecoder(
neural_dim = self.args['model']['n_input_features'],
n_units = self.args['model']['n_units'],
n_days = len(self.args['dataset']['sessions']),
n_classes = self.args['dataset']['n_classes'],
rnn_dropout = self.args['model']['rnn_dropout'],
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
patch_size = self.args['model']['patch_size'],
patch_stride = self.args['model']['patch_stride'],
)
if self.use_xla and self.amp_requested:
self.model = self.model.to(torch.bfloat16)
self.logger.info('Converted model parameters to bfloat16 for TPU training.')
self.model_dtype = next(self.model.parameters()).dtype
# Temporarily disable torch.compile for compatibility with new model architecture
# TODO: Re-enable torch.compile once model is stable
# self.logger.info("Using torch.compile")
# self.model = torch.compile(self.model)
self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")
self.logger.info(f"Initialized RNN decoding model")
self.logger.info(self.model)
# Log how many parameters are in the model
total_params = sum(p.numel() for p in self.model.parameters())
self.logger.info(f"Model has {total_params:,} parameters")
# Determine how many day-specific parameters are in the model
day_params = 0
for name, param in self.model.named_parameters():
if 'day' in name:
day_params += param.numel()
self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
# Create datasets and dataloaders
train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
# Ensure that there are no duplicate days
if len(set(train_file_paths)) != len(train_file_paths):
raise ValueError("There are duplicate sessions listed in the train dataset")
if len(set(val_file_paths)) != len(val_file_paths):
raise ValueError("There are duplicate sessions listed in the val dataset")
# Split trials into train and test sets
train_trials, _ = train_test_split_indicies(
file_paths = train_file_paths,
test_percentage = 0,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
_, val_trials = train_test_split_indicies(
file_paths = val_file_paths,
test_percentage = 1,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
# Save dictionaries to output directory to know which trials were train vs val
with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
json.dump({'train' : train_trials, 'val': val_trials}, f)
# Determine if only a subset of neural features should be used
feature_subset = None
if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] is not None:
feature_subset = self.args['dataset']['feature_subset']
self.logger.info(f'Using only a subset of features: {feature_subset}')
# train dataset and dataloader
self.train_dataset = BrainToTextDataset(
trial_indicies = train_trials,
split = 'train',
days_per_batch = self.args['dataset']['days_per_batch'],
n_batches = self.args['num_training_batches'],
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Custom collate function that handles pre-batched data from our dataset
def collate_fn(batch):
# Our dataset returns full batches, so batch will be a list of single batch dict
# Extract the first (and only) element since our dataset.__getitem__() returns a full batch
if len(batch) == 1 and isinstance(batch[0], dict):
return batch[0]
else:
# Fallback for unexpected batch structure
return batch
# DataLoader configuration compatible with Accelerate
self.train_loader = DataLoader(
self.train_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = self.args['dataset']['num_dataloader_workers'],
pin_memory = True,
collate_fn = collate_fn
)
# val dataset and dataloader
self.val_dataset = BrainToTextDataset(
trial_indicies = val_trials,
split = 'test',
days_per_batch = None,
n_batches = None,
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
# Validation DataLoader with same collate function
self.val_loader = DataLoader(
self.val_dataset,
batch_size = 1, # Use batch_size=1 since dataset returns full batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True,
collate_fn = collate_fn # Use same collate function
)
self.logger.info("Successfully initialized datasets")
# Create optimizer, learning rate scheduler, and loss
self.optimizer = self.create_optimizer()
if self.args['lr_scheduler_type'] == 'linear':
self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer = self.optimizer,
start_factor = 1.0,
end_factor = self.args['lr_min'] / self.args['lr_max'],
total_iters = self.args['lr_decay_steps'],
)
elif self.args['lr_scheduler_type'] == 'cosine':
self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
else:
raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
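# blank = 0 makes class index 0 the CTC blank token, matching the greedy decoder in
# validation() which drops zeros after collapsing repeats; reduction = 'none' returns
# one loss per trial, which the training and validation loops average with torch.mean.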
# If a checkpoint is provided, then load from checkpoint
if self.args['init_from_checkpoint']:
self.load_model_checkpoint(self.args['init_checkpoint_path'])
# Set rnn and/or input layers to not trainable if specified
for name, param in self.model.named_parameters():
if not self.args['model']['rnn_trainable'] and 'gru' in name:
param.requires_grad = False
elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
param.requires_grad = False
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
# Let Accelerator handle everything automatically for both GPU and TPU
(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
) = self.accelerator.prepare(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
)
self.model_dtype = next(self.model.parameters()).dtype
self.logger.info("Prepared model and dataloaders with Accelerator")
if self.adv_enabled:
self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, noisy_loss_weight={self.adv_noisy_loss_weight}, noise_l2_weight={self.adv_noise_l2_weight}, warmup_steps={self.adv_warmup_steps}")
def autocast_context(self):
"""Return appropriate autocast context; disable on XLA to avoid dtype mismatches."""
if self.device.type == 'xla':
return nullcontext()
return self.accelerator.autocast()
def create_optimizer(self):
'''
Create the optimizer with special param groups
Biases and day weights should not be decayed
Day weights should have a separate learning rate
'''
bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
if len(day_params) != 0:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
{'params' : other_params, 'group_type' : 'other'}
]
else:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : other_params, 'group_type' : 'other'}
]
optim = torch.optim.AdamW(
param_groups,
lr = self.args['lr_max'],
betas = (self.args['beta0'], self.args['beta1']),
eps = self.args['epsilon'],
weight_decay = self.args['weight_decay'],
fused = True
)
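# Param groups that do not set 'lr' or 'weight_decay' inherit the AdamW defaults
# given above (lr_max and weight_decay), so only the day-specific group gets its own
# learning rate and decay, and biases keep the global lr with weight_decay = 0.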
return optim
def create_cosine_lr_scheduler(self, optim):
lr_max = self.args['lr_max']
lr_min = self.args['lr_min']
lr_decay_steps = self.args['lr_decay_steps']
lr_max_day = self.args['lr_max_day']
lr_min_day = self.args['lr_min_day']
lr_decay_steps_day = self.args['lr_decay_steps_day']
lr_warmup_steps = self.args['lr_warmup_steps']
lr_warmup_steps_day = self.args['lr_warmup_steps_day']
def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
'''
Create lr lambdas for each param group that implement cosine decay
Different lr lambda decaying for day params vs rest of the model
'''
# Warmup phase
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
# Cosine decay phase
if current_step < decay_steps:
progress = float(current_step - warmup_steps) / float(
max(1, decay_steps - warmup_steps)
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
# Scale from 1.0 to min_lr_ratio
return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
# After cosine decay is complete, maintain min_lr_ratio
return min_lr_ratio
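# Worked example of lr_lambda (hypothetical values, not from the config): with
# lr_max = 5e-3, lr_min = 1e-4, warmup_steps = 1000 and decay_steps = 100000 the
# multiplier is step/1000 during warmup (0.5 at step 500), exactly 1.0 at step 1000,
# follows a half-cosine from 1.0 down to lr_min/lr_max = 0.02 between steps 1000 and
# 100000, and stays at 0.02 afterwards. LambdaLR scales each param group's base lr
# by its own lambda, so the day parameters follow their own warmup/decay curve.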
if len(optim.param_groups) == 3:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min_day / lr_max_day,
lr_decay_steps_day,
lr_warmup_steps_day,
), # day params
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
elif len(optim.param_groups) == 2:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
else:
raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
return LambdaLR(optim, lr_lambdas, -1)
def load_model_checkpoint(self, load_path):
'''
Load a training checkpoint for distributed training
'''
# Load checkpoint on CPU first to avoid OOM issues
checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
# Get unwrapped model for loading state dict
unwrapped_model = self.accelerator.unwrap_model(self.model)
unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
# Device handling is managed by Accelerator, no need to manually move to device
self.logger.info("Loaded model from checkpoint: " + load_path)
def save_model_checkpoint(self, save_path, PER, loss):
'''
Save a training checkpoint using Accelerator for distributed training
'''
# Only save on main process to avoid conflicts
if self.accelerator.is_main_process:
# Unwrap model to get base model for saving
unwrapped_model = self.accelerator.unwrap_model(self.model)
checkpoint = {
'model_state_dict' : unwrapped_model.state_dict(),
'optimizer_state_dict' : self.optimizer.state_dict(),
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
'val_PER' : PER,
'val_loss' : loss
}
torch.save(checkpoint, save_path)
self.logger.info("Saved model to checkpoint: " + save_path)
# Save the args file alongside the checkpoint
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
OmegaConf.save(config=self.args, f=f)
# Wait for all processes to complete checkpoint saving
self.accelerator.wait_for_everyone()
def create_attention_mask(self, sequence_lengths):
max_length = torch.max(sequence_lengths).item()
batch_size = sequence_lengths.size(0)
# Create a mask for valid key positions (columns)
# Shape: [batch_size, max_length]
key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
key_mask = key_mask < sequence_lengths.unsqueeze(1)
# Expand key_mask to [batch_size, 1, 1, max_length]
# This will be broadcast across all query positions
key_mask = key_mask.unsqueeze(1).unsqueeze(1)
# Create the attention mask of shape [batch_size, 1, max_length, max_length]
# by broadcasting key_mask across all query positions
attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
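# Shape example (illustrative values): for sequence_lengths = [3, 2] the key mask is
# [[True, True, True], [True, True, False]]; after broadcasting, attention_mask is
# [2, 1, 3, 3], and the conversion below puts -inf in the padded key column of the
# second batch element for every query row, zeroing its attention weight after softmax.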
# Convert boolean mask to float mask:
# - True (valid key positions) -> 0.0 (no change to attention scores)
# - False (padding key positions) -> -inf (will zero out attention after softmax)
attention_mask_float = torch.where(attention_mask,
torch.zeros_like(attention_mask, dtype=torch.float32),
torch.full_like(attention_mask, float('-inf'), dtype=torch.float32))
return attention_mask_float
def transform_data(self, features, n_time_steps, mode = 'train'):
'''
Apply various augmentations and smoothing to data
Performing augmentations is much faster on GPU than CPU
'''
# TPU and GPU should now handle data consistently with our improved DataLoader configuration
data_shape = features.shape
batch_size = data_shape[0]
channels = data_shape[-1]
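# features is expected to be [batch_size, n_time_steps, n_channels]; the gain,
# offset and random-walk noise below all broadcast against that layout, and
# random_cut trims the front of the time axis (dim 1).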
# We only apply these augmentations in training
if mode == 'train':
# add static gain noise
if self.transform_args['static_gain_std'] > 0:
warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels, device=self.device), dim = 0), (batch_size, 1, 1))
warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
features = torch.matmul(features, warp_mat)
# add white noise
if self.transform_args['white_noise_std'] > 0:
features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
# add constant offset noise
if self.transform_args['constant_offset_std'] > 0:
features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
# add random walk noise
if self.transform_args['random_walk_std'] > 0:
features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
# randomly cutoff part of the data timecourse
if self.transform_args['random_cut'] > 0:
cut = np.random.randint(0, self.transform_args['random_cut'])
features = features[:, cut:, :]
n_time_steps = n_time_steps - cut
# Apply Gaussian smoothing to data
# This is done in both training and validation
if self.transform_args['smooth_data']:
features = gauss_smooth(
inputs = features,
device = self.device,
smooth_kernel_std = self.transform_args['smooth_kernel_std'],
smooth_kernel_size= self.transform_args['smooth_kernel_size'],
)
if hasattr(self, 'model_dtype'):
features = features.to(self.model_dtype)
return features, n_time_steps
def train(self):
'''
Train the model
'''
# Set model to train mode (specifically to make sure dropout layers are engaged)
self.model.train()
# create vars to track performance
train_losses = []
val_losses = []
val_PERs = []
val_results = []
val_steps_since_improvement = 0
# training params
save_best_checkpoint = self.args.get('save_best_checkpoint', True)
early_stopping = self.args.get('early_stopping', True)
early_stopping_val_steps = self.args['early_stopping_val_steps']
train_start_time = time.time()
# train for specified number of batches
self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
for i, batch in enumerate(self.train_loader):
self.model.train()
self.optimizer.zero_grad()
# Train step
start_time = time.time()
# Data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
with self.autocast_context():
# Apply augmentations to the data
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
# Ensure proper dtype handling for TPU mixed precision
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
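# adjusted_lens is the number of patches the model emits per trial, i.e. the CTC
# input length: floor((T - patch_size) / patch_stride) + 1. Illustrative numbers
# (not from the config): T = 100, patch_size = 4, patch_stride = 2 gives
# (100 - 4) / 2 + 1 = 49 frames.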
# Get phoneme predictions using inference mode during training
# (We use inference mode for simplicity - only clean logits are used for CTC loss)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Forward pass: enable full adversarial mode if configured and past warmup
use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
if use_full:
clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
else:
logits = self.model(features, day_indicies, None, False, 'inference')
# Calculate CTC Loss
if use_full:
# Clean CTC loss
clean_log_probs = torch.permute(clean_logits, [1, 0, 2]).float().log_softmax(2)
clean_loss = self.ctc_loss(
clean_log_probs,
labels,
adjusted_lens,
phone_seq_lens
)
clean_loss = torch.mean(clean_loss)
# Noisy branch CTC loss (keeps the noisy branch decodable, but via the GRL this becomes an adversarial signal for the NoiseModel)
noisy_log_probs = torch.permute(noisy_logits, [1, 0, 2]).float().log_softmax(2)
noisy_loss = self.ctc_loss(
noisy_log_probs,
labels,
adjusted_lens,
phone_seq_lens
)
noisy_loss = torch.mean(noisy_loss)
# Optional noise energy regularization
noise_l2 = torch.tensor(0.0, device=self.device, dtype=clean_loss.dtype)
if self.adv_noise_l2_weight > 0.0:
noise_l2 = torch.mean(noise_output.float().pow(2)).to(clean_loss.dtype)
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
log_probs=log_probs,
targets=labels,
input_lengths=adjusted_lens,
target_lengths=phone_seq_lens
)
loss = torch.mean(loss) # take mean loss over batches
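# torch.nn.CTCLoss expects log-probabilities shaped (T, N, C), hence the permute from
# the model's (N, T, C) output; with reduction='none' it returns one loss per trial,
# and torch.mean collapses that to a scalar for backward().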
# Use Accelerator's backward for distributed training
self.accelerator.backward(loss)
# Clip gradient using Accelerator's clip_grad_norm_ (grad_norm defaults to 0.0 so the training log below stays valid when clipping is disabled)
grad_norm = 0.0
if self.args['grad_norm_clip_value'] > 0:
grad_norm = self.accelerator.clip_grad_norm_(self.model.parameters(),
max_norm = self.args['grad_norm_clip_value'])
self.optimizer.step()
self.learning_rate_scheduler.step()
# Save training metrics
train_step_duration = time.time() - start_time
train_losses.append(loss.detach().item())
# Incrementally log training progress
if i % self.args['batches_per_train_log'] == 0:
self.logger.info(f'Train batch {i}: ' +
f'loss: {(loss.detach().item()):.2f} ' +
f'grad norm: {grad_norm:.2f} '
f'time: {train_step_duration:.3f}')
# Incrementally run a test step
if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
self.logger.info(f"Running test after training batch: {i}")
# Calculate metrics on val data
start_time = time.time()
val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
val_step_duration = time.time() - start_time
# Log info
self.logger.info(f'Val batch {i}: ' +
f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
f'time: {val_step_duration:.3f}')
if self.args['log_individual_day_val_PER']:
for day in val_metrics['day_PERs'].keys():
self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
# Save metrics
val_PERs.append(val_metrics['avg_PER'])
val_losses.append(val_metrics['avg_loss'])
val_results.append(val_metrics)
# Determine if this is a new best validation result: lower PER, or equal PER with lower loss
new_best = False
if val_metrics['avg_PER'] < self.best_val_PER:
self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
self.best_val_PER = val_metrics['avg_PER']
self.best_val_loss = val_metrics['avg_loss']
new_best = True
elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
self.best_val_loss = val_metrics['avg_loss']
new_best = True
if new_best:
# Checkpoint if metrics have improved
if save_best_checkpoint:
self.logger.info(f"Checkpointing model")
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
# save validation metrics to pickle file
if self.args['save_val_metrics']:
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
pickle.dump(val_metrics, f)
val_steps_since_improvement = 0
else:
val_steps_since_improvement +=1
# Optionally save this validation checkpoint, regardless of performance
if self.args['save_all_val_steps']:
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
# Early stopping
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
break
# Log final training steps
training_duration = time.time() - train_start_time
self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
# Save final model
if self.args['save_final_model']:
last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)
train_stats = {}
train_stats['train_losses'] = train_losses
train_stats['val_losses'] = val_losses
train_stats['val_PERs'] = val_PERs
train_stats['val_metrics'] = val_results
return train_stats
def validation(self, loader, return_logits = False, return_data = False):
'''
Calculate metrics on the validation dataset
'''
self.model.eval()
metrics = {}
# Record metrics
if return_logits:
metrics['logits'] = []
metrics['n_time_steps'] = []
if return_data:
metrics['input_features'] = []
metrics['decoded_seqs'] = []
metrics['true_seq'] = []
metrics['phone_seq_lens'] = []
metrics['transcription'] = []
metrics['losses'] = []
metrics['block_nums'] = []
metrics['trial_nums'] = []
metrics['day_indicies'] = []
total_edit_distance = 0
total_seq_length = 0
# Calculate PER for each specific day
day_per = {}
for d in range(len(self.args['dataset']['sessions'])):
if self.args['dataset']['dataset_probability_val'][d] == 1:
day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
for i, batch in enumerate(loader):
# Data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Determine if we should perform validation on this batch
day = day_indicies[0].item()
if self.args['dataset']['dataset_probability_val'][day] == 0:
if self.args['log_val_skip_logs']:
self.logger.info(f"Skipping validation on day {day}")
continue
with torch.no_grad():
with self.autocast_context():
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure proper dtype handling for TPU mixed precision
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
model_param = next(self.model.parameters()) if self.model is not None else None
if model_param is not None and features.dtype != model_param.dtype:
features = features.to(model_param.dtype)
logits = self.model(features, day_indicies, None, False, 'inference')
val_log_probs = torch.permute(logits, [1, 0, 2]).float().log_softmax(2)
loss = self.ctc_loss(
val_log_probs,
labels,
adjusted_lens,
phone_seq_lens,
)
loss = torch.mean(loss)
metrics['losses'].append(loss.cpu().detach().numpy())
# Calculate PER per day and also avg over entire validation set
batch_edit_distance = 0
decoded_seqs = []
for iterIdx in range(logits.shape[0]):
decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
decoded_seq = decoded_seq.cpu().detach().numpy()
decoded_seq = np.array([i for i in decoded_seq if i != 0])
trueSeq = np.array(
labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
)
batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
decoded_seqs.append(decoded_seq)
day = batch['day_indicies'][0].item()
day_per[day]['total_edit_distance'] += batch_edit_distance
day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
total_edit_distance += batch_edit_distance
total_seq_length += torch.sum(phone_seq_lens)
# Record metrics
if return_logits:
metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
if return_data:
metrics['input_features'].append(batch['input_features'].cpu().numpy())
metrics['decoded_seqs'].append(decoded_seqs)
metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
metrics['losses'].append(loss.detach().item())
metrics['block_nums'].append(batch['block_nums'].numpy())
metrics['trial_nums'].append(batch['trial_nums'].numpy())
metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
if isinstance(total_seq_length, torch.Tensor):
total_length_value = float(total_seq_length.item())
else:
total_length_value = float(total_seq_length)
avg_PER = total_edit_distance / max(total_length_value, 1e-6)
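# avg_PER is an aggregate rate, not a per-trial average: total phoneme edit distance
# divided by total reference length across all validation trials (with the same
# aggregation per day in day_per), so longer utterances carry more weight.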
metrics['day_PERs'] = day_per
metrics['avg_PER'] = avg_PER
metrics['avg_loss'] = float(np.mean(metrics['losses']))
return metrics
def inference(self, features, day_indicies, n_time_steps, mode='inference'):
'''
TPU-compatible inference method for generating phoneme logits
'''
self.model.eval()
with torch.no_grad():
with self.autocast_context():
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
return logits
def inference_batch(self, batch, mode='inference'):
'''
Inference method for processing a full batch
'''
self.model.eval()
# Data is automatically moved to device by Accelerator
features = batch['input_features']
day_indicies = batch['day_indicies']
n_time_steps = batch['n_time_steps']
with torch.no_grad():
with self.autocast_context():
# Apply data transformations (no augmentation for inference)
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
# Calculate adjusted sequence lengths for CTC with proper dtype handling
adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Ensure features tensor matches model parameter dtype for TPU compatibility
if features.dtype != self.model_dtype:
features = features.to(self.model_dtype)
# Get phoneme predictions
logits = self.model(features, day_indicies, None, False, mode)
return logits, adjusted_lens
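# Example usage (a minimal sketch; the YAML filename and OmegaConf loading here are
# assumptions, not part of this file):
#
#   from omegaconf import OmegaConf
#   args = OmegaConf.load('rnn_args.yaml')  # hypothetical config file
#   trainer = BrainToTextDecoder_Trainer(args)
#   if args['mode'] == 'train':
#       train_stats = trainer.train()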