From 3b242b908dc7145aa2c48f1961473f44f0a006eb Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 15 Oct 2025 19:04:42 +0800
Subject: [PATCH] trainer

---
 model_training_nnn_tpu/trainer_tf.py | 651 +++++++++++++++++++++++++++
 1 file changed, 651 insertions(+)
 create mode 100644 model_training_nnn_tpu/trainer_tf.py

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
new file mode 100644
index 0000000..253a0db
--- /dev/null
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -0,0 +1,651 @@
+import os
+import tensorflow as tf
+import numpy as np
+import time
+import json
+import pickle
+import logging
+import pathlib
+import sys
+from typing import Dict, Any, Tuple, Optional, List
+from omegaconf import OmegaConf
+
+from rnn_model_tf import (
+    TripleGRUDecoder,
+    CTCLoss,
+    create_tpu_strategy,
+    build_model_for_tpu,
+    configure_mixed_precision
+)
+from dataset_tf import (
+    BrainToTextDatasetTF,
+    DataAugmentationTF,
+    train_test_split_indices,
+    create_input_fn
+)
+
+
+class BrainToTextDecoderTrainerTF:
+    """
+    TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8
+
+    This trainer implements the same training logic as the PyTorch version but uses
+    TensorFlow operations optimized for TPU hardware.
+    """
+
+    def __init__(self, args: Dict[str, Any]):
+        """
+        Initialize the TensorFlow trainer
+
+        Args:
+            args: Configuration dictionary containing all training parameters
+        """
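+        # Illustrative sketch of the configuration keys this trainer reads
+        # (the authoritative schema is the YAML file loaded via OmegaConf;
+        # only keys referenced in this module are listed here):
+        #
+        #   mode, output_dir, checkpoint_dir, seed, use_amp, num_training_batches
+        #   lr_scheduler_type ('cosine' or 'linear'), lr_max, lr_min, lr_decay_steps
+        #   beta0, beta1, epsilon, weight_decay, grad_norm_clip_value
+        #   batches_per_train_log, batches_per_val_step
+        #   save_best_checkpoint, save_all_val_steps, save_final_model, save_val_metrics
+        #   early_stopping, early_stopping_val_steps
+        #   adversarial: enabled, grl_lambda, noisy_loss_weight, noise_l2_weight, warmup_steps
+        #   dataset: dataset_dir, sessions, batch_size, days_per_batch, seed, n_classes,
+        #            data_transforms, bad_trials_dict, must_include_days, feature_subset
+        #   model: n_input_features, n_units, rnn_dropout, patch_size, patch_stride,
+        #          input_network.input_layer_dropout
+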
+        self.args = args
+        self.logger = None
+
+        # Initialize TPU strategy
+        self.strategy = create_tpu_strategy()
+        print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
+
+        # Configure mixed precision for TPU v5e-8
+        if args.get('use_amp', True):
+            configure_mixed_precision()
+            self.mixed_precision = True
+        else:
+            self.mixed_precision = False
+
+        # Initialize tracking variables
+        self.best_val_per = float('inf')
+        self.best_val_loss = float('inf')
+
+        # Setup directories
+        if args['mode'] == 'train':
+            os.makedirs(self.args['output_dir'], exist_ok=True)
+
+        if (args.get('save_best_checkpoint', True) or
+            args.get('save_all_val_steps', False) or
+            args.get('save_final_model', False)):
+            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
+
+        # Setup logging
+        self._setup_logging()
+
+        # Set random seeds
+        if self.args['seed'] != -1:
+            tf.random.set_seed(self.args['seed'])
+            np.random.seed(self.args['seed'])
+
+        # Initialize datasets
+        self._initialize_datasets()
+
+        # Build model within strategy scope
+        with self.strategy.scope():
+            self.model = self._build_model()
+            # The LR schedule is created first so _create_optimizer can consume it
+            self.lr_scheduler = self._create_lr_scheduler()
+            self.optimizer = self._create_optimizer()
+            self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+
+        # Log model information
+        self._log_model_info()
+
+        # Adversarial training configuration
+        adv_cfg = self.args.get('adversarial', {})
+        self.adv_enabled = adv_cfg.get('enabled', False)
+        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
+        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
+        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
+        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
+
+        if self.adv_enabled:
+            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                             f"warmup_steps={self.adv_warmup_steps}")
+
+    def _setup_logging(self):
+        """Setup logging configuration"""
+        self.logger = logging.getLogger(__name__)
+        for handler in self.logger.handlers[:]:
+            self.logger.removeHandler(handler)
+        self.logger.setLevel(logging.INFO)
+        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
+
+        if self.args['mode'] == 'train':
+            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
+            fh.setFormatter(formatter)
+            self.logger.addHandler(fh)
+
+        sh = logging.StreamHandler(sys.stdout)
+        sh.setFormatter(formatter)
+        self.logger.addHandler(sh)
+
+        self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
+        if self.mixed_precision:
+            self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
+
+    def _initialize_datasets(self):
+        """Initialize training and validation datasets"""
+        # Create file paths
+        train_file_paths = [
+            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
+            for s in self.args['dataset']['sessions']
+        ]
+        val_file_paths = [
+            os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5')
+            for s in self.args['dataset']['sessions']
+        ]
+
+        # Validate no duplicates
+        if len(set(train_file_paths)) != len(train_file_paths):
+            raise ValueError("Duplicate sessions in train dataset")
+        if len(set(val_file_paths)) != len(val_file_paths):
+            raise ValueError("Duplicate sessions in val dataset")
+
+        # Split trials
+        train_trials, _ = train_test_split_indices(
+            file_paths=train_file_paths,
+            test_percentage=0,
+            seed=self.args['dataset']['seed'],
+            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
+        )
+
+        _, val_trials = train_test_split_indices(
+            file_paths=val_file_paths,
+            test_percentage=1,
+            seed=self.args['dataset']['seed'],
+            bad_trials_dict=self.args['dataset'].get('bad_trials_dict')
+        )
+
+        # Save trial splits
+        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
+            json.dump({'train': train_trials, 'val': val_trials}, f)
+
+        # Create TensorFlow datasets
+        self.train_dataset_tf = BrainToTextDatasetTF(
+            trial_indices=train_trials,
+            n_batches=self.args['num_training_batches'],
+            split='train',
+            batch_size=self.args['dataset']['batch_size'],
+            days_per_batch=self.args['dataset']['days_per_batch'],
+            random_seed=self.args['dataset']['seed'],
+            must_include_days=self.args['dataset'].get('must_include_days'),
+            feature_subset=self.args['dataset'].get('feature_subset')
+        )
+
+        self.val_dataset_tf = BrainToTextDatasetTF(
+            trial_indices=val_trials,
+            n_batches=None,  # Use all validation data
+            split='test',
+            batch_size=self.args['dataset']['batch_size'],
+            days_per_batch=1,  # One day per validation batch
+            random_seed=self.args['dataset']['seed'],
+            feature_subset=self.args['dataset'].get('feature_subset')
+        )
+
+        self.logger.info("Successfully initialized TensorFlow datasets")
+
+    def _build_model(self) -> TripleGRUDecoder:
+        """Build the TripleGRUDecoder model"""
+        model = TripleGRUDecoder(
+            neural_dim=self.args['model']['n_input_features'],
+            n_units=self.args['model']['n_units'],
+            n_days=len(self.args['dataset']['sessions']),
+            n_classes=self.args['dataset']['n_classes'],
+            rnn_dropout=self.args['model']['rnn_dropout'],
+            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
+            patch_size=self.args['model']['patch_size'],
+            patch_stride=self.args['model']['patch_stride']
+        )
+        return model
+
+    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
+        """Create the AdamW optimizer"""
+        # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch.
+        # A single optimizer is used here and the learning rate is driven by the schedule
+        # built in _create_lr_scheduler().
+        optimizer = tf.keras.optimizers.AdamW(
+            learning_rate=self.lr_scheduler,
+            beta_1=self.args['beta0'],
+            beta_2=self.args['beta1'],
+            epsilon=self.args['epsilon'],
+            weight_decay=self.args['weight_decay']
+        )
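+
+        # Sketch (not applied): newer Keras AdamW releases expose
+        # `exclude_from_weight_decay(...)`, which would approximate PyTorch-style
+        # parameter groups by letting bias/normalization variables skip decay, e.g.
+        #     optimizer.exclude_from_weight_decay(var_names=['bias'])
+        # Whether that method and the variable-name substrings apply depends on
+        # the installed TF/Keras version and on this model's weight names.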
+
+        return optimizer
+
+    def _create_lr_scheduler(self):
+        """Create learning rate scheduler"""
+        if self.args['lr_scheduler_type'] == 'cosine':
+            return self._create_cosine_scheduler()
+        elif self.args['lr_scheduler_type'] == 'linear':
+            return tf.keras.optimizers.schedules.PolynomialDecay(
+                initial_learning_rate=self.args['lr_max'],
+                decay_steps=self.args['lr_decay_steps'],
+                end_learning_rate=self.args['lr_min'],
+                power=1.0  # Linear decay
+            )
+        else:
+            raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")
+
+    def _create_cosine_scheduler(self):
+        """Create cosine learning rate scheduler"""
+        return tf.keras.optimizers.schedules.CosineDecayRestarts(
+            initial_learning_rate=self.args['lr_max'],
+            first_decay_steps=self.args['lr_decay_steps'],
+            t_mul=1.0,
+            m_mul=1.0,
+            alpha=self.args['lr_min'] / self.args['lr_max']
+        )
+
+    def _log_model_info(self):
+        """Log model architecture and parameter information"""
+        self.logger.info("Initialized TripleGRUDecoder model")
+
+        # Build the model by calling it once with dummy data
+        dummy_batch_size = 2
+        dummy_time_steps = 100
+        dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
+        dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)
+
+        # Call the model to build it
+        _ = self.model(dummy_features, dummy_day_idx, training=False)
+
+        # Count parameters
+        total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
+        self.logger.info(f"Model has {total_params:,} trainable parameters")
+
+    @tf.function
+    def _train_step(self, batch, use_full):
+        """Single training step with gradient tape.
+
+        `use_full` selects the adversarial 'full' forward pass; it is computed
+        in train() as a plain Python bool so this function is traced at most twice.
+        """
+        features = batch['input_features']
+        labels = batch['seq_class_ids']
+        n_time_steps = batch['n_time_steps']
+        phone_seq_lens = batch['phone_seq_lens']
+        day_indices = batch['day_indices']
+
+        with tf.GradientTape() as tape:
+            # Apply data transformations
+            features, n_time_steps = DataAugmentationTF.transform_data(
+                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
+            )
+
+            # Calculate adjusted lengths for CTC
+            adjusted_lens = tf.cast(
+                (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
+                self.args['model']['patch_stride'] + 1,
+                tf.int32
+            )
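+            # Worked example of the length adjustment above (illustrative numbers):
+            # with patch_size=4 and patch_stride=2, a trial with n_time_steps=100
+            # yields (100 - 4) / 2 + 1 = 49 patched frames, which is the sequence
+            # length the CTC loss sees for that trial.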
+
+            # Forward pass (use_full selects the adversarial branch after warmup)
+            if use_full:
+                clean_logits, noisy_logits, noise_output = self.model(
+                    features, day_indices, None, False, 'full',
+                    grl_lambda=self.adv_grl_lambda, training=True
+                )
+            else:
+                clean_logits = self.model(
+                    features, day_indices, None, False, 'inference', training=True
+                )
+
+            # Calculate losses
+            if use_full:
+                # Clean CTC loss
+                clean_loss_input = {
+                    'labels': labels,
+                    'input_lengths': adjusted_lens,
+                    'label_lengths': phone_seq_lens
+                }
+                clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
+                clean_loss = tf.reduce_mean(clean_loss)
+
+                # Noisy CTC loss
+                noisy_loss_input = {
+                    'labels': labels,
+                    'input_lengths': adjusted_lens,
+                    'label_lengths': phone_seq_lens
+                }
+                noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
+                noisy_loss = tf.reduce_mean(noisy_loss)
+
+                # Optional noise L2 regularization
+                noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
+                if self.adv_noise_l2_weight > 0.0:
+                    noise_l2 = tf.reduce_mean(tf.square(noise_output))
+
+                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
+            else:
+                loss_input = {
+                    'labels': labels,
+                    'input_lengths': adjusted_lens,
+                    'label_lengths': phone_seq_lens
+                }
+                loss = self.ctc_loss(loss_input, clean_logits)
+                loss = tf.reduce_mean(loss)
+
+        # Calculate gradients. bfloat16 mixed precision on TPU does not need the
+        # float16-style loss scaling (LossScaleOptimizer), so gradients are taken
+        # directly from the unscaled loss.
+        gradients = tape.gradient(loss, self.model.trainable_variables)
+
+        # Clip gradients
+        if self.args['grad_norm_clip_value'] > 0:
+            gradients, grad_norm = tf.clip_by_global_norm(
+                gradients, self.args['grad_norm_clip_value']
+            )
+        else:
+            grad_norm = tf.linalg.global_norm(gradients)
+
+        # Apply gradients
+        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
+
+        return loss, grad_norm
+
+    @tf.function
+    def _validation_step(self, batch):
+        """Single validation step"""
+        features = batch['input_features']
+        labels = batch['seq_class_ids']
+        n_time_steps = batch['n_time_steps']
+        phone_seq_lens = batch['phone_seq_lens']
+        day_indices = batch['day_indices']
+
+        # Apply data transformations (no augmentation for validation)
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
+        )
+
+        # Calculate adjusted lengths
+        adjusted_lens = tf.cast(
+            (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
+            self.args['model']['patch_stride'] + 1,
+            tf.int32
+        )
+
+        # Forward pass (inference mode only)
+        logits = self.model(features, day_indices, None, False, 'inference', training=False)
+
+        # Calculate loss
+        loss_input = {
+            'labels': labels,
+            'input_lengths': adjusted_lens,
+            'label_lengths': phone_seq_lens
+        }
+        loss = self.ctc_loss(loss_input, logits)
+        loss = tf.reduce_mean(loss)
+
+        # Calculate PER (Phoneme Error Rate) with greedy CTC decoding:
+        # collapse repeated ids, drop blanks (blank_index=0), then compare
+        # against the reference labels with an edit distance.
+        predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
+
+        # Keep a prediction only if it differs from the previous time step
+        # (CTC repeat collapsing), is not the blank token, and lies within the
+        # valid (adjusted) length of its sequence
+        time_mask = tf.sequence_mask(adjusted_lens, maxlen=tf.shape(predicted_ids)[1])
+        previous_ids = tf.concat(
+            [tf.fill([tf.shape(predicted_ids)[0], 1], -1), predicted_ids[:, :-1]], axis=1
+        )
+        keep = tf.logical_and(
+            tf.logical_and(tf.not_equal(predicted_ids, previous_ids),
+                           tf.not_equal(predicted_ids, 0)),
+            time_mask
+        )
+
+        # Sparse hypothesis / reference sequences (zeros are dropped by
+        # tf.sparse.from_dense; real phoneme labels are >= 1 since 0 is the blank)
+        hypothesis = tf.sparse.from_dense(tf.where(keep, predicted_ids, tf.zeros_like(predicted_ids)))
+        label_mask = tf.sequence_mask(phone_seq_lens, maxlen=tf.shape(labels)[1])
+        reference = tf.sparse.from_dense(tf.where(label_mask, labels, tf.zeros_like(labels)))
+
+        batch_edit_distance = tf.reduce_sum(
+            tf.edit_distance(tf.cast(hypothesis, tf.int64), tf.cast(reference, tf.int64), normalize=False)
+        )
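+
+        # The caller (_validate) aggregates these per-batch sums into a
+        # corpus-level phoneme error rate: PER = total_edit_distance / total_seq_length,
+        # e.g. 25 accumulated edits over 500 reference phonemes -> PER = 0.05.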
+
+        return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)
+
+    def train(self) -> Dict[str, Any]:
+        """Main training loop"""
+        self.logger.info("Starting training loop...")
+
+        # Create distributed datasets
+        train_dataset = create_input_fn(
+            self.train_dataset_tf,
+            self.args['dataset']['data_transforms'],
+            training=True
+        )
+        val_dataset = create_input_fn(
+            self.val_dataset_tf,
+            self.args['dataset']['data_transforms'],
+            training=False
+        )
+
+        # Distribute datasets
+        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
+        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
+
+        # Training metrics
+        train_losses = []
+        val_losses = []
+        val_pers = []
+        val_results = []
+        val_steps_since_improvement = 0
+
+        train_start_time = time.time()
+
+        # Training loop
+        step = 0
+        for batch in train_dist_dataset:
+            if step >= self.args['num_training_batches']:
+                break
+
+            start_time = time.time()
+
+            # Decide here, as a Python bool, whether the adversarial branch is
+            # active, so _train_step is retraced at most once instead of once
+            # per distinct step value
+            use_full = self.adv_enabled and (step >= self.adv_warmup_steps)
+
+            # Distributed training step
+            per_replica_losses, per_replica_grad_norms = self.strategy.run(
+                self._train_step, args=(batch, use_full)
+            )
+
+            # Reduce across replicas
+            loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
+            grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)
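+            # ReduceOp.MEAN averages the per-replica values (each replica already
+            # returns a mean over its local batch); count-like quantities such as
+            # edit distances and sequence lengths are summed instead (see _validate).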
+
+            train_step_duration = time.time() - start_time
+            train_losses.append(float(loss.numpy()))
+
+            # Log training progress
+            if step % self.args['batches_per_train_log'] == 0:
+                self.logger.info(f'Train batch {step}: '
+                                 f'loss: {float(loss.numpy()):.2f} '
+                                 f'grad norm: {float(grad_norm.numpy()):.2f} '
+                                 f'time: {train_step_duration:.3f}')
+
+            # Validation step
+            if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
+                self.logger.info(f"Running validation after training batch: {step}")
+
+                val_start_time = time.time()
+                val_metrics = self._validate(val_dist_dataset)
+                val_step_duration = time.time() - val_start_time
+
+                self.logger.info(f'Val batch {step}: '
+                                 f'PER (avg): {val_metrics["avg_per"]:.4f} '
+                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
+                                 f'time: {val_step_duration:.3f}')
+
+                val_pers.append(val_metrics['avg_per'])
+                val_losses.append(val_metrics['avg_loss'])
+                val_results.append(val_metrics)
+
+                # Check for improvement
+                new_best = False
+                if val_metrics['avg_per'] < self.best_val_per:
+                    self.logger.info(f"New best val PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
+                    self.best_val_per = val_metrics['avg_per']
+                    self.best_val_loss = val_metrics['avg_loss']
+                    new_best = True
+                elif (val_metrics['avg_per'] == self.best_val_per and
+                      val_metrics['avg_loss'] < self.best_val_loss):
+                    self.logger.info(f"New best val loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
+                    self.best_val_loss = val_metrics['avg_loss']
+                    new_best = True
+
+                if new_best:
+                    if self.args.get('save_best_checkpoint', True):
+                        self.logger.info("Checkpointing model")
+                        self._save_checkpoint('best_checkpoint', step)
+
+                    if self.args.get('save_val_metrics', True):
+                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
+                            pickle.dump(val_metrics, f)
+
+                    val_steps_since_improvement = 0
+                else:
+                    val_steps_since_improvement += 1
+
+                # Optional save all validation checkpoints
+                if self.args.get('save_all_val_steps', False):
+                    self._save_checkpoint(f'checkpoint_batch_{step}', step)
+
+                # Early stopping
+                if (self.args.get('early_stopping', False) and
+                    val_steps_since_improvement >= self.args.get('early_stopping_val_steps', 20)):
+                    self.logger.info(f'Validation PER has not improved in {val_steps_since_improvement} '
+                                     f'validation steps. Stopping training early at batch: {step}')
+                    break
+
+            step += 1
+
+        # Training completed
+        training_duration = time.time() - train_start_time
+        self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
+        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
+
+        # Save final model
+        if self.args.get('save_final_model', False):
+            last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf')
+            self._save_checkpoint(f'final_checkpoint_batch_{step-1}', step-1)
+
+        return {
+            'train_losses': train_losses,
+            'val_losses': val_losses,
+            'val_pers': val_pers,
+            'val_metrics': val_results
+        }
+
+    def _validate(self, val_dataset) -> Dict[str, Any]:
+        """Run validation on entire validation dataset"""
+        total_loss = 0.0
+        total_edit_distance = 0
+        total_seq_length = 0
+        num_batches = 0
+
+        for batch in val_dataset:
+            per_replica_losses, per_replica_edit_distances, per_replica_seq_lengths = (
+                self.strategy.run(self._validation_step, args=(batch,))
+            )
+
+            # Reduce across replicas
+            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
+            batch_edit_distance = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_edit_distances, axis=None)
+            batch_seq_length = self.strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_seq_lengths, axis=None)
+
+            total_loss += float(batch_loss.numpy())
+            total_edit_distance += float(batch_edit_distance.numpy())
+            total_seq_length += float(batch_seq_length.numpy())
+            num_batches += 1
+
+        avg_loss = total_loss / max(num_batches, 1)
+        avg_per = total_edit_distance / max(total_seq_length, 1e-6)
+
+        return {
+            'avg_loss': avg_loss,
+            'avg_per': avg_per,
+            'total_edit_distance': total_edit_distance,
+            'total_seq_length': total_seq_length,
+            'num_batches': num_batches
+        }
+
+    def _save_checkpoint(self, name: str, step: int):
+        """Save model checkpoint"""
+        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)
+
+        # Save model weights
+        self.model.save_weights(checkpoint_path + '.weights.h5')
+
+        # Save optimizer state
+        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
+        optimizer_checkpoint.save(checkpoint_path + '.optimizer')
+
+        # Save training state
+        state = {
+            'step': step,
+            'best_val_per': float(self.best_val_per),
+            'best_val_loss': float(self.best_val_loss)
+        }
+
+        with open(checkpoint_path + '.state.json', 'w') as f:
+            json.dump(state, f)
+
+        # Save config
+        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
+            OmegaConf.save(config=self.args, f=f)
+
+        self.logger.info(f"Saved checkpoint: {checkpoint_path}")
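+
+    # A checkpoint written by _save_checkpoint consists of several files that
+    # share one prefix: '<name>.weights.h5' (model weights), '<name>.optimizer-1.*'
+    # (optimizer slots saved via tf.train.Checkpoint, whose save counter appends
+    # '-1'), '<name>.state.json' (best-metric bookkeeping), plus a shared
+    # 'args.yaml' in the checkpoint directory. load_checkpoint expects this layout.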
+
+    def load_checkpoint(self, checkpoint_path: str):
+        """Load model checkpoint"""
+        # Load model weights
+        self.model.load_weights(checkpoint_path + '.weights.h5')
+
+        # Load optimizer state
+        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
+        optimizer_checkpoint.restore(checkpoint_path + '.optimizer-1')
+
+        # Load training state
+        with open(checkpoint_path + '.state.json', 'r') as f:
+            state = json.load(f)
+
+        self.best_val_per = state['best_val_per']
+        self.best_val_loss = state['best_val_loss']
+
+        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")
+
+    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
+                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
+        """
+        Run inference on input features
+
+        Args:
+            features: Input neural features [batch_size, time_steps, features]
+            day_indices: Day indices [batch_size]
+            n_time_steps: Number of valid time steps [batch_size]
+            mode: 'inference' or 'full'
+
+        Returns:
+            Phoneme logits [batch_size, time_steps, n_classes]
+        """
+        # Apply data transformations (no augmentation)
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
+        )
+
+        # Run model inference
+        logits = self.model(features, day_indices, None, False, mode, training=False)
+
+        return logits
\ No newline at end of file