Refactor input function to implement batch-first approach with dynamic padding and apply data augmentation post-batching for TPU compatibility

Zchen
2025-10-20 00:58:29 +08:00
parent fabf70cfa9
commit 06ddbc6ac2
2 changed files with 136 additions and 132 deletions

View File

@@ -889,34 +889,35 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
max_shapes: Dict[str, int],
training: bool = True,
cache_path: Optional[str] = None) -> tf.data.Dataset:
"""
Create input function for TPU training with PRE-ANALYZED FIXED shapes
Create input function for TPU training with BATCH-FIRST approach
This function uses pre-computed maximum shapes to create fixed-size batches,
ensuring XLA compilation success on TPU hardware.
This function implements the TPU-compatible data pipeline ordering:
1. Load individual samples
2. Cache raw samples
3. Batch samples with dynamic padding
4. Apply data augmentation to entire batches (AFTER batching)
This approach prevents shape conflicts from augmentation operations
like random_cut that would otherwise make tensor shapes dynamic before batching.
Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
max_shapes: Pre-computed maximum shapes dictionary with keys:
'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
Returns:
tf.data.Dataset ready for TPU training with fixed shapes
tf.data.Dataset ready for TPU training with XLA-compatible operations
"""
# Create individual example dataset with file-grouping I/O optimization
# Step 1: Create individual example dataset with file-grouping I/O optimization
dataset = dataset_tf.create_individual_dataset()
# Step 2: Cache raw samples BEFORE any augmentation
# ========================= I/O OPTIMIZATION SOLUTION =========================
# Cache both the training and validation sets, because:
# 1. The training set is traversed in full every epoch
# 2. The validation set is used frequently: validation every 200 steps plus early-stopping checks
if cache_path:
dataset = dataset.cache(cache_path)
split_name = "training" if training else "validation"
@@ -924,63 +925,25 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
else:
# If no cache path is specified, default to in-memory caching
# For large datasets, explicitly pass a disk cache path when calling this function
dataset = dataset.cache()
split_name = "training" if training else "validation"
print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
# ================================================================
def apply_transforms(example):
"""Apply data transformations to individual examples"""
features = example['input_features']
n_time_steps = example['n_time_steps']
# Step 3: Batch samples with DYNAMIC padding (XLA-friendly for variable input sizes)
print(f"🔧 Using DYNAMIC padding for XLA compatibility:")
# Apply transformations
features, n_time_steps = DataAugmentationTF.transform_data(
tf.expand_dims(features, 0), # Add batch dimension for transforms
tf.expand_dims(n_time_steps, 0),
transform_args,
training=training
)
# Remove batch dimension
example['input_features'] = tf.squeeze(features, 0)
example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
return example
# Apply random data augmentation after caching so each epoch gets different augmentations
dataset = dataset.map(
apply_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
# ========================= FIXED SHAPES SOLUTION =========================
# Use pre-analyzed fixed shapes to ensure XLA compilation succeeds
print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
# Get shape information from the arguments passed in
max_time_steps = max_shapes['max_time_steps']
max_phone_seq_len = max_shapes['max_phone_seq_len']
max_transcription_len = max_shapes['max_transcription_len']
n_features = max_shapes['n_features']
print(f" Fixed time steps: {max_time_steps}")
print(f" Fixed phone sequence length: {max_phone_seq_len}")
print(f" Fixed transcription length: {max_transcription_len}")
print(f" Number of features: {n_features}")
# Define fixed padded shapes - NO None values for XLA compatibility
# Define padded shapes with None for dynamic dimensions
padded_shapes = {
'input_features': tf.TensorShape([max_time_steps, n_features]),
'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
'n_time_steps': tf.TensorShape([]), # scalar
'phone_seq_lens': tf.TensorShape([]), # scalar
'day_indices': tf.TensorShape([]), # scalar
'transcriptions': tf.TensorShape([max_transcription_len]),
'block_nums': tf.TensorShape([]), # scalar
'trial_nums': tf.TensorShape([]) # scalar
'input_features': tf.TensorShape([None, None]), # [time_steps, features] - dynamic
'seq_class_ids': tf.TensorShape([None]), # [phone_seq_len] - dynamic
'n_time_steps': tf.TensorShape([]), # scalar
'phone_seq_lens': tf.TensorShape([]), # scalar
'day_indices': tf.TensorShape([]), # scalar
'transcriptions': tf.TensorShape([None]), # [transcription_len] - dynamic
'block_nums': tf.TensorShape([]), # scalar
'trial_nums': tf.TensorShape([]) # scalar
}
# Define padding values for each field
@@ -995,7 +958,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
'trial_nums': 0
}
# Create batches with FIXED padding - XLA compiler will be happy!
# Create batches with dynamic padding
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,
@@ -1003,6 +966,36 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
drop_remainder=True # Critical for TPU: ensures all batches have same size
)
# Step 4: Apply data augmentation to ENTIRE BATCHES (after batching)
def apply_batch_transforms(batch):
"""Apply data transformations to entire batches - CRITICAL for XLA compatibility"""
features = batch['input_features']
n_time_steps = batch['n_time_steps']
# Apply transformations to the entire batch
features, n_time_steps = DataAugmentationTF.transform_data(
features, # Already batched: [batch_size, time_steps, features]
n_time_steps, # Already batched: [batch_size]
transform_args,
training=training
)
# Update the batch with transformed data
batch['input_features'] = features
batch['n_time_steps'] = n_time_steps
return batch
# Apply batch-level transforms (only if training)
if training:
print(f"✅ Applying batch-level data augmentation (post-batching for XLA compatibility)")
dataset = dataset.map(
apply_batch_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
else:
print(f"✅ Validation mode: no data augmentation applied")
# Prefetch for optimal performance
dataset = dataset.prefetch(tf.data.AUTOTUNE)
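
The batch-first ordering implemented above (cache raw samples, pad dynamically at batch time, augment whole batches afterwards) can be sketched in isolation. The snippet below is a minimal, self-contained illustration using plain tf.data; the toy generator and the augment_batch function are hypothetical stand-ins for create_individual_dataset and DataAugmentationTF.transform_data, not the repository's actual implementations.

import numpy as np
import tensorflow as tf

def toy_examples():
    # Hypothetical stand-in for dataset_tf.create_individual_dataset():
    # yields variable-length feature examples, as the real pipeline does.
    for length in (80, 120, 95):
        yield {
            'input_features': np.random.randn(length, 512).astype(np.float32),
            'n_time_steps': np.int32(length),
        }

dataset = tf.data.Dataset.from_generator(
    toy_examples,
    output_signature={
        'input_features': tf.TensorSpec([None, 512], tf.float32),
        'n_time_steps': tf.TensorSpec([], tf.int32),
    },
)

# Step 2: cache raw, un-augmented samples (in memory here; a file path also works)
dataset = dataset.cache()

# Step 3: dynamic padding -- each batch is padded to its own longest example
dataset = dataset.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': [None, 512], 'n_time_steps': []},
    padding_values={'input_features': 0.0, 'n_time_steps': 0},
    drop_remainder=True,   # static batch dimension for TPU
)

# Step 4: augmentation on whole batches, after padding
def augment_batch(batch):
    # Hypothetical batch-level augmentation (additive noise); the real pipeline
    # calls DataAugmentationTF.transform_data here instead.
    noise = tf.random.normal(tf.shape(batch['input_features']), stddev=0.01)
    batch['input_features'] = batch['input_features'] + noise
    return batch

dataset = dataset.map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

Because the padded dimensions are None, each batch is padded only to its own longest example, while drop_remainder=True keeps the batch dimension constant across batches.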

View File

@@ -27,11 +27,53 @@ from dataset_tf import (
BrainToTextDatasetTF,
DataAugmentationTF,
train_test_split_indices,
create_input_fn,
analyze_dataset_shapes
create_input_fn
)
def dense_to_sparse(dense_tensor, sequence_lengths):
"""
Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
Args:
dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
sequence_lengths: Actual sequence lengths [batch_size]
Returns:
SparseTensor suitable for tf.nn.ctc_loss
"""
# Create mask for valid (non-zero) elements within sequence lengths
batch_size = tf.shape(dense_tensor)[0]
max_seq_len = tf.shape(dense_tensor)[1]
# Create range indices
batch_indices = tf.range(batch_size)
seq_indices = tf.range(max_seq_len)
# Create meshgrid for batch and sequence dimensions
batch_mesh, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
# Create mask based on sequence lengths and non-zero values
length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
value_mask = tf.not_equal(dense_tensor, 0)
combined_mask = tf.logical_and(length_mask, value_mask)
# Get indices of valid elements
indices = tf.where(combined_mask)
# Get values at valid indices
values = tf.gather_nd(dense_tensor, indices)
# Create sparse tensor
dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
return tf.SparseTensor(
indices=tf.cast(indices, tf.int64),
values=tf.cast(values, tf.int32),
dense_shape=dense_shape
)
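
As a rough illustration of how this helper is consumed in the training and validation steps further down, the sketch below builds toy labels and logits and feeds the resulting SparseTensor to tf.nn.ctc_loss; the batch size, label lengths, time steps, and num_classes are made-up values, not the model's real dimensions.

import tensorflow as tf

batch_size, max_time, num_classes = 2, 30, 41   # hypothetical sizes

# Dense, zero-padded phoneme labels and their true lengths
labels = tf.constant([[7, 3, 9, 0, 0],
                      [4, 4, 2, 8, 1]], dtype=tf.int32)
phone_seq_lens = tf.constant([3, 5], dtype=tf.int32)

sparse_labels = dense_to_sparse(labels, phone_seq_lens)   # helper defined above

# Random logits standing in for model output: [batch, time, classes]
logits = tf.random.normal([batch_size, max_time, num_classes])
logit_lens = tf.fill([batch_size], max_time)

loss = tf.nn.ctc_loss(
    labels=sparse_labels,                     # sparse labels: no label_length needed
    logits=tf.transpose(logits, [1, 0, 2]),   # time-major [max_time, batch, classes]
    label_length=None,
    logit_length=tf.cast(logit_lens, tf.int32),
    blank_index=0,
    logits_time_major=True,
)
print(tf.reduce_mean(loss).numpy())

Unlike the dense-label path this commit removes, the sparse path does not pin every batch to one global maximum label length, which is what lets the dynamically padded batches work.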
class BrainToTextDecoderTrainerTF:
"""
TensorFlow/Keras trainer for brain-to-text phoneme decoder optimized for TPU v5e-8
@@ -392,7 +434,8 @@ class BrainToTextDecoderTrainerTF:
import psutil
initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
print("🔄 Initializing training dataset with GPU-style memory management...")
print("🔄 Initializing training dataset with TPU-optimized memory management...")
print(" 🚀 Preloading all data to RAM for maximum parallel analysis speed...")
init_start_time = time.time()
self.train_dataset_tf = BrainToTextDatasetTF(
trial_indices=train_trials,
@@ -403,8 +446,8 @@ class BrainToTextDecoderTrainerTF:
random_seed=self.args['dataset']['seed'],
must_include_days=self.args['dataset'].get('must_include_days'),
feature_subset=self.args['dataset'].get('feature_subset'),
cache_data=True, # Enable smart caching, like the GPU version
preload_all_data=False # 🚨 GPU-version strategy: load on demand to avoid out-of-memory
cache_data=True, # Enable smart caching
preload_all_data=True # 🚀 TPU optimization: preload all data to unlock parallel analysis
)
# Log training dataset initialization performance
@@ -413,7 +456,8 @@ class BrainToTextDecoderTrainerTF:
train_memory_used = train_memory_mb - initial_memory_mb
print(f"✅ Training dataset initialized in {train_init_time:.2f}s, using {train_memory_used:.1f} MB RAM")
print("🔄 Initializing validation dataset with GPU-style memory management...")
print("🔄 Initializing validation dataset with TPU-optimized memory management...")
print(" 🚀 Preloading all validation data to RAM for maximum parallel analysis speed...")
val_init_start_time = time.time()
self.val_dataset_tf = BrainToTextDatasetTF(
trial_indices=val_trials,
@@ -423,8 +467,8 @@ class BrainToTextDecoderTrainerTF:
days_per_batch=1, # One day per validation batch
random_seed=self.args['dataset']['seed'],
feature_subset=self.args['dataset'].get('feature_subset'),
cache_data=True, # Enable smart caching, like the GPU version
preload_all_data=False # 🚨 GPU-version strategy: load on demand to avoid out-of-memory
cache_data=True, # Enable smart caching
preload_all_data=True # 🚀 TPU optimization: preload all data to unlock parallel analysis
)
# Log validation dataset initialization performance
@@ -525,12 +569,7 @@ class BrainToTextDecoderTrainerTF:
day_indices = batch['day_indices']
with tf.GradientTape() as tape:
# Apply data transformations
features, n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, self.args['dataset']['data_transforms'], training=True
)
# Calculate adjusted lengths for CTC
# Calculate adjusted lengths for CTC (data augmentation now handled in dataset pipeline)
adjusted_lens = tf.cast(
(tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
self.args['model']['patch_stride'] + 1,
@@ -551,25 +590,28 @@ class BrainToTextDecoderTrainerTF:
# Calculate losses
if use_full:
# Clean CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
# Convert dense labels to sparse for dynamic shapes
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
# Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
clean_loss = tf.nn.ctc_loss(
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
labels=sparse_labels,
logits=clean_logits_time_major,
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
)
clean_loss = tf.reduce_mean(clean_loss)
# Noisy CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
# Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
noisy_loss = tf.nn.ctc_loss(
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
labels=sparse_labels, # Reuse same sparse labels
logits=noisy_logits_time_major,
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
@@ -583,12 +625,15 @@ class BrainToTextDecoderTrainerTF:
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
# Standard CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
# Convert dense labels to sparse for dynamic shapes
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
# Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
loss = tf.nn.ctc_loss(
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
labels=sparse_labels,
logits=logits_time_major,
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
@@ -638,12 +683,7 @@ class BrainToTextDecoderTrainerTF:
phone_seq_lens = batch['phone_seq_lens']
day_indices = batch['day_indices']
# Apply data transformations (no augmentation for validation)
features, n_time_steps = DataAugmentationTF.transform_data(
features, n_time_steps, self.args['dataset']['data_transforms'], training=False
)
# Calculate adjusted lengths
# Calculate adjusted lengths (no augmentation for validation, handled in dataset pipeline)
adjusted_lens = tf.cast(
(tf.cast(n_time_steps, tf.float32) - self.args['model']['patch_size']) /
self.args['model']['patch_stride'] + 1,
@@ -653,13 +693,16 @@ class BrainToTextDecoderTrainerTF:
# Forward pass (inference mode only)
logits = self.model(features, day_indices, None, False, 'inference', training=False)
# Calculate loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
# Convert dense labels to sparse for dynamic shapes
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
# Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
logits_time_major = tf.transpose(logits, [1, 0, 2])
loss = tf.nn.ctc_loss(
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
labels=sparse_labels,
logits=logits_time_major,
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
label_length=None, # Not needed with sparse labels
logit_length=tf.cast(adjusted_lens, tf.int32),
blank_index=0,
logits_time_major=True
@@ -680,60 +723,28 @@ class BrainToTextDecoderTrainerTF:
initial_tpu_status = self._get_detailed_tpu_status()
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
# ========================= DATASET SHAPE ANALYSIS =========================
# Perform one-time full dataset analysis for fixed shapes (TPU requirement)
self.logger.info("🚀 Performing one-time full dataset analysis for fixed shapes...")
# Analyze training dataset (all data for accurate max shapes)
train_analysis_start = time.time()
train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
train_analysis_time = time.time() - train_analysis_start
self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
# Analyze validation dataset (all data for accurate max shapes)
val_analysis_start = time.time()
val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
val_analysis_time = time.time() - val_analysis_start
self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
# Use maximum shapes across both datasets for consistent padding
final_max_shapes = {
'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
'n_features': train_max_shapes['n_features']
}
self.logger.info(f"📊 Final fixed shapes for TPU training:")
self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}")
self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}")
self.logger.info(f" Features: {final_max_shapes['n_features']}")
# =====================================================================
# Create datasets using modern distribution API with fixed shapes
def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
"""Create distributed dataset function for modern TPU strategy with fixed shapes"""
# Create datasets using modern distribution API with dynamic padding
def create_dist_dataset_fn(input_dataset_tf, training):
"""Create distributed dataset function for modern TPU strategy with batch-first approach"""
def dataset_fn(input_context):
# create_input_fn now requires max_shapes parameter for fixed shapes
# create_input_fn now uses batch-first approach with dynamic padding
return create_input_fn(
input_dataset_tf,
self.args['dataset']['data_transforms'],
max_shapes=max_shapes, # Pass pre-analyzed shapes
training=training
)
return self.strategy.distribute_datasets_from_function(dataset_fn)
# Distribute datasets using modern API with fixed shapes
# Distribute datasets using modern API with batch-first approach
self.logger.info("🔄 Distributing training dataset across TPU cores...")
dist_start_time = time.time()
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
train_dist_time = time.time() - dist_start_time
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
self.logger.info("🔄 Distributing validation dataset across TPU cores...")
val_start_time = time.time()
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
val_dist_time = time.time() - val_start_time
self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")