AdamW to Adam

This commit is contained in:
Zchen
2025-10-17 01:07:01 +08:00
parent a96e272f7b
commit 7df78244e6

View File

@@ -137,12 +137,14 @@ class BrainToTextDecoderTrainerTF:
self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
# TPU-specific weight decay handling # Manual weight decay handling for all environments (since we use Adam)
self.manual_weight_decay = False self.manual_weight_decay = False
if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0: if self.args.get('weight_decay', 0.0) > 0:
self.manual_weight_decay = True self.manual_weight_decay = True
self.weight_decay_rate = self.args['weight_decay'] self.weight_decay_rate = self.args['weight_decay']
print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}") print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
else:
print("💡 No weight decay configured")
if self.adv_enabled: if self.adv_enabled:
if self.logger: if self.logger:
@@ -435,28 +437,19 @@ class BrainToTextDecoderTrainerTF:
# For TPU training, we need to be more explicit about optimizer configuration # For TPU training, we need to be more explicit about optimizer configuration
# to avoid strategy context issues # to avoid strategy context issues
if isinstance(self.strategy, tf.distribute.TPUStrategy): # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs
print("Using TPU-optimized optimizer configuration") # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0
# TPU-specific optimizer configuration # We implement manual L2 regularization (weight decay) in the training step instead
# IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)")
# We'll implement manual L2 regularization instead print("💡 Manual L2 regularization will be applied in training step")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'], optimizer = tf.keras.optimizers.Adam(
beta_1=self.args['beta0'], learning_rate=self.args['lr_max'],
beta_2=self.args['beta1'], beta_1=self.args['beta0'],
epsilon=self.args['epsilon'], beta_2=self.args['beta1'],
weight_decay=0.0 # Disabled for TPU compatibility epsilon=self.args['epsilon']
# REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm # No weight_decay parameter in Adam - handled manually
) )
else:
print("Using standard optimizer configuration")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=self.args['weight_decay']
)
return optimizer return optimizer