Switch optimizer from AdamW to Adam (weight decay handled manually)
This commit is contained in:
		| @@ -137,12 +137,14 @@ class BrainToTextDecoderTrainerTF: | |||||||
|         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) |         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0)) | ||||||
|         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) |         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0)) | ||||||
|  |  | ||||||
|         # TPU-specific weight decay handling |         # Manual weight decay handling for all environments (since we use Adam) | ||||||
|         self.manual_weight_decay = False |         self.manual_weight_decay = False | ||||||
|         if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0: |         if self.args.get('weight_decay', 0.0) > 0: | ||||||
|             self.manual_weight_decay = True |             self.manual_weight_decay = True | ||||||
|             self.weight_decay_rate = self.args['weight_decay'] |             self.weight_decay_rate = self.args['weight_decay'] | ||||||
|             print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}") |             print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}") | ||||||
|  |         else: | ||||||
|  |             print("💡 No weight decay configured") | ||||||
|  |  | ||||||
|         if self.adv_enabled: |         if self.adv_enabled: | ||||||
|             if self.logger: |             if self.logger: | ||||||
| @@ -435,28 +437,19 @@ class BrainToTextDecoderTrainerTF: | |||||||
|  |  | ||||||
|         # For TPU training, we need to be more explicit about optimizer configuration |         # For TPU training, we need to be more explicit about optimizer configuration | ||||||
|         # to avoid strategy context issues |         # to avoid strategy context issues | ||||||
|         if isinstance(self.strategy, tf.distribute.TPUStrategy): |         # IMPORTANT: Use Adam instead of AdamW to avoid TPU distributed training bugs | ||||||
|             print("Using TPU-optimized optimizer configuration") |         # AdamW has known issues with _apply_weight_decay in TPU environments even when weight_decay=0.0 | ||||||
|             # TPU-specific optimizer configuration |         # We implement manual L2 regularization (weight decay) in the training step instead | ||||||
|             # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues |         print("Using TPU-compatible Adam optimizer (avoiding AdamW distributed training bugs)") | ||||||
|             # We'll implement manual L2 regularization instead |         print("💡 Manual L2 regularization will be applied in training step") | ||||||
|             optimizer = tf.keras.optimizers.AdamW( |  | ||||||
|                 learning_rate=self.args['lr_max'], |         optimizer = tf.keras.optimizers.Adam( | ||||||
|                 beta_1=self.args['beta0'], |             learning_rate=self.args['lr_max'], | ||||||
|                 beta_2=self.args['beta1'], |             beta_1=self.args['beta0'], | ||||||
|                 epsilon=self.args['epsilon'], |             beta_2=self.args['beta1'], | ||||||
|                 weight_decay=0.0  # Disabled for TPU compatibility |             epsilon=self.args['epsilon'] | ||||||
|                 # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm |             # No weight_decay parameter in Adam - handled manually | ||||||
|             ) |         ) | ||||||
|         else: |  | ||||||
|             print("Using standard optimizer configuration") |  | ||||||
|             optimizer = tf.keras.optimizers.AdamW( |  | ||||||
|                 learning_rate=self.args['lr_max'], |  | ||||||
|                 beta_1=self.args['beta0'], |  | ||||||
|                 beta_2=self.args['beta1'], |  | ||||||
|                 epsilon=self.args['epsilon'], |  | ||||||
|                 weight_decay=self.args['weight_decay'] |  | ||||||
|             ) |  | ||||||
|  |  | ||||||
|         return optimizer |         return optimizer | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Zchen
					Zchen