diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 0af1b12..5087c43 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -148,11 +148,24 @@ class BrainToTextDecoderTrainerTF:
         self.adv_noise_l2_weight = float(adv_cfg.get('noise_l2_weight', 0.0))
         self.adv_warmup_steps = int(adv_cfg.get('warmup_steps', 0))
 
+        # TPU-specific weight decay handling
+        self.manual_weight_decay = False
+        if isinstance(self.strategy, tf.distribute.TPUStrategy) and self.args.get('weight_decay', 0.0) > 0:
+            self.manual_weight_decay = True
+            self.weight_decay_rate = self.args['weight_decay']
+            print(f"🔧 Manual L2 regularization enabled: {self.weight_decay_rate}")
+
         if self.adv_enabled:
-            self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
-                             f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
-                             f"noise_l2_weight={self.adv_noise_l2_weight}, "
-                             f"warmup_steps={self.adv_warmup_steps}")
+            if self.logger:
+                self.logger.info(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                                 f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                                 f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                                 f"warmup_steps={self.adv_warmup_steps}")
+            else:
+                print(f"Adversarial training ENABLED | grl_lambda={self.adv_grl_lambda}, "
+                      f"noisy_loss_weight={self.adv_noisy_loss_weight}, "
+                      f"noise_l2_weight={self.adv_noise_l2_weight}, "
+                      f"warmup_steps={self.adv_warmup_steps}")
 
     def _setup_logging(self):
         """Setup logging configuration"""
@@ -436,15 +449,19 @@ class BrainToTextDecoderTrainerTF:
         if isinstance(self.strategy, tf.distribute.TPUStrategy):
             print("Using TPU-optimized optimizer configuration")
             # TPU-specific optimizer configuration
+            # IMPORTANT: Disable weight_decay for TPU due to distributed training compatibility issues
+            # We'll implement manual L2 regularization instead
             optimizer = tf.keras.optimizers.AdamW(
                 learning_rate=self.args['lr_max'],
                 beta_1=self.args['beta0'],
                 beta_2=self.args['beta1'],
                 epsilon=self.args['epsilon'],
-                weight_decay=self.args['weight_decay'],
+                weight_decay=0.0,  # Disabled for TPU compatibility
                 # TPU-specific settings
                 global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
             )
+            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
+            print("💡 Consider implementing manual L2 regularization if needed")
         else:
             print("Using standard optimizer configuration")
             optimizer = tf.keras.optimizers.AdamW(
@@ -574,6 +591,13 @@ class BrainToTextDecoderTrainerTF:
             loss = self.ctc_loss(loss_input, clean_logits)
             loss = tf.reduce_mean(loss)
 
+            # Add manual L2 regularization for TPU (since weight_decay is disabled)
+            if self.manual_weight_decay:
+                l2_loss = tf.constant(0.0, dtype=loss.dtype)
+                for var in self.model.trainable_variables:
+                    l2_loss += tf.nn.l2_loss(var)
+                loss += self.weight_decay_rate * l2_loss
+
             # TensorFlow mixed precision: no manual loss scaling needed, the Keras policy handles it automatically
             # TPU v5e-8 uses bfloat16, so loss scaling is not required
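The hunks above split the regularization into two pieces: AdamW's built-in `weight_decay` is zeroed in the TPU branch, and an explicit L2 penalty is added to the loss inside the train step whenever `manual_weight_decay` is set. The sketch below is a minimal, self-contained reproduction of that pattern, not code from this PR: the toy model, data, and the `WEIGHT_DECAY` constant are illustrative stand-ins for the trainer's decoder model and `args['weight_decay']`.

```python
import tensorflow as tf

WEIGHT_DECAY = 1e-4  # hypothetical rate; the trainer reads this from args['weight_decay']

# Toy stand-in for the real decoder model.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2),
])
model.build(input_shape=(None, 4))  # create variables before tracing train_step

# Mirrors the TPU branch above: decoupled decay is turned off in the optimizer...
optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=0.0)

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(y, logits, from_logits=True))
        # ...and re-introduced here as an explicit penalty on the loss.
        # tf.nn.l2_loss(v) is sum(v**2) / 2, so this is classic (coupled) L2 regularization.
        l2 = tf.add_n([tf.nn.l2_loss(v) for v in model.trainable_variables])
        loss += WEIGHT_DECAY * l2
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

x = tf.random.normal([16, 4])
y = tf.random.uniform([16], maxval=2, dtype=tf.int32)
print(float(train_step(x, y)))
```

One caveat worth keeping in mind when reviewing: an L2 term added to the loss is scaled by Adam's adaptive moments, whereas AdamW's `weight_decay` is decoupled from them, so the manual penalty is not numerically equivalent to the weight decay it replaces, even at the same rate. It also penalizes every trainable variable, including biases and normalization parameters, unless the loop is filtered.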