diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index 2c8bdea..ba9afb8 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -962,27 +962,20 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = def create_input_fn(dataset_tf: BrainToTextDatasetTF, transform_args: Dict[str, Any], training: bool = True, - cache_path: Optional[str] = None) -> tf.data.Dataset: + cache_path: Optional[str] = None, + use_static_shapes: bool = True) -> tf.data.Dataset: """ - Create input function for TPU training with DYNAMIC batching -> BATCH augmentation - - This function uses the proven "batch first, augment after" approach that eliminates - the time paradox between data augmentation and shape analysis. This is the FINAL - solution that resolves all XLA compilation and padding errors. - - The key insight: data augmentation (especially gauss_smooth with padding='SAME') - can increase sequence lengths unpredictably, making pre-computed static shapes invalid. - By batching first with dynamic padding, then applying augmentation to batches, - we eliminate this temporal paradox entirely. + Create input function for TPU training with configurable shape handling Args: dataset_tf: BrainToTextDatasetTF instance transform_args: Data transformation configuration training: Whether this is for training (applies augmentations) cache_path: Optional path for disk caching to improve I/O performance + use_static_shapes: If True, use pre-computed static shapes for XLA compatibility Returns: - tf.data.Dataset ready for TPU training with robust dynamic->static flow + tf.data.Dataset ready for TPU training """ # Step 1: Create individual example dataset @@ -999,21 +992,44 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, split_name = "training" if training else "validation" print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache") - # Step 3: Batch samples with DYNAMIC padding FIRST (eliminates time paradox) - print(f"🔧 Using DYNAMIC padding -> batch augmentation approach") - print(f"🔧 Feature dimension: {dataset_tf.feature_dim}") + # Step 3: Batch samples with shape handling optimized for TPU + if use_static_shapes: + print(f"🔧 Using STATIC shapes for XLA compatibility") - # Define dynamic padded shapes - use simple shapes instead of TensorSpec for padded_batch - padded_shapes = { - 'input_features': (None, dataset_tf.feature_dim), - 'seq_class_ids': (None,), - 'n_time_steps': (), # scalar - 'phone_seq_lens': (), # scalar - 'day_indices': (), # scalar - 'transcriptions': (None,), - 'block_nums': (), # scalar - 'trial_nums': () # scalar - } + # Analyze dataset to get maximum shapes + print("📊 Analyzing dataset for maximum shapes...") + max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=100) + + # Use static shapes based on analysis + padded_shapes = { + 'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim), + 'seq_class_ids': (max_shapes['max_phone_seq_len'],), + 'n_time_steps': (), + 'phone_seq_lens': (), + 'day_indices': (), + 'transcriptions': (max_shapes['max_transcription_len'],), + 'block_nums': (), + 'trial_nums': () + } + print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, " + f"phone_len={max_shapes['max_phone_seq_len']}, " + f"transcription_len={max_shapes['max_transcription_len']}") + else: + print(f"🔧 Using DYNAMIC shapes (may cause XLA compilation issues)") + + # Use dynamic shapes - may cause XLA compilation issues + padded_shapes = { + 'input_features': (None, dataset_tf.feature_dim), + 'seq_class_ids': (None,), + 'n_time_steps': (), + 'phone_seq_lens': (), + 'day_indices': (), + 'transcriptions': (None,), + 'block_nums': (), + 'trial_nums': () + } + + print(f"🔧 Feature dimension: {dataset_tf.feature_dim}") # Define padding values for each field padding_values = { diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 42b261a..6ff2d7b 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -679,31 +679,32 @@ class BrainToTextDecoderTrainerTF: initial_tpu_status = self._get_detailed_tpu_status() self.logger.info(f"Initial TPU Status: {initial_tpu_status}") - # ========================= 终极解决方案:批处理优先 ========================= - # 使用经过验证的"先批处理,后增强"方法,消除数据增强与形状分析的时间悖论 - self.logger.info("🚀 Using FINAL 'batch-first, augment-after' approach") - self.logger.info(" This eliminates the time paradox between data augmentation and shape analysis") + # ========================= XLA静态形状解决方案 ========================= + # 使用静态形状解决XLA编译错误,确保所有张量形状在编译时已知 + self.logger.info("🚀 Using STATIC shapes approach for XLA compatibility") + self.logger.info(" This resolves 'Range must be a compile-time constant' errors") - # 简化的数据集创建函数,不再需要 max_shapes + # 简化的数据集创建函数,使用静态形状 def create_dist_dataset_fn(input_dataset_tf, training): - """Create distributed dataset function for the final 'batch-first' approach.""" + """Create distributed dataset function with static shapes for XLA compatibility.""" def dataset_fn(input_context): - # 调用新版的 create_input_fn,它不需要 max_shapes + # Use static shapes for XLA compatibility on TPU return create_input_fn( input_dataset_tf, self.args['dataset']['data_transforms'], - training=training + training=training, + use_static_shapes=True # Enable static shapes for XLA compatibility ) return self.strategy.distribute_datasets_from_function(dataset_fn) - # 使用新的、简化的函数签名创建数据集 - self.logger.info("🔄 Distributing training dataset (batch-first approach)...") + # 使用静态形状创建数据集 + self.logger.info("🔄 Distributing training dataset (static shapes approach)...") dist_start_time = time.time() train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True) train_dist_time = time.time() - dist_start_time self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s") - self.logger.info("🔄 Distributing validation dataset (batch-first approach)...") + self.logger.info("🔄 Distributing validation dataset (static shapes approach)...") val_start_time = time.time() val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False) val_dist_time = time.time() - val_start_time @@ -724,8 +725,8 @@ class BrainToTextDecoderTrainerTF: step = 0 self.logger.info("🔄 Starting training loop...") - self.logger.info("📋 Note: If you see 'TPU has inputs with dynamic shapes' warnings,") - self.logger.info(" consider using padded_batch with fixed shapes in create_input_fn") + self.logger.info("📋 Using static shapes - should resolve XLA compilation errors") + self.logger.info(" If successful, you should not see 'Range must be a compile-time constant' errors") for batch in train_dist_dataset: if step >= self.args['num_training_batches']: