This commit is contained in:
Zchen
2025-10-20 13:37:11 +08:00
parent 7358ff3d79
commit e399cf262a
2 changed files with 150 additions and 107 deletions

View File

@@ -889,35 +889,32 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
max_shapes: Dict[str, int],
training: bool = True,
cache_path: Optional[str] = None) -> tf.data.Dataset:
"""
Create input function for TPU training with BATCH-FIRST approach
Create input function for TPU training with PRE-ANALYZED FIXED shapes
This function implements the correct TPU data pipeline:
1. Load individual samples
2. Cache raw samples
3. Batch samples with dynamic padding
4. Apply data augmentation to entire batches (AFTER batching)
This approach prevents shape conflicts from augmentation operations
like random_cut that would otherwise make tensor shapes dynamic before batching.
This function uses pre-computed maximum shapes to create STATIC-size batches,
ensuring XLA compilation succeeds on TPU hardware. This is CRITICAL for
resolving both the CTC-loss compatibility and graph-structure issues.
Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
max_shapes: Pre-computed maximum shapes dictionary with keys:
'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
Returns:
tf.data.Dataset ready for TPU training with XLA-compatible operations
tf.data.Dataset ready for TPU training with FIXED STATIC shapes
"""
# Step 1: Create individual example dataset with file-grouping I/O optimization
dataset = dataset_tf.create_individual_dataset()
# Step 2: Cache raw samples BEFORE any augmentation
# ========================= I/O OPTIMIZATION SOLUTION =========================
if cache_path:
dataset = dataset.cache(cache_path)
split_name = "training" if training else "validation"
@@ -929,19 +926,55 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
split_name = "training" if training else "validation"
print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
# ================================================================
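For orientation, a minimal sketch (names assumed, not copied verbatim from this file) of the two tf.data caching modes the block above selects between. Because augmentation is mapped after the cache, the cached raw samples are reused every epoch while the augmentation itself stays random:

if cache_path:
    dataset = dataset.cache(cache_path)   # file-backed cache on disk, reusable across runs
else:
    dataset = dataset.cache()             # in-memory cache, filled during the first epoch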
# Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes)
print(f"🔧 Using DYNAMIC padding for XLA compatibility:")
# Step 3: Apply transformations to individual examples BEFORE batching
def apply_transforms(example):
"""Apply data transformations to individual examples"""
features = example['input_features']
n_time_steps = example['n_time_steps']
# Define padded shapes with None for dynamic dimensions
# Apply transformations
features, n_time_steps = DataAugmentationTF.transform_data(
tf.expand_dims(features, 0), # Add batch dimension for transforms
tf.expand_dims(n_time_steps, 0),
transform_args,
training=training
)
# Remove batch dimension
example['input_features'] = tf.squeeze(features, 0)
example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
return example
# Apply transforms to cached data
dataset = dataset.map(
apply_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
# Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
# Extract pre-analyzed shape information
max_time_steps = max_shapes['max_time_steps']
max_phone_seq_len = max_shapes['max_phone_seq_len']
max_transcription_len = max_shapes['max_transcription_len']
n_features = max_shapes['n_features']
print(f" Fixed time steps: {max_time_steps}")
print(f" Fixed phone sequence length: {max_phone_seq_len}")
print(f" Fixed transcription length: {max_transcription_len}")
print(f" Number of features: {n_features}")
# Define FIXED padded shapes - NO None values for XLA compatibility
padded_shapes = {
'input_features': tf.TensorShape([None, None]), # [time_steps, features] - dynamic
'seq_class_ids': tf.TensorShape([None]), # [phone_seq_len] - dynamic
'input_features': tf.TensorShape([max_time_steps, n_features]),
'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
'n_time_steps': tf.TensorShape([]), # scalar
'phone_seq_lens': tf.TensorShape([]), # scalar
'day_indices': tf.TensorShape([]), # scalar
'transcriptions': tf.TensorShape([None]), # [transcription_len] - dynamic
'transcriptions': tf.TensorShape([max_transcription_len]),
'block_nums': tf.TensorShape([]), # scalar
'trial_nums': tf.TensorShape([]) # scalar
}
@@ -958,7 +991,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
'trial_nums': 0
}
# Create batches with dynamic padding
# Create batches with FIXED padding so the XLA compiler sees fully static shapes
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,
@@ -966,36 +999,6 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
drop_remainder=True # Critical for TPU: ensures all batches have same size
)
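To see why fixed padded_shapes plus drop_remainder=True give XLA what it needs, here is a small self-contained sketch (toy shapes and field subset, not from this repository) showing that the resulting element_spec is fully static:

import tensorflow as tf

# Toy variable-length dataset standing in for the brain-to-text samples
def gen():
    for n in (3, 5, 4, 2):
        yield {'input_features': tf.ones([n, 8]), 'seq_class_ids': tf.ones([n], tf.int32)}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'input_features': tf.TensorSpec([None, 8], tf.float32),
        'seq_class_ids': tf.TensorSpec([None], tf.int32),
    })

# Fixed padded shapes (assumed maxima) + drop_remainder=True -> static batch shapes
ds = ds.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': [6, 8], 'seq_class_ids': [6]},
    padding_values={'input_features': 0.0, 'seq_class_ids': 0},
    drop_remainder=True)

print(ds.element_spec)
# {'input_features': TensorSpec(shape=(2, 6, 8), ...), 'seq_class_ids': TensorSpec(shape=(2, 6), ...)}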
# Step 4: Apply data augmentation to ENTIRE BATCHES (after batching)
def apply_batch_transforms(batch):
"""Apply data transformations to entire batches - CRITICAL for XLA compatibility"""
features = batch['input_features']
n_time_steps = batch['n_time_steps']
# Apply transformations to the entire batch
features, n_time_steps = DataAugmentationTF.transform_data(
features, # Already batched: [batch_size, time_steps, features]
n_time_steps, # Already batched: [batch_size]
transform_args,
training=training
)
# Update the batch with transformed data
batch['input_features'] = features
batch['n_time_steps'] = n_time_steps
return batch
# Apply batch-level transforms (only if training)
if training:
print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)")
dataset = dataset.map(
apply_batch_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
else:
print(f"✅ Validation mode: no data augmentation applied")
# Prefetch for optimal performance
dataset = dataset.prefetch(tf.data.AUTOTUNE)

View File

@@ -1,5 +1,6 @@
import os
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
import time
import json
@@ -27,10 +28,41 @@ from dataset_tf import (
BrainToTextDatasetTF,
DataAugmentationTF,
train_test_split_indices,
create_input_fn
create_input_fn,
analyze_dataset_shapes
)
def ctc_loss_for_tpu(y_true, y_pred, input_length, label_length):
"""
TPU-compatible CTC loss function using Keras backend
This implementation uses K.ctc_batch_cost, which is often more robust under
XLA compilation than tf.nn.ctc_loss, especially in complex model graphs.
Args:
y_true: Dense labels [batch_size, max_label_len]
y_pred: Logits [batch_size, time_steps, num_classes]
input_length: Logit sequence lengths [batch_size]
label_length: True label sequence lengths [batch_size]
Returns:
Scalar CTC loss value
"""
# K.ctc_batch_cost expects BATCH-major probabilities [batch_size, time_steps, num_classes]
# (it transposes and takes the log internally), so convert the raw logits with a softmax
y_pred = tf.nn.softmax(y_pred, axis=-1)
# Ensure the dtypes/shapes the Keras backend expects: float32 labels and
# int32 length tensors of shape [batch_size, 1]
y_true = tf.cast(y_true, tf.float32)
input_length = tf.expand_dims(tf.cast(input_length, tf.int32), axis=-1)
label_length = tf.expand_dims(tf.cast(label_length, tf.int32), axis=-1)
# Calculate CTC loss using the Keras backend (more XLA-friendly than tf.nn.ctc_loss here)
# Note: K.ctc_batch_cost treats the LAST class as the CTC blank, unlike the previous
# tf.nn.ctc_loss call which used blank_index=0
loss = K.ctc_batch_cost(y_true, y_pred, input_length, label_length)
return tf.reduce_mean(loss)
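As a quick sanity check, a hypothetical eager-mode call with toy tensors (all shapes and values invented for illustration, with the blank assumed to be the last of the 5 classes) might look like:

batch, t_steps, n_classes = 2, 10, 5
logits = tf.random.normal([batch, t_steps, n_classes])
labels = tf.constant([[1, 2, 3, 0], [2, 2, 0, 0]], tf.int32)   # zero-padded dense labels
input_length = tf.constant([10, 10], tf.int32)                 # logit lengths per sample
label_length = tf.constant([3, 2], tf.int32)                   # true label lengths per sample
loss = ctc_loss_for_tpu(labels, logits, input_length, label_length)
print(float(loss))   # finite scalar CTC loss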
def dense_to_sparse(dense_tensor, sequence_lengths):
"""
Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
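The body of this helper falls outside the hunk; purely for illustration, a minimal sketch of such a dense-to-sparse conversion (assuming labels padded along axis 1 with per-row valid lengths; not the repository's actual implementation) could be:

def dense_to_sparse_sketch(dense_labels, seq_lens):
    # Keep only the first seq_lens[i] entries of each row as a tf.SparseTensor
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(dense_labels)[1])
    indices = tf.where(mask)                         # [nnz, 2] int64 coordinates of valid labels
    values = tf.gather_nd(dense_labels, indices)     # valid label values in row-major order
    dense_shape = tf.cast(tf.shape(dense_labels), tf.int64)
    return tf.SparseTensor(indices, values, dense_shape)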
@@ -592,34 +624,23 @@ class BrainToTextDecoderTrainerTF:
features, day_indices, None, False, 'inference', training=True
)
# Calculate losses
# Calculate losses using TPU-compatible CTC implementation
if use_full:
# Convert dense labels to sparse for dynamic shapes
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
# Clean CTC loss - will auto-fallback to CPU with soft device placement
clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
clean_loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=clean_logits_time_major,
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
# Clean CTC loss - using Keras backend for XLA compatibility
clean_loss = ctc_loss_for_tpu(
y_true=tf.cast(labels, tf.float32), # Dense labels as float32
y_pred=clean_logits,
input_length=adjusted_lens,
label_length=phone_seq_lens
)
clean_loss = tf.reduce_mean(clean_loss)
# Noisy CTC loss - will auto-fallback to CPU with soft device placement
noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
noisy_loss = tf.nn.ctc_loss(
labels=sparse_labels, # Reuse same sparse labels
logits=noisy_logits_time_major,
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
# Noisy CTC loss - using Keras backend for XLA compatibility
noisy_loss = ctc_loss_for_tpu(
y_true=tf.cast(labels, tf.float32), # Reuse same dense labels
y_pred=noisy_logits,
input_length=adjusted_lens,
label_length=phone_seq_lens
)
noisy_loss = tf.reduce_mean(noisy_loss)
# Optional noise L2 regularization
noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
@@ -628,20 +649,13 @@ class BrainToTextDecoderTrainerTF:
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
# Convert dense labels to sparse for dynamic shapes
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
# Standard CTC loss - will auto-fallback to CPU with soft device placement
logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=logits_time_major,
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
# Standard CTC loss - using Keras backend for XLA compatibility
loss = ctc_loss_for_tpu(
y_true=tf.cast(labels, tf.float32), # Dense labels as float32
y_pred=clean_logits,
input_length=adjusted_lens,
label_length=phone_seq_lens
)
loss = tf.reduce_mean(loss)
# AdamW handles weight decay automatically - no manual L2 regularization needed
# TensorFlow mixed precision needs no manual loss scaling - the Keras mixed-precision policy handles it automatically
@@ -696,20 +710,13 @@ class BrainToTextDecoderTrainerTF:
# Forward pass (inference mode only)
logits = self.model(features, day_indices, None, False, 'inference', training=False)
# Convert dense labels to sparse for dynamic shapes
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
# Calculate loss - will auto-fallback to CPU with soft device placement
logits_time_major = tf.transpose(logits, [1, 0, 2])
loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=logits_time_major,
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
# Calculate loss using TPU-compatible CTC implementation
loss = ctc_loss_for_tpu(
y_true=tf.cast(labels, tf.float32), # Dense labels as float32
y_pred=logits,
input_length=adjusted_lens,
label_length=phone_seq_lens
)
loss = tf.reduce_mean(loss)
# Greedy decoding for PER calculation
predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
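The hunk ends here; one hedged way to finish the PER computation (using tf.nn.ctc_greedy_decoder and tf.edit_distance instead of manually collapsing the argmax output; dense_to_sparse is the helper defined earlier in this file, everything else is an assumption, not code from this commit):

def greedy_per(logits, logit_lens, labels, label_lens):
    # Illustrative sketch only; assumes blank_index=0 as in the pre-change tf.nn.ctc_loss setup
    decoded, _ = tf.nn.ctc_greedy_decoder(
        tf.transpose(logits, [1, 0, 2]),          # the decoder expects time-major logits
        tf.cast(logit_lens, tf.int32),
        blank_index=0)
    hypothesis = tf.cast(decoded[0], tf.int32)    # SparseTensor of collapsed predictions
    truth = tf.cast(dense_to_sparse(labels, label_lens), tf.int32)
    edits = tf.edit_distance(hypothesis, truth, normalize=False)
    return tf.reduce_sum(edits) / tf.cast(tf.reduce_sum(label_lens), tf.float32)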
@@ -725,28 +732,61 @@ class BrainToTextDecoderTrainerTF:
initial_tpu_status = self._get_detailed_tpu_status()
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
# Create datasets using modern distribution API with dynamic padding
def create_dist_dataset_fn(input_dataset_tf, training):
"""Create distributed dataset function for modern TPU strategy with batch-first approach"""
# ========================= DATASET SHAPE ANALYSIS =========================
# Perform one-time full dataset analysis for FIXED shapes (critical for XLA)
self.logger.info("🚀 Performing one-time full dataset analysis for FIXED shapes...")
self.logger.info(" This is CRITICAL for resolving both CTC compatibility and graph structure issues")
# Analyze training dataset (all data for accurate max shapes)
train_analysis_start = time.time()
train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
train_analysis_time = time.time() - train_analysis_start
self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
# Analyze validation dataset (all data for accurate max shapes)
val_analysis_start = time.time()
val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
val_analysis_time = time.time() - val_analysis_start
self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
# Use maximum shapes across both datasets for consistent padding
final_max_shapes = {
'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
'n_features': train_max_shapes['n_features']
}
self.logger.info(f"📊 Final FIXED shapes for TPU training (eliminates XLA dynamic shape issues):")
self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}")
self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}")
self.logger.info(f" Features: {final_max_shapes['n_features']}")
# =====================================================================
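For reference, a minimal sketch of what a shape-analysis pass like analyze_dataset_shapes(dataset_tf, sample_size=-1) could do (field names taken from the padded_shapes dictionary earlier in this diff; the actual implementation may differ):

def analyze_dataset_shapes_sketch(dataset_tf, sample_size=-1):
    # Scan all examples (or a sample) and record the maxima needed for fixed padding
    maxima = {'max_time_steps': 0, 'max_phone_seq_len': 0,
              'max_transcription_len': 0, 'n_features': 0}
    ds = dataset_tf.create_individual_dataset()
    if sample_size > 0:
        ds = ds.take(sample_size)
    for ex in ds:
        maxima['max_time_steps'] = max(maxima['max_time_steps'], int(ex['n_time_steps']))
        maxima['max_phone_seq_len'] = max(maxima['max_phone_seq_len'], int(tf.shape(ex['seq_class_ids'])[0]))
        maxima['max_transcription_len'] = max(maxima['max_transcription_len'], int(tf.shape(ex['transcriptions'])[0]))
        maxima['n_features'] = int(tf.shape(ex['input_features'])[1])
    return maxima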
# Create datasets using modern distribution API with FIXED shapes
def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
"""Create distributed dataset function for modern TPU strategy with FIXED shapes"""
def dataset_fn(input_context):
# create_input_fn now uses batch-first approach with dynamic padding
# create_input_fn now requires max_shapes parameter for FIXED shapes
return create_input_fn(
input_dataset_tf,
self.args['dataset']['data_transforms'],
max_shapes=max_shapes, # Pass pre-analyzed FIXED shapes
training=training
)
return self.strategy.distribute_datasets_from_function(dataset_fn)
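Note that dataset_fn receives input_context but does not use it; if per-host sharding were desired, a hedged variant (an assumption, not part of this commit) could shard the finished pipeline like this:

def dataset_fn(input_context):
    ds = create_input_fn(input_dataset_tf,
                         self.args['dataset']['data_transforms'],
                         max_shapes=max_shapes,
                         training=training)
    # Each input pipeline reads a disjoint slice; sharding before batching is also common
    return ds.shard(input_context.num_input_pipelines,
                    input_context.input_pipeline_id)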
# Distribute datasets using modern API with batch-first approach
# Distribute datasets using modern API with FIXED shapes
self.logger.info("🔄 Distributing training dataset across TPU cores...")
dist_start_time = time.time()
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
train_dist_time = time.time() - dist_start_time
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
self.logger.info("🔄 Distributing validation dataset across TPU cores...")
val_start_time = time.time()
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
val_dist_time = time.time() - val_start_time
self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")