This commit is contained in:
Zchen
2025-10-17 10:53:58 +08:00
parent 7ede7b5f12
commit 6c7abfcca8

View File

@@ -93,25 +93,8 @@ class BrainToTextDecoderTrainerTF:
print("🔧 Initializing optimizer for TPU training...") print("🔧 Initializing optimizer for TPU training...")
print(f"Optimizer type: {type(self.optimizer).__name__}") print(f"Optimizer type: {type(self.optimizer).__name__}")
# Initialize optimizer slot variables using strategy.run
# This ensures we're in the correct replica context
print("🔧 Creating optimizer slot variables within TPU replica context...")
@tf.function
def init_optimizer_slots():
# Use ALL trainable variables for slot initialization, not just filtered ones
# This ensures slot variables are created for all variables that might need gradients
all_variables = self.model.trainable_variables
dummy_gradients = [tf.zeros_like(var) for var in all_variables]
# Apply gradients for all variables to ensure all slots are created
self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
return tf.constant(True) # Return something to satisfy strategy.run
# Run the slot initialization in replica context
self.strategy.run(init_optimizer_slots)
print("✅ Optimizer ready for TPU training") print("✅ Optimizer ready for TPU training")
print("📝 Note: Optimizer slot variables will be created automatically during first training step")
self.lr_scheduler = self._create_lr_scheduler() self.lr_scheduler = self._create_lr_scheduler()
self.ctc_loss = CTCLoss(blank_index=0, reduction='none') self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
@@ -503,7 +486,6 @@ class BrainToTextDecoderTrainerTF:
else: else:
print(f"Model has {total_params:,} trainable parameters") print(f"Model has {total_params:,} trainable parameters")
@tf.function
def _train_step(self, batch, step): def _train_step(self, batch, step):
"""Single training step with gradient tape""" """Single training step with gradient tape"""
features = batch['input_features'] features = batch['input_features']