@@ -148,11 +148,24 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
 
+        # TPU-specific weight decay handling
+        self.manual_weight_decay = False
+        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+            self.manual_weight_decay = True
+            self.weight_decay_rate = self.args['weight_decay']
+            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
+
         if self.adv_enabled:
-            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
-                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
-                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
-                             f"warmup_steps={self.adv_warmup_steps}")
+            if self.logger:
+                self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                                 f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                                 f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                                 f"warmup_steps={self.adv_warmup_steps}")
+            else:
+                print(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                      f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                      f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                      f"warmup_steps={self.adv_warmup_steps}")
 
     def _setup_logging(self):
         """Setup logging configuration"""
@@ -436,15 +449,19 @@ class BrainToTextDecoderTrainerTF:
         if isinstance(self.strategy, tf.distribute.TPUStrategy):
             print("Using TPU-optimized optimizer configuration")
             # TPU-specific optimizer configuration
+            # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
+            # We'll implement manual L2 regularization instead
             optimizer = tf.keras.optimizers.AdamW(
                 learning_rate=self.args['lr_max'],
                 beta_1=self.args['beta0'],
                 beta_2=self.args['beta1'],
                 epsilon=self.args['epsilon'],
-                weight_decay=self.args['weight_decay'],
+                weight_decay=0.0,  # Disabled for TPU compatibility
                 # TPU-specific settings
                 global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
             )
+            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
+            print("💡 Consider implementing manual L2 regularization if needed")
         else:
             print("Using standard optimizer configuration")
             optimizer = tf.keras.optimizers.AdamW(
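For context, here is a minimal sketch of the TPU branch as a free function, assuming a plain config dict with the same keys (lr_max, beta0, beta1, epsilon, grad_norm_clip_value); the function name build_tpu_adamw is illustrative only. Note that AdamW's built-in weight_decay is decoupled decay, which is not mathematically identical to the L2 penalty later folded into the loss:

import tensorflow as tf

def build_tpu_adamw(cfg: dict) -> tf.keras.optimizers.Optimizer:
    # Clip gradients by global norm only when a positive threshold is configured.
    clip = cfg.get('grad_norm_clip_value', 0.0)
    return tf.keras.optimizers.AdamW(
        learning_rate=cfg['lr_max'],
        beta_1=cfg['beta0'],
        beta_2=cfg['beta1'],
        epsilon=cfg['epsilon'],
        weight_decay=0.0,  # decoupled weight decay disabled on TPU; L2 is added to the loss instead
        global_clipnorm=clip if clip > 0 else None,
    )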
@@ -574,6 +591,13 @@ class BrainToTextDecoderTrainerTF:
             loss = self.ctc_loss(loss_input, clean_logits)
             loss = tf.reduce_mean(loss)
 
+            # Add manual L2 regularization for TPU (since weight_decay is disabled)
+            if self.manual_weight_decay:
+                l2_loss = tf.constant(0.0, dtype=loss.dtype)
+                for var in self.model.trainable_variables:
+                    l2_loss += tf.nn.l2_loss(var)
+                loss += self.weight_decay_rate * l2_loss
+
             # TensorFlow mixed precision: no manual loss scaling needed; the Keras policy handles it automatically
             # TPU v5e-8 uses bfloat16, so loss scaling is not required
 
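Two things are worth noting about the penalty added in this hunk. First, tf.nn.l2_loss(v) computes sum(v ** 2) / 2, so the effective coefficient is half the configured weight decay. Second, an L2 term folded into the loss is rescaled by Adam's adaptive step sizes, which is not the same as AdamW's decoupled weight decay, and it must be built inside the GradientTape so it contributes gradients. The condensed sketch below shows where such a term sits in a typical custom train step; model, loss_fn and weight_decay_rate are stand-ins, not the trainer's exact code:

import tensorflow as tf

def train_step(model, optimizer, features, labels, loss_fn, weight_decay_rate=1e-5):
    with tf.GradientTape() as tape:
        logits = model(features, training=True)
        loss = tf.reduce_mean(loss_fn(labels, logits))
        # Build the L2 penalty inside the tape so its gradients are tracked.
        l2 = tf.add_n([tf.nn.l2_loss(v) for v in model.trainable_variables])
        loss += tf.cast(weight_decay_rate * l2, loss.dtype)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

As in the commit, every trainable variable is penalized here; in practice bias and normalization parameters are often excluded from such a term.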