fix double gradient clipping
@@ -88,9 +88,7 @@ class BrainToTextDecoderTrainerTF:
# Build model within strategy scope
with self.strategy.scope():
print("🔨 Building model within TPU strategy scope...")
self.model = self._build_model()
print("⚙️ Creating optimizer...")
self.optimizer = self._create_optimizer()
print("🔧 Pre-building optimizer state for TPU...")
# For TPU, we must ensure optimizer is completely ready before training
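As a rough sketch of what that pre-build step can look like (hypothetical helper, not the trainer's actual code; it assumes a TF >= 2.11 / Keras-3-style optimizer that exposes build()):

import tensorflow as tf

def prebuild_optimizer(strategy, model, optimizer):
    # Create all optimizer slot variables inside the TPU strategy scope so
    # nothing is created lazily during the first compiled train step.
    with strategy.scope():
        optimizer.build(model.trainable_variables)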
@@ -125,13 +123,8 @@ class BrainToTextDecoderTrainerTF:
print(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
|
print("📅 Setting up learning rate scheduler...")
self.lr_scheduler = self._create_lr_scheduler()
print("✅ LR scheduler ready")
|
print("🎯 Initializing CTC loss...")
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
print("✅ CTC loss initialized")
|
# Log model information
self._log_model_info()
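The CTCLoss class configured above is project-specific and its definition is not part of this diff. As a rough sketch only, a per-example (reduction='none') CTC loss with blank index 0 can be computed with TensorFlow's built-in op along these lines (tensor names are illustrative):

import tensorflow as tf

def ctc_loss_per_example(logits, labels, logit_lengths, label_lengths):
    # logits: [batch, time, num_classes]; labels: dense int32 [batch, max_label_len]
    # Returns one unreduced loss value per batch element, with class 0 as the blank.
    return tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        logits_time_major=False,
        blank_index=0,
    )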
@@ -452,12 +445,9 @@ class BrainToTextDecoderTrainerTF:
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
-weight_decay=0.0, # Disabled for TPU compatibility
-# TPU-specific settings
-global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
+weight_decay=0.0 # Disabled for TPU compatibility
+# REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
)
print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
print("💡 Consider implementing manual L2 regularization if needed")
else:
print("Using standard optimizer configuration")
optimizer = tf.keras.optimizers.AdamW(
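For context on the double clipping this commit removes: the training step already clips gradients manually with tf.clip_by_global_norm, so also passing global_clipnorm to the optimizer would scale the gradients down twice. A minimal sketch of the manual path that is kept (illustrative helper; only the config key grad_norm_clip_value comes from the code above):

import tensorflow as tf

def apply_clipped_gradients(optimizer, gradients, variables, clip_value):
    # Manual global-norm clipping. Because this runs in the train step, the
    # optimizer must NOT also be constructed with global_clipnorm, otherwise
    # the same gradients get clipped a second time.
    clipped, _ = tf.clip_by_global_norm(gradients, clip_value)
    optimizer.apply_gradients(zip(clipped, variables))

Since weight_decay is disabled for TPU compatibility, the "manual L2 regularization" suggested in the log message would mean adding an explicit penalty term to the loss rather than relying on AdamW's decoupled weight decay.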