From a96e272f7bc41caeaa353ece3c25919d5439c6f4 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Fri, 17 Oct 2025 00:51:53 +0800
Subject: [PATCH] fix twice gradient cut

---
 model_training_nnn_tpu/ISSUES.md     | 36 ++++++++++++++++++++++++++++
 model_training_nnn_tpu/trainer_tf.py | 14 ++---------
 2 files changed, 38 insertions(+), 12 deletions(-)
 create mode 100644 model_training_nnn_tpu/ISSUES.md

diff --git a/model_training_nnn_tpu/ISSUES.md b/model_training_nnn_tpu/ISSUES.md
new file mode 100644
index 0000000..0b013cc
--- /dev/null
+++ b/model_training_nnn_tpu/ISSUES.md
@@ -0,0 +1,36 @@
+# ISSUES
+
+## 双重梯度裁剪
+优化器级别：global_clipnorm=self.args.get('grad_norm_clip_value', 0.0)(第283行)
+手动级别：tf.clip_by_global_norm(第447-449行)
+这导致梯度被裁剪两次，并且在TPU的分布式训练中可能引发内部状态冲突。
+
+修复总结
+问题根源：双重梯度裁剪导致AdamW内部状态冲突 修复内容:
+移除了优化器级别的梯度裁剪：删除了 global_clipnorm 参数
+保留手动梯度裁剪：在 _train_step 中继续使用 tf.clip_by_global_norm
+为什么会出错:
+```python
+# 之前：双重裁剪
+optimizer = tf.keras.optimizers.AdamW(
+    global_clipnorm=clip_value # 第一次裁剪
+)
+```
+```python
+# 在 _train_step 中:
+tf.clip_by_global_norm(gradients, clip_value) # 第二次裁剪
+optimizer.apply_gradients(...) # 内部再次处理，导致冲突 现在的修复:
+```
+```python
+# 现在：只有一次裁剪
+optimizer = tf.keras.optimizers.AdamW(
+    # 没有 global_clipnorm
+)
+
+```
+```python
+# 在 _train_step 中:
+tf.clip_by_global_norm(gradients, clip_value) # 唯一的裁剪
+optimizer.apply_gradients(...)
+# 正常工作
+```
diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 519fa5a..88f6df7 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -88,9 +88,7 @@ class BrainToTextDecoderTrainerTF:
 
         # Build model within strategy scope
         with self.strategy.scope():
-            print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
-            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
             print("🔧 Pre-building optimizer state for TPU...")
             # For TPU, we must ensure optimizer is completely ready before training
@@ -125,13 +123,8 @@ class BrainToTextDecoderTrainerTF:
             print(f"Full traceback: {traceback.format_exc()}")
             raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
 
-        print("📅 Setting up learning rate scheduler...")
         self.lr_scheduler = self._create_lr_scheduler()
-        print("✅ LR scheduler ready")
-
-        print("🎯 Initializing CTC loss...")
         self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
-        print("✅ CTC loss initialized")
 
         # Log model information
         self._log_model_info()
@@ -452,12 +445,9 @@ class BrainToTextDecoderTrainerTF:
                 beta_1=self.args['beta0'],
                 beta_2=self.args['beta1'],
                 epsilon=self.args['epsilon'],
-                weight_decay=0.0, # Disabled for TPU compatibility
-                # TPU-specific settings
-                global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
+                weight_decay=0.0 # Disabled for TPU compatibility
+                # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
             )
-            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
-            print("💡 Consider implementing manual L2 regularization if needed")
         else:
             print("Using standard optimizer configuration")
             optimizer = tf.keras.optimizers.AdamW(