@@ -93,25 +93,8 @@ class BrainToTextDecoderTrainerTF:
        print("🔧 Initializing optimizer for TPU training...")
        print(f"Optimizer type: {type(self.optimizer).__name__}")

        # Initialize optimizer slot variables using strategy.run
        # This ensures we're in the correct replica context
        print("🔧 Creating optimizer slot variables within TPU replica context...")

        @tf.function
        def init_optimizer_slots():
            # Use ALL trainable variables for slot initialization, not just filtered ones
            # This ensures slot variables are created for all variables that might need gradients
            all_variables = self.model.trainable_variables
            dummy_gradients = [tf.zeros_like(var) for var in all_variables]

            # Apply gradients for all variables to ensure all slots are created
            self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
            return tf.constant(True)  # Return something to satisfy strategy.run

        # Run the slot initialization in replica context
        self.strategy.run(init_optimizer_slots)
        print("✅ Optimizer ready for TPU training")
        print("📝 Note: Optimizer slot variables will be created automatically during first training step")

        self.lr_scheduler = self._create_lr_scheduler()
        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
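For reference, the dummy-gradient pattern above is one way to force slot creation inside the replica context. A minimal alternative sketch, assuming a TF 2.11+ Keras optimizer that exposes build() and strategy/model/optimizer objects like the ones in the diff (the helper name is hypothetical, not from this repo):

    import tensorflow as tf

    def eagerly_build_optimizer_state(strategy, model, optimizer):
        """Sketch: create optimizer state variables up front for TPU training.

        Assumes a TF >= 2.11 Keras optimizer with a .build() method; older
        optimizers would need the dummy apply_gradients() trick from the diff.
        """
        with strategy.scope():
            # build() creates the per-variable optimizer state (momentum, etc.)
            # without touching the model weights.
            optimizer.build(model.trainable_variables)

Either way, the note printed in the hunk holds: if nothing is done up front, Keras optimizers create their state lazily on the first apply_gradients() call.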
@@ -503,7 +486,6 @@ class BrainToTextDecoderTrainerTF:
        else:
            print(f"Model has {total_params:,} trainable parameters")

    @tf.function
    def _train_step(self, batch, step):
        """Single training step with gradient tape"""
        features = batch['input_features']
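The second hunk cuts off after the first line of _train_step. As a hedged sketch of how such a step commonly continues, assuming a batch key named 'labels' and that the repo's CTCLoss instance is callable as loss_fn(labels, logits) (both assumptions, not shown in the diff):

    @tf.function
    def _train_step(self, batch, step):
        """Single training step with gradient tape (sketch; batch keys and loss signature assumed)"""
        features = batch['input_features']

        with tf.GradientTape() as tape:
            logits = self.model(features, training=True)
            # reduction='none' yields per-example losses, so average them here
            per_example_loss = self.ctc_loss(batch['labels'], logits)
            loss = tf.reduce_mean(per_example_loss)

        grads = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
        return loss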