From bd61136f934f51524f59df09bb311cf0e8b40c25 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Thu, 16 Oct 2025 22:02:11 +0800 Subject: [PATCH] f --- model_training_nnn_tpu/trainer_tf.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 175a2b3..61d7a03 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -107,17 +107,12 @@ class BrainToTextDecoderTrainerTF: self.optimizer.build(self.model.trainable_variables) print("✅ Optimizer built with model variables") - # Verify optimizer is properly built + # Verify optimizer is properly built - just check iterations print(f"Optimizer iterations: {self.optimizer.iterations}") - print(f"Optimizer built: {self.optimizer.built}") - # Print optimizer variable names for debugging - if hasattr(self.optimizer, 'variables') and self.optimizer.variables: - print(f"Optimizer has {len(self.optimizer.variables)} internal variables") - else: - print("⚠️ Optimizer has no internal variables - this might cause issues") + # Simple check - if we have iterations, optimizer is ready + print("✅ Optimizer ready for training") - print("✅ Optimizer pre-build completed successfully") print("📝 Note: Optimizer state will be fully initialized on first training step") except Exception as e: @@ -481,7 +476,10 @@ class BrainToTextDecoderTrainerTF: def _log_model_info(self): """Log model architecture and parameter information""" - self.logger.info("Initialized TripleGRUDecoder model") + if self.logger: + self.logger.info("Initialized TripleGRUDecoder model") + else: + print("Initialized TripleGRUDecoder model") # Build the model by calling it once with dummy data dummy_batch_size = 2 @@ -494,7 +492,11 @@ class BrainToTextDecoderTrainerTF: # Count parameters total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights]) - self.logger.info(f"Model has {total_params:,} trainable parameters") + + if self.logger: + self.logger.info(f"Model has {total_params:,} trainable parameters") + else: + print(f"Model has {total_params:,} trainable parameters") def _train_step(self, batch, step): """Single training step with gradient tape"""