diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index e59b2f0..0a72551 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -98,6 +98,7 @@ class BrainToTextDecoderTrainerTF:
 
         # This ensures we're in the correct replica context
         print("🔧 Creating optimizer slot variables within TPU replica context...")
+        @tf.function
         def init_optimizer_slots():
             dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
             self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))