This commit is contained in:
Zchen
2025-10-16 21:40:43 +08:00
parent 426b72ef25
commit eefff1ce5e

View File

@@ -100,27 +100,40 @@ class BrainToTextDecoderTrainerTF:
# Force optimizer to build its internal state within strategy scope # Force optimizer to build its internal state within strategy scope
# This prevents the 'NoneType' strategy error during first apply_gradients # This prevents the 'NoneType' strategy error during first apply_gradients
try: try:
# Check if strategy is properly initialized before applying gradients print("✅ Building optimizer with complete state initialization...")
if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
print("✅ Strategy has merge_call, building optimizer properly...")
# Build optimizer by explicitly calling build method # First, explicitly build the optimizer with model variables
self.optimizer.build(self.model.trainable_variables) print(f"Building optimizer with {len(self.model.trainable_variables)} variables")
print("✅ Optimizer built with model variables") self.optimizer.build(self.model.trainable_variables)
print("✅ Optimizer built with model variables")
# Test with dummy gradients to ensure everything works # Create dummy gradients and variables for full state initialization
dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables] dummy_grads = [tf.zeros_like(var) for var in self.model.trainable_variables]
self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables)) print(f"Created {len(dummy_grads)} dummy gradients")
print("✅ Optimizer state pre-built successfully with TPU strategy")
# Apply dummy gradients to fully initialize optimizer state
# This ensures all optimizer variables are created within the strategy scope
self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
print("✅ Optimizer state fully initialized with dummy gradients")
# Verify optimizer is properly built
print(f"Optimizer iterations: {self.optimizer.iterations}")
print(f"Optimizer built: {self.optimizer.built}")
# Print optimizer variable names for debugging
if hasattr(self.optimizer, 'variables') and self.optimizer.variables:
print(f"Optimizer has {len(self.optimizer.variables)} internal variables")
else: else:
# Fallback: just build optimizer variables without applying gradients print("⚠️ Optimizer has no internal variables - this might cause issues")
print("⚠️ Strategy not fully initialized, using fallback optimizer build")
# Force build the optimizer with the model variables print("✅ Optimizer pre-build completed successfully")
self.optimizer.build(self.model.trainable_variables)
print("✅ Optimizer built in fallback mode")
except Exception as e: except Exception as e:
print(f"⚠️ Warning: Could not pre-build optimizer state: {e}") print(f"❌ CRITICAL: Could not pre-build optimizer state: {e}")
print("✅ Continuing without optimizer pre-build - optimizer will build during first training step") print(f"Error type: {type(e).__name__}")
import traceback
print(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(f"Optimizer pre-build failed: {e}") from e
print("📅 Setting up learning rate scheduler...") print("📅 Setting up learning rate scheduler...")
self.lr_scheduler = self._create_lr_scheduler() self.lr_scheduler = self._create_lr_scheduler()
@@ -422,15 +435,31 @@ class BrainToTextDecoderTrainerTF:
# Note: TensorFlow doesn't have the same parameter group functionality as PyTorch # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
# We'll use a single optimizer and handle different learning rates in the scheduler # We'll use a single optimizer and handle different learning rates in the scheduler
# Create optimizer within strategy scope to ensure proper initialization
print(f"Creating optimizer with strategy: {type(self.strategy).__name__}") print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'], # For TPU training, we need to be more explicit about optimizer configuration
beta_1=self.args['beta0'], # to avoid strategy context issues
beta_2=self.args['beta1'], if isinstance(self.strategy, tf.distribute.TPUStrategy):
epsilon=self.args['epsilon'], print("Using TPU-optimized optimizer configuration")
weight_decay=self.args['weight_decay'] # TPU-specific optimizer configuration
) optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=self.args['weight_decay'],
# TPU-specific settings
global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
)
else:
print("Using standard optimizer configuration")
optimizer = tf.keras.optimizers.AdamW(
learning_rate=self.args['lr_max'],
beta_1=self.args['beta0'],
beta_2=self.args['beta1'],
epsilon=self.args['epsilon'],
weight_decay=self.args['weight_decay']
)
return optimizer return optimizer
@@ -475,7 +504,6 @@ class BrainToTextDecoderTrainerTF:
total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights]) total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
self.logger.info(f"Model has {total_params:,} trainable parameters") self.logger.info(f"Model has {total_params:,} trainable parameters")
@tf.function
def _train_step(self, batch, step): def _train_step(self, batch, step):
"""Single training step with gradient tape""" """Single training step with gradient tape"""
features = batch['input_features'] features = batch['input_features']
@@ -575,22 +603,43 @@ class BrainToTextDecoderTrainerTF:
# Apply gradients (only for variables that have gradients) # Apply gradients (only for variables that have gradients)
if len(filtered_gradients) > 0: if len(filtered_gradients) > 0:
# Ensure we're in the strategy scope when applying gradients # Apply gradients with comprehensive error handling
# This prevents the 'NoneType' extended attribute error # The optimizer should already be built and have all necessary variables
try: try:
# Check if optimizer is properly built before applying gradients
if not self.optimizer.built:
print("WARNING: Optimizer not built, building now...")
# This should not happen if pre-build worked correctly
self.optimizer.build(filtered_variables)
# Apply gradients - this should work since optimizer is pre-built
self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables)) self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
except AttributeError as e: except AttributeError as e:
if "'NoneType' object has no attribute 'extended'" in str(e): print("CRITICAL ERROR in gradient application:")
# Strategy context was lost, this should not happen in a @tf.function print(f"Error: {e}")
tf.print(f"ERROR: Strategy context lost during gradient application: {e}") print("This indicates the optimizer lost its strategy context")
tf.print("This indicates a serious issue with the distributed training setup") print(f"Optimizer built: {self.optimizer.built}")
raise RuntimeError(f"Strategy context lost during training: {e}") print(f"Number of gradients: {len(filtered_gradients)}")
else: print(f"Number of variables: {len(filtered_variables)}")
raise
# Check current strategy
current_strategy = tf.distribute.get_strategy()
print(f"Current strategy: {type(current_strategy).__name__}")
print(f"Training strategy: {type(self.strategy).__name__}")
# Re-raise with more context
raise RuntimeError(f"Gradient application failed - optimizer strategy context lost: {e}")
except Exception as e:
# Catch any other errors during gradient application
print("Unexpected error during gradient application:")
print(f"Error type: {type(e).__name__}")
print(f"Error message: {e}")
raise
return loss, grad_norm return loss, grad_norm
@tf.function
def _validation_step(self, batch): def _validation_step(self, batch):
"""Single validation step""" """Single validation step"""
features = batch['input_features'] features = batch['input_features']