fix double gradient clipping

Zchen
2025-10-17 00:51:53 +08:00
parent 7a43ebfb71
commit a96e272f7b
2 changed files with 38 additions and 12 deletions


@@ -0,0 +1,36 @@
# ISSUES
## Double gradient clipping
Optimizer-level clipping: global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) at line 283.
Manual clipping: tf.clip_by_global_norm at lines 447-449.
As a result, gradients were clipped twice, and in distributed TPU training this could trigger internal-state conflicts.

## Fix summary
Root cause: double gradient clipping conflicted with AdamW's internal state. Changes:
- Removed optimizer-level gradient clipping: dropped the global_clipnorm argument.
- Kept manual gradient clipping: _train_step continues to use tf.clip_by_global_norm.

Why it failed:
```python
# Before: double clipping
optimizer = tf.keras.optimizers.AdamW(
    global_clipnorm=clip_value  # first clip
)
```
```python
# In _train_step:
gradients, _ = tf.clip_by_global_norm(gradients, clip_value)  # second clip
optimizer.apply_gradients(...)  # the optimizer clips again internally, causing the conflict
```
The fix now:
```python
# Now: a single clip
optimizer = tf.keras.optimizers.AdamW(
    # no global_clipnorm
)
```
```python
# In _train_step:
gradients, _ = tf.clip_by_global_norm(gradients, clip_value)  # the only clip
optimizer.apply_gradients(...)  # works as expected
```
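For context, a minimal self-contained sketch of the intended single-clip training step. The model, loss, input shapes, and `clip_value` below are illustrative placeholders, not taken from the trainer:

```python
import tensorflow as tf

# Illustrative stand-ins for the real model, loss, and config.
model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)  # no global_clipnorm
loss_fn = tf.keras.losses.MeanSquaredError()
clip_value = 1.0

@tf.function
def train_step(features, labels):
    with tf.GradientTape() as tape:
        loss = loss_fn(labels, model(features, training=True))
    gradients = tape.gradient(loss, model.trainable_variables)
    # The one and only clip; tf.clip_by_global_norm returns (clipped_list, global_norm).
    gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

# Dummy batch just to exercise the step.
train_step(tf.random.normal([8, 16]), tf.random.normal([8, 4]))
```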


@@ -88,9 +88,7 @@ class BrainToTextDecoderTrainerTF:
         # Build model within strategy scope
         with self.strategy.scope():
-            print("🔨 Building model within TPU strategy scope...")
             self.model = self._build_model()
-            print("⚙️ Creating optimizer...")
             self.optimizer = self._create_optimizer()
             print("🔧 Pre-building optimizer state for TPU...")
             # For TPU, we must ensure optimizer is completely ready before training
@@ -125,13 +123,8 @@ class BrainToTextDecoderTrainerTF:
print(f"Full traceback: {traceback.format_exc()}") print(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(f"Optimizer pre-build failed: {e}") from e raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
print("📅 Setting up learning rate scheduler...")
self.lr_scheduler = self._create_lr_scheduler() self.lr_scheduler = self._create_lr_scheduler()
print("✅ LR scheduler ready")
print("🎯 Initializing CTC loss...")
self.ctc_loss = CTCLoss(blank_index=0, reduction='none') self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
print("✅ CTC loss initialized")
# Log model information # Log model information
self._log_model_info() self._log_model_info()
@@ -452,12 +445,9 @@ class BrainToTextDecoderTrainerTF:
                 beta_1=self.args['beta0'],
                 beta_2=self.args['beta1'],
                 epsilon=self.args['epsilon'],
-                weight_decay=0.0,  # Disabled for TPU compatibility
-                # TPU-specific settings
-                global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
+                weight_decay=0.0  # Disabled for TPU compatibility
+                # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
             )
-            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
-            print("💡 Consider implementing manual L2 regularization if needed")
         else:
             print("Using standard optimizer configuration")
             optimizer = tf.keras.optimizers.AdamW(
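
The two deleted notes point at manual L2 regularization as a possible substitute for the disabled `weight_decay`. A rough sketch of that idea, assuming the coefficient comes from `self.args['weight_decay']`; note that an L2 penalty added to the loss approximates, but is not identical to, AdamW's decoupled weight decay:

```python
import tensorflow as tf

def l2_penalty(variables):
    # Sum of squared weights; biases and normalization parameters are
    # skipped by name, which is a common (but assumed) convention here.
    return tf.add_n([
        tf.nn.l2_loss(v)
        for v in variables
        if 'bias' not in v.name and 'norm' not in v.name.lower()
    ])

# Inside the training step, the penalty would be folded into the task loss
# before gradients are computed and clipped:
#     loss = ctc_loss + weight_decay * l2_penalty(model.trainable_variables)
```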