Author: Zchen
Date: 2025-10-17 02:09:14 +08:00
Parent: ca8c615505
Commit: 7ede7b5f12


@@ -100,8 +100,13 @@ class BrainToTextDecoderTrainerTF:
         @tf.function
         def init_optimizer_slots():
-            dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
-            self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+            # Use ALL trainable variables for slot initialization, not just filtered ones.
+            # This ensures slot variables are created for every variable that might need gradients.
+            all_variables = self.model.trainable_variables
+            dummy_gradients = [tf.zeros_like(var) for var in all_variables]
+            # Apply gradients for all variables so that all slots are created
+            self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
             return tf.constant(True)  # Return something to satisfy strategy.run
         # Run the slot initialization in replica context
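
Below is a minimal, standalone sketch of the slot pre-building trick this hunk relies on. The toy Dense model, its shapes, and the stock Adam optimizer are assumptions for illustration only and are not taken from the trainer: applying all-zero gradients once forces a Keras optimizer to create its slot variables for every trainable variable, so nothing new has to be created later inside a @tf.function training step.

# Minimal, standalone sketch (hypothetical toy model; layer size and input
# shape are made up for illustration, not taken from the trainer).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model.build((None, 8))                     # create the model's variables
optimizer = tf.keras.optimizers.Adam()

# Applying all-zero gradients once makes the optimizer create its slot
# variables (e.g. Adam's moment accumulators) for every trainable variable,
# so no optimizer variables need to be created later inside a @tf.function.
zero_grads = [tf.zeros_like(v) for v in model.trainable_variables]
optimizer.apply_gradients(zip(zero_grads, model.trainable_variables))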
@@ -583,34 +588,30 @@ class BrainToTextDecoderTrainerTF:
         # Calculate gradients - TensorFlow handles mixed precision automatically
         gradients = tape.gradient(loss, self.model.trainable_variables)
-        # Filter out None gradients (for h0 variables that don't need gradients)
-        filtered_grads_and_vars = []
-        for grad, var in zip(gradients, self.model.trainable_variables):
-            if grad is not None:
-                filtered_grads_and_vars.append((grad, var))
-            else:
-                # Log which variables don't have gradients (informational)
-                tf.print(f"No gradient for variable: {var.name}")
+        # For TPU compatibility, use all variables (TensorFlow will handle None gradients automatically).
+        # This ensures consistency with slot variable initialization.
+        all_variables = self.model.trainable_variables
-        # Extract filtered gradients and variables
-        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
-        filtered_variables = [var for _, var in filtered_grads_and_vars]
+        # Replace None gradients with zeros to maintain consistency
+        safe_gradients = []
+        for grad, var in zip(gradients, all_variables):
+            if grad is not None:
+                safe_gradients.append(grad)
+            else:
+                # Create a zero gradient for variables without gradients
+                safe_gradients.append(tf.zeros_like(var))
         # Clip gradients
-        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
-            filtered_gradients, grad_norm = tf.clip_by_global_norm(
-                filtered_gradients, self.args['grad_norm_clip_value']
+        if self.args['grad_norm_clip_value'] > 0:
+            safe_gradients, grad_norm = tf.clip_by_global_norm(
+                safe_gradients, self.args['grad_norm_clip_value']
             )
-        elif len(filtered_gradients) > 0:
-            grad_norm = tf.global_norm(filtered_gradients)
         else:
-            grad_norm = tf.constant(0.0)
+            grad_norm = tf.global_norm(safe_gradients)
-        # Apply gradients (only for variables that have gradients)
-        if len(filtered_gradients) > 0:
-            # Apply gradients directly - optimizer should be pre-built and ready
-            # In @tf.function, we need to keep error handling simple
-            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+        # Apply gradients to ALL variables (consistent with initialization).
+        # The optimizer handles zero gradients without special casing.
+        self.optimizer.apply_gradients(zip(safe_gradients, all_variables))
         return loss, grad_norm
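
For reference, here is a minimal, standalone sketch of the None-to-zero gradient handling this hunk introduces. The toy model, the synthetic data, and the clip value of 10.0 are assumptions for illustration; the real trainer reads the threshold from self.args['grad_norm_clip_value'] and works on self.model / self.optimizer.

# Minimal, standalone sketch (toy model, synthetic data, and clip value are
# made up; they stand in for self.model, the dataset, and self.args).
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.build((None, 3))
optimizer = tf.keras.optimizers.Adam()
grad_norm_clip_value = 10.0

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    gradients = tape.gradient(loss, model.trainable_variables)
    # tape.gradient returns None for variables the loss does not depend on;
    # substituting zeros keeps the gradient list aligned with the variable
    # list, so apply_gradients sees the same variables on every step.
    safe_gradients = [g if g is not None else tf.zeros_like(v)
                      for g, v in zip(gradients, model.trainable_variables)]
    if grad_norm_clip_value > 0:
        safe_gradients, grad_norm = tf.clip_by_global_norm(
            safe_gradients, grad_norm_clip_value)
    else:
        grad_norm = tf.linalg.global_norm(safe_gradients)  # TF2 endpoint
    optimizer.apply_gradients(zip(safe_gradients, model.trainable_variables))
    return loss, grad_norm

loss, grad_norm = train_step(tf.random.normal((16, 3)), tf.random.normal((16, 1)))

Note that a zero gradient is not exactly the same as skipping a variable (a stateful optimizer such as Adam still updates its moment estimates for it), but it keeps the gradient and variable lists identical on every step, which is what the all-variables apply_gradients call above relies on.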