From e399cf262a8c06dc4e48f7edd99c6f4e583223ee Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Mon, 20 Oct 2025 13:37:11 +0800 Subject: [PATCH] ff --- model_training_nnn_tpu/dataset_tf.py | 101 ++++++++--------- model_training_nnn_tpu/trainer_tf.py | 156 +++++++++++++++++---------- 2 files changed, 150 insertions(+), 107 deletions(-) diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index c03d8ce..e16f13f 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -889,35 +889,32 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = # Utility functions for TPU-optimized data pipeline def create_input_fn(dataset_tf: BrainToTextDatasetTF, transform_args: Dict[str, Any], + max_shapes: Dict[str, int], training: bool = True, cache_path: Optional[str] = None) -> tf.data.Dataset: """ - Create input function for TPU training with BATCH-FIRST approach + Create input function for TPU training with PRE-ANALYZED FIXED shapes - This function implements the correct TPU data pipeline: - 1. Load individual samples - 2. Cache raw samples - 3. Batch samples with dynamic padding - 4. Apply data augmentation to entire batches (AFTER batching) - - This approach prevents shape conflicts from augmentation operations - like random_cut that would otherwise make tensor shapes dynamic before batching. + This function uses pre-computed maximum shapes to create STATIC-size batches, + ensuring XLA compilation success on TPU hardware. This is CRITICAL for the + final resolution of both CTC loss compatibility and graph structure issues. Args: dataset_tf: BrainToTextDatasetTF instance transform_args: Data transformation configuration + max_shapes: Pre-computed maximum shapes dictionary with keys: + 'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features' training: Whether this is for training (applies augmentations) cache_path: Optional path for disk caching to improve I/O performance Returns: - tf.data.Dataset ready for TPU training with XLA-compatible operations + tf.data.Dataset ready for TPU training with FIXED STATIC shapes """ # Step 1: Create individual example dataset with file-grouping I/O optimization dataset = dataset_tf.create_individual_dataset() # Step 2: Cache raw samples BEFORE any augmentation - # ========================= I/O OPTIMIZATION SOLUTION ========================= if cache_path: dataset = dataset.cache(cache_path) split_name = "training" if training else "validation" @@ -929,19 +926,55 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, split_name = "training" if training else "validation" print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache") print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster") - # ================================================================ - # Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes) - print(f"🔧 Using DYNAMIC padding for XLA compatibility:") + # Step 3: Apply transformations to individual examples BEFORE batching + def apply_transforms(example): + """Apply data transformations to individual examples""" + features = example['input_features'] + n_time_steps = example['n_time_steps'] - # Define padded shapes with None for dynamic dimensions + # Apply transformations + features, n_time_steps = DataAugmentationTF.transform_data( + tf.expand_dims(features, 0), # Add batch dimension for 
transforms + tf.expand_dims(n_time_steps, 0), + transform_args, + training=training + ) + + # Remove batch dimension + example['input_features'] = tf.squeeze(features, 0) + example['n_time_steps'] = tf.squeeze(n_time_steps, 0) + + return example + + # Apply transforms to cached data + dataset = dataset.map( + apply_transforms, + num_parallel_calls=tf.data.AUTOTUNE + ) + + # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA) + print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:") + + # Extract pre-analyzed shape information + max_time_steps = max_shapes['max_time_steps'] + max_phone_seq_len = max_shapes['max_phone_seq_len'] + max_transcription_len = max_shapes['max_transcription_len'] + n_features = max_shapes['n_features'] + + print(f" Fixed time steps: {max_time_steps}") + print(f" Fixed phone sequence length: {max_phone_seq_len}") + print(f" Fixed transcription length: {max_transcription_len}") + print(f" Number of features: {n_features}") + + # Define FIXED padded shapes - NO None values for XLA compatibility padded_shapes = { - 'input_features': tf.TensorShape([None, None]), # [time_steps, features] - dynamic - 'seq_class_ids': tf.TensorShape([None]), # [phone_seq_len] - dynamic + 'input_features': tf.TensorShape([max_time_steps, n_features]), + 'seq_class_ids': tf.TensorShape([max_phone_seq_len]), 'n_time_steps': tf.TensorShape([]), # scalar 'phone_seq_lens': tf.TensorShape([]), # scalar 'day_indices': tf.TensorShape([]), # scalar - 'transcriptions': tf.TensorShape([None]), # [transcription_len] - dynamic + 'transcriptions': tf.TensorShape([max_transcription_len]), 'block_nums': tf.TensorShape([]), # scalar 'trial_nums': tf.TensorShape([]) # scalar } @@ -958,7 +991,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, 'trial_nums': 0 } - # Create batches with dynamic padding + # Create batches with FIXED padding - XLA compiler will be happy! 
dataset = dataset.padded_batch( batch_size=dataset_tf.batch_size, padded_shapes=padded_shapes, @@ -966,36 +999,6 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, drop_remainder=True # Critical for TPU: ensures all batches have same size ) - # Step 4: Apply data augmentation to ENTIRE BATCHES (after batching) - def apply_batch_transforms(batch): - """Apply data transformations to entire batches - CRITICAL for XLA compatibility""" - features = batch['input_features'] - n_time_steps = batch['n_time_steps'] - - # Apply transformations to the entire batch - features, n_time_steps = DataAugmentationTF.transform_data( - features, # Already batched: [batch_size, time_steps, features] - n_time_steps, # Already batched: [batch_size] - transform_args, - training=training - ) - - # Update the batch with transformed data - batch['input_features'] = features - batch['n_time_steps'] = n_time_steps - - return batch - - # Apply batch-level transforms (only if training) - if training: - print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)") - dataset = dataset.map( - apply_batch_transforms, - num_parallel_calls=tf.data.AUTOTUNE - ) - else: - print(f"✅ Validation mode: no data augmentation applied") - # Prefetch for optimal performance dataset = dataset.prefetch(tf.data.AUTOTUNE) diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 154b557..249ae47 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -1,5 +1,6 @@ import os import tensorflow as tf +import tensorflow.keras.backend as K import numpy as np import time import json @@ -27,10 +28,41 @@ from dataset_tf import ( BrainToTextDatasetTF, DataAugmentationTF, train_test_split_indices, - create_input_fn + create_input_fn, + analyze_dataset_shapes ) +def ctc_loss_for_tpu(y_true, y_pred, input_length, label_length): + """ + TPU-compatible CTC loss function using Keras backend + + This implementation uses K.ctc_batch_cost which is often more robust + for XLA compilation than tf.nn.ctc_loss, especially in complex model graphs. 
+
+    Args:
+        y_true: Dense labels [batch_size, max_label_len]
+        y_pred: Logits [batch_size, time_steps, num_classes] (softmax is applied internally)
+        input_length: Logit sequence lengths [batch_size]
+        label_length: True label sequence lengths [batch_size]
+
+    Returns:
+        Scalar CTC loss value
+    """
+    # K.ctc_batch_cost expects BATCH-major predictions [batch_size, time_steps, num_classes]
+    # (it transposes to time-major internally) and softmax probabilities, not raw logits
+    y_pred_probs = tf.nn.softmax(y_pred, axis=-1)
+
+    # Ensure correct dtypes and shapes for the Keras backend:
+    # input_length and label_length must be [batch_size, 1] column vectors
+    y_true = tf.cast(y_true, tf.float32)
+    input_length = tf.reshape(tf.cast(input_length, tf.int32), [-1, 1])
+    label_length = tf.reshape(tf.cast(label_length, tf.int32), [-1, 1])
+
+    # Calculate CTC loss using Keras backend (more XLA-friendly)
+    # NOTE: K.ctc_batch_cost treats the LAST class as the CTC blank, unlike the
+    # previous tf.nn.ctc_loss call which used blank_index=0
+    loss = K.ctc_batch_cost(y_true, y_pred_probs, input_length, label_length)
+
+    return tf.reduce_mean(loss)
+
+
 def dense_to_sparse(dense_tensor, sequence_lengths):
     """
     Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
@@ -592,34 +624,23 @@ class BrainToTextDecoderTrainerTF:
                     features, day_indices, None, False, 'inference', training=True
                 )
 
-                # Calculate losses
+                # Calculate losses using TPU-compatible CTC implementation
                 if use_full:
-                    # Convert dense labels to sparse for dynamic shapes
-                    sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-
-                    # Clean CTC loss - will auto-fallback to CPU with soft device placement
-                    clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
-                    clean_loss = tf.nn.ctc_loss(
-                        labels=sparse_labels,
-                        logits=clean_logits_time_major,
-                        label_length=None,  # Not needed with sparse labels
-                        logit_length=tf.cast(adjusted_lens, tf.int32),
-                        blank_index=0,
-                        logits_time_major=True
+                    # Clean CTC loss - using Keras backend for XLA compatibility
+                    clean_loss = ctc_loss_for_tpu(
+                        y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                        y_pred=clean_logits,
+                        input_length=adjusted_lens,
+                        label_length=phone_seq_lens
                     )
-                    clean_loss = tf.reduce_mean(clean_loss)
 
-                    # Noisy CTC loss - will auto-fallback to CPU with soft device placement
-                    noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
-                    noisy_loss = tf.nn.ctc_loss(
-                        labels=sparse_labels,  # Reuse same sparse labels
-                        logits=noisy_logits_time_major,
-                        label_length=None,  # Not needed with sparse labels
-                        logit_length=tf.cast(adjusted_lens, tf.int32),
-                        blank_index=0,
-                        logits_time_major=True
+                    # Noisy CTC loss - using Keras backend for XLA compatibility
+                    noisy_loss = ctc_loss_for_tpu(
+                        y_true=tf.cast(labels, tf.float32),  # Reuse same dense labels
+                        y_pred=noisy_logits,
+                        input_length=adjusted_lens,
+                        label_length=phone_seq_lens
                     )
-                    noisy_loss = tf.reduce_mean(noisy_loss)
 
                     # Optional noise L2 regularization
                     noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
@@ -628,20 +649,13 @@ class BrainToTextDecoderTrainerTF:
 
                     loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
                 else:
-                    # Convert dense labels to sparse for dynamic shapes
-                    sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-
-                    # Standard CTC loss - will auto-fallback to CPU with soft device placement
-                    logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
-                    loss = tf.nn.ctc_loss(
-                        labels=sparse_labels,
-                        logits=logits_time_major,
-                        label_length=None,  # Not needed with sparse labels
-                        logit_length=tf.cast(adjusted_lens, tf.int32),
-                        blank_index=0,
-                        logits_time_major=True
+                    # Standard CTC loss - using Keras backend for XLA compatibility
+                    loss = ctc_loss_for_tpu(
+                        y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                        y_pred=clean_logits,
+                        
input_length=adjusted_lens, + label_length=phone_seq_lens ) - loss = tf.reduce_mean(loss) # AdamW handles weight decay automatically - no manual L2 regularization needed # TensorFlow混合精度处理:不需要手动scaling,Keras policy自动处理 @@ -696,20 +710,13 @@ class BrainToTextDecoderTrainerTF: # Forward pass (inference mode only) logits = self.model(features, day_indices, None, False, 'inference', training=False) - # Convert dense labels to sparse for dynamic shapes - sparse_labels = dense_to_sparse(labels, phone_seq_lens) - - # Calculate loss - will auto-fallback to CPU with soft device placement - logits_time_major = tf.transpose(logits, [1, 0, 2]) - loss = tf.nn.ctc_loss( - labels=sparse_labels, - logits=logits_time_major, - label_length=None, # Not needed with sparse labels - logit_length=tf.cast(adjusted_lens, tf.int32), - blank_index=0, - logits_time_major=True + # Calculate loss using TPU-compatible CTC implementation + loss = ctc_loss_for_tpu( + y_true=tf.cast(labels, tf.float32), # Dense labels as float32 + y_pred=logits, + input_length=adjusted_lens, + label_length=phone_seq_lens ) - loss = tf.reduce_mean(loss) # Greedy decoding for PER calculation predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32) @@ -725,28 +732,61 @@ class BrainToTextDecoderTrainerTF: initial_tpu_status = self._get_detailed_tpu_status() self.logger.info(f"Initial TPU Status: {initial_tpu_status}") - # Create datasets using modern distribution API with dynamic padding - def create_dist_dataset_fn(input_dataset_tf, training): - """Create distributed dataset function for modern TPU strategy with batch-first approach""" + # ========================= DATASET SHAPE ANALYSIS ========================= + # Perform one-time full dataset analysis for FIXED shapes (critical for XLA) + self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...") + self.logger.info(" This is CRITICAL for resolving both CTC compatibility and graph structure issues") + + # Analyze training dataset (all data for accurate max shapes) + train_analysis_start = time.time() + train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1) + train_analysis_time = time.time() - train_analysis_start + self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s") + + # Analyze validation dataset (all data for accurate max shapes) + val_analysis_start = time.time() + val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1) + val_analysis_time = time.time() - val_analysis_start + self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s") + + # Use maximum shapes across both datasets for consistent padding + final_max_shapes = { + 'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']), + 'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']), + 'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']), + 'n_features': train_max_shapes['n_features'] + } + + self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):") + self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}") + self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}") + self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}") + self.logger.info(f" Features: {final_max_shapes['n_features']}") + # 
===================================================================== + + # Create datasets using modern distribution API with FIXED shapes + def create_dist_dataset_fn(input_dataset_tf, training, max_shapes): + """Create distributed dataset function for modern TPU strategy with FIXED shapes""" def dataset_fn(input_context): - # create_input_fn now uses batch-first approach with dynamic padding + # create_input_fn now requires max_shapes parameter for FIXED shapes return create_input_fn( input_dataset_tf, self.args['dataset']['data_transforms'], + max_shapes=max_shapes, # Pass pre-analyzed FIXED shapes training=training ) return self.strategy.distribute_datasets_from_function(dataset_fn) - # Distribute datasets using modern API with batch-first approach + # Distribute datasets using modern API with FIXED shapes self.logger.info("🔄 Distributing training dataset across TPU cores...") dist_start_time = time.time() - train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True) + train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes) train_dist_time = time.time() - dist_start_time self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s") self.logger.info("🔄 Distributing validation dataset across TPU cores...") val_start_time = time.time() - val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False) + val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes) val_dist_time = time.time() - val_start_time self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
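
Editor's note (not part of the patch): K.ctc_batch_cost expects batch-major softmax
probabilities and length tensors shaped [batch_size, 1], and it uses the LAST class as
the CTC blank. The sketch below is a minimal, standalone check of those conventions,
which the ctc_loss_for_tpu helper above relies on; all shapes and variable names here
are purely illustrative and do not come from this repository.

    import tensorflow as tf
    from tensorflow.keras import backend as K

    batch_size, time_steps, num_classes, max_label_len = 2, 50, 41, 10

    # Softmax probabilities, batch-major: [batch, time, classes]
    probs = tf.nn.softmax(tf.random.normal([batch_size, time_steps, num_classes]), axis=-1)

    # Dense integer labels, avoiding the blank class (last index under K.ctc_batch_cost)
    labels = tf.random.uniform([batch_size, max_label_len], minval=1,
                               maxval=num_classes - 1, dtype=tf.int32)

    # Length tensors must be column vectors: [batch, 1]
    input_len = tf.fill([batch_size, 1], time_steps)
    label_len = tf.fill([batch_size, 1], max_label_len)

    per_example_loss = K.ctc_batch_cost(tf.cast(labels, tf.float32), probs,
                                        input_len, label_len)
    print(per_example_loss.shape)  # (2, 1) -> one CTC loss value per example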