From 8ee09b6b5e268ed90e9022f31fc4eab1b8b2ede4 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Fri, 17 Oct 2025 01:54:32 +0800
Subject: [PATCH] Initialize optimizer slot variables via strategy.run

---
 model_training_nnn_tpu/trainer_tf.py | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index bb44d44..e59b2f0 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -94,11 +94,17 @@ class BrainToTextDecoderTrainerTF:
         print("🔧 Initializing optimizer for TPU training...")
         print(f"Optimizer type: {type(self.optimizer).__name__}")
 
-        # Initialize optimizer slot variables within strategy scope
-        # This prevents the "different scope" error
-        print("🔧 Creating optimizer slot variables within TPU strategy scope...")
-        dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
-        self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+        # Initialize optimizer slot variables using strategy.run
+        # This ensures we're in the correct replica context
+        print("🔧 Creating optimizer slot variables within TPU replica context...")
+
+        def init_optimizer_slots():
+            dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
+            self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+            return tf.constant(True)  # Return something to satisfy strategy.run
+
+        # Run the slot initialization in replica context
+        self.strategy.run(init_optimizer_slots)
 
         print("✅ Optimizer ready for TPU training")
         self.lr_scheduler = self._create_lr_scheduler()
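
For reference, a minimal, self-contained sketch of the pattern this patch adopts, so it can be sanity-checked outside the trainer. It assumes a MirroredStrategy as a stand-in for the TPUStrategy used by BrainToTextDecoderTrainerTF, and a toy Keras model; `model` and `optimizer` here correspond to `self.model` and `self.optimizer` in trainer_tf.py.

import tensorflow as tf

# MirroredStrategy stands in for the TPUStrategy used by the real trainer.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # Model and optimizer variables must be created under the strategy scope,
    # mirroring self.model / self.optimizer in trainer_tf.py.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(4),
    ])
    optimizer = tf.keras.optimizers.Adam(1e-3)

def init_optimizer_slots():
    # Applying all-zero gradients forces the optimizer to create its slot
    # variables (e.g. Adam's first/second-moment accumulators) inside the
    # replica context, the same context real training steps will run in.
    dummy_gradients = [tf.zeros_like(v) for v in model.trainable_variables]
    optimizer.apply_gradients(zip(dummy_gradients, model.trainable_variables))
    return tf.constant(True)  # return value is optional; kept to mirror the patch

# strategy.run executes the function once per replica, so slot creation
# happens in replica context rather than in the outer (cross-replica) scope.
strategy.run(init_optimizer_slots)

On TF versions with the newer Keras optimizer API, optimizer.build(model.trainable_variables) creates the slot variables more directly; the zero-gradient trick used here additionally guarantees they are created in the replica context that strategy.run provides.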