Refactor create_input_fn to support static shape handling for XLA compatibility

Zchen
2025-10-22 01:29:31 +08:00
parent c03441d8f3
commit 6fb5907c72
2 changed files with 56 additions and 39 deletions
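For orientation, a minimal usage sketch of the signature this commit introduces. The names train_dataset_tf, val_dataset_tf, and config are assumptions for illustration and are not part of the diff:

    # Hypothetical caller; only the keyword arguments mirror the new signature.
    train_ds = create_input_fn(
        train_dataset_tf,                      # assumed BrainToTextDatasetTF instance
        config['dataset']['data_transforms'],  # transform_args
        training=True,
        use_static_shapes=True,                # pre-computed static shapes for XLA
    )
    val_ds = create_input_fn(
        val_dataset_tf,
        config['dataset']['data_transforms'],
        training=False,
        use_static_shapes=True,
    )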


@@ -962,27 +962,20 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True,
cache_path: Optional[str] = None) -> tf.data.Dataset:
cache_path: Optional[str] = None,
use_static_shapes: bool = True) -> tf.data.Dataset:
"""
Create input function for TPU training with DYNAMIC batching -> BATCH augmentation
This function uses the "batch first, augment after" approach, which removes the
ordering conflict (the "time paradox") between data augmentation and shape analysis
and thereby resolves the XLA compilation and padding errors seen previously.
The key insight: data augmentation (especially gauss_smooth with padding='SAME')
can increase sequence lengths unpredictably, making pre-computed static shapes invalid.
By batching first with dynamic padding, then applying augmentation to batches,
we eliminate this temporal paradox entirely.
Create input function for TPU training with configurable shape handling
Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
Returns:
tf.data.Dataset ready for TPU training with robust dynamic->static flow
tf.data.Dataset ready for TPU training
"""
# Step 1: Create individual example dataset
@@ -999,21 +992,44 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
split_name = "training" if training else "validation"
print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
# Step 3: Batch samples with DYNAMIC padding FIRST (eliminates time paradox)
print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
# Step 3: Batch samples with shape handling optimized for TPU
if use_static_shapes:
print(f"🔧 Using STATIC shapes for XLA compatibility")
# Define dynamic padded shapes - use simple shapes instead of TensorSpec for padded_batch
padded_shapes = {
'input_features': (None, dataset_tf.feature_dim),
'seq_class_ids': (None,),
'n_time_steps': (), # scalar
'phone_seq_lens': (), # scalar
'day_indices': (), # scalar
'transcriptions': (None,),
'block_nums': (), # scalar
'trial_nums': () # scalar
}
# Analyze dataset to get maximum shapes
print("📊 Analyzing dataset for maximum shapes...")
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=100)
# Use static shapes based on analysis
padded_shapes = {
'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
'seq_class_ids': (max_shapes['max_phone_seq_len'],),
'n_time_steps': (),
'phone_seq_lens': (),
'day_indices': (),
'transcriptions': (max_shapes['max_transcription_len'],),
'block_nums': (),
'trial_nums': ()
}
print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
f"phone_len={max_shapes['max_phone_seq_len']}, "
f"transcription_len={max_shapes['max_transcription_len']}")
else:
print(f"🔧 Using DYNAMIC shapes (may cause XLA compilation issues)")
# Use dynamic shapes - may cause XLA compilation issues
padded_shapes = {
'input_features': (None, dataset_tf.feature_dim),
'seq_class_ids': (None,),
'n_time_steps': (),
'phone_seq_lens': (),
'day_indices': (),
'transcriptions': (None,),
'block_nums': (),
'trial_nums': ()
}
print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
# Define padding values for each field
padding_values = {


@@ -679,31 +679,32 @@ class BrainToTextDecoderTrainerTF:
initial_tpu_status = self._get_detailed_tpu_status()
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
# ========================= Final solution: batch-first =========================
# Use the validated "batch first, augment after" approach to eliminate the "time paradox" between data augmentation and shape analysis
self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach")
self.logger.info(" This eliminates the time paradox between data augmentation and shape analysis")
# ========================= XLA static-shape solution =========================
# Use static shapes to resolve XLA compilation errors and ensure every tensor shape is known at compile time
self.logger.info("🚀 Using STATIC shapes approach for XLA compatibility")
self.logger.info(" This resolves 'Range must be a compile-time constant' errors")
# Simplified dataset creation function; max_shapes is no longer needed
# Simplified dataset creation function that uses static shapes
def create_dist_dataset_fn(input_dataset_tf, training):
"""Create distributed dataset function for the final 'batch-first' approach."""
"""Create distributed dataset function with static shapes for XLA compatibility."""
def dataset_fn(input_context):
# Call the new version of create_input_fn; it no longer needs max_shapes
# Use static shapes for XLA compatibility on TPU
return create_input_fn(
input_dataset_tf,
self.args['dataset']['data_transforms'],
training=training
training=training,
use_static_shapes=True # Enable static shapes for XLA compatibility
)
return self.strategy.distribute_datasets_from_function(dataset_fn)
# Create the datasets with the new, simplified function signature
self.logger.info("🔄 Distributing training dataset (batch-first approach)...")
# Create the datasets using static shapes
self.logger.info("🔄 Distributing training dataset (static shapes approach)...")
dist_start_time = time.time()
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
train_dist_time = time.time() - dist_start_time
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
self.logger.info("🔄 Distributing validation dataset (batch-first approach)...")
self.logger.info("🔄 Distributing validation dataset (static shapes approach)...")
val_start_time = time.time()
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
val_dist_time = time.time() - val_start_time
@@ -724,8 +725,8 @@ class BrainToTextDecoderTrainerTF:
step = 0
self.logger.info("🔄 Starting training loop...")
self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,")
self.logger.info(" consider using padded_batch with fixed shapes in create_input_fn")
self.logger.info("📋 Using static shapes - should resolve XLA compilation errors")
self.logger.info(" If successful, you should not see 'Range must be a compile-time constant' errors")
for batch in train_dist_dataset:
if step >= self.args['num_training_batches']:
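The hunk is cut off inside the training loop. A rough sketch of the usual consumption pattern for a dataset distributed this way under tf.distribute.TPUStrategy; self.train_step and the loss reduction are assumptions, not taken from the diff:

    for batch in train_dist_dataset:
        if step >= self.args['num_training_batches']:
            break
        # run the (typically @tf.function-compiled) step on every TPU replica
        per_replica_loss = self.strategy.run(self.train_step, args=(batch,))
        # fold the per-replica losses into one scalar for logging
        loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
        step += 1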