@@ -97,25 +97,16 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer created")
 
         print("🔧 Pre-building optimizer state for TPU...")
-        # Force optimizer to build its internal state within strategy scope
-        # This prevents the 'NoneType' strategy error during first apply_gradients
+        # Build optimizer within strategy scope but don't apply gradients yet
+        # The actual gradient application will happen in distributed training context
         try:
-            print("✅ Building optimizer with complete state initialization...")
+            print("✅ Building optimizer with model variables...")
 
-            # First, explicitly build the optimizer with model variables
+            # Explicitly build the optimizer with model variables
             print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
             self.optimizer.build(self.model.trainable_variables)
             print("✅ Optimizer built with model variables")
 
-            # Create dummy gradients and variables for full state initialization
-            dummy_grads = [tf.zeros_like(var) for var in self.model.trainable_variables]
-            print(f"Created {len(dummy_grads)} dummy gradients")
-
-            # Apply dummy gradients to fully initialize optimizer state
-            # This ensures all optimizer variables are created within the strategy scope
-            self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
-            print("✅ Optimizer state fully initialized with dummy gradients")
-
             # Verify optimizer is properly built
             print(f"Optimizer iterations: {self.optimizer.iterations}")
             print(f"Optimizer built: {self.optimizer.built}")
@@ -127,6 +118,7 @@ class BrainToTextDecoderTrainerTF:
                 print("⚠️ Optimizer has no internal variables - this might cause issues")
 
             print("✅ Optimizer pre-build completed successfully")
+            print("📝 Note: Optimizer state will be fully initialized on first training step")
 
         except Exception as e:
             print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
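
The two hunks above drop the dummy-gradient trick and instead rely on a plain optimizer.build() call made inside the strategy scope, deferring full state initialization to the first real training step. A minimal, self-contained sketch of that pattern follows; it is not the repository's code, and the toy model, input shape, and Adam settings are illustrative assumptions:

    import tensorflow as tf

    # Illustrative stand-in for the real brain-to-text model.
    def make_model():
        return tf.keras.Sequential([
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(32),
        ])

    strategy = tf.distribute.get_strategy()  # a TPUStrategy in the actual trainer
    with strategy.scope():
        model = make_model()
        model.build(input_shape=(None, 16))          # create the model variables
        optimizer = tf.keras.optimizers.Adam(1e-3)
        # Create optimizer slot variables inside the scope, without applying
        # any dummy gradients.
        optimizer.build(model.trainable_variables)
        print(f"Optimizer built: {optimizer.built}, iterations: {int(optimizer.iterations)}")

Building the optimizer this way keeps all of its variables on the TPU devices managed by the strategy, while leaving the first apply_gradients call to happen inside the distributed training step.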
@@ -603,36 +595,32 @@ class BrainToTextDecoderTrainerTF:
 
         # Apply gradients (only for variables that have gradients)
         if len(filtered_gradients) > 0:
-            # Apply gradients with comprehensive error handling
-            # The optimizer should already be built and have all necessary variables
             try:
-                # Check if optimizer is properly built before applying gradients
-                if not self.optimizer.built:
-                    print("WARNING: Optimizer not built, building now...")
-                    # This should not happen if pre-build worked correctly
-                    self.optimizer.build(filtered_variables)
-
-                # Apply gradients - this should work since optimizer is pre-built
+                # Apply gradients - optimizer should be built and ready
+                # This will work correctly in distributed training context
                 self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
 
             except AttributeError as e:
-                print("CRITICAL ERROR in gradient application:")
-                print(f"Error: {e}")
-                print("This indicates the optimizer lost its strategy context")
-                print(f"Optimizer built: {self.optimizer.built}")
-                print(f"Number of gradients: {len(filtered_gradients)}")
-                print(f"Number of variables: {len(filtered_variables)}")
+                if "merge_call" in str(e) or "replica_context" in str(e):
+                    print("CRITICAL ERROR: Distributed training context issue")
+                    print(f"Error: {e}")
+                    print("This indicates TPU strategy context is not properly set up")
 
-                # Check current strategy
-                current_strategy = tf.distribute.get_strategy()
-                print(f"Current strategy: {type(current_strategy).__name__}")
-                print(f"Training strategy: {type(self.strategy).__name__}")
+                    # Try to get current strategy and replica context info
+                    try:
+                        current_strategy = tf.distribute.get_strategy()
+                        replica_context = tf.distribute.get_replica_context()
+                        print(f"Current strategy: {type(current_strategy).__name__}")
+                        print(f"Replica context: {replica_context}")
+                    except:
+                        print("Could not get strategy/context information")
 
-                # Re-raise with more context
-                raise RuntimeError(f"Gradient application failed - optimizer strategy context lost: {e}")
+                    raise RuntimeError(f"TPU distributed training context error: {e}")
+                else:
+                    print(f"Optimizer AttributeError: {e}")
+                    raise
 
             except Exception as e:
-                # Catch any other errors during gradient application
                 print("Unexpected error during gradient application:")
                 print(f"Error type: {type(e).__name__}")
                 print(f"Error message: {e}")