| 
									
										
										
										
											2025-10-12 15:31:45 +08:00
										 |  |  | import argparse | 
					
						
							| 
									
										
										
										
											2025-10-12 09:11:32 +08:00
										 |  |  | from omegaconf import OmegaConf | 
					
						
							|  |  |  | from rnn_trainer import BrainToTextDecoder_Trainer | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-10-12 15:31:45 +08:00
										 |  |  | def main(): | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser(description='Train Brain-to-Text RNN Model') | 
					
						
							|  |  |  |     parser.add_argument('--config_path', default='rnn_args.yaml', | 
					
						
							|  |  |  |                        help='Path to configuration file (default: rnn_args.yaml)') | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Load configuration | 
					
						
							|  |  |  |     config = OmegaConf.load(args.config_path) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Initialize trainer | 
					
						
							|  |  |  |     trainer = BrainToTextDecoder_Trainer(config) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Start training | 
					
						
							|  |  |  |     trainer.train() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     print("Training completed successfully!") | 
					
						
							|  |  |  |     print(f"Best validation PER: {trainer.best_val_PER:.5f}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     main() |