From 1e7077bba7c7eb7767170d285a036fe10f11b9e9 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 20:44:55 +0800
Subject: [PATCH] =?UTF-8?q?adamw=E4=BF=AE=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model_training_nnn_tpu/trainer_tf.py | 15 ++++-----------
 1 file changed, 4 insertions(+), 11 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index d4598c6..1917585 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -511,18 +511,11 @@ class BrainToTextDecoderTrainerTF:
             loss = self.ctc_loss(loss_input, clean_logits)
             loss = tf.reduce_mean(loss)
 
-        # Scale loss for mixed precision
-        if self.mixed_precision:
-            scaled_loss = self.optimizer.get_scaled_loss(loss)
-        else:
-            scaled_loss = loss
+        # TensorFlow混合精度处理:不需要手动scaling,Keras policy自动处理
+        # TPU v5e-8使用bfloat16,不需要loss scaling
 
-        # Calculate gradients
-        if self.mixed_precision:
-            scaled_gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
-            gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)
-        else:
-            gradients = tape.gradient(scaled_loss, self.model.trainable_variables)
+        # Calculate gradients - TensorFlow自动处理混合精度
+        gradients = tape.gradient(loss, self.model.trainable_variables)
 
         # Clip gradients
         if self.args['grad_norm_clip_value'] > 0: