diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index bb44d44..e59b2f0 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -94,11 +94,17 @@ class BrainToTextDecoderTrainerTF: print("🔧 Initializing optimizer for TPU training...") print(f"Optimizer type: {type(self.optimizer).__name__}") - # Initialize optimizer slot variables within strategy scope - # This prevents the "different scope" error - print("🔧 Creating optimizer slot variables within TPU strategy scope...") - dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables] - self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables)) + # Initialize optimizer slot variables using strategy.run + # This ensures we're in the correct replica context + print("🔧 Creating optimizer slot variables within TPU replica context...") + + def init_optimizer_slots(): + dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables] + self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables)) + return tf.constant(True) # Return something to satisfy strategy.run + + # Run the slot initialization in replica context + self.strategy.run(init_optimizer_slots) print("✅ Optimizer ready for TPU training") self.lr_scheduler = self._create_lr_scheduler()