diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 970d858..175a2b3 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -97,25 +97,16 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer created")
         print("🔧 Pre-building optimizer state for TPU...")

-        # Force optimizer to build its internal state within strategy scope
-        # This prevents the 'NoneType' strategy error during first apply_gradients
+        # Build optimizer within strategy scope but don't apply gradients yet
+        # The actual gradient application will happen in distributed training context
         try:
-            print("✅ Building optimizer with complete state initialization...")
+            print("✅ Building optimizer with model variables...")

-            # First, explicitly build the optimizer with model variables
+            # Explicitly build the optimizer with model variables
             print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
             self.optimizer.build(self.model.trainable_variables)
             print("✅ Optimizer built with model variables")

-            # Create dummy gradients and variables for full state initialization
-            dummy_grads = [tf.zeros_like(var) for var in self.model.trainable_variables]
-            print(f"Created {len(dummy_grads)} dummy gradients")
-
-            # Apply dummy gradients to fully initialize optimizer state
-            # This ensures all optimizer variables are created within the strategy scope
-            self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
-            print("✅ Optimizer state fully initialized with dummy gradients")
-
             # Verify optimizer is properly built
             print(f"Optimizer iterations: {self.optimizer.iterations}")
             print(f"Optimizer built: {self.optimizer.built}")
@@ -127,6 +118,7 @@ class BrainToTextDecoderTrainerTF:
                 print("⚠️ Optimizer has no internal variables - this might cause issues")

             print("✅ Optimizer pre-build completed successfully")
+            print("📝 Note: Optimizer state will be fully initialized on first training step")

         except Exception as e:
             print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
@@ -603,36 +595,32 @@ class BrainToTextDecoderTrainerTF:
         # Apply gradients (only for variables that have gradients)
         if len(filtered_gradients) > 0:
-            # Apply gradients with comprehensive error handling
-            # The optimizer should already be built and have all necessary variables
             try:
-                # Check if optimizer is properly built before applying gradients
-                if not self.optimizer.built:
-                    print("WARNING: Optimizer not built, building now...")
-                    # This should not happen if pre-build worked correctly
-                    self.optimizer.build(filtered_variables)
-
-                # Apply gradients - this should work since optimizer is pre-built
+                # Apply gradients - optimizer should be built and ready
+                # This will work correctly in distributed training context
                 self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
             except AttributeError as e:
-                print("CRITICAL ERROR in gradient application:")
-                print(f"Error: {e}")
-                print("This indicates the optimizer lost its strategy context")
-                print(f"Optimizer built: {self.optimizer.built}")
-                print(f"Number of gradients: {len(filtered_gradients)}")
-                print(f"Number of variables: {len(filtered_variables)}")
+                if "merge_call" in str(e) or "replica_context" in str(e):
+                    print("CRITICAL ERROR: Distributed training context issue")
+                    print(f"Error: {e}")
+                    print("This indicates TPU strategy context is not properly set up")

-                # Check current strategy
-                current_strategy = tf.distribute.get_strategy()
-                print(f"Current strategy: {type(current_strategy).__name__}")
-                print(f"Training strategy: {type(self.strategy).__name__}")
+                    # Try to get current strategy and replica context info
+                    try:
+                        current_strategy = tf.distribute.get_strategy()
+                        replica_context = tf.distribute.get_replica_context()
+                        print(f"Current strategy: {type(current_strategy).__name__}")
+                        print(f"Replica context: {replica_context}")
+                    except:
+                        print("Could not get strategy/context information")

-                # Re-raise with more context
-                raise RuntimeError(f"Gradient application failed - optimizer strategy context lost: {e}")
+                    raise RuntimeError(f"TPU distributed training context error: {e}")
+                else:
+                    print(f"Optimizer AttributeError: {e}")
+                    raise

             except Exception as e:
-                # Catch any other errors during gradient application
                 print("Unexpected error during gradient application:")
                 print(f"Error type: {type(e).__name__}")
                 print(f"Error message: {e}")