fix: harden optimizer pre-build and gradient application under the TPU distribution strategy
```diff
@@ -102,18 +102,25 @@ class BrainToTextDecoderTrainerTF:
         try:
             # Check if strategy is properly initialized before applying gradients
             if hasattr(self.strategy, 'merge_call') and callable(getattr(self.strategy, 'merge_call')):
+                print("✅ Strategy has merge_call, building optimizer properly...")
+
                 # Build optimizer by explicitly calling build method
                 self.optimizer.build(self.model.trainable_variables)
                 print("✅ Optimizer built with model variables")
+
+                # Test with dummy gradients to ensure everything works
+                dummy_grads = [tf.zeros_like(w) for w in self.model.trainable_variables]
+                self.optimizer.apply_gradients(zip(dummy_grads, self.model.trainable_variables))
+                print("✅ Optimizer state pre-built successfully with TPU strategy")
             else:
                 # Fallback: just build optimizer variables without applying gradients
-                print("⚠️ Strategy not fully initialized, skipping optimizer pre-build")
-                # Alternative: trigger optimizer variable creation
-                _ = self.optimizer.iterations
-                print("✅ Optimizer state initialized (fallback mode)")
+                print("⚠️ Strategy not fully initialized, using fallback optimizer build")
+                # Force build the optimizer with the model variables
+                self.optimizer.build(self.model.trainable_variables)
+                print("✅ Optimizer built in fallback mode")
         except Exception as e:
             print(f"⚠️ Warning: Could not pre-build optimizer state: {e}")
-            print("✅ Continuing without optimizer pre-build")
+            print("✅ Continuing without optimizer pre-build - optimizer will build during first training step")
 
         print("📅 Setting up learning rate scheduler...")
         self.lr_scheduler = self._create_lr_scheduler()
```
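For context, a minimal standalone sketch of the pre-build pattern this hunk adds, outside the repository's trainer class: `build()` allocates the AdamW slot variables eagerly, and a pass of all-zero gradients exercises the full `apply_gradients` path before the first real step. The toy model, variable names, and `weight_decay=0.0` are illustrative assumptions, not this repo's configuration (with AdamW's default nonzero weight decay, even a zero-gradient step would slightly decay the weights).

```python
import tensorflow as tf

# Illustrative stand-ins for the trainer's model/optimizer (assumptions).
model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])
optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=0.0)

# Allocate slot variables (momentum/velocity) up front instead of
# lazily inside the first compiled training step.
optimizer.build(model.trainable_variables)

# A zero-gradient update exercises the full apply path; with
# weight_decay=0.0 it leaves the weights unchanged and only
# increments optimizer.iterations.
dummy_grads = [tf.zeros_like(v) for v in model.trainable_variables]
optimizer.apply_gradients(zip(dummy_grads, model.trainable_variables))
print(int(optimizer.iterations))  # -> 1
```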
```diff
@@ -414,6 +421,9 @@ class BrainToTextDecoderTrainerTF:
         """Create AdamW optimizer with parameter groups"""
         # Note: TensorFlow doesn't have the same parameter group functionality as PyTorch
         # We'll use a single optimizer and handle different learning rates in the scheduler
+
+        # Create optimizer within strategy scope to ensure proper initialization
+        print(f"Creating optimizer with strategy: {type(self.strategy).__name__}")
         optimizer = tf.keras.optimizers.AdamW(
             learning_rate=self.args['lr_max'],
             beta_1=self.args['beta0'],
```
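As the comment in this hunk notes, the commit keeps a single optimizer and leaves per-group learning rates to the scheduler. For comparison, one way to approximate PyTorch-style parameter groups in Keras is to partition the variables and keep one optimizer per partition. Everything below (the helper name, the `"bias"` grouping rule, the two learning rates) is a hypothetical sketch, not this repository's code.

```python
import tensorflow as tf

# Hypothetical sketch: emulate PyTorch parameter groups with one
# AdamW instance per group, each carrying its own learning rate.
fast_opt = tf.keras.optimizers.AdamW(learning_rate=3e-3)
slow_opt = tf.keras.optimizers.AdamW(learning_rate=1e-4)

def apply_grouped_gradients(gradients, variables):
    fast, slow = [], []
    for g, v in zip(gradients, variables):
        # Partition by naming convention; 'bias' is only an example rule.
        (fast if "bias" in v.name else slow).append((g, v))
    if fast:
        fast_opt.apply_gradients(fast)
    if slow:
        slow_opt.apply_gradients(slow)
```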
```diff
@@ -565,7 +575,18 @@ class BrainToTextDecoderTrainerTF:
 
         # Apply gradients (only for variables that have gradients)
         if len(filtered_gradients) > 0:
-            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+            # Ensure we're in the strategy scope when applying gradients
+            # This prevents the 'NoneType' extended attribute error
+            try:
+                self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+            except AttributeError as e:
+                if "'NoneType' object has no attribute 'extended'" in str(e):
+                    # Strategy context was lost, this should not happen in a @tf.function
+                    tf.print(f"ERROR: Strategy context lost during gradient application: {e}")
+                    tf.print("This indicates a serious issue with the distributed training setup")
+                    raise RuntimeError(f"Strategy context lost during training: {e}")
+                else:
+                    raise
 
         return loss, grad_norm
```
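The names `filtered_gradients`/`filtered_variables` in this hunk imply an earlier step that drops variables without gradients, since `apply_gradients` rejects `None` entries. A minimal sketch of that pattern, with a hypothetical helper name and a global-norm computation matching the `grad_norm` value the step returns:

```python
import tensorflow as tf

# Hypothetical helper: drop (grad, var) pairs whose gradient is None
# (e.g. variables unused in the loss), apply the rest, and report the
# global norm of the gradients that were actually applied.
def apply_filtered_gradients(optimizer, gradients, variables):
    pairs = [(g, v) for g, v in zip(gradients, variables) if g is not None]
    if pairs:
        optimizer.apply_gradients(pairs)
        return tf.linalg.global_norm([g for g, _ in pairs])
    return tf.constant(0.0)
```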