import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import os
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle

from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth

import torchaudio.functional as F  # for edit distance
from omegaconf import OmegaConf

# Import Accelerate for TPU support
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed

# XLA multi-threading optimization for faster compilation
import torch_xla.core.xla_model as xm
if xm.get_xla_supported_devices():
    # Enable XLA multi-threading for compilation speedup
    os.environ.setdefault(
        'XLA_FLAGS',
        '--xla_cpu_multi_thread_eigen=true '
        + '--xla_cpu_enable_fast_math=true '
        + f'--xla_force_host_platform_device_count={os.cpu_count()}'
    )
    # Set PyTorch XLA threading
    os.environ.setdefault('PYTORCH_XLA_COMPILATION_THREADS', str(os.cpu_count()))

torch.set_float32_matmul_precision('high')  # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True   # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64

from rnn_model import TripleGRUDecoder


class BrainToTextDecoder_Trainer:
    """
    This class will initialize and train a brain-to-text phoneme decoder

    Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
    """

    def __init__(self, args):
        '''
        args : dictionary of training arguments
        '''
        # Configure DataLoader behavior for TPU compatibility
        dataloader_config = DataLoaderConfiguration(
            even_batches=False  # Required for batch_size=None DataLoaders on TPU
        )

        # Initialize Accelerator for TPU/multi-device support
        self.accelerator = Accelerator(
            mixed_precision='bf16' if args.get('use_amp', True) else 'no',
            gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
            log_with=None,  # We'll use our own logging
            project_dir=args.get('output_dir', './output'),
            dataloader_config=dataloader_config,
        )

        # Trainer fields
        self.args = args
        self.logger = None
        self.device = self.accelerator.device  # Use accelerator device instead of manual device selection
        self.model = None
        self.optimizer = None
        self.learning_rate_scheduler = None
        self.ctc_loss = None

        self.best_val_PER = torch.inf   # track best PER for checkpointing
        self.best_val_loss = torch.inf  # track best loss for checkpointing

        self.train_dataset = None
        self.val_dataset = None
        self.train_loader = None
        self.val_loader = None

        self.transform_args = self.args['dataset']['data_transforms']

        # Adversarial training config (safe defaults if not provided)
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))                # GRL strength
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))  # weight for noisy branch CTC
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))      # optional L2 on noise output
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))                # delay enabling adversarial after N steps
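        # For reference, a config block like the following would enable the adversarial
        # branch (keys mirror the .get() calls above; the values shown are only the
        # built-in defaults, not a recommendation):
        #
        #   adversarial:
        #     enabled: true
        #     grl_lambda: 0.5
        #     noisy_loss_weight: 0.2
        #     noise_l2_weight: 0.0
        #     warmup_steps: 0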
        # Create output directory
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)

        # Create checkpoint directory
        if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Set up logging
        self.logger = logging.getLogger(__name__)
        for handler in self.logger.handlers[:]:  # make a copy of the list
            self.logger.removeHandler(handler)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if args['mode'] == 'train':
            # During training, save logs to file in output directory
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        # Always print logs to stdout
        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        # Log device information (managed by Accelerator)
        self.logger.info(f'Using device: {self.device}')
        self.logger.info(f'Accelerator state: {self.accelerator.state}')
        if self.accelerator.num_processes > 1:
            self.logger.info(f'Distributed training on {self.accelerator.num_processes} processes')

        # Set seed if provided (using Accelerator's set_seed for proper distributed seeding)
        if self.args['seed'] != -1:
            set_seed(self.args['seed'])

        # Initialize the model
        self.model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride'],
        )

        # Temporarily disable torch.compile for compatibility with the new model architecture
        # TODO: Re-enable torch.compile once the model is stable
        # self.logger.info("Using torch.compile")
        # self.model = torch.compile(self.model)
        self.logger.info("torch.compile disabled for new TripleGRUDecoder compatibility")

        self.logger.info("Initialized RNN decoding model")
        self.logger.info(self.model)

        # Log how many parameters are in the model
        total_params = sum(p.numel() for p in self.model.parameters())
        self.logger.info(f"Model has {total_params:,} parameters")

        # Determine how many day-specific parameters are in the model
        day_params = 0
        for name, param in self.model.named_parameters():
            if 'day' in name:
                day_params += param.numel()
        self.logger.info(
            f"Model has {day_params:,} day-specific parameters | "
            f"{((day_params / total_params) * 100):.2f}% of total parameters"
        )

        # Create datasets and dataloaders
        train_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
            for s in self.args['dataset']['sessions']
        ]
        val_file_paths = [
            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
            for s in self.args['dataset']['sessions']
        ]

        # Ensure that there are no duplicate days
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("There are duplicate sessions listed in the train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("There are duplicate sessions listed in the val dataset")

        # Split trials into train and test sets
        train_trials, _ = train_test_split_indicies(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=None,
        )
        _, val_trials = train_test_split_indicies(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=self.args['dataset']['seed'],
            bad_trials_dict=None,
        )

        # Save dictionaries to output directory to know which trials were train vs val
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)
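        # Note on the two split calls above: test_percentage=0 keeps every trial from the
        # data_train.hdf5 files in the train split, while test_percentage=1 routes every
        # trial from the data_val.hdf5 files into the val split, so each file is used whole.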
        # Determine if only a subset of neural features should be used
        feature_subset = None
        if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] is not None:
            feature_subset = self.args['dataset']['feature_subset']
            self.logger.info(f'Using only a subset of features: {feature_subset}')

        # train dataset and dataloader
        self.train_dataset = BrainToTextDataset(
            trial_indicies=train_trials,
            split='train',
            days_per_batch=self.args['dataset']['days_per_batch'],
            n_batches=self.args['num_training_batches'],
            batch_size=self.args['dataset']['batch_size'],
            must_include_days=None,
            random_seed=self.args['dataset']['seed'],
            feature_subset=feature_subset,
        )

        # Custom collate function that handles pre-batched data from our dataset
        def collate_fn(batch):
            # Our dataset returns full batches, so `batch` is a list holding a single batch dict.
            # Extract the first (and only) element since dataset.__getitem__() returns a full batch.
            if len(batch) == 1 and isinstance(batch[0], dict):
                return batch[0]
            else:
                # Fallback for unexpected batch structure
                return batch

        # DataLoader configuration compatible with Accelerate
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=1,  # Use batch_size=1 since dataset returns full batches
            shuffle=self.args['dataset']['loader_shuffle'],
            num_workers=self.args['dataset']['num_dataloader_workers'],
            pin_memory=True,
            collate_fn=collate_fn,
        )

        # val dataset and dataloader
        self.val_dataset = BrainToTextDataset(
            trial_indicies=val_trials,
            split='test',
            days_per_batch=None,
            n_batches=None,
            batch_size=self.args['dataset']['batch_size'],
            must_include_days=None,
            random_seed=self.args['dataset']['seed'],
            feature_subset=feature_subset,
        )

        # Validation DataLoader with the same collate function
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=1,   # Use batch_size=1 since dataset returns full batches
            shuffle=False,
            num_workers=0,  # Keep validation dataloader single-threaded for consistency
            pin_memory=True,
            collate_fn=collate_fn,
        )

        self.logger.info("Successfully initialized datasets")

        # Create optimizer, learning rate scheduler, and loss
        self.optimizer = self.create_optimizer()

        if self.args['lr_scheduler_type'] == 'linear':
            self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer=self.optimizer,
                start_factor=1.0,
                end_factor=self.args['lr_min'] / self.args['lr_max'],
                total_iters=self.args['lr_decay_steps'],
            )
        elif self.args['lr_scheduler_type'] == 'cosine':
            self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
        else:
            raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")

        self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction='none', zero_infinity=False)

        # If a checkpoint is provided, then load from checkpoint
        if self.args['init_from_checkpoint']:
            self.load_model_checkpoint(self.args['init_checkpoint_path'])

        # Set rnn and/or input layers to not trainable if specified
        for name, param in self.model.named_parameters():
            if not self.args['model']['rnn_trainable'] and 'gru' in name:
                param.requires_grad = False
            elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
                param.requires_grad = False

        # Prepare model, optimizer, scheduler, and dataloaders for distributed training.
        # Let Accelerator handle everything automatically for both GPU and TPU.
        (
            self.model,
            self.optimizer,
            self.learning_rate_scheduler,
            self.train_loader,
            self.val_loader,
        ) = self.accelerator.prepare(
            self.model,
            self.optimizer,
            self.learning_rate_scheduler,
            self.train_loader,
            self.val_loader,
        )
        self.logger.info("Prepared model and dataloaders with Accelerator")
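        # After prepare(), self.model may come back wrapped (e.g., in DistributedDataParallel
        # when running on multiple processes), which is why save_model_checkpoint() and
        # load_model_checkpoint() go through self.accelerator.unwrap_model() to reach the
        # underlying TripleGRUDecoder.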
        if self.adv_enabled:
            self.logger.info(
                f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                f"noise_l2_weight={self.adv_noise_l2_weight}, "
                f"warmup_steps={self.adv_warmup_steps}"
            )

    def create_optimizer(self):
        '''
        Create the optimizer with special param groups

        Biases and day weights should not be decayed

        Day weights should have a separate learning rate
        '''
        bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
        day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
        other_params = [
            p for name, p in self.model.named_parameters()
            if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name
        ]

        if len(day_params) != 0:
            param_groups = [
                {'params': bias_params, 'weight_decay': 0, 'group_type': 'bias'},
                {'params': day_params, 'lr': self.args['lr_max_day'], 'weight_decay': self.args['weight_decay_day'], 'group_type': 'day_layer'},
                {'params': other_params, 'group_type': 'other'},
            ]
        else:
            param_groups = [
                {'params': bias_params, 'weight_decay': 0, 'group_type': 'bias'},
                {'params': other_params, 'group_type': 'other'},
            ]

        optim = torch.optim.AdamW(
            param_groups,
            lr=self.args['lr_max'],
            betas=(self.args['beta0'], self.args['beta1']),
            eps=self.args['epsilon'],
            weight_decay=self.args['weight_decay'],
            fused=True,
        )
        return optim

    def create_cosine_lr_scheduler(self, optim):
        lr_max = self.args['lr_max']
        lr_min = self.args['lr_min']
        lr_decay_steps = self.args['lr_decay_steps']

        lr_max_day = self.args['lr_max_day']
        lr_min_day = self.args['lr_min_day']
        lr_decay_steps_day = self.args['lr_decay_steps_day']

        lr_warmup_steps = self.args['lr_warmup_steps']
        lr_warmup_steps_day = self.args['lr_warmup_steps_day']

        def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
            '''
            Create lr lambdas for each param group that implement cosine decay

            Different lr lambda decaying for day params vs rest of the model
            '''
            # Warmup phase
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))

            # Cosine decay phase
            if current_step < decay_steps:
                progress = float(current_step - warmup_steps) / float(
                    max(1, decay_steps - warmup_steps)
                )
                cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
                # Scale from 1.0 to min_lr_ratio
                return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)

            # After cosine decay is complete, maintain min_lr_ratio
            return min_lr_ratio

        if len(optim.param_groups) == 3:
            lr_lambdas = [
                lambda step: lr_lambda(
                    step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # biases
                lambda step: lr_lambda(
                    step,
                    lr_min_day / lr_max_day,
                    lr_decay_steps_day,
                    lr_warmup_steps_day,
                ),  # day params
                lambda step: lr_lambda(
                    step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # rest of model weights
            ]
        elif len(optim.param_groups) == 2:
            lr_lambdas = [
                lambda step: lr_lambda(
                    step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # biases
                lambda step: lr_lambda(
                    step, lr_min / lr_max, lr_decay_steps, lr_warmup_steps),  # rest of model weights
            ]
        else:
            raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")

        return LambdaLR(optim, lr_lambdas, -1)
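    # Shape of the schedule produced by create_cosine_lr_scheduler, using purely illustrative
    # numbers (not taken from any config): with lr_warmup_steps=1000 and lr_decay_steps=100000,
    # the multiplier ramps linearly 0 -> 1 over the first 1000 steps (step 500 -> 0.5), then
    # follows a half-cosine from 1.0 down to lr_min/lr_max, and stays at lr_min/lr_max after
    # step 100000. Day-specific parameters follow the same curve with their own *_day settings.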
    def load_model_checkpoint(self, load_path):
        '''
        Load a training checkpoint for distributed training
        '''
        # Load checkpoint on CPU first to avoid OOM issues
        checkpoint = torch.load(load_path, map_location='cpu', weights_only=False)  # checkpoint is just a dict

        # Get unwrapped model for loading state dict
        unwrapped_model = self.accelerator.unwrap_model(self.model)
        unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_PER = checkpoint['val_PER']  # best phoneme error rate
        self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf

        # Device handling is managed by Accelerator, no need to manually move to device
        self.logger.info("Loaded model from checkpoint: " + load_path)

    def save_model_checkpoint(self, save_path, PER, loss):
        '''
        Save a training checkpoint using Accelerator for distributed training
        '''
        # Only save on main process to avoid conflicts
        if self.accelerator.is_main_process:
            # Unwrap model to get base model for saving
            unwrapped_model = self.accelerator.unwrap_model(self.model)

            checkpoint = {
                'model_state_dict': unwrapped_model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'scheduler_state_dict': self.learning_rate_scheduler.state_dict(),
                'val_PER': PER,
                'val_loss': loss,
            }

            torch.save(checkpoint, save_path)
            self.logger.info("Saved model to checkpoint: " + save_path)

            # Save the args file alongside the checkpoint
            with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
                OmegaConf.save(config=self.args, f=f)

        # Wait for all processes to complete checkpoint saving
        self.accelerator.wait_for_everyone()

    def create_attention_mask(self, sequence_lengths):
        max_length = torch.max(sequence_lengths).item()
        batch_size = sequence_lengths.size(0)

        # Create a mask for valid key positions (columns)
        # Shape: [batch_size, max_length]
        key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
        key_mask = key_mask < sequence_lengths.unsqueeze(1)

        # Expand key_mask to [batch_size, 1, 1, max_length]
        # This will be broadcast across all query positions
        key_mask = key_mask.unsqueeze(1).unsqueeze(1)

        # Create the attention mask of shape [batch_size, 1, max_length, max_length]
        # by broadcasting key_mask across all query positions
        attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)

        # Convert boolean mask to float mask:
        # - True (valid key positions)    -> 0.0  (no change to attention scores)
        # - False (padding key positions) -> -inf (will become 0 after softmax)
        attention_mask_float = torch.zeros(
            attention_mask.shape, dtype=torch.float32, device=sequence_lengths.device
        )
        attention_mask_float = attention_mask_float.masked_fill(~attention_mask, float('-inf'))

        return attention_mask_float
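    # Quick shape check for create_attention_mask (illustrative values only): with
    # sequence_lengths = tensor([3, 5]), max_length is 5 and the returned mask has shape
    # [2, 1, 5, 5]; for the first trial every column index >= 3 holds -inf in all rows,
    # so padded key positions receive zero attention weight after the softmax.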
    def transform_data(self, features, n_time_steps, mode='train'):
        '''
        Apply various augmentations and smoothing to data

        Performing augmentations is much faster on GPU than CPU
        '''
        # TPU and GPU should now handle data consistently with our improved DataLoader configuration
        data_shape = features.shape
        batch_size = data_shape[0]
        channels = data_shape[-1]

        # We only apply these augmentations in training
        if mode == 'train':
            # add static gain noise
            if self.transform_args['static_gain_std'] > 0:
                warp_mat = torch.tile(
                    torch.unsqueeze(torch.eye(channels, device=self.device), dim=0),
                    (batch_size, 1, 1),
                )
                warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
                features = torch.matmul(features, warp_mat)

            # add white noise
            if self.transform_args['white_noise_std'] > 0:
                features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']

            # add constant offset noise
            if self.transform_args['constant_offset_std'] > 0:
                features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']

            # add random walk noise
            if self.transform_args['random_walk_std'] > 0:
                features += torch.cumsum(
                    torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'],
                    dim=self.transform_args['random_walk_axis'],
                )

            # randomly cut off part of the data timecourse
            if self.transform_args['random_cut'] > 0:
                cut = np.random.randint(0, self.transform_args['random_cut'])
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing to data
        # This is done in both training and validation
        if self.transform_args['smooth_data']:
            features = gauss_smooth(
                inputs=features,
                device=self.device,
                smooth_kernel_std=self.transform_args['smooth_kernel_std'],
                smooth_kernel_size=self.transform_args['smooth_kernel_size'],
            )

        return features, n_time_steps

    def train(self):
        '''
        Train the model
        '''
        # Set model to train mode (specifically to make sure dropout layers are engaged)
        self.model.train()

        # create vars to track performance
        train_losses = []
        val_losses = []
        val_PERs = []
        val_results = []

        val_steps_since_improvement = 0

        # training params
        save_best_checkpoint = self.args.get('save_best_checkpoint', True)
        early_stopping = self.args.get('early_stopping', True)
        early_stopping_val_steps = self.args['early_stopping_val_steps']

        train_start_time = time.time()

        # train for specified number of batches
        self.logger.info("Starting training loop - loading first batch (TPU compilation may take 5-15 minutes)...")
        for i, batch in enumerate(self.train_loader):

            self.model.train()
            self.optimizer.zero_grad()

            # Train step
            start_time = time.time()

            # Data is automatically moved to device by Accelerator
            features = batch['input_features']
            labels = batch['seq_class_ids']
            n_time_steps = batch['n_time_steps']
            phone_seq_lens = batch['phone_seq_lens']
            day_indicies = batch['day_indicies']

            # Use Accelerator's autocast (mixed precision handled by Accelerator init)
            with self.accelerator.autocast():

                # Apply augmentations to the data
                features, n_time_steps = self.transform_data(features, n_time_steps, 'train')

                # Ensure proper dtype handling for TPU mixed precision
                adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                # Get phoneme predictions using inference mode during training
                # (We use inference mode for simplicity - only clean logits are used for CTC loss)

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                if self.accelerator.mixed_precision == 'bf16':
                    # In mixed precision mode, ensure features match the expected precision
                    features = features.to(torch.float32)

                # Forward pass: enable full adversarial mode if configured and past warmup
                use_full = self.adv_enabled and (i >= self.adv_warmup_steps)
                if use_full:
                    clean_logits, noisy_logits, noise_output = self.model(features, day_indicies, None, False, 'full', grl_lambda=self.adv_grl_lambda)
                else:
                    logits = self.model(features, day_indicies, None, False, 'inference')

                # Calculate CTC Loss
                if use_full:
                    # Clean CTC loss
                    clean_loss = self.ctc_loss(
                        torch.permute(clean_logits.log_softmax(2), [1, 0, 2]),
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    )
                    clean_loss = torch.mean(clean_loss)

                    # Noisy-branch CTC loss (keeps the noisy branch decodable; through the GRL
                    # this gradient becomes adversarial for the NoiseModel)
                    noisy_loss = self.ctc_loss(
                        torch.permute(noisy_logits.log_softmax(2), [1, 0, 2]),
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    )
                    noisy_loss = torch.mean(noisy_loss)

                    # Optional noise energy regularization
                    noise_l2 = torch.tensor(0.0, device=self.device)
                    if self.adv_noise_l2_weight > 0.0:
                        noise_l2 = torch.mean(noise_output.pow(2))
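                    # Combined objective (weights parsed in __init__):
                    #   loss = clean_CTC + adv_noisy_loss_weight * noisy_CTC
                    #          + adv_noise_l2_weight * mean(noise_output^2)
                    # The gradient reversal (grl_lambda passed to the model in 'full' mode) is what
                    # turns the noisy-branch CTC term into an adversarial signal for the noise
                    # estimator, while the clean branch is trained normally.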
                    loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
                else:
                    loss = self.ctc_loss(
                        log_probs=torch.permute(logits.log_softmax(2), [1, 0, 2]),
                        targets=labels,
                        input_lengths=adjusted_lens,
                        target_lengths=phone_seq_lens,
                    )
                    loss = torch.mean(loss)  # take mean loss over batches

            # Use Accelerator's backward for distributed training
            self.accelerator.backward(loss)

            # Clip gradient using Accelerator's clip_grad_norm_
            if self.args['grad_norm_clip_value'] > 0:
                grad_norm = self.accelerator.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm=self.args['grad_norm_clip_value'],
                )

            self.optimizer.step()
            self.learning_rate_scheduler.step()

            # Save training metrics
            train_step_duration = time.time() - start_time
            train_losses.append(loss.detach().item())

            # Incrementally log training progress
            if i % self.args['batches_per_train_log'] == 0:
                self.logger.info(
                    f'Train batch {i}: '
                    f'loss: {(loss.detach().item()):.2f} '
                    f'grad norm: {grad_norm:.2f} '
                    f'time: {train_step_duration:.3f}'
                )

            # Incrementally run a test step
            if i % self.args['batches_per_val_step'] == 0 or i == (self.args['num_training_batches'] - 1):
                self.logger.info(f"Running test after training batch: {i}")

                # Calculate metrics on val data
                start_time = time.time()
                val_metrics = self.validation(
                    loader=self.val_loader,
                    return_logits=self.args['save_val_logits'],
                    return_data=self.args['save_val_data'],
                )
                val_step_duration = time.time() - start_time

                # Log info
                self.logger.info(
                    f'Val batch {i}: '
                    f'PER (avg): {val_metrics["avg_PER"]:.4f} '
                    f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                    f'time: {val_step_duration:.3f}'
                )

                if self.args['log_individual_day_val_PER']:
                    for day in val_metrics['day_PERs'].keys():
                        self.logger.info(
                            f"{self.args['dataset']['sessions'][day]} val PER: "
                            f"{val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}"
                        )

                # Save metrics
                val_PERs.append(val_metrics['avg_PER'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)
                # Determine if this is a new best model, based on whether PER is lower or,
                # in the case of a PER tie, whether loss is lower
                new_best = False
                if val_metrics['avg_PER'] < self.best_val_PER:
                    self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
                    self.best_val_PER = val_metrics['avg_PER']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    # Checkpoint if metrics have improved
                    if save_best_checkpoint:
                        self.logger.info("Checkpointing model")
                        self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)

                    # save validation metrics to pickle file
                    if self.args['save_val_metrics']:
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optionally save this validation checkpoint, regardless of performance
                if self.args['save_all_val_steps']:
                    self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])

                # Early stopping
                if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
                    self.logger.info(
                        f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. '
                        f'Stopping training early at batch: {i}'
                    )
                    break

        # Log final training steps
        training_duration = time.time() - train_start_time

        self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model
        if self.args['save_final_model']:
            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
            self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], last_loss)

        train_stats = {}
        train_stats['train_losses'] = train_losses
        train_stats['val_losses'] = val_losses
        train_stats['val_PERs'] = val_PERs
        train_stats['val_metrics'] = val_results

        return train_stats

    def validation(self, loader, return_logits=False, return_data=False):
        '''
        Calculate metrics on the validation dataset
        '''
        self.model.eval()

        metrics = {}

        # Record metrics
        if return_logits:
            metrics['logits'] = []
            metrics['n_time_steps'] = []

        if return_data:
            metrics['input_features'] = []

        metrics['decoded_seqs'] = []
        metrics['true_seq'] = []
        metrics['phone_seq_lens'] = []
        metrics['transcription'] = []
        metrics['losses'] = []
        metrics['block_nums'] = []
        metrics['trial_nums'] = []
        metrics['day_indicies'] = []

        total_edit_distance = 0
        total_seq_length = 0

        # Calculate PER for each specific day
        day_per = {}
        for d in range(len(self.args['dataset']['sessions'])):
            if self.args['dataset']['dataset_probability_val'][d] == 1:
                day_per[d] = {'total_edit_distance': 0, 'total_seq_length': 0}

        for i, batch in enumerate(loader):

            # Data is automatically moved to device by Accelerator
            features = batch['input_features']
            labels = batch['seq_class_ids']
            n_time_steps = batch['n_time_steps']
            phone_seq_lens = batch['phone_seq_lens']
            day_indicies = batch['day_indicies']

            # Determine if we should perform validation on this batch
            day = day_indicies[0].item()
            if self.args['dataset']['dataset_probability_val'][day] == 0:
                if self.args['log_val_skip_logs']:
                    self.logger.info(f"Skipping validation on day {day}")
                continue
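            # For each batch below we compute the CTC loss and a greedy decode: argmax over
            # classes, collapse repeats with unique_consecutive, then drop CTC blanks (class 0).
            # For example, a frame-wise argmax of [5, 5, 0, 3, 3, 0, 3] collapses to
            # [5, 0, 3, 0, 3] and becomes the phoneme sequence [5, 3, 3], which is compared to
            # the reference with an edit distance to accumulate the PER numerator.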
            with torch.no_grad():
                with self.accelerator.autocast():
                    features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                    # Ensure proper dtype handling for TPU mixed precision
                    adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                    # Ensure features tensor matches model parameter dtype for TPU compatibility
                    if self.accelerator.mixed_precision == 'bf16':
                        # In mixed precision mode, ensure features match the expected precision
                        features = features.to(torch.float32)

                    logits = self.model(features, day_indicies, None, False, 'inference')

                    loss = self.ctc_loss(
                        torch.permute(logits.log_softmax(2), [1, 0, 2]),
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    )
                    loss = torch.mean(loss)

                metrics['losses'].append(loss.cpu().detach().numpy())

                # Calculate PER per day and also avg over the entire validation set
                batch_edit_distance = 0
                decoded_seqs = []
                for iterIdx in range(logits.shape[0]):
                    decoded_seq = torch.argmax(logits[iterIdx, 0:adjusted_lens[iterIdx], :].clone().detach(), dim=-1)
                    decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
                    decoded_seq = decoded_seq.cpu().detach().numpy()
                    decoded_seq = np.array([i for i in decoded_seq if i != 0])

                    trueSeq = np.array(
                        labels[iterIdx][0:phone_seq_lens[iterIdx]].cpu().detach()
                    )

                    batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)

                    decoded_seqs.append(decoded_seq)

                day = batch['day_indicies'][0].item()

                day_per[day]['total_edit_distance'] += batch_edit_distance
                day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()

                total_edit_distance += batch_edit_distance
                total_seq_length += torch.sum(phone_seq_lens)

                # Record metrics
                if return_logits:
                    # Logits will be in bfloat16 if AMP is enabled, so cast back to float32
                    metrics['logits'].append(logits.cpu().float().numpy())
                    metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())

                if return_data:
                    metrics['input_features'].append(batch['input_features'].cpu().numpy())

                metrics['decoded_seqs'].append(decoded_seqs)
                metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
                metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
                metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
                metrics['losses'].append(loss.detach().item())
                metrics['block_nums'].append(batch['block_nums'].numpy())
                metrics['trial_nums'].append(batch['trial_nums'].numpy())
                metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())

        avg_PER = total_edit_distance / total_seq_length

        metrics['day_PERs'] = day_per
        metrics['avg_PER'] = avg_PER.item()
        metrics['avg_loss'] = np.mean(metrics['losses'])

        return metrics

    def inference(self, features, day_indicies, n_time_steps, mode='inference'):
        '''
        TPU-compatible inference method for generating phoneme logits
        '''
        self.model.eval()

        with torch.no_grad():
            with self.accelerator.autocast():
                # Apply data transformations (no augmentation for inference)
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                if self.accelerator.mixed_precision == 'bf16':
                    # In mixed precision mode, ensure features match the expected precision
                    features = features.to(torch.float32)

                # Get phoneme predictions
                logits = self.model(features, day_indicies, None, False, mode)

        return logits
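    # inference_batch() below mirrors inference() but works directly on a dataloader batch and
    # additionally returns adjusted_lens, the per-trial number of patches fed to the decoder:
    # (n_time_steps - patch_size) / patch_stride + 1. With purely illustrative values
    # (patch_size=14, patch_stride=4, a 514-step trial), that gives (514 - 14) / 4 + 1 = 126.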
    def inference_batch(self, batch, mode='inference'):
        '''
        Inference method for processing a full batch
        '''
        self.model.eval()

        # Data is automatically moved to device by Accelerator
        features = batch['input_features']
        day_indicies = batch['day_indicies']
        n_time_steps = batch['n_time_steps']

        with torch.no_grad():
            with self.accelerator.autocast():
                # Apply data transformations (no augmentation for inference)
                features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                # Calculate adjusted sequence lengths for CTC with proper dtype handling
                adjusted_lens = ((n_time_steps.float() - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                # Ensure features tensor matches model parameter dtype for TPU compatibility
                if self.accelerator.mixed_precision == 'bf16':
                    # In mixed precision mode, ensure features match the expected precision
                    features = features.to(torch.float32)

                # Get phoneme predictions
                logits = self.model(features, day_indicies, None, False, mode)

        return logits, adjusted_lens
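
# Minimal usage sketch (hypothetical entry point; the module and config file names below are
# assumptions, not part of this file):
#
#   from omegaconf import OmegaConf
#   from rnn_trainer import BrainToTextDecoder_Trainer  # assumed module name
#
#   args = OmegaConf.load('rnn_args.yaml')               # assumed config path
#   trainer = BrainToTextDecoder_Trainer(args)
#   train_stats = trainer.train()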