import os import tensorflow as tf import numpy as np import time import json import pickle import logging import pathlib import sys from typing import Dict, Any, Tuple, Optional, List from omegaconf import OmegaConf # For accurate PER calculation try: import editdistance except ImportError: print("Warning: editdistance not available, falling back to approximation") editdistance = None # Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach # for compatibility with "batch first, augment after" data pipeline def dense_to_sparse(dense_tensor, sequence_lengths): """ Convert dense tensor to sparse tensor for CTC loss with dynamic shapes This function is essential for the "batch first, augment after" approach as it handles the conversion from dynamic dense tensors to SparseTensor format required by tf.nn.ctc_loss. Args: dense_tensor: Dense tensor with shape [batch_size, max_seq_len] sequence_lengths: Actual sequence lengths [batch_size] Returns: SparseTensor suitable for tf.nn.ctc_loss """ # Create mask for valid (non-zero) elements within sequence lengths batch_size = tf.shape(dense_tensor)[0] max_seq_len = tf.shape(dense_tensor)[1] # Create range indices batch_indices = tf.range(batch_size) seq_indices = tf.range(max_seq_len) # Create meshgrid for sequence dimensions _, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij') # Create mask based on sequence lengths and non-zero values length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1) value_mask = tf.not_equal(dense_tensor, 0) combined_mask = tf.logical_and(length_mask, value_mask) # Get indices of valid elements indices = tf.where(combined_mask) # Get values at valid indices values = tf.gather_nd(dense_tensor, indices) # Create sparse tensor dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64) return tf.SparseTensor( indices=tf.cast(indices, tf.int64), values=tf.cast(values, tf.int32), dense_shape=dense_shape ) from rnn_model_tf import ( TripleGRUDecoder, create_tpu_strategy, build_model_for_tpu, configure_mixed_precision ) from dataset_tf import ( BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices, create_input_fn ) class BrainToTextDecoderTrainerTF: """ TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8 This trainer implements the same training logic as the PyTorch version but uses TensorFlow operations optimized for TPU hardware. """ def __init__(self, args: Dict[str, Any]): """ Initialize the TensorFlow trainer Args: args: Configuration dictionary containing all training parameters """ self.args = args self.logger = None # Enable soft device placement for XLA unsupported ops (like CTC) tf.config.set_soft_device_placement(True) print("✅ Enabled soft device placement for CTC operations") # Initialize TPU strategy self.strategy = create_tpu_strategy() if self.strategy is None: raise RuntimeError("Failed to create TPU strategy - strategy is None") print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores") print(f"Strategy type: {type(self.strategy).__name__}") print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance") print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations") print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)") print(" This has minimal performance impact as CTC is a small portion of computation") # Configure mixed precision for TPU v5e-8 if args.get('use_amp', True): configure_mixed_precision() self.mixed_precision = True else: self.mixed_precision = False # Initialize tracking variables self.best_val_per = float('inf') self.best_val_loss = float('inf') # Setup directories if args['mode'] == 'train': os.makedirs(self.args['output_dir'], exist_ok=True) if (args.get('save_best_checkpoint', True) or args.get('save_all_val_steps', False) or args.get('save_final_model', False)): os.makedirs(self.args['checkpoint_dir'], exist_ok=True) # Setup logging self._setup_logging() # Set random seeds if self.args['seed'] != -1: tf.random.set_seed(self.args['seed']) np.random.seed(self.args['seed']) # Initialize datasets self._initialize_datasets() # Build model within strategy scope with self.strategy.scope(): self.model = self._build_model() self.optimizer = self._create_optimizer() print("🔧 Initializing optimizer for TPU training...") print(f"Optimizer type: {type(self.optimizer).__name__}") # ========================= SOLUTION ========================= # Explicitly build optimizer within strategy scope before training. # This forces creation of all slot variables (e.g., AdamW momentum) # avoiding lazy initialization inside @tf.function which loses context. # Note: Model must be built first for .build() to work. # The _log_model_info method builds the model via forward pass. # Ensure model is built (will be called later in _log_model_info anyway) if not self.model.built: dummy_batch_size = 2 dummy_time_steps = 100 dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features'])) dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32) _ = self.model(dummy_features, dummy_day_idx, training=False) print("🔧 Building optimizer with model variables...") self.optimizer.build(self.model.trainable_variables) print("✅ Optimizer built successfully") # ============================================================ print("✅ Optimizer ready for TPU training") self.lr_scheduler = self._create_lr_scheduler() # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible) # Create unified checkpoint management self.ckpt = tf.train.Checkpoint( optimizer=self.optimizer, model=self.model ) self.ckpt_manager = tf.train.CheckpointManager( self.ckpt, directory=self.args['checkpoint_dir'], max_to_keep=5 # Keep only the 5 most recent checkpoints ) # Try to restore from latest checkpoint if self.ckpt_manager.latest_checkpoint: print(f"🔄 Restoring from {self.ckpt_manager.latest_checkpoint}") if self.logger: self.logger.info(f"Restoring from {self.ckpt_manager.latest_checkpoint}") # expect_partial() avoids failures due to model structure changes self.ckpt.restore(self.ckpt_manager.latest_checkpoint).expect_partial() print("✅ Checkpoint restored successfully") else: print("🆕 Initializing training from scratch") if self.logger: self.logger.info("Initializing training from scratch") # Log model information self._log_model_info() # Adversarial training configuration adv_cfg = self.args.get('adversarial', {}) self.adv_enabled = adv_cfg.get('enabled', False) self.adv_grl_lambda = float(adv_cfg.get('grl_lambda', 0.5)) self.adv_noisy_loss_weight = float(adv_cfg.get('noisy_loss_weight', 0.2)) self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) # Manual weight decay handling - disable since AdamW handles it self.manual_weight_decay = False if self.args.get('weight_decay', 0.0) > 0: print(f"🔧 Weight decay configured in AdamW: {self.args.get('weight_decay', 0.0)}") else: print("💡 No weight decay configured") if self.adv_enabled: if self.logger: self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, " f"noisy_loss_weight={self.adv_noisy_loss_weight}, " f"noise_l2_weight={self.adv_noise_l2_weight}, " f"warmup_steps={self.adv_warmup_steps}") else: print(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, " f"noisy_loss_weight={self.adv_noisy_loss_weight}, " f"noise_l2_weight={self.adv_noise_l2_weight}, " f"warmup_steps={self.adv_warmup_steps}") def _setup_logging(self): """Setup logging configuration""" self.logger = logging.getLogger(__name__) for handler in self.logger.handlers[:]: self.logger.removeHandler(handler) self.logger.setLevel(logging.INFO) formatter = logging.Formatter(fmt='%(asctime)s: %(message)s') if self.args['mode'] == 'train': fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'], 'training_log'))) fh.setFormatter(formatter) self.logger.addHandler(fh) sh = logging.StreamHandler(sys.stdout) sh.setFormatter(formatter) self.logger.addHandler(sh) self.logger.info(f'Using TPU strategy with {self.strategy.num_replicas_in_sync} replicas') if self.mixed_precision: self.logger.info('Mixed precision (bfloat16) enabled for TPU training') def _configure_cpu_optimization(self): """Configure CPU utilization to make use of 224 cores for data pipeline""" import multiprocessing # Get available CPU cores available_cores = multiprocessing.cpu_count() print(f"💻 Available CPU cores: {available_cores}") # Optimize for data pipeline parallelism # For 224 cores, use more threads for better data loading performance if available_cores >= 200: # High core count system inter_op_threads = min(64, available_cores // 3) # More aggressive for 224 cores intra_op_threads = min(32, available_cores // 6) else: # Use ~1/4 of cores for inter-op (between operations) # Use ~1/8 of cores for intra-op (within operations) inter_op_threads = min(32, available_cores // 4) intra_op_threads = min(16, available_cores // 8) tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads) tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads) print(f"🔧 CPU optimization configured:") print(f" Inter-op parallelism: {inter_op_threads} threads") print(f" Intra-op parallelism: {intra_op_threads} threads") print(f" This will accelerate data loading and preprocessing while TPU handles training") def _get_tpu_status(self) -> str: """Get current TPU status and HBM utilization info""" try: # Get TPU devices tpu_devices = tf.config.list_logical_devices('TPU') if not tpu_devices: return "TPU: No devices" # Get strategy info num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1 # Get TPU memory info using the working /device:TPU:X format try: # Check all TPU devices for memory usage active_cores = 0 total_current_mb = 0 max_peak_mb = 0 for device in tpu_devices: try: memory_info = tf.config.experimental.get_memory_info(device.name) if memory_info and 'current' in memory_info: current_mb = memory_info['current'] // (1024 * 1024) peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024) if current_mb > 1: # >1MB considered active active_cores += 1 total_current_mb += current_mb max_peak_mb = max(max_peak_mb, peak_mb) except: continue if active_cores > 0: if active_cores == 1: hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)" else: hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)" else: hbm_info = "HBM:idle" except Exception: # Fallback: simple TPU activity check try: with tf.device('/TPU:0'): _ = tf.constant(1.0) hbm_info = "HBM:active" except Exception: hbm_info = "HBM:inactive" return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores " f"{hbm_info}") except Exception as e: return f"TPU: status_error({str(e)[:20]})" def _get_detailed_tpu_status(self) -> str: """Get detailed TPU status for training start""" try: # Get TPU devices tpu_devices = tf.config.list_logical_devices('TPU') if not tpu_devices: return "❌ No TPU devices detected" # Get strategy info num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1 strategy_type = type(self.strategy).__name__ # Get TPU HBM memory info using working device format try: active_cores = 0 total_current_gb = 0 max_peak_gb = 0 memory_details = [] for i, device in enumerate(tpu_devices): try: memory_info = tf.config.experimental.get_memory_info(device.name) if memory_info and 'current' in memory_info: current_gb = memory_info['current'] // (1024 * 1024 * 1024) peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024) if current_gb > 0 or memory_info['current'] > 1024*1024: # >1MB active_cores += 1 total_current_gb += current_gb max_peak_gb = max(max_peak_gb, peak_gb) if current_gb > 0: memory_details.append(f"Core{i}:{current_gb}GB") except: continue if active_cores > 0: # Based on your test: TPU:0 peaked at 14.5GB, suggesting ~16GB per core estimated_per_core = 16 # Conservative estimate estimated_total_gb = estimated_per_core * len(tpu_devices) hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores" else: hbm_usage = "HBM: 0GB/256GB (idle)" except Exception: hbm_usage = "HBM: unavailable" # Simple TPU test try: with tf.device('/TPU:0'): test_result = tf.constant([1.0, 2.0]) _ = tf.reduce_sum(test_result) tpu_test = "✅ responsive" except Exception as e: tpu_test = f"❌ test_failed({str(e)[:15]})" return (f"TPU Devices: {len(tpu_devices)} | " f"Strategy: {strategy_type} | " f"Cores: {num_replicas} | " f"{hbm_usage} | " f"Test: {tpu_test}") except Exception as e: return f"❌ TPU status check failed: {str(e)[:50]}" def _initialize_datasets(self): """Initialize training and validation datasets""" # Create file paths train_file_paths = [ os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_train.hdf5') for s in self.args['dataset']['sessions'] ] val_file_paths = [ os.path.join(self.args["dataset"]["dataset_dir"], s, 'data_val.hdf5') for s in self.args['dataset']['sessions'] ] # Validate no duplicates if len(set(train_file_paths)) != len(train_file_paths): raise ValueError("Duplicate sessions in train dataset") if len(set(val_file_paths)) != len(val_file_paths): raise ValueError("Duplicate sessions in val dataset") # Split trials train_trials, _ = train_test_split_indices( file_paths=train_file_paths, test_percentage=0, seed=self.args['dataset']['seed'], bad_trials_dict=self.args['dataset'].get('bad_trials_dict') ) _, val_trials = train_test_split_indices( file_paths=val_file_paths, test_percentage=1, seed=self.args['dataset']['seed'], bad_trials_dict=self.args['dataset'].get('bad_trials_dict') ) # Save trial splits with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f: json.dump({'train': train_trials, 'val': val_trials}, f) # Create TensorFlow datasets with aggressive data preloading for TPU optimization # Monitor memory usage during data preloading import psutil initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024 print("🔄 Initializing training dataset with TPU-optimized memory management...") print(" 🚀 Preloading all data to RAM for maximum parallel analysis speed...") init_start_time = time.time() self.train_dataset_tf = BrainToTextDatasetTF( trial_indices=train_trials, n_batches=self.args['num_training_batches'], split='train', batch_size=self.args['dataset']['batch_size'], days_per_batch=self.args['dataset']['days_per_batch'], random_seed=self.args['dataset']['seed'], must_include_days=self.args['dataset'].get('must_include_days'), feature_subset=self.args['dataset'].get('feature_subset'), cache_data=True, # 启用智能缓存 preload_all_data=True # 🚀 TPU优化:预加载全部数据,解锁并行分析 ) # Log training dataset initialization performance train_init_time = time.time() - init_start_time train_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024 train_memory_used = train_memory_mb - initial_memory_mb print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM") print("🔄 Initializing validation dataset with TPU-optimized memory management...") print(" 🚀 Preloading all validation data to RAM for maximum parallel analysis speed...") val_init_start_time = time.time() self.val_dataset_tf = BrainToTextDatasetTF( trial_indices=val_trials, n_batches=None, # Use all validation data split='test', batch_size=self.args['dataset']['batch_size'], days_per_batch=1, # One day per validation batch random_seed=self.args['dataset']['seed'], feature_subset=self.args['dataset'].get('feature_subset'), cache_data=True, # 启用智能缓存 preload_all_data=True # 🚀 TPU优化:预加载全部数据,解锁并行分析 ) # Log validation dataset initialization performance val_init_time = time.time() - val_init_start_time final_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024 total_memory_used = final_memory_mb - initial_memory_mb val_memory_used = final_memory_mb - train_memory_mb print(f"✅ Validation dataset initialized in {val_init_time:.2f}s, using {val_memory_used:.1f} MB RAM") print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets") self.logger.info("Successfully initialized TensorFlow datasets") def _build_model(self) -> TripleGRUDecoder: """Build the TripleGRUDecoder model""" model = TripleGRUDecoder( neural_dim=self.args['model']['n_input_features'], n_units=self.args['model']['n_units'], n_days=len(self.args['dataset']['sessions']), n_classes=self.args['dataset']['n_classes'], rnn_dropout=self.args['model']['rnn_dropout'], input_dropout=self.args['model']['input_network']['input_layer_dropout'], patch_size=self.args['model']['patch_size'], patch_stride=self.args['model']['patch_stride'] ) return model def _create_optimizer(self) -> tf.keras.optimizers.Optimizer: """Create AdamW optimizer""" print(f"Creating optimizer with strategy: {type(self.strategy).__name__}") print("Using AdamW optimizer for TPU training") optimizer = tf.keras.optimizers.AdamW( learning_rate=self.args['lr_max'], beta_1=self.args['beta0'], beta_2=self.args['beta1'], epsilon=self.args['epsilon'], weight_decay=self.args.get('weight_decay', 0.0) ) print("✅ Using AdamW optimizer") return optimizer def _create_lr_scheduler(self): """Create learning rate scheduler""" if self.args['lr_scheduler_type'] == 'cosine': return self._create_cosine_scheduler() elif self.args['lr_scheduler_type'] == 'linear': return tf.keras.optimizers.schedules.PolynomialDecay( initial_learning_rate=self.args['lr_max'], decay_steps=self.args['lr_decay_steps'], end_learning_rate=self.args['lr_min'], power=1.0 # Linear decay ) else: raise ValueError(f"Unknown scheduler type: {self.args['lr_scheduler_type']}") def _create_cosine_scheduler(self): """Create cosine learning rate scheduler""" return tf.keras.optimizers.schedules.CosineDecayRestarts( initial_learning_rate=self.args['lr_max'], first_decay_steps=self.args['lr_decay_steps'], t_mul=1.0, m_mul=1.0, alpha=self.args['lr_min'] / self.args['lr_max'] ) def _log_model_info(self): """Log model architecture and parameter information""" if self.logger: self.logger.info("Initialized TripleGRUDecoder model") else: print("Initialized TripleGRUDecoder model") # Build the model by calling it once with dummy data dummy_batch_size = 2 dummy_time_steps = 100 dummy_features = tf.zeros((dummy_batch_size, dummy_time_steps, self.args['model']['n_input_features'])) dummy_day_idx = tf.zeros((dummy_batch_size,), dtype=tf.int32) # Call the model to build it _ = self.model(dummy_features, dummy_day_idx, training=False) # Count parameters total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights]) if self.logger: self.logger.info(f"Model has {total_params:,} trainable parameters") else: print(f"Model has {total_params:,} trainable parameters") @tf.function def _train_step(self, batch, step): """Single training step with gradient tape""" features = batch['input_features'] labels = batch['seq_class_ids'] n_time_steps = batch['n_time_steps'] phone_seq_lens = batch['phone_seq_lens'] day_indices = batch['day_indices'] with tf.GradientTape() as tape: # Calculate adjusted lengths for CTC (data augmentation now handled in dataset pipeline) adjusted_lens = tf.cast( (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1, tf.int32 ) # Forward pass use_full = self.adv_enabled and (step >= self.adv_warmup_steps) if use_full: clean_logits, noisy_logits, noise_output = self.model( features, day_indices, None, False, 'full', grl_lambda=self.adv_grl_lambda, training=True ) else: clean_logits = self.model( features, day_indices, None, False, 'inference', training=True ) # Calculate losses using TPU-compatible CTC implementation if use_full: # Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor sparse_labels = dense_to_sparse(labels, phone_seq_lens) clean_loss = tf.nn.ctc_loss( labels=sparse_labels, logits=clean_logits, label_length=None, # SparseTensor doesn't need label_length logit_length=adjusted_lens, logits_time_major=False, blank_index=0 ) clean_loss = tf.reduce_mean(clean_loss) # Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor # Reuse the same sparse_labels from above noisy_loss = tf.nn.ctc_loss( labels=sparse_labels, logits=noisy_logits, label_length=None, # SparseTensor doesn't need label_length logit_length=adjusted_lens, logits_time_major=False, blank_index=0 ) noisy_loss = tf.reduce_mean(noisy_loss) # Optional noise L2 regularization noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype) if self.adv_noise_l2_weight > 0.0: noise_l2 = tf.reduce_mean(tf.square(noise_output)) loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 else: # Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor sparse_labels = dense_to_sparse(labels, phone_seq_lens) loss = tf.nn.ctc_loss( labels=sparse_labels, logits=clean_logits, label_length=None, # SparseTensor doesn't need label_length logit_length=adjusted_lens, logits_time_major=False, blank_index=0 ) loss = tf.reduce_mean(loss) # AdamW handles weight decay automatically - no manual L2 regularization needed # TensorFlow混合精度处理:不需要手动scaling,Keras policy自动处理 # TPU v5e-8使用bfloat16,不需要loss scaling # Calculate gradients - TensorFlow自动处理混合精度 gradients = tape.gradient(loss, self.model.trainable_variables) # For TPU compatibility, use all variables (TensorFlow will handle None gradients automatically) # This ensures consistency with slot variable initialization all_variables = self.model.trainable_variables # Replace None gradients with zeros to maintain consistency safe_gradients = [] for grad, var in zip(gradients, all_variables): if grad is not None: safe_gradients.append(grad) else: # Create zero gradient for variables without gradients safe_gradients.append(tf.zeros_like(var)) # Clip gradients if self.args['grad_norm_clip_value'] > 0: safe_gradients, grad_norm = tf.clip_by_global_norm( safe_gradients, self.args['grad_norm_clip_value'] ) else: grad_norm = tf.global_norm(safe_gradients) # Apply gradients to ALL variables (consistent with initialization) # TensorFlow optimizer will handle zero gradients correctly self.optimizer.apply_gradients(zip(safe_gradients, all_variables)) return loss, grad_norm @tf.function def _validation_step(self, batch): """Single validation step - returns data for accurate PER calculation""" features = batch['input_features'] labels = batch['seq_class_ids'] n_time_steps = batch['n_time_steps'] phone_seq_lens = batch['phone_seq_lens'] day_indices = batch['day_indices'] # Calculate adjusted lengths (no augmentation for validation, handled in dataset pipeline) adjusted_lens = tf.cast( (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1, tf.int32 ) # Forward pass (inference mode only) logits = self.model(features, day_indices, None, False, 'inference', training=False) # Calculate loss using standard tf.nn.ctc_loss with SparseTensor sparse_labels = dense_to_sparse(labels, phone_seq_lens) loss = tf.nn.ctc_loss( labels=sparse_labels, logits=logits, label_length=None, # SparseTensor doesn't need label_length logit_length=adjusted_lens, logits_time_major=False, blank_index=0 ) loss = tf.reduce_mean(loss) # Greedy decoding for PER calculation predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32) # Return all necessary data for accurate PER calculation on CPU return loss, predicted_ids, labels, adjusted_lens, phone_seq_lens def train(self) -> Dict[str, Any]: """Main training loop""" self.logger.info("Starting training loop...") # Log initial TPU status initial_tpu_status = self._get_detailed_tpu_status() self.logger.info(f"Initial TPU Status: {initial_tpu_status}") # ========================= 终极解决方案:批处理优先 ========================= # 使用经过验证的"先批处理,后增强"方法,消除数据增强与形状分析的时间悖论 self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach") self.logger.info(" This eliminates the time paradox between data augmentation and shape analysis") # 简化的数据集创建函数,不再需要 max_shapes def create_dist_dataset_fn(input_dataset_tf, training): """Create distributed dataset function for the final 'batch-first' approach.""" def dataset_fn(input_context): # 调用新版的 create_input_fn,它不需要 max_shapes return create_input_fn( input_dataset_tf, self.args['dataset']['data_transforms'], training=training ) return self.strategy.distribute_datasets_from_function(dataset_fn) # 使用新的、简化的函数签名创建数据集 self.logger.info("🔄 Distributing training dataset (batch-first approach)...") dist_start_time = time.time() train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True) train_dist_time = time.time() - dist_start_time self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s") self.logger.info("🔄 Distributing validation dataset (batch-first approach)...") val_start_time = time.time() val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False) val_dist_time = time.time() - val_start_time self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s") # ===================================================================== self.logger.info("Created distributed training and validation datasets") # Training metrics train_losses = [] val_losses = [] val_pers = [] val_results = [] val_steps_since_improvement = 0 self.logger.info("Training time count beginning...") train_start_time = time.time() # Training loop step = 0 self.logger.info("🔄 Starting training loop...") self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,") self.logger.info(" consider using padded_batch with fixed shapes in create_input_fn") for batch in train_dist_dataset: if step >= self.args['num_training_batches']: self.logger.info("Reached maximum training batches, stopping training") break start_time = time.time() # Distributed training step self.logger.info("Running distributed training step...") # Ensure we're in the correct TPU strategy scope try: with self.strategy.scope(): per_replica_losses, per_replica_grad_norms = self.strategy.run( self._train_step, args=(batch, step) ) except AttributeError as e: if "merge_call" in str(e): error_msg = f"Strategy merge_call error at step {step}: {e}" print(error_msg) if self.logger: self.logger.error(error_msg) self.logger.error("This indicates the strategy is not properly initialized") raise RuntimeError(f"TPU strategy failed during training step {step}: {e}") else: raise # Reduce across replicas self.logger.info("Reducing results across replicas...") loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None) grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None) train_step_duration = time.time() - start_time train_losses.append(float(loss.numpy())) # Log training progress with TPU status if step % self.args['batches_per_train_log'] == 0: tpu_status = self._get_tpu_status() self.logger.info(f'Train batch {step}: ' f'loss: {float(loss.numpy()):.2f} ' f'grad norm: {float(grad_norm.numpy()):.2f} ' f'time: {train_step_duration:.3f}s ' f'| {tpu_status}') # Validation step if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1): self.logger.info(f"Running validation after training batch: {step}") val_start_time = time.time() val_metrics = self._validate(val_dist_dataset) val_step_duration = time.time() - val_start_time tpu_status = self._get_tpu_status() self.logger.info(f'Val batch {step}: ' f'PER (avg): {val_metrics["avg_per"]:.4f} ' f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' f'time: {val_step_duration:.3f}s ' f'| {tpu_status}') val_pers.append(val_metrics['avg_per']) val_losses.append(val_metrics['avg_loss']) val_results.append(val_metrics) # Check for improvement new_best = False if val_metrics['avg_per'] < self.best_val_per: self.logger.info(f"New best test PER {self.best_val_per:.4f} --> {val_metrics['avg_per']:.4f}") self.best_val_per = val_metrics['avg_per'] self.best_val_loss = val_metrics['avg_loss'] new_best = True elif (val_metrics['avg_per'] == self.best_val_per and val_metrics['avg_loss'] < self.best_val_loss): self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}") self.best_val_loss = val_metrics['avg_loss'] new_best = True if new_best: if self.args.get('save_best_checkpoint', True): self.logger.info("Checkpointing model") self._save_checkpoint(step) if self.args.get('save_val_metrics', True): with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f: pickle.dump(val_metrics, f) val_steps_since_improvement = 0 else: val_steps_since_improvement += 1 # Optional save all validation checkpoints if self.args.get('save_all_val_steps', False): self._save_checkpoint(step) # Early stopping if (self.args.get('early_stopping', False) and val_steps_since_improvement >= self.args.get('early_stopping_val_steps', 20)): self.logger.info(f'Validation PER has not improved in {self.args["early_stopping_val_steps"]} ' f'validation steps. Stopping training early at batch: {step}') break step += 1 # Training completed training_duration = time.time() - train_start_time self.logger.info(f'Best avg val PER achieved: {self.best_val_per:.5f}') self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes') # Save final model if self.args.get('save_final_model', False): last_loss = val_losses[-1] if len(val_losses) > 0 else float('inf') self._save_checkpoint(step-1) return { 'train_losses': train_losses, 'val_losses': val_losses, 'val_pers': val_pers, 'val_metrics': val_results } def _validate(self, val_dataset) -> Dict[str, Any]: """Run validation on entire validation dataset with accurate PER calculation""" total_loss = 0.0 total_edit_distance = 0 total_seq_length = 0 num_batches = 0 for batch in val_dataset: # Get predictions and labels from all TPU cores per_replica_losses, per_replica_preds, per_replica_labels, per_replica_pred_lens, per_replica_label_lens = ( self.strategy.run(self._validation_step, args=(batch,)) ) # Reduce loss across replicas batch_loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None) total_loss += float(batch_loss.numpy()) # Gather all data from TPU cores to CPU for accurate PER calculation all_preds = self.strategy.gather(per_replica_preds, axis=0) all_labels = self.strategy.gather(per_replica_labels, axis=0) all_pred_lens = self.strategy.gather(per_replica_pred_lens, axis=0) all_label_lens = self.strategy.gather(per_replica_label_lens, axis=0) # Calculate accurate edit distance on CPU batch_size = all_preds.shape[0] for i in range(batch_size): pred_len = int(all_pred_lens[i].numpy()) label_len = int(all_label_lens[i].numpy()) # Extract sequences and remove CTC blanks (assuming blank_index=0) pred_seq = all_preds[i, :pred_len].numpy() pred_seq = [p for p in pred_seq if p != 0] # Remove blanks # Remove consecutive duplicates (CTC decoding) if len(pred_seq) > 0: deduped_pred = [pred_seq[0]] for j in range(1, len(pred_seq)): if pred_seq[j] != pred_seq[j-1]: deduped_pred.append(pred_seq[j]) pred_seq = deduped_pred true_seq = all_labels[i, :label_len].numpy().tolist() # Calculate edit distance using proper library if available if editdistance is not None: edit_dist = editdistance.eval(pred_seq, true_seq) else: # Fallback to simple approximation if editdistance not available edit_dist = self._simple_edit_distance(pred_seq, true_seq) total_edit_distance += edit_dist total_seq_length += label_len num_batches += 1 avg_loss = total_loss / max(num_batches, 1) avg_per = total_edit_distance / max(total_seq_length, 1e-6) return { 'avg_loss': avg_loss, 'avg_per': avg_per, 'total_edit_distance': total_edit_distance, 'total_seq_length': total_seq_length, 'num_batches': num_batches } def _simple_edit_distance(self, seq1, seq2): """Simple edit distance implementation as fallback""" # Dynamic programming implementation of edit distance m, n = len(seq1), len(seq2) dp = [[0] * (n + 1) for _ in range(m + 1)] # Initialize base cases for i in range(m + 1): dp[i][0] = i for j in range(n + 1): dp[0][j] = j # Fill the DP table for i in range(1, m + 1): for j in range(1, n + 1): if seq1[i-1] == seq2[j-1]: dp[i][j] = dp[i-1][j-1] else: dp[i][j] = 1 + min( dp[i-1][j], # deletion dp[i][j-1], # insertion dp[i-1][j-1] # substitution ) return dp[m][n] def _save_checkpoint(self, step: int, name: str = ""): """Save checkpoint using the unified CheckpointManager""" # CheckpointManager automatically handles naming and numbering # The 'name' parameter is kept for backward compatibility but not used save_path = self.ckpt_manager.save(checkpoint_number=step) if self.logger: self.logger.info(f"Saved checkpoint for step {step}: {save_path}") else: print(f"Saved checkpoint for step {step}: {save_path}") # Save non-TensorFlow Python state separately state = { 'step': step, 'best_val_per': float(self.best_val_per), 'best_val_loss': float(self.best_val_loss) } # Associate state file with checkpoint state_path = os.path.join(self.args['checkpoint_dir'], f'state-{step}.json') with open(state_path, 'w') as f: json.dump(state, f) # Save config file (only once) config_path = os.path.join(self.args['checkpoint_dir'], 'args.yaml') if not os.path.exists(config_path): with open(config_path, 'w') as f: OmegaConf.save(config=self.args, f=f) def load_checkpoint(self, checkpoint_path: str): """Load a specific checkpoint and its associated training state""" if self.logger: self.logger.info(f"Loading checkpoint from: {checkpoint_path}") else: print(f"Loading checkpoint from: {checkpoint_path}") # Restore TensorFlow objects (model and optimizer) self.ckpt.restore(checkpoint_path).expect_partial() # Restore non-TensorFlow training state try: # Extract step number from checkpoint path (e.g., ckpt-123 -> 123) step = int(checkpoint_path.split('-')[-1]) state_path = os.path.join(os.path.dirname(checkpoint_path), f'state-{step}.json') with open(state_path, 'r') as f: state = json.load(f) self.best_val_per = state['best_val_per'] self.best_val_loss = state['best_val_loss'] if self.logger: self.logger.info(f"Restored training state from: {state_path}") else: print(f"Restored training state from: {state_path}") except (IOError, ValueError, KeyError) as e: warning_msg = (f"Could not load or parse state file for checkpoint {checkpoint_path}. " f"Starting with fresh state. Error: {e}") if self.logger: self.logger.warning(warning_msg) else: print(f"⚠️ {warning_msg}") def inference(self, features: tf.Tensor, day_indices: tf.Tensor, n_time_steps: tf.Tensor, mode: str = 'inference') -> tf.Tensor: """ Run inference on input features Args: features: Input neural features [batch_size, time_steps, features] day_indices: Day indices [batch_size] n_time_steps: Number of valid time steps [batch_size] mode: 'inference' or 'full' Returns: Phoneme logits [batch_size, time_steps, n_classes] """ # Apply data transformations (no augmentation) features, n_time_steps = DataAugmentationTF.transform_data( features, n_time_steps, self.args['dataset']['data_transforms'], training=False ) # Run model inference logits = self.model(features, day_indices, None, False, mode, training=False) return logits