fix
This commit is contained in:
@@ -100,27 +100,40 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
# Force optimizer to build its internal state within strategy scope
|
# Force optimizer to build its internal state within strategy scope
|
||||||
# This prevents the 'NoneType' strategy error during first apply_gradients
|
# This prevents the 'NoneType' strategy error during first apply_gradients
|
||||||
try:
|
try:
|
||||||
# Check if strategy is properly initialized before applying gradients
|
print("✅ Building optimizer with complete state initialization...")
|
||||||
if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
|
|
||||||
print("✅ Strategy has merge_call, building optimizer properly...")
|
|
||||||
|
|
||||||
# Build optimizer by explicitly calling build method
|
# First, explicitly build the optimizer with model variables
|
||||||
self.optimizer.build(self.model.trainable_variables)
|
print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
|
||||||
print("✅ Optimizer built with model variables")
|
self.optimizer.build(self.model.trainable_variables)
|
||||||
|
print("✅ Optimizer built with model variables")
|
||||||
|
|
||||||
# Test with dummy gradients to ensure everything works
|
# Create dummy gradients and variables for full state initialization
|
||||||
dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
|
dummy_grads = [tf.zeros_like(var) for var in self.model.trainable_variables]
|
||||||
self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
|
print(f"Created {len(dummy_grads)} dummy gradients")
|
||||||
print("✅ Optimizer state pre-built successfully with TPU strategy")
|
|
||||||
|
# Apply dummy gradients to fully initialize optimizer state
|
||||||
|
# This ensures all optimizer variables are created within the strategy scope
|
||||||
|
self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
|
||||||
|
print("✅ Optimizer state fully initialized with dummy gradients")
|
||||||
|
|
||||||
|
# Verify optimizer is properly built
|
||||||
|
print(f"Optimizer iterations: {self.optimizer.iterations}")
|
||||||
|
print(f"Optimizer built: {self.optimizer.built}")
|
||||||
|
|
||||||
|
# Print optimizer variable names for debugging
|
||||||
|
if hasattr(self.optimizer, 'variables') and self.optimizer.variables:
|
||||||
|
print(f"Optimizer has {len(self.optimizer.variables)} internal variables")
|
||||||
else:
|
else:
|
||||||
# Fallback: just build optimizer variables without applying gradients
|
print("⚠️ Optimizer has no internal variables - this might cause issues")
|
||||||
print("⚠️ Strategy not fully initialized, using fallback optimizer build")
|
|
||||||
# Force build the optimizer with the model variables
|
print("✅ Optimizer pre-build completed successfully")
|
||||||
self.optimizer.build(self.model.trainable_variables)
|
|
||||||
print("✅ Optimizer built in fallback mode")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
|
print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
|
||||||
print("✅ Continuing without optimizer pre-build - optimizer will build during first training step")
|
print(f"Error type: {type(e).__name__}")
|
||||||
|
import traceback
|
||||||
|
print(f"Full traceback: {traceback.format_exc()}")
|
||||||
|
raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
|
||||||
|
|
||||||
print("📅 Setting up learning rate scheduler...")
|
print("📅 Setting up learning rate scheduler...")
|
||||||
self.lr_scheduler = self._create_lr_scheduler()
|
self.lr_scheduler = self._create_lr_scheduler()
|
||||||
@@ -422,15 +435,31 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
# Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
|
# Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
|
||||||
# We'll use a single optimizer and handle different learning rates in the scheduler
|
# We'll use a single optimizer and handle different learning rates in the scheduler
|
||||||
|
|
||||||
# Create optimizer within strategy scope to ensure proper initialization
|
|
||||||
print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
|
print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
|
||||||
optimizer = tf.keras.optimizers.AdamW(
|
|
||||||
learning_rate=self.args['lr_max'],
|
# For TPU training, we need to be more explicit about optimizer configuration
|
||||||
beta_1=self.args['beta0'],
|
# to avoid strategy context issues
|
||||||
beta_2=self.args['beta1'],
|
if isinstance(self.strategy, tf.distribute.TPUStrategy):
|
||||||
epsilon=self.args['epsilon'],
|
print("Using TPU-optimized optimizer configuration")
|
||||||
weight_decay=self.args['weight_decay']
|
# TPU-specific optimizer configuration
|
||||||
)
|
optimizer = tf.keras.optimizers.AdamW(
|
||||||
|
learning_rate=self.args['lr_max'],
|
||||||
|
beta_1=self.args['beta0'],
|
||||||
|
beta_2=self.args['beta1'],
|
||||||
|
epsilon=self.args['epsilon'],
|
||||||
|
weight_decay=self.args['weight_decay'],
|
||||||
|
# TPU-specific settings
|
||||||
|
global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print("Using standard optimizer configuration")
|
||||||
|
optimizer = tf.keras.optimizers.AdamW(
|
||||||
|
learning_rate=self.args['lr_max'],
|
||||||
|
beta_1=self.args['beta0'],
|
||||||
|
beta_2=self.args['beta1'],
|
||||||
|
epsilon=self.args['epsilon'],
|
||||||
|
weight_decay=self.args['weight_decay']
|
||||||
|
)
|
||||||
|
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
@@ -475,7 +504,6 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
|
total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
|
||||||
self.logger.info(f"Model has {total_params:,} trainable parameters")
|
self.logger.info(f"Model has {total_params:,} trainable parameters")
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _train_step(self, batch, step):
|
def _train_step(self, batch, step):
|
||||||
"""Single training step with gradient tape"""
|
"""Single training step with gradient tape"""
|
||||||
features = batch['input_features']
|
features = batch['input_features']
|
||||||
@@ -575,22 +603,43 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
|
|
||||||
# Apply gradients (only for variables that have gradients)
|
# Apply gradients (only for variables that have gradients)
|
||||||
if len(filtered_gradients) > 0:
|
if len(filtered_gradients) > 0:
|
||||||
# Ensure we're in the strategy scope when applying gradients
|
# Apply gradients with comprehensive error handling
|
||||||
# This prevents the 'NoneType' extended attribute error
|
# The optimizer should already be built and have all necessary variables
|
||||||
try:
|
try:
|
||||||
|
# Check if optimizer is properly built before applying gradients
|
||||||
|
if not self.optimizer.built:
|
||||||
|
print("WARNING: Optimizer not built, building now...")
|
||||||
|
# This should not happen if pre-build worked correctly
|
||||||
|
self.optimizer.build(filtered_variables)
|
||||||
|
|
||||||
|
# Apply gradients - this should work since optimizer is pre-built
|
||||||
self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
|
self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
|
||||||
|
|
||||||
except AttributeError as e:
|
except AttributeError as e:
|
||||||
if "'NoneType' object has no attribute 'extended'" in str(e):
|
print("CRITICAL ERROR in gradient application:")
|
||||||
# Strategy context was lost, this should not happen in a @tf.function
|
print(f"Error: {e}")
|
||||||
tf.print(f"ERROR: Strategy context lost during gradient application: {e}")
|
print("This indicates the optimizer lost its strategy context")
|
||||||
tf.print("This indicates a serious issue with the distributed training setup")
|
print(f"Optimizer built: {self.optimizer.built}")
|
||||||
raise RuntimeError(f"Strategy context lost during training: {e}")
|
print(f"Number of gradients: {len(filtered_gradients)}")
|
||||||
else:
|
print(f"Number of variables: {len(filtered_variables)}")
|
||||||
raise
|
|
||||||
|
# Check current strategy
|
||||||
|
current_strategy = tf.distribute.get_strategy()
|
||||||
|
print(f"Current strategy: {type(current_strategy).__name__}")
|
||||||
|
print(f"Training strategy: {type(self.strategy).__name__}")
|
||||||
|
|
||||||
|
# Re-raise with more context
|
||||||
|
raise RuntimeError(f"Gradient application failed - optimizer strategy context lost: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Catch any other errors during gradient application
|
||||||
|
print("Unexpected error during gradient application:")
|
||||||
|
print(f"Error type: {type(e).__name__}")
|
||||||
|
print(f"Error message: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
return loss, grad_norm
|
return loss, grad_norm
|
||||||
|
|
||||||
@tf.function
|
|
||||||
def _validation_step(self, batch):
|
def _validation_step(self, batch):
|
||||||
"""Single validation step"""
|
"""Single validation step"""
|
||||||
features = batch['input_features']
|
features = batch['input_features']
|
||||||
|
Reference in New Issue
Block a user