legacy adam

Zchen
2025-10-17 01:26:02 +08:00
parent 7df78244e6
commit 0a72143513


@@ -443,13 +443,25 @@ class BrainToTextDecoderTrainerTF:
print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)") print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
print("💡 Manual L2 regularization will be applied in training step") print("💡 Manual L2 regularization will be applied in training step")
optimizer = tf.keras.optimizers.Adam( # Use legacy Adam optimizer for better TPU distributed training compatibility
learning_rate=self.args['lr_max'], # Legacy optimizers have more stable distributed training implementations
beta_1=self.args['beta0'], try:
beta_2=self.args['beta1'], optimizer = tf.keras.optimizers.legacy.Adam(
epsilon=self.args['epsilon'] learning_rate=self.args['lr_max'],
# No weight_decay parameter in Adam - handled manually beta_1=self.args['beta0'],
) beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
)
print("✅ Using legacy Adam optimizer for better TPU compatibility")
except AttributeError:
# Fallback to standard Adam if legacy is not available
optimizer = tf.keras.optimizers.Adam(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon']
)
print("⚠️ Using standard Adam optimizer (legacy not available)")
return optimizer return optimizer
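
Taken out of the trainer, the same fallback pattern can be summarized as a small standalone sketch. The helper names build_adam and l2_penalty and the default hyperparameters below are illustrative and not part of the repository; only the try/except structure and the idea of adding an L2 term manually (because plain Adam, unlike AdamW, takes no weight_decay argument) come from the commit. Whether tf.keras.optimizers.legacy is present depends on the installed TensorFlow/Keras version, which is what the AttributeError fallback guards against.

import tensorflow as tf


def build_adam(lr=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-7):
    """Return the legacy Adam optimizer when the install exposes it, else standard Adam."""
    try:
        # Accessing tf.keras.optimizers.legacy raises AttributeError on installs
        # that do not ship the legacy namespace.
        return tf.keras.optimizers.legacy.Adam(
            learning_rate=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon
        )
    except AttributeError:
        return tf.keras.optimizers.Adam(
            learning_rate=lr, beta_1=beta_1, beta_2=beta_2, epsilon=epsilon
        )


def l2_penalty(model, weight_decay):
    """Illustrative manual L2 term, to be added to the loss inside the training step
    since neither Adam variant here applies weight decay itself."""
    return weight_decay * tf.add_n(
        [tf.nn.l2_loss(v) for v in model.trainable_variables]
    )

In the trainer this would correspond to the optimizer-construction path shown in the diff, with the L2 term added to the training loss inside the train step.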