@@ -107,17 +107,12 @@ class BrainToTextDecoderTrainerTF:
             self.optimizer.build(self.model.trainable_variables)
             print("✅ Optimizer built with model variables")
 
-            # Verify optimizer is properly built
+            # Verify optimizer is properly built - just check iterations
             print(f"Optimizer iterations: {self.optimizer.iterations}")
-            print(f"Optimizer built: {self.optimizer.built}")
 
-            # Print optimizer variable names for debugging
-            if hasattr(self.optimizer, 'variables') and self.optimizer.variables:
-                print(f"Optimizer has {len(self.optimizer.variables)} internal variables")
-            else:
-                print("⚠️ Optimizer has no internal variables - this might cause issues")
+            # Simple check - if we have iterations, optimizer is ready
+            print("✅ Optimizer ready for training")
 
-            print("✅ Optimizer pre-build completed successfully")
             print("📝 Note: Optimizer state will be fully initialized on first training step")
 
         except Exception as e:
@@ -481,7 +476,10 @@ class BrainToTextDecoderTrainerTF:
 
     def _log_model_info(self):
         """Log model architecture and parameter information"""
-        self.logger.info("Initialized TripleGRUDecoder model")
+        if self.logger:
+            self.logger.info("Initialized TripleGRUDecoder model")
+        else:
+            print("Initialized TripleGRUDecoder model")
 
         # Build the model by calling it once with dummy data
         dummy_batch_size = 2
@@ -494,7 +492,11 @@ class BrainToTextDecoderTrainerTF:
 
         # Count parameters
         total_params = sum([tf.size(w).numpy() for w in self.model.trainable_weights])
-        self.logger.info(f"Model has {total_params:,} trainable parameters")
+        if self.logger:
+            self.logger.info(f"Model has {total_params:,} trainable parameters")
+        else:
+            print(f"Model has {total_params:,} trainable parameters")
 
     def _train_step(self, batch, step):
         """Single training step with gradient tape"""