commit a5a3179ca6
parent 59fb73ee9f
Author: Zchen
Date:   2025-10-17 01:49:03 +08:00

@@ -90,31 +90,16 @@ class BrainToTextDecoderTrainerTF:
         with self.strategy.scope():
             self.model = self._build_model()
             self.optimizer = self._create_optimizer()
-            print("🔧 Initializing optimizer for TPU training...")
-            # For TPU, we initialize the optimizer by accessing its basic properties
-            # The optimizer will be properly built when first used in training
-            try:
-                print("✅ Checking optimizer initialization...")
-                print(f"Optimizer type: {type(self.optimizer).__name__}")
-                # Access optimizer properties to ensure it's properly initialized
-                # This is safe and works with all TensorFlow/Keras optimizer versions
-                print(f"Optimizer type: {type(self.optimizer).__name__}")
-                print(f"Learning rate: {self.optimizer.learning_rate}")
-                # Access iterations to ensure optimizer state tracking is ready
-                # This creates the iterations variable without building the full state
-                iterations = self.optimizer.iterations
-                print(f"Optimizer iterations initialized: {iterations}")
-                print("✅ Optimizer ready for TPU training")
-                print("📝 Note: Optimizer state will be built automatically during first training step")
-            except Exception as e:
-                print(f"❌ CRITICAL: Could not initialize optimizer: {e}")
-                print(f"Error type: {type(e).__name__}")
-                import traceback
-                print(f"Full traceback: {traceback.format_exc()}")
-                raise RuntimeError(f"Optimizer initialization failed: {e}") from e
+            # Initialize optimizer slot variables within strategy scope
+            # This prevents the "different scope" error
+            print("🔧 Creating optimizer slot variables within TPU strategy scope...")
+            dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
+            self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+            print("✅ Optimizer ready for TPU training")
             self.lr_scheduler = self._create_lr_scheduler()
             self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
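
For context, the pattern this commit introduces — applying all-zero gradients once so the optimizer builds its slot variables inside the distribution strategy scope — can be reproduced standalone. The sketch below uses a hypothetical model, optimizer, and input shape as stand-ins for the trainer's; only the zero-gradient warm-up step mirrors the committed code.

```python
import tensorflow as tf

# Minimal sketch: model/optimizer here are illustrative stand-ins,
# not the BrainToTextDecoderTrainerTF internals.
strategy = tf.distribute.get_strategy()  # a TPUStrategy in the real trainer

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
    model.build(input_shape=(None, 8))  # create the model variables up front
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

    # Applying all-zero gradients once forces the optimizer to create its
    # slot variables (e.g. Adam's moment accumulators) under the same
    # strategy scope as the model variables, so the first real training
    # step cannot trip a cross-scope variable-creation error.
    zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
    optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))

print("iterations after warm-up step:", optimizer.iterations.numpy())
```

One side effect worth noting: `apply_gradients` increments `optimizer.iterations`, so training starts from step 1 rather than 0. For plain Adam or SGD a zero gradient leaves the weights themselves unchanged, though an optimizer with decoupled weight decay would still apply its decay term on this warm-up step.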