diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index 2a42f24..c03d8ce 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -889,34 +889,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = # Utility functions for TPU-optimized data pipeline def create_input_fn(dataset_tf: BrainToTextDatasetTF, transform_args: Dict[str, Any], - max_shapes: Dict[str, int], training: bool = True, cache_path: Optional[str] = None) -> tf.data.Dataset: """ - Create input function for TPU training with PRE-ANALYZED FIXED shapes + Create input function for TPU training with BATCH-FIRST approach - This function uses pre-computed maximum shapes to create fixed-size batches, - ensuring XLA compilation success on TPU hardware. + This function implements the correct TPU data pipeline: + 1. Load individual samples + 2. Cache raw samples + 3. Batch samples with dynamic padding + 4. Apply data augmentation to entire batches (AFTER batching) + + This approach prevents shape conflicts from augmentation operations + like random_cut that would otherwise make tensor shapes dynamic before batching. Args: dataset_tf: BrainToTextDatasetTF instance transform_args: Data transformation configuration - max_shapes: Pre-computed maximum shapes dictionary with keys: - 'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features' training: Whether this is for training (applies augmentations) cache_path: Optional path for disk caching to improve I/O performance Returns: - tf.data.Dataset ready for TPU training with fixed shapes + tf.data.Dataset ready for TPU training with XLA-compatible operations """ - # Create individual example dataset with file-grouping I/O optimization + # Step 1: Create individual example dataset with file-grouping I/O optimization dataset = dataset_tf.create_individual_dataset() + # Step 2: Cache raw samples BEFORE any augmentation # ========================= I/O OPTIMIZATION SOLUTION ========================= - # 对训练集和验证集都进行缓存,因为: - # 1. 训练集:每个epoch都要完整遍历 - # 2. 验证集:每200轮验证一次 + 早停检查,会被频繁使用 if cache_path: dataset = dataset.cache(cache_path) split_name = "training" if training else "validation" @@ -924,63 +925,25 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster") else: # 如果没有指定缓存路径,默认使用内存缓存 - # 对于大型数据集,建议在调用时显式指定磁盘缓存路径 dataset = dataset.cache() split_name = "training" if training else "validation" print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache") print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster") # ================================================================ - def apply_transforms(example): - """Apply data transformations to individual examples""" - features = example['input_features'] - n_time_steps = example['n_time_steps'] + # Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes) + print(f"🔧 Using DYNAMIC padding for XLA compatibility:") - # Apply transformations - features, n_time_steps = DataAugmentationTF.transform_data( - tf.expand_dims(features, 0), # Add batch dimension for transforms - tf.expand_dims(n_time_steps, 0), - transform_args, - training=training - ) - - # Remove batch dimension - example['input_features'] = tf.squeeze(features, 0) - example['n_time_steps'] = tf.squeeze(n_time_steps, 0) - - return example - - # 在缓存之后应用随机的数据增强,确保每个epoch的增强都不同 - dataset = dataset.map( - apply_transforms, - num_parallel_calls=tf.data.AUTOTUNE - ) - - # ========================= FIXED SHAPES SOLUTION ========================= - # 使用预分析的固定形状确保 XLA 编译成功 - print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:") - - # 从传入的参数中获取形状信息 - max_time_steps = max_shapes['max_time_steps'] - max_phone_seq_len = max_shapes['max_phone_seq_len'] - max_transcription_len = max_shapes['max_transcription_len'] - n_features = max_shapes['n_features'] - - print(f" Fixed time steps: {max_time_steps}") - print(f" Fixed phone sequence length: {max_phone_seq_len}") - print(f" Fixed transcription length: {max_transcription_len}") - print(f" Number of features: {n_features}") - - # Define fixed padded shapes - NO None values for XLA compatibility + # Define padded shapes with None for dynamic dimensions padded_shapes = { - 'input_features': tf.TensorShape([max_time_steps, n_features]), - 'seq_class_ids': tf.TensorShape([max_phone_seq_len]), - 'n_time_steps': tf.TensorShape([]), # 标量 - 'phone_seq_lens': tf.TensorShape([]), # 标量 - 'day_indices': tf.TensorShape([]), # 标量 - 'transcriptions': tf.TensorShape([max_transcription_len]), - 'block_nums': tf.TensorShape([]), # 标量 - 'trial_nums': tf.TensorShape([]) # 标量 + 'input_features': tf.TensorShape([None, None]), # [time_steps, features] - dynamic + 'seq_class_ids': tf.TensorShape([None]), # [phone_seq_len] - dynamic + 'n_time_steps': tf.TensorShape([]), # scalar + 'phone_seq_lens': tf.TensorShape([]), # scalar + 'day_indices': tf.TensorShape([]), # scalar + 'transcriptions': tf.TensorShape([None]), # [transcription_len] - dynamic + 'block_nums': tf.TensorShape([]), # scalar + 'trial_nums': tf.TensorShape([]) # scalar } # Define padding values for each field @@ -995,7 +958,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, 'trial_nums': 0 } - # Create batches with FIXED padding - XLA compiler will be happy! + # Create batches with dynamic padding dataset = dataset.padded_batch( batch_size=dataset_tf.batch_size, padded_shapes=padded_shapes, @@ -1003,6 +966,36 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, drop_remainder=True # Critical for TPU: ensures all batches have same size ) + # Step 4: Apply data augmentation to ENTIRE BATCHES (after batching) + def apply_batch_transforms(batch): + """Apply data transformations to entire batches - CRITICAL for XLA compatibility""" + features = batch['input_features'] + n_time_steps = batch['n_time_steps'] + + # Apply transformations to the entire batch + features, n_time_steps = DataAugmentationTF.transform_data( + features, # Already batched: [batch_size, time_steps, features] + n_time_steps, # Already batched: [batch_size] + transform_args, + training=training + ) + + # Update the batch with transformed data + batch['input_features'] = features + batch['n_time_steps'] = n_time_steps + + return batch + + # Apply batch-level transforms (only if training) + if training: + print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)") + dataset = dataset.map( + apply_batch_transforms, + num_parallel_calls=tf.data.AUTOTUNE + ) + else: + print(f"✅ Validation mode: no data augmentation applied") + # Prefetch for optimal performance dataset = dataset.prefetch(tf.data.AUTOTUNE) diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 7b7d8c7..77243ca 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -27,11 +27,53 @@ from dataset_tf import ( BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices, - create_input_fn, - analyze_dataset_shapes + create_input_fn ) +def dense_to_sparse(dense_tensor, sequence_lengths): + """ + Convert dense tensor to sparse tensor for CTC loss with dynamic shapes + + Args: + dense_tensor: Dense tensor with shape [batch_size, max_seq_len] + sequence_lengths: Actual sequence lengths [batch_size] + + Returns: + SparseTensor suitable for tf.nn.ctc_loss + """ + # Create mask for valid (non-zero) elements within sequence lengths + batch_size = tf.shape(dense_tensor)[0] + max_seq_len = tf.shape(dense_tensor)[1] + + # Create range indices + batch_indices = tf.range(batch_size) + seq_indices = tf.range(max_seq_len) + + # Create meshgrid for batch and sequence dimensions + batch_mesh, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij') + + # Create mask based on sequence lengths and non-zero values + length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1) + value_mask = tf.not_equal(dense_tensor, 0) + combined_mask = tf.logical_and(length_mask, value_mask) + + # Get indices of valid elements + indices = tf.where(combined_mask) + + # Get values at valid indices + values = tf.gather_nd(dense_tensor, indices) + + # Create sparse tensor + dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64) + + return tf.SparseTensor( + indices=tf.cast(indices, tf.int64), + values=tf.cast(values, tf.int32), + dense_shape=dense_shape + ) + + class BrainToTextDecoderTrainerTF: """ TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8 @@ -392,7 +434,8 @@ class BrainToTextDecoderTrainerTF: import psutil initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024 - print("🔄 Initializing training dataset with GPU-style memory management...") + print("🔄 Initializing training dataset with TPU-optimized memory management...") + print(" 🚀 Preloading all data to RAM for maximum parallel analysis speed...") init_start_time = time.time() self.train_dataset_tf = BrainToTextDatasetTF( trial_indices=train_trials, @@ -403,8 +446,8 @@ class BrainToTextDecoderTrainerTF: random_seed=self.args['dataset']['seed'], must_include_days=self.args['dataset'].get('must_include_days'), feature_subset=self.args['dataset'].get('feature_subset'), - cache_data=True, # 启用智能缓存(像GPU版本一样) - preload_all_data=False # 🚨 采用GPU版本策略:按需加载,避免内存溢出 + cache_data=True, # 启用智能缓存 + preload_all_data=True # 🚀 TPU优化:预加载全部数据,解锁并行分析 ) # Log training dataset initialization performance @@ -413,7 +456,8 @@ class BrainToTextDecoderTrainerTF: train_memory_used = train_memory_mb - initial_memory_mb print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM") - print("🔄 Initializing validation dataset with GPU-style memory management...") + print("🔄 Initializing validation dataset with TPU-optimized memory management...") + print(" 🚀 Preloading all validation data to RAM for maximum parallel analysis speed...") val_init_start_time = time.time() self.val_dataset_tf = BrainToTextDatasetTF( trial_indices=val_trials, @@ -423,8 +467,8 @@ class BrainToTextDecoderTrainerTF: days_per_batch=1, # One day per validation batch random_seed=self.args['dataset']['seed'], feature_subset=self.args['dataset'].get('feature_subset'), - cache_data=True, # 启用智能缓存(像GPU版本一样) - preload_all_data=False # 🚨 采用GPU版本策略:按需加载,避免内存溢出 + cache_data=True, # 启用智能缓存 + preload_all_data=True # 🚀 TPU优化:预加载全部数据,解锁并行分析 ) # Log validation dataset initialization performance @@ -525,12 +569,7 @@ class BrainToTextDecoderTrainerTF: day_indices = batch['day_indices'] with tf.GradientTape() as tape: - # Apply data transformations - features, n_time_steps = DataAugmentationTF.transform_data( - features, n_time_steps, self.args['dataset']['data_transforms'], training=True - ) - - # Calculate adjusted lengths for CTC + # Calculate adjusted lengths for CTC (data augmentation now handled in dataset pipeline) adjusted_lens = tf.cast( (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1, @@ -551,25 +590,28 @@ class BrainToTextDecoderTrainerTF: # Calculate losses if use_full: - # Clean CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes) + # Convert dense labels to sparse for dynamic shapes + sparse_labels = dense_to_sparse(labels, phone_seq_lens) + + # Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes) # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes] clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2]) clean_loss = tf.nn.ctc_loss( - labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes + labels=sparse_labels, logits=clean_logits_time_major, - label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length + label_length=None, # Not needed with sparse labels logit_length=tf.cast(adjusted_lens, tf.int32), blank_index=0, logits_time_major=True ) clean_loss = tf.reduce_mean(clean_loss) - # Noisy CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes) + # Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes) noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2]) noisy_loss = tf.nn.ctc_loss( - labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes + labels=sparse_labels, # Reuse same sparse labels logits=noisy_logits_time_major, - label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length + label_length=None, # Not needed with sparse labels logit_length=tf.cast(adjusted_lens, tf.int32), blank_index=0, logits_time_major=True @@ -583,12 +625,15 @@ class BrainToTextDecoderTrainerTF: loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 else: - # Standard CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes) + # Convert dense labels to sparse for dynamic shapes + sparse_labels = dense_to_sparse(labels, phone_seq_lens) + + # Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes) logits_time_major = tf.transpose(clean_logits, [1, 0, 2]) loss = tf.nn.ctc_loss( - labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes + labels=sparse_labels, logits=logits_time_major, - label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length + label_length=None, # Not needed with sparse labels logit_length=tf.cast(adjusted_lens, tf.int32), blank_index=0, logits_time_major=True @@ -638,12 +683,7 @@ class BrainToTextDecoderTrainerTF: phone_seq_lens = batch['phone_seq_lens'] day_indices = batch['day_indices'] - # Apply data transformations (no augmentation for validation) - features, n_time_steps = DataAugmentationTF.transform_data( - features, n_time_steps, self.args['dataset']['data_transforms'], training=False - ) - - # Calculate adjusted lengths + # Calculate adjusted lengths (no augmentation for validation, handled in dataset pipeline) adjusted_lens = tf.cast( (tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1, @@ -653,13 +693,16 @@ class BrainToTextDecoderTrainerTF: # Forward pass (inference mode only) logits = self.model(features, day_indices, None, False, 'inference', training=False) - # Calculate loss - use tf.nn.ctc_loss with dense labels (fixed shapes) + # Convert dense labels to sparse for dynamic shapes + sparse_labels = dense_to_sparse(labels, phone_seq_lens) + + # Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes) # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes] logits_time_major = tf.transpose(logits, [1, 0, 2]) loss = tf.nn.ctc_loss( - labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes + labels=sparse_labels, logits=logits_time_major, - label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length + label_length=None, # Not needed with sparse labels logit_length=tf.cast(adjusted_lens, tf.int32), blank_index=0, logits_time_major=True @@ -680,60 +723,28 @@ class BrainToTextDecoderTrainerTF: initial_tpu_status = self._get_detailed_tpu_status() self.logger.info(f"Initial TPU Status: {initial_tpu_status}") - # ========================= DATASET SHAPE ANALYSIS ========================= - # Perform one-time full dataset analysis for fixed shapes (TPU requirement) - self.logger.info("🚀 Performing one-time full dataset analysis for fixed shapes...") - - # Analyze training dataset (all data for accurate max shapes) - train_analysis_start = time.time() - train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1) - train_analysis_time = time.time() - train_analysis_start - self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s") - - # Analyze validation dataset (all data for accurate max shapes) - val_analysis_start = time.time() - val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1) - val_analysis_time = time.time() - val_analysis_start - self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s") - - # Use maximum shapes across both datasets for consistent padding - final_max_shapes = { - 'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']), - 'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']), - 'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']), - 'n_features': train_max_shapes['n_features'] - } - - self.logger.info(f"📊 Final fixed shapes for TPU training:") - self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}") - self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}") - self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}") - self.logger.info(f" Features: {final_max_shapes['n_features']}") - # ===================================================================== - - # Create datasets using modern distribution API with fixed shapes - def create_dist_dataset_fn(input_dataset_tf, training, max_shapes): - """Create distributed dataset function for modern TPU strategy with fixed shapes""" + # Create datasets using modern distribution API with dynamic padding + def create_dist_dataset_fn(input_dataset_tf, training): + """Create distributed dataset function for modern TPU strategy with batch-first approach""" def dataset_fn(input_context): - # create_input_fn now requires max_shapes parameter for fixed shapes + # create_input_fn now uses batch-first approach with dynamic padding return create_input_fn( input_dataset_tf, self.args['dataset']['data_transforms'], - max_shapes=max_shapes, # Pass pre-analyzed shapes training=training ) return self.strategy.distribute_datasets_from_function(dataset_fn) - # Distribute datasets using modern API with fixed shapes + # Distribute datasets using modern API with batch-first approach self.logger.info("🔄 Distributing training dataset across TPU cores...") dist_start_time = time.time() - train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes) + train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True) train_dist_time = time.time() - dist_start_time self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s") self.logger.info("🔄 Distributing validation dataset across TPU cores...") val_start_time = time.time() - val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes) + val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False) val_dist_time = time.time() - val_start_time self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")