265 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			265 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | |
| """
 | |
| TensorFlow Training Script for Brain-to-Text RNN Model
 | |
| Optimized for TPU v5e-8
 | |
| 
 | |
| This script trains the TripleGRUDecoder model using TensorFlow and TPU hardware.
 | |
| It provides the same functionality as the PyTorch version but with TensorFlow
 | |
| operations optimized for TPU performance.
 | |
| 
 | |
| Usage:
 | |
|     python train_model_tf.py --config_path rnn_args.yaml
 | |
| 
 | |
| Requirements:
 | |
|     - TensorFlow >= 2.15.0
 | |
|     - TPU v5e-8 environment
 | |
|     - Access to brain-to-text HDF5 dataset
 | |
| """
 | |
| 
 | |
| import argparse
 | |
| import os
 | |
| import sys
 | |
| import logging
 | |
| from omegaconf import OmegaConf
 | |
| 
 | |
| # Add the current directory to Python path for imports
 | |
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 | |
| 
 | |
| from trainer_tf import BrainToTextDecoderTrainerTF
 | |
| 
 | |
| 
 | |
def setup_tpu_environment():
    """Populate TPU/XLA-related environment variables with default values.

    Each variable is written with ``os.environ.setdefault``, so any value the
    user (or the VM image) has already exported takes precedence over the
    defaults below. A confirmation line is printed once all defaults are in
    place.

    NOTE(review): this runs from ``main()``, i.e. after the module-level
    ``trainer_tf`` import. If that import already loads TensorFlow, variables
    read only at TF import time may have no effect here; confirm, or set them
    before the import.
    """
    env_defaults = {
        # NOTE(review): 'PyTorch/XLA' looks carried over from the PyTorch
        # version of this script; confirm it is intended for a TF run.
        'TPU_ML_PLATFORM': 'PyTorch/XLA',
        # NOTE(review): XLA_USE_BF16 is a PyTorch/XLA variable; confirm that
        # TensorFlow honours it before relying on it for bfloat16.
        'XLA_USE_BF16': '1',
        # Ask TensorFlow to auto-cluster ops for XLA JIT compilation.
        'TF_XLA_FLAGS': '--tf_xla_auto_jit=2',
        # TPU memory / sharding tuning knobs.
        'TPU_MEGACORE': '1',
        'LIBTPU_INIT_ARGS': '--xla_tpu_spmd_threshold_for_allgather_cse=10000',
        # Level 2 hides TensorFlow C++ INFO and WARNING messages.
        'TF_CPP_MIN_LOG_LEVEL': '2',
    }

    for env_name, env_value in env_defaults.items():
        os.environ.setdefault(env_name, env_value)

    print("TPU environment configured for v5e-8 optimizations")
 | |
| 
 | |
| 
 | |
def _lookup_config_field(config, field):
    """Return the value at dotted path *field* (e.g. ``'model.n_units'``) in *config*.

    Raises:
        ValueError: If any key along the path is absent, or an intermediate
            value cannot be indexed by key (e.g. it is ``None`` or a scalar).
    """
    value = config
    for key in field.split('.'):
        try:
            value = value[key]
        except (KeyError, TypeError) as err:
            # TypeError covers parents such as ``model: null`` or a scalar,
            # which previously escaped as an unrelated-looking error.
            raise ValueError(f"Missing required configuration field: {field}") from err
    return value


def validate_config(config):
    """Validate configuration for TensorFlow TPU training.

    Args:
        config: Mapping-like configuration supporting ``config[key]`` and
            ``config.get(key, default)`` (an OmegaConf ``DictConfig`` when
            called from ``main``; plain nested dicts also work).

    Raises:
        ValueError: If a required field is missing. This includes
            ``dataset.batch_size`` when ``use_tpu`` is enabled (the default).
        FileNotFoundError: If ``dataset.dataset_dir`` does not exist, or no
            listed session contains a ``data_train.hdf5`` file.
    """
    required_fields = [
        'model.n_input_features',
        'model.n_units',
        'dataset.sessions',
        'dataset.n_classes',
        'num_training_batches',
        'output_dir',
        'checkpoint_dir',
        # Read unconditionally below, so it must be present as well.
        'dataset.dataset_dir',
    ]
    for field in required_fields:
        _lookup_config_field(config, field)

    # TPU-specific validations (TPU mode is assumed unless use_tpu is falsy).
    if config.get('use_tpu', True):
        batch_size = _lookup_config_field(config, 'dataset.batch_size')
        if batch_size < 8:
            logging.warning("Small batch size may not utilize TPU efficiently. Consider batch_size >= 32")

        if not config.get('use_amp', True):
            logging.warning("Mixed precision disabled. Consider enabling for better TPU performance")

    # Dataset validation
    dataset_dir = config['dataset']['dataset_dir']
    if not os.path.exists(dataset_dir):
        raise FileNotFoundError(f"Dataset directory not found: {dataset_dir}")

    # At least one session must provide training data.
    session_found = any(
        os.path.exists(os.path.join(dataset_dir, session, 'data_train.hdf5'))
        for session in config['dataset']['sessions']
    )
    if not session_found:
        raise FileNotFoundError("No valid session data files found in dataset directory")

    print("Configuration validation passed")
 | |
| 
 | |
| 
 | |
