From 0a721435136e4b95f0e0a597f526ff0f9d7e876c Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Fri, 17 Oct 2025 01:26:02 +0800
Subject: [PATCH] legacy adam

---
 model_training_nnn_tpu/trainer_tf.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index c4e97d1..75e7b1e 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -443,13 +443,25 @@ class BrainToTextDecoderTrainerTF:
         print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
         print("💡 Manual L2 regularization will be applied in training step")
 
-        optimizer = tf.keras.optimizers.Adam(
-            learning_rate=self.args['lr_max'],
-            beta_1=self.args['beta0'],
-            beta_2=self.args['beta1'],
-            epsilon=self.args['epsilon']
-            # No weight_decay parameter in Adam - handled manually
-        )
+        # Use legacy Adam optimizer for better TPU distributed training compatibility
+        # Legacy optimizers have more stable distributed training implementations
+        try:
+            optimizer = tf.keras.optimizers.legacy.Adam(
+                learning_rate=self.args['lr_max'],
+                beta_1=self.args['beta0'],
+                beta_2=self.args['beta1'],
+                epsilon=self.args['epsilon']
+            )
+            print("✅ Using legacy Adam optimizer for better TPU compatibility")
+        except AttributeError:
+            # Fallback to standard Adam if legacy is not available
+            optimizer = tf.keras.optimizers.Adam(
+                learning_rate=self.args['lr_max'],
+                beta_1=self.args['beta0'],
+                beta_2=self.args['beta1'],
+                epsilon=self.args['epsilon']
+            )
+            print("⚠️ Using standard Adam optimizer (legacy not available)")
 
         return optimizer
 
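
For reference, here is a minimal, self-contained sketch of the optimizer-selection pattern this patch introduces: prefer tf.keras.optimizers.legacy.Adam when the legacy namespace is available, and fall back to the standard Adam otherwise. The build_adam helper and its default hyperparameter values are illustrative assumptions, not the trainer's actual configuration.

import tensorflow as tf

def build_adam(lr=1e-3, beta0=0.9, beta1=0.999, eps=1e-7):
    # Prefer the legacy Adam implementation when the namespace exists;
    # otherwise fall back to the standard Adam, mirroring the patch above.
    try:
        return tf.keras.optimizers.legacy.Adam(
            learning_rate=lr, beta_1=beta0, beta_2=beta1, epsilon=eps)
    except AttributeError:
        return tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=beta0, beta_2=beta1, epsilon=eps)

optimizer = build_adam()
print(type(optimizer).__module__, type(optimizer).__name__)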