From ab12d0b7eebd954fbfd16023e6004a08885b8044 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Tue, 21 Oct 2025 00:31:59 +0800
Subject: [PATCH] Batch first, augment after: dynamic padding + tf.nn.ctc_loss

---
 model_training_nnn_tpu/dataset_tf.py | 127 +++++++++++----------------
 model_training_nnn_tpu/trainer_tf.py |  95 ++++++++++++++++----
 2 files changed, 128 insertions(+), 94 deletions(-)
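Reviewer note (not part of the diff): the core change below reorders the input
pipeline from "augment each example, then batch to pre-analyzed static shapes"
to "padded_batch with dynamic shapes first, then augment whole batches". A
minimal, self-contained sketch of that ordering on toy tensors (hypothetical
names, feature dimension 3; nothing here comes from the repo):

    import tensorflow as tf

    # Variable-length toy examples: lengths 5 and 8, feature dimension 3.
    ds = tf.data.Dataset.from_tensor_slices(tf.constant([5, 8], tf.int32))
    ds = ds.map(lambda n: {'x': tf.ones([n, 3]), 'len': n})

    # Batch FIRST with dynamic padding - no precomputed max shapes required.
    ds = ds.padded_batch(2, padded_shapes={'x': [None, 3], 'len': []})

    # Augment AFTER batching, on the whole padded batch.
    def toy_augment(batch):
        batch['x'] = batch['x'] + 0.01 * tf.random.normal(tf.shape(batch['x']))
        return batch

    ds = ds.map(toy_augment)

    for batch in ds:
        print(batch['x'].shape)  # (2, 8, 3): padded to the longest example in the batch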
diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 7448657..74ff95f 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -908,111 +908,56 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with PRE-ANALYZED FIXED shapes
+    Create input function for TPU training with DYNAMIC batching -> BATCH augmentation
 
-    This function uses pre-computed maximum shapes to create STATIC-size batches,
-    ensuring XLA compilation success on TPU hardware. This is CRITICAL for the
-    final resolution of both CTC loss compatibility and graph structure issues.
+    This function uses the "batch first, augment after" approach, which removes
+    the ordering conflict between data augmentation and shape analysis and
+    resolves the remaining XLA compilation and padding errors.
+
+    The key insight: data augmentation (especially gauss_smooth with padding='SAME')
+    can change sequence lengths unpredictably, which invalidates pre-computed static
+    shapes. Batching first with dynamic padding, then augmenting whole batches,
+    eliminates this ordering conflict entirely.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
-        max_shapes: Pre-computed maximum shapes dictionary with keys:
-                    'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
-        tf.data.Dataset ready for TPU training with FIXED STATIC shapes
+        tf.data.Dataset ready for TPU training with a robust dynamic->static flow
     """
-    # Step 1: Create individual example dataset with file-grouping I/O optimization
+    # Step 1: Create individual example dataset
     dataset = dataset_tf.create_individual_dataset()
 
-    # Step 2: Cache raw samples BEFORE any augmentation
+    # Step 2: Cache raw samples BEFORE any augmentation or batching
     if cache_path:
         dataset = dataset.cache(cache_path)
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     else:
-        # If no cache path is specified, fall back to an in-memory cache
         dataset = dataset.cache()
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
-        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
 
-    # Step 3: Apply transformations to individual examples BEFORE batching
-    def apply_transforms(example):
-        """Apply data transformations to individual examples"""
-        features = example['input_features']
-        n_time_steps = example['n_time_steps']
+    # Step 3: Batch samples with DYNAMIC padding FIRST (removes the ordering conflict)
+    print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
+    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
 
-        # Apply transformations
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            tf.expand_dims(features, 0),  # Add batch dimension for transforms
-            tf.expand_dims(n_time_steps, 0),
-            transform_args,
-            training=training
-        )
-
-        # Remove batch dimension
-        example['input_features'] = tf.squeeze(features, 0)
-        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
-
-        return example
-
-    # Apply transforms to cached data
-    dataset = dataset.map(
-        apply_transforms,
-        num_parallel_calls=tf.data.AUTOTUNE
-    )
-
-    # ========================= debug code =========================
-    def debug_print_shape(example):
-        """Debug helper: print each sample's shape before padded_batch"""
-        tf.print("🔍 Sample Shape Debug:",
-                 tf.shape(example['input_features']),
-                 "Expected feature dim:", dataset_tf.feature_dim,
-                 output_stream=sys.stdout)
-        return example
-
-    # Add shape debugging - this prints at graph execution time
-    dataset = dataset.map(debug_print_shape)
-    print(f"⚠️ Debug mode: Will print each sample shape before padded_batch")
-    # =============================================================
-
-    # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
-    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
-
-    # Extract pre-analyzed shape information
-    max_time_steps = max_shapes['max_time_steps']
-    max_phone_seq_len = max_shapes['max_phone_seq_len']
-    max_transcription_len = max_shapes['max_transcription_len']
-
-    # ========================= unified feature dimension =========================
-    # Use the verified feature dimension stored on the dataset_tf object
-    # instead of relying on an external parameter
-    n_features = dataset_tf.feature_dim  # <--- key change: use the auto-detected feature dimension
-    print(f"🔧 Using verified feature dimension from dataset: {n_features}")
-    # ========================= end of feature-dimension change =========================
-
-    print(f"   Fixed time steps: {max_time_steps}")
-    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
-    print(f"   Fixed transcription length: {max_transcription_len}")
-    print(f"   Number of features: {n_features}")
-
-    # Define FIXED padded shapes with TensorSpec for better type safety
+    # Define dynamic padded shapes - None allows variable lengths per batch.
+    # Note: padded_batch expects plain shapes (tf.TensorShape / lists), not
+    # tf.TensorSpec; element dtypes are inherited from the dataset itself.
     padded_shapes = {
-        'input_features': tf.TensorSpec(shape=[max_time_steps, n_features], dtype=tf.float32),
-        'seq_class_ids': tf.TensorSpec(shape=[max_phone_seq_len], dtype=tf.int32),
-        'n_time_steps': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'phone_seq_lens': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'day_indices': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'transcriptions': tf.TensorSpec(shape=[max_transcription_len], dtype=tf.int32),
-        'block_nums': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'trial_nums': tf.TensorSpec(shape=[], dtype=tf.int32)  # scalar
+        'input_features': [None, dataset_tf.feature_dim],
+        'seq_class_ids': [None],
+        'n_time_steps': [],    # scalar
+        'phone_seq_lens': [],  # scalar
+        'day_indices': [],     # scalar
+        'transcriptions': [None],
+        'block_nums': [],      # scalar
+        'trial_nums': []       # scalar
     }
@@ -1029,7 +974,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create batches with FIXED padding - XLA compiler will be happy!
+    # Create batches with DYNAMIC padding - size mismatches are no longer possible
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
@@ -1037,7 +982,37 @@
         drop_remainder=True  # Critical for TPU: ensures all batches have same size
     )
 
-    # Prefetch for optimal performance
+    # Step 4: Apply data augmentation to BATCHES (after dynamic batching)
+    def apply_batch_transforms(batch):
+        """Apply data transformations to an entire batch"""
+        features = batch['input_features']
+        n_time_steps = batch['n_time_steps']
+
+        # Apply transformations to the entire batch
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            features,  # Already has a batch dimension
+            n_time_steps,
+            transform_args,
+            training=training
+        )
+
+        # Update batch with transformed data
+        batch['input_features'] = features
+        batch['n_time_steps'] = n_time_steps
+
+        return batch
+
+    # Apply batch transforms only during training
+    if training:
+        dataset = dataset.map(
+            apply_batch_transforms,
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+        print(f"✅ Batch augmentation enabled for training")
+    else:
+        print(f"✅ No augmentation for validation")
+
+    # Step 5: Prefetch for optimal performance
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
     return dataset
\ No newline at end of file
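Aside (not part of the diff) on the drop_remainder=True kept above: XLA on TPU
needs a static leading batch dimension, and dropping the remainder batch is
what makes it static while the time dimension stays dynamic per batch. A toy
check, independent of this repo, showing the effect on element_spec:

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    print(ds.batch(4).element_spec)                       # shape=(None,): last batch may be short
    print(ds.batch(4, drop_remainder=True).element_spec)  # shape=(4,): static batch dimension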
diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index cd20920..b5f92b1 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -17,8 +17,55 @@
 except ImportError:
     print("Warning: editdistance not available, falling back to approximation")
     editdistance = None
 
-# XLA-compatible CTC loss implementation
-from tf_seq2seq_losses import classic_ctc_loss
+# Note: reverted to the standard tf.nn.ctc_loss + SparseTensor approach
+# for compatibility with the "batch first, augment after" data pipeline
+
+
+def dense_to_sparse(dense_tensor, sequence_lengths):
+    """
+    Convert a dense label tensor to a SparseTensor for CTC loss with dynamic shapes
+
+    This function is essential for the "batch first, augment after" approach,
+    as it handles the conversion from dynamically padded dense tensors to the
+    SparseTensor format required by tf.nn.ctc_loss.
+
+    Args:
+        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
+        sequence_lengths: Actual sequence lengths [batch_size]
+
+    Returns:
+        SparseTensor suitable for tf.nn.ctc_loss
+    """
+    # Mask out padding: keep non-zero elements that fall within each sequence length
+    batch_size = tf.shape(dense_tensor)[0]
+    max_seq_len = tf.shape(dense_tensor)[1]
+
+    # Index ranges over the batch and sequence dimensions
+    batch_indices = tf.range(batch_size)
+    seq_indices = tf.range(max_seq_len)
+
+    # Meshgrid of per-position sequence indices, shape [batch_size, max_seq_len]
+    _, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
+
+    # Combine the length mask with a non-zero value mask
+    length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
+    value_mask = tf.not_equal(dense_tensor, 0)
+    combined_mask = tf.logical_and(length_mask, value_mask)
+
+    # Indices of valid elements
+    indices = tf.where(combined_mask)
+
+    # Values at the valid indices
+    values = tf.gather_nd(dense_tensor, indices)
+
+    # Assemble the sparse tensor
+    dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
+
+    return tf.SparseTensor(
+        indices=tf.cast(indices, tf.int64),
+        values=tf.cast(values, tf.int32),
+        dense_shape=dense_shape
+    )
 
 from rnn_model_tf import (
     TripleGRUDecoder,
@@ -559,23 +606,29 @@ class BrainToTextDecoderTrainerTF:
 
         # Calculate losses using TPU-compatible CTC implementation
         if use_full:
-            # Clean CTC loss - using XLA-compatible classic_ctc_loss
-            clean_loss = classic_ctc_loss(
-                labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            # Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+            clean_loss = tf.nn.ctc_loss(
+                labels=sparse_labels,
                 logits=clean_logits,
-                label_length=phone_seq_lens,
+                label_length=None,  # SparseTensor doesn't need label_length
                 logit_length=adjusted_lens,
+                logits_time_major=False,
                 blank_index=0
             )
+            clean_loss = tf.reduce_mean(clean_loss)
 
-            # Noisy CTC loss - using XLA-compatible classic_ctc_loss
-            noisy_loss = classic_ctc_loss(
-                labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            # Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+            # Reuse the same sparse_labels from above
+            noisy_loss = tf.nn.ctc_loss(
+                labels=sparse_labels,
                 logits=noisy_logits,
-                label_length=phone_seq_lens,
+                label_length=None,  # SparseTensor doesn't need label_length
                 logit_length=adjusted_lens,
+                logits_time_major=False,
                 blank_index=0
             )
+            noisy_loss = tf.reduce_mean(noisy_loss)
 
             # Optional noise L2 regularization
             noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
@@ -584,14 +637,17 @@
             loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
         else:
-            # Standard CTC loss - using XLA-compatible classic_ctc_loss
-            loss = classic_ctc_loss(
-                labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            # Standard CTC loss - using tf.nn.ctc_loss with SparseTensor
+            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+            loss = tf.nn.ctc_loss(
+                labels=sparse_labels,
                 logits=clean_logits,
-                label_length=phone_seq_lens,
+                label_length=None,  # SparseTensor doesn't need label_length
                 logit_length=adjusted_lens,
+                logits_time_major=False,
                 blank_index=0
            )
+            loss = tf.reduce_mean(loss)
 
         # AdamW handles weight decay automatically - no manual L2 regularization needed
         # TensorFlow mixed precision: no manual loss scaling needed, the Keras policy handles it
@@ -646,14 +702,17 @@
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss using XLA-compatible classic_ctc_loss
-        loss = classic_ctc_loss(
-            labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+        # Calculate loss using standard tf.nn.ctc_loss with SparseTensor
+        sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+        loss = tf.nn.ctc_loss(
+            labels=sparse_labels,
             logits=logits,
-            label_length=phone_seq_lens,
+            label_length=None,  # SparseTensor doesn't need label_length
             logit_length=adjusted_lens,
+            logits_time_major=False,
             blank_index=0
         )
+        loss = tf.reduce_mean(loss)
 
         # Greedy decoding for PER calculation
         predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
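For reference, a small usage sketch of the dense_to_sparse + tf.nn.ctc_loss
combination introduced above, on toy data (sizes and values are made up; it
assumes dense_to_sparse from this patch is in scope, and that 0 is both the
label padding value and the CTC blank, as in the trainer):

    import tensorflow as tf

    batch, max_t, num_classes = 2, 6, 5
    labels_dense = tf.constant([[1, 2, 3], [4, 2, 0]], tf.int32)  # 0 = padding
    label_lens   = tf.constant([3, 2], tf.int32)
    logits       = tf.random.normal([batch, max_t, num_classes])  # [batch, time, classes]
    logit_lens   = tf.constant([6, 5], tf.int32)

    sparse_labels = dense_to_sparse(labels_dense, label_lens)
    loss = tf.nn.ctc_loss(
        labels=sparse_labels,
        logits=logits,
        label_length=None,        # lengths are implicit in the SparseTensor
        logit_length=logit_lens,
        logits_time_major=False,  # logits are batch-major here
        blank_index=0)
    print(tf.reduce_mean(loss).numpy())  # scalar batch loss, as in the trainer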