From 7ede7b5f120c4e36d3b44d5f25dd12b2ed725f11 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Fri, 17 Oct 2025 02:09:14 +0800
Subject: [PATCH] f

---
 model_training_nnn_tpu/trainer_tf.py | 49 ++++++++++++++--------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 2d0c8d3..32bff8a 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -100,8 +100,13 @@ class BrainToTextDecoderTrainerTF:
 
         @tf.function
         def init_optimizer_slots():
-            dummy_gradients = [tf.zeros_like(var) for var in self.model.trainable_variables]
-            self.optimizer.apply_gradients(zip(dummy_gradients, self.model.trainable_variables))
+            # Use ALL trainable variables for slot initialization, not just filtered ones
+            # This ensures slot variables are created for all variables that might need gradients
+            all_variables = self.model.trainable_variables
+            dummy_gradients = [tf.zeros_like(var) for var in all_variables]
+
+            # Apply gradients for all variables to ensure all slots are created
+            self.optimizer.apply_gradients(zip(dummy_gradients, all_variables))
             return tf.constant(True)  # Return something to satisfy strategy.run
 
         # Run the slot initialization in replica context
@@ -583,34 +588,30 @@ class BrainToTextDecoderTrainerTF:
 
         # Calculate gradients - TensorFlow automatically handles mixed precision
         gradients = tape.gradient(loss, self.model.trainable_variables)
 
-        # Filter out None gradients (for h0 variables that don't need gradients)
-        filtered_grads_and_vars = []
-        for grad, var in zip(gradients, self.model.trainable_variables):
-            if grad is not None:
-                filtered_grads_and_vars.append((grad, var))
-            else:
-                # Log which variables don't have gradients (informational)
-                tf.print(f"No gradient for variable: {var.name}")
+        # For TPU compatibility, use all variables (TensorFlow will handle None gradients automatically)
+        # This ensures consistency with slot variable initialization
+        all_variables = self.model.trainable_variables
 
-        # Extract filtered gradients and variables
-        filtered_gradients = [grad for grad, _ in filtered_grads_and_vars]
-        filtered_variables = [var for _, var in filtered_grads_and_vars]
+        # Replace None gradients with zeros to maintain consistency
+        safe_gradients = []
+        for grad, var in zip(gradients, all_variables):
+            if grad is not None:
+                safe_gradients.append(grad)
+            else:
+                # Create zero gradient for variables without gradients
+                safe_gradients.append(tf.zeros_like(var))
 
         # Clip gradients
-        if self.args['grad_norm_clip_value'] > 0 and len(filtered_gradients) > 0:
-            filtered_gradients, grad_norm = tf.clip_by_global_norm(
-                filtered_gradients, self.args['grad_norm_clip_value']
+        if self.args['grad_norm_clip_value'] > 0:
+            safe_gradients, grad_norm = tf.clip_by_global_norm(
+                safe_gradients, self.args['grad_norm_clip_value']
             )
-        elif len(filtered_gradients) > 0:
-            grad_norm = tf.global_norm(filtered_gradients)
         else:
-            grad_norm = tf.constant(0.0)
+            grad_norm = tf.global_norm(safe_gradients)
 
-        # Apply gradients (only for variables that have gradients)
-        if len(filtered_gradients) > 0:
-            # Apply gradients directly - optimizer should be pre-built and ready
-            # In @tf.function, we need to keep error handling simple
-            self.optimizer.apply_gradients(zip(filtered_gradients, filtered_variables))
+        # Apply gradients to ALL variables (consistent with initialization)
+        # TensorFlow optimizer will handle zero gradients correctly
+        self.optimizer.apply_gradients(zip(safe_gradients, all_variables))
 
         return loss, grad_norm