From 7a43ebfb71a7b4eb638a50b961b4c487268bedb1 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 23:06:09 +0800
Subject: [PATCH] refactor: streamline model building and ensure dtype
 consistency in L2 loss calculation

---
 model_training_nnn_tpu/trainer_tf.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 5087c43..519fa5a 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -90,12 +90,8 @@ class BrainToTextDecoderTrainerTF:
         with self.strategy.scope():
             print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
-            print("✅ Model built successfully")
-            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
-            print("✅ Optimizer created")
-            print("🔧 Pre-building optimizer state for TPU...")
             # For TPU, we must ensure optimizer is completely ready before training
             # since @tf.function doesn't allow dynamic building
@@ -595,7 +591,10 @@ class BrainToTextDecoderTrainerTF:
         if self.manual_weight_decay:
             l2_loss = tf.constant(0.0, dtype=loss.dtype)
             for var in self.model.trainable_variables:
-                l2_loss += tf.nn.l2_loss(var)
+                # Ensure dtype consistency for mixed precision training
+                var_l2 = tf.nn.l2_loss(var)
+                var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # Cast to match loss dtype
+                l2_loss += var_l2
             loss += self.weight_decay_rate * l2_loss
             # TensorFlow mixed precision handling: no manual scaling needed, the Keras policy handles it automatically
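
Note on the first hunk: the removed print statements were pure progress logging; the substantive constraint is in the retained comment — optimizer state must be created eagerly before the traced train step runs, since variables cannot be created inside a @tf.function after tracing. A minimal sketch of that pattern, not the trainer's actual code (it assumes a TF >= 2.11 Keras optimizer, which exposes build(var_list); the toy model and train_step are illustrative):

    import tensorflow as tf

    model = tf.keras.Sequential([
        tf.keras.Input(shape=(8,)),
        tf.keras.layers.Dense(4),
    ])
    optimizer = tf.keras.optimizers.Adam()

    # Eagerly create all optimizer slot variables (e.g. Adam momenta) up front,
    # so the traced step below never has to create variables itself.
    optimizer.build(model.trainable_variables)

    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(tf.square(model(x) - y))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss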
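
Note on the second hunk: under a Keras mixed-precision policy, trainable variables remain float32 while the computed loss can be a lower-precision float, so tf.nn.l2_loss(var) returns float32 and cannot be accumulated into the lower-precision loss without an explicit cast. A standalone sketch of the mismatch and the fix (the layer, the bfloat16 loss, and the 1e-3 decay rate are assumptions for illustration, not values from the trainer):

    import tensorflow as tf
    from tensorflow.keras import mixed_precision

    mixed_precision.set_global_policy("mixed_bfloat16")  # typical TPU setup

    layer = tf.keras.layers.Dense(4)
    layer.build((None, 8))

    loss = tf.constant(1.0, dtype=tf.bfloat16)    # assumed compute-dtype loss
    l2_loss = tf.constant(0.0, dtype=loss.dtype)
    for var in layer.trainable_variables:         # variables stay float32
        var_l2 = tf.nn.l2_loss(var)               # float32 result
        var_l2 = tf.cast(var_l2, dtype=loss.dtype)  # without this cast, += raises a dtype error
        l2_loss += var_l2
    loss += 1e-3 * l2_loss                        # hypothetical weight_decay_rate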