Refactor create_input_fn to support static shape handling for XLA compatibility
@@ -962,27 +962,20 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None,
+                    use_static_shapes: bool = True) -> tf.data.Dataset:
     """
-    Create input function for TPU training with DYNAMIC batching -> BATCH augmentation
-
-    This function uses the proven "batch first, augment after" approach that eliminates
-    the time paradox between data augmentation and shape analysis. This is the FINAL
-    solution that resolves all XLA compilation and padding errors.
-
-    The key insight: data augmentation (especially gauss_smooth with padding='SAME')
-    can increase sequence lengths unpredictably, making pre-computed static shapes invalid.
-    By batching first with dynamic padding, then applying augmentation to batches,
-    we eliminate this temporal paradox entirely.
+    Create input function for TPU training with configurable shape handling
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance
+        use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
 
     Returns:
-        tf.data.Dataset ready for TPU training with robust dynamic->static flow
+        tf.data.Dataset ready for TPU training
     """
 
     # Step 1: Create individual example dataset
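For context, the hunk above leans on analyze_dataset_shapes (visible only in the truncated hunk header) to supply the maximum dimensions used when use_static_shapes=True. A minimal sketch of what such a helper could look like follows; it is a hypothetical stand-in that assumes examples are plain dicts with the keys used later in the diff ('input_features', 'seq_class_ids', 'transcriptions'), not the repository's actual implementation:

from typing import Dict, Iterable

def analyze_dataset_shapes_sketch(examples: Iterable[dict], sample_size: int = 100) -> Dict[str, int]:
    """Scan up to `sample_size` example dicts and record the maximum sequence lengths.

    Hypothetical helper: the real analyze_dataset_shapes() takes a
    BrainToTextDatasetTF instance; here any iterable of dicts is assumed.
    """
    max_shapes = {'max_time_steps': 0, 'max_phone_seq_len': 0, 'max_transcription_len': 0}
    for i, ex in enumerate(examples):
        if i >= sample_size:
            break
        max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], len(ex['input_features']))
        max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], len(ex['seq_class_ids']))
        max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], len(ex['transcriptions']))
    return max_shapes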
@@ -999,21 +992,44 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
 
-    # Step 3: Batch samples with DYNAMIC padding FIRST (eliminates time paradox)
-    print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
-    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
+    # Step 3: Batch samples with shape handling optimized for TPU
+    if use_static_shapes:
+        print(f"🔧 Using STATIC shapes for XLA compatibility")
 
-    # Define dynamic padded shapes - use simple shapes instead of TensorSpec for padded_batch
-    padded_shapes = {
-        'input_features': (None, dataset_tf.feature_dim),
-        'seq_class_ids': (None,),
-        'n_time_steps': (), # scalar
-        'phone_seq_lens': (), # scalar
-        'day_indices': (), # scalar
-        'transcriptions': (None,),
-        'block_nums': (), # scalar
-        'trial_nums': () # scalar
-    }
+        # Analyze dataset to get maximum shapes
+        print("📊 Analyzing dataset for maximum shapes...")
+        max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=100)
+
+        # Use static shapes based on analysis
+        padded_shapes = {
+            'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
+            'seq_class_ids': (max_shapes['max_phone_seq_len'],),
+            'n_time_steps': (),
+            'phone_seq_lens': (),
+            'day_indices': (),
+            'transcriptions': (max_shapes['max_transcription_len'],),
+            'block_nums': (),
+            'trial_nums': ()
+        }
+        print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
+              f"phone_len={max_shapes['max_phone_seq_len']}, "
+              f"transcription_len={max_shapes['max_transcription_len']}")
+    else:
+        print(f"🔧 Using DYNAMIC shapes (may cause XLA compilation issues)")
+
+        # Use dynamic shapes - may cause XLA compilation issues
+        padded_shapes = {
+            'input_features': (None, dataset_tf.feature_dim),
+            'seq_class_ids': (None,),
+            'n_time_steps': (),
+            'phone_seq_lens': (),
+            'day_indices': (),
+            'transcriptions': (None,),
+            'block_nums': (),
+            'trial_nums': ()
+        }
+
+    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
 
     # Define padding values for each field
     padding_values = {
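The padded_shapes dict built in this hunk, together with the padding_values defined in the unchanged code that follows, is the form tf.data.Dataset.padded_batch expects; with static maxima every batch is padded to the same shape, so XLA can compile a single program. A hedged sketch of that batching step, with toy stand-ins for the dataset, batch size, shapes, and padding values (none of these literals come from the diff):

import tensorflow as tf

# Illustrative only: example_ds, padded_shapes, and padding_values stand in
# for the values create_input_fn builds from the real dataset.
example_ds = tf.data.Dataset.from_tensor_slices({
    'input_features': tf.zeros((4, 7, 3)),            # 4 examples, 7 timesteps, feature_dim=3
    'seq_class_ids': tf.zeros((4, 5), tf.int32),
})
padded_shapes = {'input_features': (10, 3), 'seq_class_ids': (8,)}   # static maxima
padding_values = {'input_features': 0.0, 'seq_class_ids': 0}

batched = example_ds.padded_batch(
    batch_size=2,
    padded_shapes=padded_shapes,
    padding_values=padding_values,
    drop_remainder=True,   # fixed batch dimension, also needed for fully static XLA shapes
)
print(batched.element_spec)   # shapes are fully defined: (2, 10, 3) and (2, 8)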
@@ -679,31 +679,32 @@ class BrainToTextDecoderTrainerTF:
         initial_tpu_status = self._get_detailed_tpu_status()
         self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
 
-        # ========================= Ultimate solution: batch-first =========================
-        # Use the proven "batch first, augment after" approach to eliminate the time paradox between data augmentation and shape analysis
-        self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach")
-        self.logger.info(" This eliminates the time paradox between data augmentation and shape analysis")
+        # ========================= XLA static-shape solution =========================
+        # Use static shapes to fix XLA compilation errors and ensure all tensor shapes are known at compile time
+        self.logger.info("🚀 Using STATIC shapes approach for XLA compatibility")
+        self.logger.info(" This resolves 'Range must be a compile-time constant' errors")
 
-        # Simplified dataset creation function, no longer needs max_shapes
+        # Simplified dataset creation function using static shapes
         def create_dist_dataset_fn(input_dataset_tf, training):
-            """Create distributed dataset function for the final 'batch-first' approach."""
+            """Create distributed dataset function with static shapes for XLA compatibility."""
             def dataset_fn(input_context):
-                # Call the new create_input_fn, which no longer needs max_shapes
+                # Use static shapes for XLA compatibility on TPU
                 return create_input_fn(
                     input_dataset_tf,
                     self.args['dataset']['data_transforms'],
-                    training=training
+                    training=training,
+                    use_static_shapes=True  # Enable static shapes for XLA compatibility
                 )
             return self.strategy.distribute_datasets_from_function(dataset_fn)
 
-        # Create datasets using the new, simplified function signature
-        self.logger.info("🔄 Distributing training dataset (batch-first approach)...")
+        # Create datasets using static shapes
+        self.logger.info("🔄 Distributing training dataset (static shapes approach)...")
         dist_start_time = time.time()
         train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
         train_dist_time = time.time() - dist_start_time
         self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
 
-        self.logger.info("🔄 Distributing validation dataset (batch-first approach)...")
+        self.logger.info("🔄 Distributing validation dataset (static shapes approach)...")
         val_start_time = time.time()
         val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
         val_dist_time = time.time() - val_start_time
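The trainer hunk above wraps create_input_fn in strategy.distribute_datasets_from_function, whose dataset_fn receives a tf.distribute.InputContext. A minimal sketch of that standard pattern is shown below; make_dataset and the batch sizes are assumed placeholders, not the repository's code, and the default strategy stands in for the trainer's TPUStrategy:

import tensorflow as tf

def make_dataset(per_replica_batch_size):
    # Placeholder for create_input_fn: any per-replica dataset works here.
    return tf.data.Dataset.range(64).batch(per_replica_batch_size, drop_remainder=True)

strategy = tf.distribute.get_strategy()   # TPUStrategy in the trainer; default strategy here

def dataset_fn(input_context: tf.distribute.InputContext):
    # Derive the per-replica batch size from the global one instead of hard-coding it.
    per_replica_bs = input_context.get_per_replica_batch_size(global_batch_size=32)
    ds = make_dataset(per_replica_bs)
    # Optionally shard the input pipeline across workers.
    return ds.shard(input_context.num_input_pipelines, input_context.input_pipeline_id)

dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)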
@@ -724,8 +725,8 @@ class BrainToTextDecoderTrainerTF:
         step = 0
 
         self.logger.info("🔄 Starting training loop...")
-        self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,")
-        self.logger.info(" consider using padded_batch with fixed shapes in create_input_fn")
+        self.logger.info("📋 Using static shapes - should resolve XLA compilation errors")
+        self.logger.info(" If successful, you should not see 'Range must be a compile-time constant' errors")
 
         for batch in train_dist_dataset:
             if step >= self.args['num_training_batches']:
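For completeness, a hedged sketch of how a distributed dataset like train_dist_dataset is typically consumed in a loop such as the one above: each element is a per-replica batch passed to strategy.run inside a tf.function, so XLA compiles one program per static input signature. The train_step body, batch sizes, and step limit below are placeholders, not the trainer's actual implementation:

import tensorflow as tf

strategy = tf.distribute.get_strategy()   # TPUStrategy in the trainer; default strategy here

@tf.function
def train_step(batch):
    # Placeholder computation standing in for the model's forward/backward pass.
    return tf.reduce_sum(tf.cast(batch, tf.float32))

dist_dataset = strategy.distribute_datasets_from_function(
    lambda ctx: tf.data.Dataset.range(32).batch(8, drop_remainder=True))

for step, batch in enumerate(dist_dataset):
    per_replica_loss = strategy.run(train_step, args=(batch,))
    loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
    print(step, float(loss))
    if step >= 3:   # stand-in for num_training_batches
        break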