commit e399cf262a
parent 7358ff3d79
Author: Zchen
Date: 2025-10-20 13:37:11 +08:00

2 changed files with 150 additions and 107 deletions

View File

@@ -889,35 +889,32 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
+                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with BATCH-FIRST approach
-    This function implements the correct TPU data pipeline:
-    1. Load individual samples
-    2. Cache raw samples
-    3. Batch samples with dynamic padding
-    4. Apply data augmentation to entire batches (AFTER batching)
-    This approach prevents shape conflicts from augmentation operations
-    like random_cut that would otherwise make tensor shapes dynamic before batching.
+    Create input function for TPU training with PRE-ANALYZED FIXED shapes
+    This function uses pre-computed maximum shapes to create STATIC-size batches,
+    ensuring XLA compilation success on TPU hardware. This is CRITICAL for the
+    final resolution of both CTC loss compatibility and graph structure issues.
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
+        max_shapes: Pre-computed maximum shapes dictionary with keys:
+            'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
         training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance
     Returns:
-        tf.data.Dataset ready for TPU training with XLA-compatible operations
+        tf.data.Dataset ready for TPU training with FIXED STATIC shapes
     """
     # Step 1: Create individual example dataset with file-grouping I/O optimization
     dataset = dataset_tf.create_individual_dataset()
     # Step 2: Cache raw samples BEFORE any augmentation
+    # ========================= I/O OPTIMIZATION SOLUTION =========================
     if cache_path:
         dataset = dataset.cache(cache_path)
         split_name = "training" if training else "validation"
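The new max_shapes argument comes from a one-time scan for dataset-wide maxima (analyze_dataset_shapes in the hunk header above). As a tiny self-contained illustration of that scan idea, using toy ragged data rather than the repo's loader:

    import tensorflow as tf

    # Stand-in for the one-time shape analysis: find the dataset-wide maximum
    # length so every batch can later be padded to the same static size.
    ds = tf.data.Dataset.from_tensor_slices(
        tf.ragged.constant([[1, 2, 3], [4, 5], [6, 7, 8, 9]])
    )
    max_len = ds.reduce(0, lambda m, x: tf.maximum(m, tf.shape(x)[0]))
    print(int(max_len))  # 4 -> the fixed pad target, as max_shapes provides here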
@@ -929,19 +926,55 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
+    # ================================================================
-    # Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes)
-    print(f"🔧 Using DYNAMIC padding for XLA compatibility:")
-    # Define padded shapes with None for dynamic dimensions
+    # Step 3: Apply transformations to individual examples BEFORE batching
+    def apply_transforms(example):
+        """Apply data transformations to individual examples"""
+        features = example['input_features']
+        n_time_steps = example['n_time_steps']
+        # Apply transformations
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            tf.expand_dims(features, 0),  # Add batch dimension for transforms
+            tf.expand_dims(n_time_steps, 0),
+            transform_args,
+            training=training
+        )
+        # Remove batch dimension
+        example['input_features'] = tf.squeeze(features, 0)
+        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
+        return example
+    # Apply transforms to cached data
+    dataset = dataset.map(
+        apply_transforms,
+        num_parallel_calls=tf.data.AUTOTUNE
+    )
+    # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
+    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
+    # Extract pre-analyzed shape information
+    max_time_steps = max_shapes['max_time_steps']
+    max_phone_seq_len = max_shapes['max_phone_seq_len']
+    max_transcription_len = max_shapes['max_transcription_len']
+    n_features = max_shapes['n_features']
+    print(f"   Fixed time steps: {max_time_steps}")
+    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
+    print(f"   Fixed transcription length: {max_transcription_len}")
+    print(f"   Number of features: {n_features}")
+    # Define FIXED padded shapes - NO None values for XLA compatibility
     padded_shapes = {
-        'input_features': tf.TensorShape([None, None]),  # [time_steps, features] - dynamic
-        'seq_class_ids': tf.TensorShape([None]),         # [phone_seq_len] - dynamic
+        'input_features': tf.TensorShape([max_time_steps, n_features]),
+        'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
         'n_time_steps': tf.TensorShape([]),              # scalar
         'phone_seq_lens': tf.TensorShape([]),            # scalar
         'day_indices': tf.TensorShape([]),               # scalar
-        'transcriptions': tf.TensorShape([None]),        # [transcription_len] - dynamic
+        'transcriptions': tf.TensorShape([max_transcription_len]),
         'block_nums': tf.TensorShape([]),                # scalar
         'trial_nums': tf.TensorShape([])                 # scalar
     }
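The padded_shapes switch above is the crux of the commit: None dimensions force XLA to recompile whenever the longest example in a batch changes, while fixed maxima yield one stable batch shape. A self-contained toy comparison (not the repo's pipeline):

    import tensorflow as tf

    # Toy dataset of variable-length [time, features] examples
    examples = [tf.random.normal([t, 4]) for t in (5, 9, 7, 9)]
    ds = tf.data.Dataset.from_generator(
        lambda: iter(examples),
        output_signature=tf.TensorSpec([None, 4], tf.float32),
    )

    # Dynamic padding: batch shape tracks the longest example per batch
    dyn = ds.padded_batch(2, padded_shapes=tf.TensorShape([None, 4]))

    # Fixed padding: every batch is padded to the same pre-analyzed maximum
    MAX_T = 9  # would come from the full dataset scan, as in this commit
    fixed = ds.padded_batch(2, padded_shapes=tf.TensorShape([MAX_T, 4]),
                            drop_remainder=True)

    for b in fixed:
        print(b.shape)  # always (2, 9, 4) -> one XLA compilation, no retraces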
@@ -958,7 +991,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
-    # Create batches with dynamic padding
+    # Create batches with FIXED padding - XLA compiler will be happy!
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
@@ -966,36 +999,6 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         drop_remainder=True  # Critical for TPU: ensures all batches have same size
     )
-    # Step 4: Apply data augmentation to ENTIRE BATCHES (after batching)
-    def apply_batch_transforms(batch):
-        """Apply data transformations to entire batches - CRITICAL for XLA compatibility"""
-        features = batch['input_features']
-        n_time_steps = batch['n_time_steps']
-        # Apply transformations to the entire batch
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            features,      # Already batched: [batch_size, time_steps, features]
-            n_time_steps,  # Already batched: [batch_size]
-            transform_args,
-            training=training
-        )
-        # Update the batch with transformed data
-        batch['input_features'] = features
-        batch['n_time_steps'] = n_time_steps
-        return batch
-    # Apply batch-level transforms (only if training)
-    if training:
-        print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)")
-        dataset = dataset.map(
-            apply_batch_transforms,
-            num_parallel_calls=tf.data.AUTOTUNE
-        )
-    else:
-        print(f"✅ Validation mode: no data augmentation applied")
     # Prefetch for optimal performance
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
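Note the ordering this file now settles on: cache sits before the augmentation map. In tf.data, stages placed after .cache() re-execute on every epoch, so random augmentation stays random while the expensive raw reads are cached. A minimal sketch of why the order matters:

    import tensorflow as tf

    base = tf.data.Dataset.range(3).map(lambda x: tf.cast(x, tf.float32))

    # Cache BEFORE the random map (this commit's ordering): raw data is cached,
    # augmentation noise is re-sampled on every epoch
    good = base.cache().map(lambda x: x + tf.random.uniform([]))

    # Cache AFTER the random map: epoch-one noise would be frozen into the cache
    bad = base.map(lambda x: x + tf.random.uniform([])).cache()

    print([float(x) for x in good], [float(x) for x in good])  # two epochs differ
    print([float(x) for x in bad], [float(x) for x in bad])    # two epochs identical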

View File

@@ -1,5 +1,6 @@
 import os
 import tensorflow as tf
+import tensorflow.keras.backend as K
 import numpy as np
 import time
 import json
@@ -27,10 +28,41 @@ from dataset_tf import (
     BrainToTextDatasetTF,
     DataAugmentationTF,
     train_test_split_indices,
-    create_input_fn
+    create_input_fn,
+    analyze_dataset_shapes
 )
+
+def ctc_loss_for_tpu(y_true, y_pred, input_length, label_length):
+    """
+    TPU-compatible CTC loss function using the Keras backend
+    K.ctc_batch_cost is often more robust under XLA compilation than
+    tf.nn.ctc_loss, especially in complex model graphs.
+    Args:
+        y_true: Dense labels [batch_size, max_label_len]
+        y_pred: Logits [batch_size, time_steps, num_classes]
+        input_length: Logit sequence lengths [batch_size]
+        label_length: True label sequence lengths [batch_size]
+    Returns:
+        Scalar CTC loss value
+    """
+    # K.ctc_batch_cost expects BATCH-major softmax probabilities
+    # [batch_size, time_steps, num_classes]; it transposes to time-major and
+    # takes the log internally, so convert logits here rather than transposing
+    y_pred = tf.nn.softmax(y_pred, axis=-1)
+    # The backend also expects float32 labels and [batch_size, 1] length columns
+    y_true = tf.cast(y_true, tf.float32)
+    input_length = tf.expand_dims(tf.cast(input_length, tf.int32), axis=-1)
+    label_length = tf.expand_dims(tf.cast(label_length, tf.int32), axis=-1)
+    # Caution: K.ctc_batch_cost treats the LAST class (num_classes - 1) as the
+    # CTC blank, whereas the tf.nn.ctc_loss calls it replaces used blank_index=0
+    loss = K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
+    return tf.reduce_mean(loss)
+
 def dense_to_sparse(dense_tensor, sequence_lengths):
     """
     Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
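A quick standalone sanity check of this wrapper (toy shapes; assumes the batch-major form noted above and a label set that reserves the last class index for the blank):

    import tensorflow as tf

    batch, time_steps, num_classes, max_label = 2, 50, 41, 10
    logits = tf.random.normal([batch, time_steps, num_classes])
    labels = tf.random.uniform([batch, max_label], 1, num_classes - 1, tf.int32)
    input_len = tf.fill([batch], time_steps)   # frames per utterance
    label_len = tf.fill([batch], max_label)    # phones per utterance

    # ctc_loss_for_tpu as defined in this commit
    loss = ctc_loss_for_tpu(labels, logits, input_len, label_len)
    print(float(loss))  # a finite scalar; CTC requires input_len >= label_len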
@@ -592,34 +624,23 @@ class BrainToTextDecoderTrainerTF:
                 features, day_indices, None, False, 'inference', training=True
             )
-            # Calculate losses
+            # Calculate losses using TPU-compatible CTC implementation
             if use_full:
-                # Convert dense labels to sparse for dynamic shapes
-                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-                # Clean CTC loss - will auto-fallback to CPU with soft device placement
-                clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
-                clean_loss = tf.nn.ctc_loss(
-                    labels=sparse_labels,
-                    logits=clean_logits_time_major,
-                    label_length=None,  # Not needed with sparse labels
-                    logit_length=tf.cast(adjusted_lens, tf.int32),
-                    blank_index=0,
-                    logits_time_major=True
-                )
-                clean_loss = tf.reduce_mean(clean_loss)
+                # Clean CTC loss - using Keras backend for XLA compatibility
+                clean_loss = ctc_loss_for_tpu(
+                    y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                    y_pred=clean_logits,
+                    input_length=adjusted_lens,
+                    label_length=phone_seq_lens
+                )
-                # Noisy CTC loss - will auto-fallback to CPU with soft device placement
-                noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
-                noisy_loss = tf.nn.ctc_loss(
-                    labels=sparse_labels,  # Reuse same sparse labels
-                    logits=noisy_logits_time_major,
-                    label_length=None,  # Not needed with sparse labels
-                    logit_length=tf.cast(adjusted_lens, tf.int32),
-                    blank_index=0,
-                    logits_time_major=True
-                )
-                noisy_loss = tf.reduce_mean(noisy_loss)
+                # Noisy CTC loss - using Keras backend for XLA compatibility
+                noisy_loss = ctc_loss_for_tpu(
+                    y_true=tf.cast(labels, tf.float32),  # Reuse same dense labels
+                    y_pred=noisy_logits,
+                    input_length=adjusted_lens,
+                    label_length=phone_seq_lens
+                )
             # Optional noise L2 regularization
             noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
@@ -628,20 +649,13 @@ class BrainToTextDecoderTrainerTF:
                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                # Convert dense labels to sparse for dynamic shapes
-                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-                # Standard CTC loss - will auto-fallback to CPU with soft device placement
-                logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
-                loss = tf.nn.ctc_loss(
-                    labels=sparse_labels,
-                    logits=logits_time_major,
-                    label_length=None,  # Not needed with sparse labels
-                    logit_length=tf.cast(adjusted_lens, tf.int32),
-                    blank_index=0,
-                    logits_time_major=True
-                )
-                loss = tf.reduce_mean(loss)
+                # Standard CTC loss - using Keras backend for XLA compatibility
+                loss = ctc_loss_for_tpu(
+                    y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                    y_pred=clean_logits,
+                    input_length=adjusted_lens,
+                    label_length=phone_seq_lens
+                )
             # AdamW handles weight decay automatically - no manual L2 regularization needed
             # TensorFlow mixed precision needs no manual loss scaling; the Keras policy handles it automatically
@@ -696,20 +710,13 @@ class BrainToTextDecoderTrainerTF:
             # Forward pass (inference mode only)
             logits = self.model(features, day_indices, None, False, 'inference', training=False)
-            # Convert dense labels to sparse for dynamic shapes
-            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-            # Calculate loss - will auto-fallback to CPU with soft device placement
-            logits_time_major = tf.transpose(logits, [1, 0, 2])
-            loss = tf.nn.ctc_loss(
-                labels=sparse_labels,
-                logits=logits_time_major,
-                label_length=None,  # Not needed with sparse labels
-                logit_length=tf.cast(adjusted_lens, tf.int32),
-                blank_index=0,
-                logits_time_major=True
-            )
-            loss = tf.reduce_mean(loss)
+            # Calculate loss using TPU-compatible CTC implementation
+            loss = ctc_loss_for_tpu(
+                y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
+                y_pred=logits,
+                input_length=adjusted_lens,
+                label_length=phone_seq_lens
+            )
             # Greedy decoding for PER calculation
             predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
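The argmax above yields frame-level predictions only; computing PER additionally requires the standard CTC collapse (merge repeats, then drop blanks). The repo's implementation is not shown in this diff, but the step looks roughly like this sketch (blank index must follow whatever convention the loss uses):

    import tensorflow as tf

    def ctc_greedy_collapse(frame_ids, seq_len, blank_id):
        """Merge repeated frame labels, then remove blanks (single sequence)."""
        ids = frame_ids[:seq_len]
        # Keep positions where the label differs from its predecessor
        prev = tf.concat([[-1], ids[:-1]], axis=0)
        merged = tf.boolean_mask(ids, tf.not_equal(ids, prev))
        return tf.boolean_mask(merged, tf.not_equal(merged, blank_id))

    frames = tf.constant([3, 3, 0, 0, 5, 5, 5, 0, 3])
    print(ctc_greedy_collapse(frames, tf.size(frames), blank_id=0).numpy())  # [3 5 3]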
@@ -725,28 +732,61 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
-        # Create datasets using modern distribution API with dynamic padding
+        # ========================= DATASET SHAPE ANALYSIS =========================
+        # Perform one-time full dataset analysis for FIXED shapes (critical for XLA)
+        self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...")
+        self.logger.info("   This is CRITICAL for resolving both CTC compatibility and graph structure issues")
+        # Analyze training dataset (all data for accurate max shapes)
+        train_analysis_start = time.time()
+        train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
+        train_analysis_time = time.time() - train_analysis_start
+        self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
+        # Analyze validation dataset (all data for accurate max shapes)
+        val_analysis_start = time.time()
+        val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
+        val_analysis_time = time.time() - val_analysis_start
+        self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
+        # Use maximum shapes across both datasets for consistent padding
+        final_max_shapes = {
+            'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
+            'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
+            'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
+            'n_features': train_max_shapes['n_features']
+        }
+        self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):")
+        self.logger.info(f"   Time steps: {final_max_shapes['max_time_steps']}")
+        self.logger.info(f"   Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
+        self.logger.info(f"   Transcription length: {final_max_shapes['max_transcription_len']}")
+        self.logger.info(f"   Features: {final_max_shapes['n_features']}")
+        # =====================================================================
+        # Create datasets using modern distribution API with FIXED shapes
-        def create_dist_dataset_fn(input_dataset_tf, training):
-            """Create distributed dataset function for modern TPU strategy with batch-first approach"""
+        def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
+            """Create distributed dataset function for modern TPU strategy with FIXED shapes"""
             def dataset_fn(input_context):
-                # create_input_fn now uses batch-first approach with dynamic padding
+                # create_input_fn now requires max_shapes parameter for FIXED shapes
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
+                    max_shapes=max_shapes,  # Pass pre-analyzed FIXED shapes
                     training=training
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
-        # Distribute datasets using modern API with batch-first approach
+        # Distribute datasets using modern API with FIXED shapes
         self.logger.info("🔄 Distributing training dataset across TPU cores...")
         dist_start_time = time.time()
-        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
+        train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
         self.logger.info("🔄 Distributing validation dataset across TPU cores...")
         val_start_time = time.time()
-        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
+        val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
         val_dist_time = time.time() - val_start_time
         self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")