fix twice gradient cut

This commit is contained in:
Zchen
2025-10-17 00:51:53 +08:00
parent 7a43ebfb71
commit a96e272f7b
2 changed files with 38 additions and 12 deletions

View File

@@ -88,9 +88,7 @@ class BrainToTextDecoderTrainerTF:
# Build model within strategy scope
with self.strategy.scope():
print("🔨 Building model within TPU strategy scope...")
self.model = self._build_model()
print("⚙️ Creating optimizer...")
self.optimizer = self._create_optimizer()
print("🔧 Pre-building optimizer state for TPU...")
# For TPU, we must ensure optimizer is completely ready before training
@@ -125,13 +123,8 @@ class BrainToTextDecoderTrainerTF:
print(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
print("📅 Setting up learning rate scheduler...")
self.lr_scheduler = self._create_lr_scheduler()
print("✅ LR scheduler ready")
print("🎯 Initializing CTC loss...")
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
print("✅ CTC loss initialized")
# Log model information
self._log_model_info()
@@ -452,12 +445,9 @@ class BrainToTextDecoderTrainerTF:
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=0.0, # Disabled for TPU compatibility
# TPU-specific settings
global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
weight_decay=0.0 # Disabled for TPU compatibility
# REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
)
print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
print("💡 Consider implementing manual L2 regularization if needed")
else:
print("Using standard optimizer configuration")
optimizer = tf.keras.optimizers.AdamW(