Zchen
2025-10-21 00:31:59 +08:00
parent e7c9b95b00
commit ab12d0b7ee
2 changed files with 128 additions and 94 deletions

View File

@@ -908,111 +908,56 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with PRE-ANALYZED FIXED shapes
-    This function uses pre-computed maximum shapes to create STATIC-size batches,
-    ensuring XLA compilation success on TPU hardware. This is CRITICAL for the
-    final resolution of both CTC loss compatibility and graph structure issues.
+    Create input function for TPU training with DYNAMIC batching -> BATCH augmentation
+    This function uses the proven "batch first, augment after" approach that eliminates
+    the time paradox between data augmentation and shape analysis. This is the FINAL
+    solution that resolves all XLA compilation and padding errors.
+    The key insight: data augmentation (especially gauss_smooth with padding='SAME')
+    can increase sequence lengths unpredictably, making pre-computed static shapes invalid.
+    By batching first with dynamic padding, then applying augmentation to batches,
+    we eliminate this temporal paradox entirely.
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
-        max_shapes: Pre-computed maximum shapes dictionary with keys:
-            'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
     Returns:
-        tf.data.Dataset ready for TPU training with FIXED STATIC shapes
+        tf.data.Dataset ready for TPU training with robust dynamic->static flow
     """
-    # Step 1: Create individual example dataset with file-grouping I/O optimization
+    # Step 1: Create individual example dataset
     dataset = dataset_tf.create_individual_dataset()
-    # Step 2: Cache raw samples BEFORE any augmentation
+    # Step 2: Cache raw samples BEFORE any augmentation or batching
     if cache_path:
         dataset = dataset.cache(cache_path)
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     else:
-        # If no cache path is specified, default to an in-memory cache
         dataset = dataset.cache()
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
-        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
-    # Step 3: Apply transformations to individual examples BEFORE batching
-    def apply_transforms(example):
-        """Apply data transformations to individual examples"""
-        features = example['input_features']
-        n_time_steps = example['n_time_steps']
-        # Apply transformations
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            tf.expand_dims(features, 0),  # Add batch dimension for transforms
-            tf.expand_dims(n_time_steps, 0),
-            transform_args,
-            training=training
-        )
-        # Remove batch dimension
-        example['input_features'] = tf.squeeze(features, 0)
-        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
-        return example
-    # Apply transforms to cached data
-    dataset = dataset.map(
-        apply_transforms,
-        num_parallel_calls=tf.data.AUTOTUNE
-    )
-    # ========================= Ultimate debug code =========================
-    def debug_print_shape(example):
-        """Debug helper: print each sample's shape before padded_batch"""
-        tf.print("🔍 Sample Shape Debug:",
-                 tf.shape(example['input_features']),
-                 "Expected feature dim:", dataset_tf.feature_dim,
-                 output_stream=sys.stdout)
-        return example
-    # Add shape debugging - this prints info at graph execution time
-    dataset = dataset.map(debug_print_shape)
-    print(f"⚠️ Debug mode: Will print each sample shape before padded_batch")
-    # =============================================================
-    # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
-    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
-    # Extract pre-analyzed shape information
-    max_time_steps = max_shapes['max_time_steps']
-    max_phone_seq_len = max_shapes['max_phone_seq_len']
-    max_transcription_len = max_shapes['max_transcription_len']
-    # ========================= Use the unified feature dimension =========================
-    # Use the verified feature dimension stored on the dataset_tf object, not an external parameter
-    n_features = dataset_tf.feature_dim  # <--- Key change: use the auto-detected feature dimension
-    print(f"🔧 Using verified feature dimension from dataset: {n_features}")
-    # ========================= End of feature-dimension change =========================
-    print(f"   Fixed time steps: {max_time_steps}")
-    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
-    print(f"   Fixed transcription length: {max_transcription_len}")
-    print(f"   Number of features: {n_features}")
-    # Define FIXED padded shapes with TensorSpec for better type safety
+    # Step 3: Batch samples with DYNAMIC padding FIRST (eliminates time paradox)
+    print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
+    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
+    # Define dynamic padded shapes - key insight: None allows for dynamic lengths
     padded_shapes = {
-        'input_features': tf.TensorSpec(shape=[max_time_steps, n_features], dtype=tf.float32),
-        'seq_class_ids': tf.TensorSpec(shape=[max_phone_seq_len], dtype=tf.int32),
+        'input_features': tf.TensorSpec(shape=(None, dataset_tf.feature_dim), dtype=tf.float32),
+        'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
         'n_time_steps': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
         'phone_seq_lens': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
         'day_indices': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'transcriptions': tf.TensorSpec(shape=[max_transcription_len], dtype=tf.int32),
+        'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
         'block_nums': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
         'trial_nums': tf.TensorSpec(shape=[], dtype=tf.int32)  # scalar
     }
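As a reference for the dynamic-padding idea this hunk adopts, a minimal, self-contained sketch. The toy generator and its field names are hypothetical; it passes plain [None, n_features] shape lists for padded_shapes, which is the form the tf.data padded_batch documentation describes.

import tensorflow as tf

def make_toy_dataset(n_features=4):
    # Three "trials" of different lengths, mimicking variable-length neural data
    def gen():
        for t in (5, 8, 3):
            yield {
                'input_features': tf.ones([t, n_features], tf.float32),
                'n_time_steps': tf.constant(t, tf.int32),
            }
    return tf.data.Dataset.from_generator(
        gen,
        output_signature={
            'input_features': tf.TensorSpec([None, n_features], tf.float32),
            'n_time_steps': tf.TensorSpec([], tf.int32),
        })

ds = make_toy_dataset().padded_batch(
    batch_size=3,
    padded_shapes={'input_features': [None, 4], 'n_time_steps': []},
    padding_values={'input_features': 0.0, 'n_time_steps': 0},
    drop_remainder=True)  # fixed batch dimension, dynamic time dimension

for batch in ds:
    print(batch['input_features'].shape)  # (3, 8, 4): padded to the longest trial in the batch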
@@ -1029,7 +974,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
-    # Create batches with FIXED padding - XLA compiler will be happy!
+    # Create batches with DYNAMIC padding - this cannot fail due to size mismatches
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
@@ -1037,7 +982,37 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         drop_remainder=True  # Critical for TPU: ensures all batches have same size
     )
-    # Prefetch for optimal performance
+    # Step 4: Apply data augmentation to BATCHES (after dynamic batching)
+    def apply_batch_transforms(batch):
+        """Apply data transformations to entire batches - resolves time paradox"""
+        features = batch['input_features']
+        n_time_steps = batch['n_time_steps']
+        # Apply transformations to the entire batch
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            features,  # Already has batch dimension
+            n_time_steps,
+            transform_args,
+            training=training
+        )
+        # Update batch with transformed data
+        batch['input_features'] = features
+        batch['n_time_steps'] = n_time_steps
+        return batch
+    # Apply batch transforms only during training
+    if training:
+        dataset = dataset.map(
+            apply_batch_transforms,
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+        print(f"✅ Batch augmentation enabled for training")
+    else:
+        print(f"✅ No augmentation for validation")
+    # Step 5: Prefetch for optimal performance
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
     return dataset
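To make the batch-first ordering concrete, a small sketch with a stand-in additive-noise augmentation; DataAugmentationTF.transform_data plays this role in the real pipeline, and the toy batch below is hypothetical.

import tensorflow as tf

def augment_batch(batch):
    # Runs on the already-padded [batch, time, features] tensor, so augmentation
    # can never invalidate shapes that padded_batch has already produced
    feats = batch['input_features']
    batch['input_features'] = feats + tf.random.normal(tf.shape(feats), stddev=0.1)
    return batch

# Stand-in for the padded_batch output above: one batch of two padded trials
ds = tf.data.Dataset.from_tensors({
    'input_features': tf.zeros([2, 8, 4], tf.float32),
    'n_time_steps': tf.constant([5, 8], tf.int32),
})
ds = ds.map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
for b in ds:
    print(b['input_features'].shape)  # (2, 8, 4): batch-level noise leaves shapes intact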

View File

@@ -17,8 +17,55 @@ except ImportError:
print("Warning: editdistance not available, falling back to approximation") print("Warning: editdistance not available, falling back to approximation")
editdistance = None editdistance = None
# XLA-compatible CTC loss implementation # Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach
from tf_seq2seq_losses import classic_ctc_loss # for compatibility with "batch first, augment after" data pipeline
def dense_to_sparse(dense_tensor, sequence_lengths):
"""
Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
This function is essential for the "batch first, augment after" approach
as it handles the conversion from dynamic dense tensors to SparseTensor
format required by tf.nn.ctc_loss.
Args:
dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
sequence_lengths: Actual sequence lengths [batch_size]
Returns:
SparseTensor suitable for tf.nn.ctc_loss
"""
# Create mask for valid (non-zero) elements within sequence lengths
batch_size = tf.shape(dense_tensor)[0]
max_seq_len = tf.shape(dense_tensor)[1]
# Create range indices
batch_indices = tf.range(batch_size)
seq_indices = tf.range(max_seq_len)
# Create meshgrid for sequence dimensions
_, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
# Create mask based on sequence lengths and non-zero values
length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
value_mask = tf.not_equal(dense_tensor, 0)
combined_mask = tf.logical_and(length_mask, value_mask)
# Get indices of valid elements
indices = tf.where(combined_mask)
# Get values at valid indices
values = tf.gather_nd(dense_tensor, indices)
# Create sparse tensor
dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
return tf.SparseTensor(
indices=tf.cast(indices, tf.int64),
values=tf.cast(values, tf.int32),
dense_shape=dense_shape
)
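A quick usage sketch of the helper above (assuming dense_to_sparse is in scope), on hypothetical 0-padded toy labels; with blank_index=0, real label ids start at 1:

import tensorflow as tf

labels = tf.constant([[3, 1, 2, 0],
                      [5, 4, 0, 0]], tf.int32)   # 0 is both padding and the CTC blank
lens = tf.constant([3, 2], tf.int32)

sp = dense_to_sparse(labels, lens)
print(sp.values.numpy())                 # [3 1 2 5 4]: only in-length, non-zero entries survive
print(tf.sparse.to_dense(sp).numpy())    # round-trips to the original because padding is zero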
 from rnn_model_tf import (
     TripleGRUDecoder,
@@ -559,23 +606,29 @@ class BrainToTextDecoderTrainerTF:
         # Calculate losses using TPU-compatible CTC implementation
         if use_full:
-            # Clean CTC loss - using XLA-compatible classic_ctc_loss
-            clean_loss = classic_ctc_loss(
-                labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            # Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+            clean_loss = tf.nn.ctc_loss(
+                labels=sparse_labels,
                 logits=clean_logits,
-                label_length=phone_seq_lens,
+                label_length=None,  # SparseTensor doesn't need label_length
                 logit_length=adjusted_lens,
+                logits_time_major=False,
                 blank_index=0
             )
+            clean_loss = tf.reduce_mean(clean_loss)
-            # Noisy CTC loss - using XLA-compatible classic_ctc_loss
-            noisy_loss = classic_ctc_loss(
-                labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            # Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+            # Reuse the same sparse_labels from above
+            noisy_loss = tf.nn.ctc_loss(
+                labels=sparse_labels,
                 logits=noisy_logits,
-                label_length=phone_seq_lens,
+                label_length=None,  # SparseTensor doesn't need label_length
                 logit_length=adjusted_lens,
+                logits_time_major=False,
                 blank_index=0
             )
+            noisy_loss = tf.reduce_mean(noisy_loss)
         # Optional noise L2 regularization
         noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
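For clarity, a self-contained sketch of this tf.nn.ctc_loss call pattern, with random logits and toy labels (all names and shapes illustrative). tf.sparse.from_dense yields the same SparseTensor as the commit's dense_to_sparse when padding is all zeros and real labels are >= 1:

import tensorflow as tf

batch, frames, n_classes = 2, 10, 6           # class 0 is the CTC blank
logits = tf.random.normal([batch, frames, n_classes])
labels = tf.constant([[3, 1, 2, 0],
                      [5, 4, 0, 0]], tf.int32)
logit_lens = tf.fill([batch], frames)

sparse_labels = tf.sparse.from_dense(labels)  # zeros become implicit, as in dense_to_sparse
per_example = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=logits,
    label_length=None,        # lengths are implicit in the SparseTensor
    logit_length=logit_lens,
    logits_time_major=False,  # logits are [batch, frames, classes]
    blank_index=0)
loss = tf.reduce_mean(per_example)  # scalar, matching the trainer's reduction
print(loss.numpy())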
@@ -584,14 +637,17 @@ class BrainToTextDecoderTrainerTF:
             loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
         else:
-            # Standard CTC loss - using XLA-compatible classic_ctc_loss
-            loss = classic_ctc_loss(
-                labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            # Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+            sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+            loss = tf.nn.ctc_loss(
+                labels=sparse_labels,
                 logits=clean_logits,
-                label_length=phone_seq_lens,
+                label_length=None,  # SparseTensor doesn't need label_length
                 logit_length=adjusted_lens,
+                logits_time_major=False,
                 blank_index=0
             )
+            loss = tf.reduce_mean(loss)
         # AdamW handles weight decay automatically - no manual L2 regularization needed
         # TensorFlow mixed precision needs no manual scaling; the Keras policy handles it automatically
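A minimal sketch of the setup these two comments assume: decoupled weight decay inside AdamW, and bfloat16 mixed precision (which, unlike float16, needs no loss scaling). Assumes a TF/Keras version that ships tf.keras.optimizers.AdamW:

import tensorflow as tf

# Compute runs in bfloat16 while variables stay float32; no LossScaleOptimizer needed
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

# Weight decay is applied in the optimizer update itself, so the training loss
# needs no manual L2 term
optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-2)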
@@ -646,14 +702,17 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
-        # Calculate loss using XLA-compatible classic_ctc_loss
-        loss = classic_ctc_loss(
-            labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+        # Calculate loss using standard tf.nn.ctc_loss with SparseTensor
+        sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+        loss = tf.nn.ctc_loss(
+            labels=sparse_labels,
             logits=logits,
-            label_length=phone_seq_lens,
+            label_length=None,  # SparseTensor doesn't need label_length
             logit_length=adjusted_lens,
+            logits_time_major=False,
             blank_index=0
         )
+        loss = tf.reduce_mean(loss)
         # Greedy decoding for PER calculation
         predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
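The PER computation itself lies outside this hunk; as an illustration only, greedy CTC decoding collapses consecutive repeats, drops blanks, and scores the result with edit distance (editdistance is the package imported at the top of this file; greedy_ctc_decode is a hypothetical helper):

import tensorflow as tf
import editdistance

def greedy_ctc_decode(ids, blank=0):
    # Best-path CTC decode: collapse consecutive repeats, then drop blanks
    collapsed = [k for i, k in enumerate(ids) if i == 0 or k != ids[i - 1]]
    return [k for k in collapsed if k != blank]

frame_ids = tf.argmax(tf.random.normal([20, 6]), axis=-1, output_type=tf.int32)
hyp = greedy_ctc_decode(frame_ids.numpy().tolist())
ref = [3, 1, 2]                                   # toy reference phoneme sequence
per = editdistance.eval(hyp, ref) / max(len(ref), 1)
print(f"PER: {per:.2f}")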