651 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			651 lines
		
	
	
		
			25 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import tensorflow as tf
 | |
| import numpy as np
 | |
| import time
 | |
| import json
 | |
| import pickle
 | |
| import logging
 | |
| import pathlib
 | |
| import sys
 | |
| from typing import Dict, Any, Tuple, Optional, List
 | |
| from omegaconf import OmegaConf
 | |
| 
 | |
| from rnn_model_tf import (
 | |
|     TripleGRUDecoder,
 | |
|     CTCLoss,
 | |
|     create_tpu_strategy,
 | |
|     build_model_for_tpu,
 | |
|     configure_mixed_precision
 | |
| )
 | |
| from dataset_tf import (
 | |
|     BrainToTextDatasetTF,
 | |
|     DataAugmentationTF,
 | |
|     train_test_split_indices,
 | |
|     create_input_fn
 | |
| )
 | |
| 
 | |
| 
 | |
class BrainToTextDecoderTrainerTF:
    """
    TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8

    This trainer implements the same training logic as the PyTorch version but uses
    TensorFlow operations optimized for TPU hardware.

    Public entry points:
        train()            -- training loop with periodic validation / checkpointing
        load_checkpoint()  -- restore weights, optimizer state and best-metric state
        inference()        -- phoneme logits for a batch of neural features
    """

    def __init__(self, args: Dict[str, Any]):
        """
        Initialize the TensorFlow trainer

        Args:
            args: Configuration dictionary containing all training parameters
                (e.g. 'mode', 'output_dir', 'checkpoint_dir', 'seed', 'use_amp',
                'adversarial', and the 'dataset' / 'model' sub-dictionaries).
        """
        self.args = args
        self.logger = None

        # Initialize TPU strategy
        self.strategy = create_tpu_strategy()
        print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")

        # Configure mixed precision for TPU v5e-8
        if args.get('use_amp', True):
            configure_mixed_precision()
            self.mixed_precision = True
        else:
            self.mixed_precision = False

        # Initialize tracking variables
        self.best_val_per = float('inf')
        self.best_val_loss = float('inf')

        # Setup directories
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)

        if (args.get('save_best_checkpoint', True) or
                args.get('save_all_val_steps', False) or
                args.get('save_final_model', False)):
            os.makedirs(self.args['checkpoint_dir'], exist_ok=True)

        # Setup logging
        self._setup_logging()

        # Set random seeds
        if self.args['seed'] != -1:
            tf.random.set_seed(self.args['seed'])
            np.random.seed(self.args['seed'])

        # Initialize datasets
        self._initialize_datasets()

        # Build model, schedule, optimizer and loss within the strategy scope. The dummy
        # forward pass in _log_model_info is what actually creates the model variables,
        # so it must also run inside the scope for them to be distributed variables.
        with self.strategy.scope():
            self.model = self._build_model()
            # The schedule is created first so the optimizer can be driven by it.
            self.lr_scheduler = self._create_lr_scheduler()
            self.optimizer = self._create_optimizer()
            self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
            self._log_model_info()

        # Loss scaling only applies to a loss-scale optimizer (float16 training). A plain
        # AdamW (as created above) has no get_scaled_loss()/get_unscaled_gradients(), and
        # bfloat16 does not need loss scaling.
        self._uses_loss_scaling = (hasattr(self.optimizer, 'get_scaled_loss') and
                                   hasattr(self.optimizer, 'get_unscaled_gradients'))

        # Adversarial training configuration
        adv_cfg = self.args.get('adversarial', {})
        self.adv_enabled = adv_cfg.get('enabled', False)
        self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5))
        self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2))
        self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
        self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))

        if self.adv_enabled:
            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
                             f"warmup_steps={self.adv_warmup_steps}")

    def _setup_logging(self):
        """Configure the module logger: a file handler in 'train' mode plus stdout."""
        self.logger = logging.getLogger(__name__)
        # Remove (and close) handlers left by a previous trainer instance so lines are
        # not duplicated and old log files are not kept open.
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
            handler.close()
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if self.args['mode'] == 'train':
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas')
        if self.mixed_precision:
            self.logger.info('Mixed precision (bfloat16) enabled for TPU training')

    def _initialize_datasets(self):
        """Split trials, save the split to JSON, and create train/val datasets.

        Raises:
            ValueError: if a session is listed more than once.
        """
        dataset_cfg = self.args['dataset']

        # Create file paths
        train_file_paths = [
            os.path.join(dataset_cfg["dataset_dir"], s, 'data_train.hdf5')
            for s in dataset_cfg['sessions']
        ]
        val_file_paths = [
            os.path.join(dataset_cfg["dataset_dir"], s, 'data_val.hdf5')
            for s in dataset_cfg['sessions']
        ]

        # Validate no duplicates
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("Duplicate sessions in train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("Duplicate sessions in val dataset")

        # Split trials: all train-file trials go to training, all val-file trials to validation
        train_trials, _ = train_test_split_indices(
            file_paths=train_file_paths,
            test_percentage=0,
            seed=dataset_cfg['seed'],
            bad_trials_dict=dataset_cfg.get('bad_trials_dict')
        )

        _, val_trials = train_test_split_indices(
            file_paths=val_file_paths,
            test_percentage=1,
            seed=dataset_cfg['seed'],
            bad_trials_dict=dataset_cfg.get('bad_trials_dict')
        )

        # Save trial splits. output_dir is only created by __init__ in 'train' mode, so make
        # sure it exists here to avoid a FileNotFoundError in other modes.
        os.makedirs(self.args['output_dir'], exist_ok=True)
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train': train_trials, 'val': val_trials}, f)

        # Create TensorFlow datasets
        self.train_dataset_tf = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=self.args['num_training_batches'],
            split='train',
            batch_size=dataset_cfg['batch_size'],
            days_per_batch=dataset_cfg['days_per_batch'],
            random_seed=dataset_cfg['seed'],
            must_include_days=dataset_cfg.get('must_include_days'),
            feature_subset=dataset_cfg.get('feature_subset')
        )

        self.val_dataset_tf = BrainToTextDatasetTF(
            trial_indices=val_trials,
            n_batches=None,  # Use all validation data
            split='test',
            batch_size=dataset_cfg['batch_size'],
            days_per_batch=1,  # One day per validation batch
            random_seed=dataset_cfg['seed'],
            feature_subset=dataset_cfg.get('feature_subset')
        )

        self.logger.info("Successfully initialized TensorFlow datasets")

    def _build_model(self) -> TripleGRUDecoder:
        """Build the TripleGRUDecoder model from the 'model' / 'dataset' config."""
        model = TripleGRUDecoder(
            neural_dim=self.args['model']['n_input_features'],
            n_units=self.args['model']['n_units'],
            n_days=len(self.args['dataset']['sessions']),
            n_classes=self.args['dataset']['n_classes'],
            rnn_dropout=self.args['model']['rnn_dropout'],
            input_dropout=self.args['model']['input_network']['input_layer_dropout'],
            patch_size=self.args['model']['patch_size'],
            patch_stride=self.args['model']['patch_stride']
        )
        return model

    def _create_optimizer(self) -> tf.keras.optimizers.Optimizer:
        """Create the AdamW optimizer.

        TensorFlow has no PyTorch-style parameter groups, so a single optimizer is used.
        Its learning rate is driven by ``self.lr_scheduler`` when that attribute exists
        (it is created just before this method in __init__); otherwise the constant
        ``lr_max`` is used.
        """
        learning_rate = getattr(self, 'lr_scheduler', None)
        if learning_rate is None:
            learning_rate = self.args['lr_max']

        optimizer = tf.keras.optimizers.AdamW(
            learning_rate=learning_rate,
            beta_1=self.args['beta0'],
            beta_2=self.args['beta1'],
            epsilon=self.args['epsilon'],
            weight_decay=self.args['weight_decay']
        )

        return optimizer

    def _create_lr_scheduler(self):
        """Create the learning-rate schedule selected by ``lr_scheduler_type``.

        Raises:
            ValueError: for an unknown scheduler type.
        """
        if self.args['lr_scheduler_type'] == 'cosine':
            return self._create_cosine_scheduler()
        elif self.args['lr_scheduler_type'] == 'linear':
            return tf.keras.optimizers.schedules.PolynomialDecay(
                initial_learning_rate=self.args['lr_max'],
                decay_steps=self.args['lr_decay_steps'],
                end_learning_rate=self.args['lr_min'],
                power=1.0  # Linear decay
            )
        else:
            raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}")

    def _create_cosine_scheduler(self):
        """Cosine schedule from lr_max down to lr_min over lr_decay_steps (restarting with
        the same period and amplitude, since t_mul = m_mul = 1)."""
        return tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=self.args['lr_max'],
            first_decay_steps=self.args['lr_decay_steps'],
            t_mul=1.0,
            m_mul=1.0,
            alpha=self.args['lr_min'] / self.args['lr_max']
        )

    def _log_model_info(self):
        """Build the model with a dummy forward pass and log its trainable parameter count."""
        self.logger.info("Initialized TripleGRUDecoder model")

        # Build the model by calling it once with dummy data
        dummy_batch_size = 2
        dummy_time_steps = 100
        dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features']))
        dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32)

        # Call the model to build it
        _ = self.model(dummy_features, dummy_day_idx, training=False)

        # Count parameters
        total_params = sum(int(tf.size(w).numpy()) for w in self.model.trainable_weights)
        self.logger.info(f"Model has {total_params:,} trainable parameters")

    def _compute_adjusted_lens(self, n_time_steps):
        """Number of model output frames per trial after patching:
        (n_time_steps - patch_size) / patch_stride + 1, truncated to int32."""
        patch_size = self.args['model']['patch_size']
        patch_stride = self.args['model']['patch_stride']
        return tf.cast(
            (tf.cast(n_time_steps, tf.float32) - patch_size) / patch_stride + 1,
            tf.int32
        )

    def _ctc_mean_loss(self, labels, logits, input_lengths, label_lengths):
        """Batch-mean CTC loss for ``logits`` (blank index 0, per-sample CTCLoss)."""
        loss_input = {
            'labels': labels,
            'input_lengths': input_lengths,
            'label_lengths': label_lengths
        }
        # Compute CTC in float32 even when the model emits reduced-precision logits.
        per_sample = self.ctc_loss(loss_input, tf.cast(logits, tf.float32))
        return tf.reduce_mean(per_sample)

    @tf.function
    def _train_step(self, batch, use_adversarial):
        """Single per-replica training step.

        Args:
            batch: per-replica batch dict ('input_features', 'seq_class_ids',
                'n_time_steps', 'phone_seq_lens', 'day_indices').
            use_adversarial: Python bool selecting the adversarial ('full') forward pass.
                It is a Python value (not a step counter) so tf.function traces at most
                two variants instead of retracing on every step.

        Returns:
            (loss, grad_norm): this replica's mean CTC loss and its gradient global norm
            (measured before clipping, on the unscaled per-replica-mean-loss gradients).
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        num_replicas = self.strategy.num_replicas_in_sync

        with tf.GradientTape() as tape:
            # Apply data transformations
            features, n_time_steps = DataAugmentationTF.transform_data(
                features, n_time_steps, self.args['dataset']['data_transforms'], training=True
            )

            # Calculate adjusted lengths for CTC
            adjusted_lens = self._compute_adjusted_lens(n_time_steps)

            if use_adversarial:
                clean_logits, noisy_logits, noise_output = self.model(
                    features, day_indices, None, False, 'full',
                    grl_lambda=self.adv_grl_lambda, training=True
                )
                clean_loss = self._ctc_mean_loss(labels, clean_logits, adjusted_lens, phone_seq_lens)
                noisy_loss = self._ctc_mean_loss(labels, noisy_logits, adjusted_lens, phone_seq_lens)
                loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss

                # Optional noise L2 regularization
                if self.adv_noise_l2_weight > 0.0:
                    noise_l2 = tf.reduce_mean(tf.square(tf.cast(noise_output, loss.dtype)))
                    loss = loss + self.adv_noise_l2_weight * noise_l2
            else:
                clean_logits = self.model(
                    features, day_indices, None, False, 'inference', training=True
                )
                loss = self._ctc_mean_loss(labels, clean_logits, adjusted_lens, phone_seq_lens)

            # Keras optimizers SUM gradients across replicas in apply_gradients, so divide
            # the per-replica loss by the replica count to apply the cross-replica mean.
            loss_for_grads = loss / num_replicas
            if self._uses_loss_scaling:
                loss_for_grads = self.optimizer.get_scaled_loss(loss_for_grads)

        # Calculate gradients
        gradients = tape.gradient(loss_for_grads, self.model.trainable_variables)
        if self._uses_loss_scaling:
            gradients = self.optimizer.get_unscaled_gradients(gradients)

        # Clip gradients. The gradients are 1/num_replicas of the per-replica mean-loss
        # gradients, so the threshold is scaled the same way; this clips exactly as if
        # the unscaled gradients were clipped at grad_norm_clip_value.
        clip_value = self.args['grad_norm_clip_value']
        if clip_value > 0:
            gradients, scaled_norm = tf.clip_by_global_norm(gradients, clip_value / num_replicas)
        else:
            scaled_norm = tf.linalg.global_norm(gradients)
        grad_norm = scaled_norm * num_replicas

        # Apply gradients
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        return loss, grad_norm

    @tf.function
    def _validation_step(self, batch):
        """Single per-replica validation step (no augmentation, inference mode).

        Returns:
            (loss, predicted_ids, adjusted_lens, labels, phone_seq_lens): mean CTC loss,
            greedy (argmax) class ids per output frame, valid output lengths, and the
            reference labels / lengths. The edit distance itself is computed on the host
            in _validate, because sequence collapsing and edit distance need dynamic
            shapes that do not compile for TPU.
        """
        features = batch['input_features']
        labels = batch['seq_class_ids']
        n_time_steps = batch['n_time_steps']
        phone_seq_lens = batch['phone_seq_lens']
        day_indices = batch['day_indices']

        # Apply data transformations (no augmentation for validation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Calculate adjusted lengths
        adjusted_lens = self._compute_adjusted_lens(n_time_steps)

        # Forward pass (inference mode only)
        logits = self.model(features, day_indices, None, False, 'inference', training=False)

        # Calculate loss
        loss = self._ctc_mean_loss(labels, logits, adjusted_lens, phone_seq_lens)

        # Greedy decoding: most likely class per output frame
        predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)

        return loss, predicted_ids, adjusted_lens, labels, phone_seq_lens

    @staticmethod
    def _greedy_ctc_collapse(token_ids, blank_index: int = 0) -> List[int]:
        """Collapse a greedy CTC path: merge consecutive repeats, then drop blanks."""
        collapsed: List[int] = []
        previous = None
        for token in token_ids:
            token = int(token)
            if token != previous and token != blank_index:
                collapsed.append(token)
            previous = token
        return collapsed

    @staticmethod
    def _edit_distance(hypothesis: List[int], reference: List[int]) -> int:
        """Levenshtein distance (unit-cost insertion/deletion/substitution)."""
        if not hypothesis:
            return len(reference)
        if not reference:
            return len(hypothesis)
        previous_row = list(range(len(reference) + 1))
        for i, hyp_token in enumerate(hypothesis, start=1):
            current_row = [i]
            for j, ref_token in enumerate(reference, start=1):
                current_row.append(min(
                    previous_row[j] + 1,                              # deletion
                    current_row[j - 1] + 1,                           # insertion
                    previous_row[j - 1] + (hyp_token != ref_token),   # substitution
                ))
            previous_row = current_row
        return previous_row[-1]

    def _batch_edit_distance(self, pred_ids, pred_lens, labels, label_lens) -> Tuple[int, int]:
        """Sum of edit distances and of reference lengths over one replica's batch.

        All arguments are host numpy arrays: pred_ids / labels are indexed as
        [row, position], pred_lens / label_lens as [row].
        """
        edit_total = 0
        length_total = 0
        for row in range(pred_ids.shape[0]):
            n_pred = max(int(pred_lens[row]), 0)
            n_true = max(int(label_lens[row]), 0)
            hypothesis = self._greedy_ctc_collapse(pred_ids[row, :n_pred])
            reference = [int(t) for t in labels[row, :n_true]]
            edit_total += self._edit_distance(hypothesis, reference)
            length_total += n_true
        return edit_total, length_total

    def train(self) -> Dict[str, Any]:
        """Main training loop.

        Runs up to ``num_training_batches`` steps, validates every
        ``batches_per_val_step`` steps (and on the last step), checkpoints on
        improvement, and optionally stops early.

        Returns:
            dict with 'train_losses', 'val_losses', 'val_pers' and 'val_metrics'.
        """
        self.logger.info("Starting training loop...")

        data_transforms = self.args['dataset']['data_transforms']
        num_training_batches = self.args['num_training_batches']
        early_stopping = self.args.get('early_stopping', False)
        early_stopping_val_steps = self.args.get('early_stopping_val_steps', 20)

        # Create distributed datasets
        train_dataset = create_input_fn(self.train_dataset_tf, data_transforms, training=True)
        val_dataset = create_input_fn(self.val_dataset_tf, data_transforms, training=False)

        # Distribute datasets
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)

        # Training metrics
        train_losses = []
        val_losses = []
        val_pers = []
        val_results = []
        val_steps_since_improvement = 0

        train_start_time = time.time()

        # Training loop
        step = 0
        last_step = -1  # index of the last batch actually trained on
        for batch in train_dist_dataset:
            if step >= num_training_batches:
                break

            last_step = step
            start_time = time.time()

            # Python bool => at most two traces of _train_step
            use_adversarial = bool(self.adv_enabled and step >= self.adv_warmup_steps)

            # Distributed training step
            per_replica_losses, per_replica_grad_norms = self.strategy.run(
                self._train_step, args=(batch, use_adversarial)
            )

            # Reduce across replicas
            loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)

            train_step_duration = time.time() - start_time
            loss_value = float(loss.numpy())
            train_losses.append(loss_value)

            # Log training progress
            if step % self.args['batches_per_train_log'] == 0:
                self.logger.info(f'Train batch {step}: '
                                 f'loss: {loss_value:.2f} '
                                 f'grad norm: {float(grad_norm.numpy()):.2f} '
                                 f'time: {train_step_duration:.3f}')

            # Validation step
            if step % self.args['batches_per_val_step'] == 0 or step == (num_training_batches - 1):
                self.logger.info(f"Running validation after training batch: {step}")

                val_start_time = time.time()
                val_metrics = self._validate(val_dist_dataset)
                val_step_duration = time.time() - val_start_time

                self.logger.info(f'Val batch {step}: '
                                 f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
                                 f'time: {val_step_duration:.3f}')

                val_pers.append(val_metrics['avg_per'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Check for improvement: lower PER wins; equal PER is broken by lower loss
                new_best = False
                if val_metrics['avg_per'] < self.best_val_per:
                    self.logger.info(f"New best test PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}")
                    self.best_val_per = val_metrics['avg_per']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif (val_metrics['avg_per'] == self.best_val_per and
                      val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:
                    if self.args.get('save_best_checkpoint', True):
                        self.logger.info("Checkpointing model")
                        self._save_checkpoint('best_checkpoint', step)

                    if self.args.get('save_val_metrics', True):
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0
                else:
                    val_steps_since_improvement += 1

                # Optional save all validation checkpoints
                if self.args.get('save_all_val_steps', False):
                    self._save_checkpoint(f'checkpoint_batch_{step}', step)

                # Early stopping
                if early_stopping and val_steps_since_improvement >= early_stopping_val_steps:
                    self.logger.info(f'Validation PER has not improved in {early_stopping_val_steps} '
                                     f'validation steps. Stopping training early at batch: {step}')
                    break

            step += 1

        # Training completed
        training_duration = time.time() - train_start_time
        self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model (named after the last batch actually trained, including
        # when training stopped early)
        if self.args.get('save_final_model', False):
            self._save_checkpoint(f'final_checkpoint_batch_{last_step}', last_step)

        return {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'val_pers': val_pers,
            'val_metrics': val_results
        }

    def _validate(self, val_dataset) -> Dict[str, Any]:
        """Run validation on the entire (distributed) validation dataset.

        avg_loss is the mean over batches of the cross-replica mean loss; avg_per is the
        total edit distance divided by the total reference length, with edit distances
        computed on the host from the greedy predictions of _validation_step.
        """
        total_loss = 0.0
        total_edit_distance = 0
        total_seq_length = 0
        num_batches = 0
        local_results = self.strategy.experimental_local_results

        for batch in val_dataset:
            (per_replica_losses, per_replica_pred_ids, per_replica_pred_lens,
             per_replica_labels, per_replica_label_lens) = self.strategy.run(
                self._validation_step, args=(batch,)
            )

            # Reduce loss across replicas
            batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            total_loss += float(batch_loss.numpy())

            # Edit distance on the host, one local replica at a time
            for pred_ids, pred_lens, labels, label_lens in zip(
                    local_results(per_replica_pred_ids),
                    local_results(per_replica_pred_lens),
                    local_results(per_replica_labels),
                    local_results(per_replica_label_lens)):
                replica_edit, replica_len = self._batch_edit_distance(
                    pred_ids.numpy(), pred_lens.numpy(), labels.numpy(), label_lens.numpy()
                )
                total_edit_distance += replica_edit
                total_seq_length += replica_len

            num_batches += 1

        avg_loss = total_loss / max(num_batches, 1)
        avg_per = total_edit_distance / max(total_seq_length, 1e-6)

        return {
            'avg_loss': avg_loss,
            'avg_per': avg_per,
            'total_edit_distance': float(total_edit_distance),
            'total_seq_length': float(total_seq_length),
            'num_batches': num_batches
        }

    def _save_checkpoint(self, name: str, step: int):
        """Save weights, optimizer state, training state JSON and the config YAML."""
        checkpoint_path = os.path.join(self.args['checkpoint_dir'], name)

        # Save model weights
        self.model.save_weights(checkpoint_path + '.weights.h5')

        # Save optimizer state. Checkpoint.save() appends a '-<counter>' suffix and returns
        # the real prefix, which is recorded below so loading does not have to guess it.
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_save_path = optimizer_checkpoint.save(checkpoint_path + '.optimizer')

        # Save training state
        state = {
            'step': step,
            'best_val_per': float(self.best_val_per),
            'best_val_loss': float(self.best_val_loss),
            'optimizer_checkpoint': os.path.basename(optimizer_save_path)
        }

        with open(checkpoint_path + '.state.json', 'w') as f:
            json.dump(state, f)

        # Save config
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

        self.logger.info(f"Saved checkpoint: {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load a checkpoint written by _save_checkpoint.

        Args:
            checkpoint_path: path prefix without suffix (e.g. '<checkpoint_dir>/best_checkpoint').
        """
        # Load training state first; it records the optimizer checkpoint prefix
        with open(checkpoint_path + '.state.json', 'r') as f:
            state = json.load(f)

        # Load model weights
        self.model.load_weights(checkpoint_path + '.weights.h5')

        # Load optimizer state (older state files lack the recorded prefix; those
        # checkpoints were always written with the '-1' suffix)
        optimizer_prefix = state.get('optimizer_checkpoint')
        if optimizer_prefix:
            optimizer_path = os.path.join(os.path.dirname(checkpoint_path), optimizer_prefix)
        else:
            optimizer_path = checkpoint_path + '.optimizer-1'
        optimizer_checkpoint = tf.train.Checkpoint(optimizer=self.optimizer)
        optimizer_checkpoint.restore(optimizer_path)

        self.best_val_per = state['best_val_per']
        self.best_val_loss = state['best_val_loss']

        self.logger.info(f"Loaded checkpoint: {checkpoint_path}")

    def inference(self, features: tf.Tensor, day_indices: tf.Tensor,
                  n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor:
        """
        Run inference on input features

        Args:
            features: Input neural features [batch_size, time_steps, features]
            day_indices: Day indices [batch_size]
            n_time_steps: Number of valid time steps [batch_size]; only passed to the
                data transforms here.
            mode: 'inference' or 'full'

        Returns:
            Model output for the given mode (phoneme logits in 'inference' mode).
        """
        # Apply data transformations (no augmentation)
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, self.args['dataset']['data_transforms'], training=False
        )

        # Run model inference
        logits = self.model(features, day_indices, None, False, mode, training=False)

        return logits
