f
This commit is contained in:
@@ -93,25 +93,8 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
|
|
||||||
print("🔧 Initializing optimizer for TPU training...")
|
print("🔧 Initializing optimizer for TPU training...")
|
||||||
print(f"Optimizer type: {type(self.optimizer).__name__}")
|
print(f"Optimizer type: {type(self.optimizer).__name__}")
|
||||||
|
|
||||||
# Initialize optimizer slot variables using strategy.run
|
|
||||||
# This ensures we're in the correct replica context
|
|
||||||
print("🔧 Creating optimizer slot variables within TPU replica context...")
|
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def init_optimizer_slots():
|
|
||||||
# Use ALL trainable variables for slot initialization, not just filtered ones
|
|
||||||
# This ensures slot variables are created for all variables that might need gradients
|
|
||||||
all_variables = self.model.trainable_variables
|
|
||||||
dummy_gradients = [tf.zeros_like(var) for var in all_variables]
|
|
||||||
|
|
||||||
# Apply gradients for all variables to ensure all slots are created
|
|
||||||
self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
|
|
||||||
return tf.constant(True) # Return something to satisfy strategy.run
|
|
||||||
|
|
||||||
# Run the slot initialization in replica context
|
|
||||||
self.strategy.run(init_optimizer_slots)
|
|
||||||
print("✅ Optimizer ready for TPU training")
|
print("✅ Optimizer ready for TPU training")
|
||||||
|
print("📝 Note: Optimizer slot variables will be created automatically during first training step")
|
||||||
|
|
||||||
self.lr_scheduler = self._create_lr_scheduler()
|
self.lr_scheduler = self._create_lr_scheduler()
|
||||||
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
|
self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
|
||||||
@@ -503,7 +486,6 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
else:
|
else:
|
||||||
print(f"Model has {total_params:,} trainable parameters")
|
print(f"Model has {total_params:,} trainable parameters")
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _train_step(self, batch, step):
|
def _train_step(self, batch, step):
|
||||||
"""Single training step with gradient tape"""
|
"""Single training step with gradient tape"""
|
||||||
features = batch['input_features']
|
features = batch['input_features']
|
||||||
|
Reference in New Issue
Block a user