def _build_arg_parser():
    """Build the command-line parser for the training script."""
    parser = argparse.ArgumentParser(
        description='Train Brain-to-Text RNN Model with TensorFlow on TPU v5e-8',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument('--config_path', default='rnn_args.yaml',
                        help='Path to configuration YAML file')
    parser.add_argument('--output_dir', default=None,
                        help='Override output directory from config')
    parser.add_argument('--checkpoint_dir', default=None,
                        help='Override checkpoint directory from config')
    parser.add_argument('--resume_from', default=None,
                        help='Path to checkpoint to resume training from')
    parser.add_argument('--num_batches', type=int, default=None,
                        help='Override number of training batches')
    parser.add_argument('--batch_size', type=int, default=None,
                        help='Override batch size')

    # The two precision flags contradict each other; previously passing both
    # silently let the "disable" flag win. Reject the combination instead.
    precision = parser.add_mutually_exclusive_group()
    precision.add_argument('--mixed_precision', action='store_true', default=None,
                           help='Enable mixed precision training (bfloat16)')
    precision.add_argument('--disable_mixed_precision', action='store_true',
                           help='Disable mixed precision training')

    parser.add_argument('--validate_only', action='store_true',
                        help='Only run validation, do not train')
    parser.add_argument('--debug', action='store_true',
                        help='Enable debug logging')
    return parser


def _apply_cli_overrides(config, args):
    """Copy command-line overrides from *args* onto *config* in place."""
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.checkpoint_dir:
        config.checkpoint_dir = args.checkpoint_dir
    # Compare with None so an explicitly given 0 is not silently ignored.
    if args.num_batches is not None:
        config.num_training_batches = args.num_batches
    if args.batch_size is not None:
        config.dataset.batch_size = args.batch_size
    if args.mixed_precision:
        config.use_amp = True
    if args.disable_mixed_precision:
        config.use_amp = False


def _run_validation_only(trainer):
    """Run one pass over the trainer's validation data and print the metrics."""
    print("Running validation only...")
    # Imported lazily: only this mode needs the input-pipeline factory.
    from dataset_tf import create_input_fn

    val_dataset = create_input_fn(
        trainer.val_dataset_tf,
        trainer.args['dataset']['data_transforms'],
        training=False,
    )
    val_dist_dataset = trainer.strategy.experimental_distribute_dataset(val_dataset)
    val_metrics = trainer._validate(val_dist_dataset)

    print("Validation Results:")
    print(f"  Average Loss: {val_metrics['avg_loss']:.4f}")
    print(f"  Average PER: {val_metrics['avg_per']:.4f}")
    print(f"  Total Edit Distance: {val_metrics['total_edit_distance']}")
    print(f"  Total Sequence Length: {val_metrics['total_seq_length']}")


def _run_training(trainer, config):
    """Train, pickle the returned statistics to ``output_dir``, and print a summary."""
    import pickle

    print("Starting training...")
    train_stats = trainer.train()

    # Persist the statistics before formatting the summary. Previously an
    # empty loss history raised IndexError in the summary, which was then
    # reported as a training failure and the stats were never saved.
    os.makedirs(config.output_dir, exist_ok=True)
    stats_path = os.path.join(config.output_dir, 'training_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(train_stats, f)

    train_losses = train_stats['train_losses']
    val_losses = train_stats['val_losses']

    print("\nTraining completed successfully!")
    print(f"Best validation PER: {trainer.best_val_per:.5f}")
    if train_losses:
        print(f"Final training loss: {train_losses[-1]:.4f}")
    else:
        print("Final training loss: n/a (no training losses recorded)")
    if val_losses:
        print(f"Final validation loss: {val_losses[-1]:.4f}")
    else:
        print("Final validation loss: n/a (no validation losses recorded)")
    print(f"Total training batches: {len(train_losses)}")
    print(f"Training statistics saved to: {stats_path}")


def main():
    """Main training function.

    Parses the command line, configures logging and the TPU environment,
    loads and validates the YAML config (applying CLI overrides), then either
    runs validation only or full training. Failures after the trainer is
    being built are reported and exit with status 1.

    Raises:
        FileNotFoundError: If ``--config_path`` does not exist, or (via
            ``validate_config``) the dataset files are missing.
        ValueError: Propagated from ``validate_config`` for missing fields.
    """
    args = _build_arg_parser().parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.debug else logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
    )

    setup_tpu_environment()

    if not os.path.exists(args.config_path):
        raise FileNotFoundError(f"Configuration file not found: {args.config_path}")
    config = OmegaConf.load(args.config_path)
    print(f"Loaded configuration from: {args.config_path}")

    _apply_cli_overrides(config, args)
    validate_config(config)

    try:
        print("Initializing TensorFlow Brain-to-Text trainer...")
        trainer = BrainToTextDecoderTrainerTF(config)

        if args.resume_from:
            # resume_from is treated as a prefix; existence is checked on
            # '<prefix>.weights.h5' before handing the prefix to the trainer.
            if not os.path.exists(args.resume_from + '.weights.h5'):
                raise FileNotFoundError(f"Checkpoint not found: {args.resume_from}")
            trainer.load_checkpoint(args.resume_from)
            print(f"Resumed training from checkpoint: {args.resume_from}")

        if args.validate_only:
            _run_validation_only(trainer)
        else:
            _run_training(trainer, config)

    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        sys.exit(1)
    except Exception as e:
        # Top-level boundary: report, optionally show the traceback, and exit non-zero.
        print(f"\nTraining failed with error: {e}")
        if args.debug:
            import traceback
            traceback.print_exc()
        sys.exit(1)
 | |
| 
 | |
| 
 | |
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
