Zchen
2025-10-20 13:37:11 +08:00
parent 7358ff3d79
commit e399cf262a
2 changed files with 150 additions and 107 deletions


@@ -889,35 +889,32 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
max_shapes: Dict[str, int],
training: bool = True,
cache_path: Optional[str] = None) -> tf.data.Dataset:
"""
Create input function for TPU training with BATCH-FIRST approach
Create input function for TPU training with PRE-ANALYZED FIXED shapes
This function implements the correct TPU data pipeline:
1. Load individual samples
2. Cache raw samples
3. Batch samples with dynamic padding
4. Apply data augmentation to entire batches (AFTER batching)
This approach prevents shape conflicts from augmentation operations
like random_cut that would otherwise make tensor shapes dynamic before batching.
This function uses pre-computed maximum shapes to create STATIC-size batches,
ensuring XLA compilation succeeds on TPU hardware. This is CRITICAL for
resolving both the CTC loss compatibility and graph structure issues.
Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
max_shapes: Pre-computed maximum shapes dictionary with keys:
'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
Returns:
tf.data.Dataset ready for TPU training with XLA-compatible operations
tf.data.Dataset ready for TPU training with FIXED STATIC shapes
"""
# Step 1: Create individual example dataset with file-grouping I/O optimization
dataset = dataset_tf.create_individual_dataset()
# Step 2: Cache raw samples BEFORE any augmentation
# ========================= I/O OPTIMIZATION SOLUTION =========================
if cache_path:
dataset = dataset.cache(cache_path)
split_name = "training" if training else "validation"
@@ -929,19 +926,55 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
split_name = "training" if training else "validation"
print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
# ================================================================
# Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes)
print(f"🔧 Using DYNAMIC padding for XLA compatibility:")
# Step 3: Apply transformations to individual examples BEFORE batching
def apply_transforms(example):
"""Apply data transformations to individual examples"""
features = example['input_features']
n_time_steps = example['n_time_steps']
# Define padded shapes with None for dynamic dimensions
# Apply transformations
features, n_time_steps = DataAugmentationTF.transform_data(
tf.expand_dims(features, 0), # Add batch dimension for transforms
tf.expand_dims(n_time_steps, 0),
transform_args,
training=training
)
# Remove batch dimension
example['input_features'] = tf.squeeze(features, 0)
example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
return example
# Apply transforms to cached data
dataset = dataset.map(
apply_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
# Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
# Extract pre-analyzed shape information
max_time_steps = max_shapes['max_time_steps']
max_phone_seq_len = max_shapes['max_phone_seq_len']
max_transcription_len = max_shapes['max_transcription_len']
n_features = max_shapes['n_features']
print(f" Fixed time steps: {max_time_steps}")
print(f" Fixed phone sequence length: {max_phone_seq_len}")
print(f" Fixed transcription length: {max_transcription_len}")
print(f" Number of features: {n_features}")
# Define FIXED padded shapes - NO None values for XLA compatibility
padded_shapes = {
'input_features': tf.TensorShape([None, None]), # [time_steps, features] - dynamic
'seq_class_ids': tf.TensorShape([None]), # [phone_seq_len] - dynamic
'input_features': tf.TensorShape([max_time_steps, n_features]),
'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
'n_time_steps': tf.TensorShape([]), # scalar
'phone_seq_lens': tf.TensorShape([]), # scalar
'day_indices': tf.TensorShape([]), # scalar
'transcriptions': tf.TensorShape([None]), # [transcription_len] - dynamic
'transcriptions': tf.TensorShape([max_transcription_len]),
'block_nums': tf.TensorShape([]), # scalar
'trial_nums': tf.TensorShape([]) # scalar
}
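To see why the fully specified padded_shapes matter, a toy check (independent of this dataset) shows that fixed maxima plus drop_remainder=True give the batched dataset a completely static element_spec, which is what lets XLA compile a single TPU program instead of retracing per shape. All names and sizes below are arbitrary:

import numpy as np
import tensorflow as tf

MAX_T, N_FEAT, BATCH = 16, 4, 8  # toy maxima, not the project's values

def gen():
    for t in (5, 9, 16, 3):  # variable-length toy sequences
        yield {'input_features': np.random.randn(t, N_FEAT).astype(np.float32),
               'n_time_steps': t}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'input_features': tf.TensorSpec([None, N_FEAT], tf.float32),
        'n_time_steps': tf.TensorSpec([], tf.int32)})

ds = ds.repeat().padded_batch(
    BATCH,
    padded_shapes={'input_features': [MAX_T, N_FEAT], 'n_time_steps': []},
    drop_remainder=True)  # fixes the batch dimension as well

print(ds.element_spec['input_features'].shape)  # (8, 16, 4): fully static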
@@ -958,7 +991,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
'trial_nums': 0
}
# Create batches with dynamic padding
# Create batches with FIXED padding - the XLA compiler will be happy!
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,
@@ -966,36 +999,6 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
drop_remainder=True # Critical for TPU: ensures all batches have the same size
)
# Step 4: Apply data augmentation to ENTIRE BATCHES (after batching)
def apply_batch_transforms(batch):
"""Apply data transformations to entire batches - CRITICAL for XLA compatibility"""
features = batch['input_features']
n_time_steps = batch['n_time_steps']
# Apply transformations to the entire batch
features, n_time_steps = DataAugmentationTF.transform_data(
features, # Already batched: [batch_size, time_steps, features]
n_time_steps, # Already batched: [batch_size]
transform_args,
training=training
)
# Update the batch with transformed data
batch['input_features'] = features
batch['n_time_steps'] = n_time_steps
return batch
# Apply batch-level transforms (only if training)
if training:
print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)")
dataset = dataset.map(
apply_batch_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
else:
print(f"✅ Validation mode: no data augmentation applied")
# Prefetch for optimal performance
dataset = dataset.prefetch(tf.data.AUTOTUNE)
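For orientation, a hedged sketch of how the pieces touched by this commit could be wired together on TPU. Only create_input_fn, analyze_dataset_shapes and BrainToTextDatasetTF come from the diff; the TPU initialization calls, the assumption that analyze_dataset_shapes returns the max_shapes dictionary, and all argument values are illustrative:

import tensorflow as tf

def build_tpu_input_pipelines(train_dataset_tf, val_dataset_tf, transform_args):
    """Sketch: wire pre-analyzed shapes into TPU train/val input pipelines."""
    # One-time scan for the static padding targets. Assumed to return the
    # max_shapes dict consumed by create_input_fn; in practice the maxima
    # should cover the validation split as well.
    max_shapes = analyze_dataset_shapes(train_dataset_tf)

    train_ds = create_input_fn(train_dataset_tf, transform_args, max_shapes,
                               training=True, cache_path='/tmp/train_cache')
    val_ds = create_input_fn(val_dataset_tf, transform_args, max_shapes,
                             training=False)

    # Standard TPU setup; adapt to the actual training script.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)

    return (strategy.experimental_distribute_dataset(train_ds),
            strategy.experimental_distribute_dataset(val_ds),
            strategy)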