commit bd61136f93 (parent 6f94ad5fae)
Author: Zchen
Date:   2025-10-16 22:02:11 +08:00


@@ -107,17 +107,12 @@ class BrainToTextDecoderTrainerTF:
             self.optimizer.build(self.model.trainable_variables)
             print("✅ Optimizer built with model variables")
-            # Verify optimizer is properly built
+            # Verify optimizer is properly built - just check iterations
             print(f"Optimizer iterations: {self.optimizer.iterations}")
-            print(f"Optimizer built: {self.optimizer.built}")
-            # Print optimizer variable names for debugging
-            if hasattr(self.optimizer, 'variables') and self.optimizer.variables:
-                print(f"Optimizer has {len(self.optimizer.variables)} internal variables")
-            else:
-                print("⚠️ Optimizer has no internal variables - this might cause issues")
-            print("✅ Optimizer pre-build completed successfully")
+            # Simple check - if we have iterations, optimizer is ready
+            print("✅ Optimizer ready for training")
             print("📝 Note: Optimizer state will be fully initialized on first training step")
         except Exception as e:
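Note on this hunk: it drops the `optimizer.built` flag and the internal-variable inspection, keeping only the `iterations` check as the readiness signal. A minimal sketch of the surviving pre-build pattern, assuming a TF 2.11+/Keras optimizer; the `model` below is a stand-in for illustration, not the repo's TripleGRUDecoder:

import tensorflow as tf

# Stand-in model (assumption for illustration only).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4),
])

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# build() creates the optimizer's state variables (e.g. Adam moments)
# ahead of the first training step.
optimizer.build(model.trainable_variables)

# Once built, `iterations` is readable - the commit treats this as a
# sufficient readiness check instead of inspecting internal variables.
print(f"Optimizer iterations: {optimizer.iterations}")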
@@ -481,7 +476,10 @@ class BrainToTextDecoderTrainerTF:
     def _log_model_info(self):
         """Log model architecture and parameter information"""
-        self.logger.info("Initialized TripleGRUDecoder model")
+        if self.logger:
+            self.logger.info("Initialized TripleGRUDecoder model")
+        else:
+            print("Initialized TripleGRUDecoder model")

         # Build the model by calling it once with dummy data
         dummy_batch_size = 2
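The same logger-or-print fallback recurs in the next hunk. A hypothetical helper (not in the source, purely illustrative) that factors out the repeated branch:

import logging
from typing import Optional

def log_info(logger: Optional[logging.Logger], message: str) -> None:
    # Route through the configured logger when present; otherwise fall
    # back to stdout, mirroring the branches added in this commit.
    if logger:
        logger.info(message)
    else:
        print(message)

log_info(None, "Initialized TripleGRUDecoder model")  # falls back to stdout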
@@ -494,7 +492,11 @@ class BrainToTextDecoderTrainerTF:
         # Count parameters
         total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
-        self.logger.info(f"Model has {total_params:,} trainable parameters")
+        if self.logger:
+            self.logger.info(f"Model has {total_params:,} trainable parameters")
+        else:
+            print(f"Model has {total_params:,} trainable parameters")

     def _train_step(self, batch, step):
         """Single training step with gradient tape"""