diff --git a/model_training_nnn/TPU_SETUP_GUIDE.md b/model_training_nnn/TPU_SETUP_GUIDE.md
new file mode 100644
index 0000000..fed0b80
--- /dev/null
+++ b/model_training_nnn/TPU_SETUP_GUIDE.md
@@ -0,0 +1,204 @@
+# TPU Training Setup Guide for Brain-to-Text RNN
+
+This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
+
+## Prerequisites
+
+### 1. Install PyTorch XLA for TPU Support
+```bash
+# Install PyTorch XLA (adjust the version to match your PyTorch install)
+pip install "torch_xla[tpu]" -f https://storage.googleapis.com/libtpu-releases/index.html
+
+# Or pin a specific version:
+pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
+```
+
+### 2. Install Accelerate Library
+```bash
+pip install accelerate
+```
+
+### 3. Verify TPU Access
+```bash
+# Check if a TPU device is visible
+python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
+```
+
+## Configuration Setup
+
+### 1. Enable TPU in Configuration File
+
+Update your `rnn_args.yaml` file with TPU settings:
+
+```yaml
+# TPU and distributed training settings
+use_tpu: true                  # Enable TPU training
+num_tpu_cores: 8               # Number of TPU cores (8 for v3-8 or v4-8)
+gradient_accumulation_steps: 1 # Gradient accumulation for a larger effective batch size
+dataloader_num_workers: 0      # Must be 0 on TPU to avoid multiprocessing issues
+use_amp: true                  # Enable mixed precision (bfloat16)
+
+# Adjust batch size for multi-core TPU
+dataset:
+  batch_size: 8                # Per-core batch size (total = 8 cores × 8 = 64)
+```
+
+### 2. TPU-Optimized Hyperparameters
+
+Recommended adjustments for TPU training:
+
+```yaml
+# Learning rate scaling for distributed training
+lr_max: 0.005                  # May need to scale with the number of cores
+lr_max_day: 0.005
+
+# Batch size considerations
+dataset:
+  batch_size: 8                # Per-core batch size
+  days_per_batch: 4            # Keep consistent across cores
+```
+
+## Training Launch Options
+
+### Method 1: Using the TPU Launch Script (Recommended)
+
+```bash
+# Basic TPU training with 8 cores
+python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+
+# Check the TPU environment only
+python launch_tpu_training.py --check_only
+
+# Custom configuration file
+python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
+```
+
+### Method 2: Direct Accelerate Launch
+
+```bash
+# Configure accelerate (one-time setup)
+accelerate config
+
+# Or use the provided TPU config
+export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml
+
+# Launch training
+accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
+```
+
+### Method 3: Manual XLA Launch (Advanced)
+
+```bash
+# Set TPU environment variables (PJRT runtime, torch_xla >= 2.0)
+export PJRT_DEVICE=TPU
+export XLA_USE_BF16=1
+
+# Run the training script directly under XLA. For multi-core runs, prefer the
+# Accelerate launchers above; the exact torch_xla launcher flags vary by version.
+python train_model.py --config_path rnn_args.yaml
+```
+
+## Key TPU Features Implemented
+
+### 1. Distributed Training Support
+- Automatic data-parallel replication of the model across 8 TPU cores
+- Synchronized gradient updates across all cores
+- Proper checkpoint saving/loading for distributed training
+
+### 2. Mixed Precision Training
+- Automatic bfloat16 precision for TPU optimization
+- Faster training while maintaining numerical stability
+- Reduced memory usage
+
+### 3. TPU-Optimized Data Loading
+- Single-threaded data loading (`num_workers=0`) for TPU compatibility
+- Automatic data distribution across TPU cores (see the sketch below)
+- Efficient batch processing
+
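+A minimal sketch of the TPU dataloader wiring (simplified from `rnn_trainer.py`; `args` and `train_dataset` are assumed to already exist, and the actual sharding is handled by Accelerate):
+
+```python
+from accelerate import Accelerator
+from torch.utils.data import DataLoader
+
+accelerator = Accelerator()  # picks up the TPU / bf16 settings from the Accelerate config
+
+# 0 workers on TPU; fall back to the normal worker count otherwise
+num_workers = args.get('dataloader_num_workers', 0) if args.get('use_tpu', False) \
+    else args['dataset']['num_dataloader_workers']
+
+train_loader = DataLoader(
+    train_dataset,
+    batch_size=None,   # the dataset already returns fully collated batches
+    shuffle=args['dataset']['loader_shuffle'],
+    num_workers=num_workers,
+    pin_memory=True,
+)
+
+# prepare() shards batches across the TPU cores and moves them to the XLA device
+train_loader = accelerator.prepare(train_loader)
+```
+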
+### 4. Inference Support
+- TPU-compatible inference methods added to the trainer class
+- `inference()` and `inference_batch()` methods for production use
+- Automatic mixed precision during inference
+
+## Performance Optimization Tips
+
+### 1. Batch Size Tuning
+- Start with a total batch size of 64 (8 cores × 8 per core)
+- Increase gradually if memory allows
+- Monitor TPU utilization (host-side `top` only shows CPU usage; use a TPU-specific tool such as `tpu-info` if available)
+
+### 2. Gradient Accumulation
+- Use `gradient_accumulation_steps` to simulate larger batch sizes
+- Effective batch size = batch_size × num_cores × gradient_accumulation_steps
+
+### 3. Learning Rate Scaling
+- Consider scaling the learning rate with the number of cores
+- Linear scaling: `lr_new = lr_base × num_cores`
+- Large-batch training may also need a longer warmup
+
+### 4. Memory Management
+- TPU v3-8: 128 GB HBM total (16 GB per core)
+- TPU v4-8: 128 GB HBM total (32 GB per chip)
+- Monitor memory usage to avoid OOM errors
+
+## Monitoring and Debugging
+
+### 1. TPU Utilization
+```bash
+# Print the number of visible TPU cores (a device count, not a utilization figure);
+# on newer torch_xla versions use torch_xla.runtime.world_size() instead
+python -c "import torch_xla.core.xla_model as xm; print(f'TPU cores: {xm.xrt_world_size()}')"
+```
+
+### 2. Training Logs
+- Training logs include device information and core count
+- Monitor validation metrics across all cores
+- Check for synchronization issues in distributed training
+
+### 3. Common Issues and Solutions
+
+**Issue**: "No TPU devices found"
+- **Solution**: Verify that the TPU runtime is started and accessible
+
+**Issue**: "DataLoader workers > 0 causes hangs"
+- **Solution**: Set `dataloader_num_workers: 0` in the config
+
+**Issue**: "Mixed precision errors"
+- **Solution**: Ensure `use_amp: true` and that your PyTorch XLA version supports bfloat16
+
+**Issue**: "Gradient synchronization timeouts"
+- **Solution**: Check connectivity between TPU hosts/cores
+
+## Example Training Command
+
+```bash
+# Complete TPU training example
+cd model_training_nnn
+
+# 1. Update the config for TPU
+vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8
+
+# 2. Launch TPU training
+python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+
+# 3. Monitor training progress
+tail -f trained_models/baseline_rnn/training_log
+```
+
+## Configuration Reference
+
+### Required TPU Settings
+```yaml
+use_tpu: true
+num_tpu_cores: 8
+dataloader_num_workers: 0
+use_amp: true
+```
+
+### Optional TPU Optimizations
+```yaml
+gradient_accumulation_steps: 1
+dataset:
+  batch_size: 8         # Per-core batch size
+mixed_precision: bf16   # Accelerate setting (see accelerate_config_tpu.yaml)
+```
+
+This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with distributed training managed automatically by the Accelerate library.
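+
+## Appendix: Batch Inference Example (Sketch)
+
+A minimal sketch of batch inference with the new trainer methods. It assumes a trainer built from a TPU-enabled config (for example the `rnn_args_tpu.yaml` written by `launch_tpu_training.py`) and uses a batch from the trainer's own validation loader; the greedy per-frame argmax is for illustration only, not the full decoding pipeline.
+
+```python
+import torch
+from omegaconf import OmegaConf
+from rnn_trainer import BrainToTextDecoder_Trainer
+
+# Build a trainer from a config with use_tpu: true
+config = OmegaConf.load('rnn_args_tpu.yaml')
+trainer = BrainToTextDecoder_Trainer(config)
+
+# The dataset already returns collated batches, so grab one from the val loader
+batch = next(iter(trainer.val_loader))
+
+# Per-frame phoneme logits and CTC-adjusted sequence lengths
+logits, adjusted_lens = trainer.inference_batch(batch)
+
+# Greedy per-frame argmax, for illustration only
+pred_ids = torch.argmax(logits, dim=-1)
+print(pred_ids.shape, adjusted_lens)
+```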
\ No newline at end of file
diff --git a/model_training_nnn/accelerate_config_tpu.yaml b/model_training_nnn/accelerate_config_tpu.yaml
new file mode 100644
index 0000000..0b48dab
--- /dev/null
+++ b/model_training_nnn/accelerate_config_tpu.yaml
@@ -0,0 +1,26 @@
+# Accelerate configuration for TPU training
+# This file configures the Accelerate library for 8-core TPU training
+# with mixed precision (bfloat16) support
+
+compute_environment: LOCAL_MACHINE # Accelerate expects LOCAL_MACHINE here, even for TPU runs
+distributed_type: TPU
+tpu_name: null # Will use the default TPU
+tpu_zone: null # Will use the default zone
+
+# Mixed precision settings (use bfloat16 on TPU)
+mixed_precision: bf16
+
+# Number of TPU cores (v3-8 or v4-8 TPUs have 8 cores)
+num_processes: 8
+
+# TPU pod options (leave false for a single v3-8 / v4-8 host)
+tpu_use_cluster: false
+tpu_use_sudo: false
+
+# Single-machine settings
+main_process_port: null
+machine_rank: 0
+num_machines: 1
+
+# Do not force CPU execution
+use_cpu: false
\ No newline at end of file
diff --git a/model_training_nnn/launch_tpu_training.py b/model_training_nnn/launch_tpu_training.py
new file mode 100644
index 0000000..ae51538
--- /dev/null
+++ b/model_training_nnn/launch_tpu_training.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python3
+"""
+TPU Training Launch Script for Brain-to-Text RNN Model
+
+This script provides an easy TPU training setup using the Accelerate library.
+It supports both single-core and multi-core (8 cores) TPU training.
+
+Usage:
+    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
+
+Requirements:
+    - PyTorch XLA installed
+    - Accelerate library installed
+    - TPU runtime available
+"""
+
+import argparse
+import os
+import sys
+
+import yaml
+
+def update_config_for_tpu(config_path, num_cores=8):
+    """
+    Update the configuration file to enable TPU training.
+    """
+    with open(config_path, 'r') as f:
+        config = yaml.safe_load(f)
+
+    # Enable TPU settings
+    config['use_tpu'] = True
+    config['num_tpu_cores'] = num_cores
+    config['dataloader_num_workers'] = 0  # Required for TPU
+    config['use_amp'] = True              # Enable mixed precision with bfloat16
+
+    # Adjust batch size and gradient accumulation for multi-core TPU
+    if num_cores > 1:
+        # Distribute the global batch size across cores
+        original_batch_size = config['dataset']['batch_size']
+        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
+        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
+
+        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
+        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
+
+    # Save the updated config next to the original
+    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
+    with open(tpu_config_path, 'w') as f:
+        yaml.dump(config, f, default_flow_style=False)
+
+    print(f"TPU configuration saved to: {tpu_config_path}")
+    return tpu_config_path
+
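+# Illustrative helper (not called anywhere in this script): shows how the
+# per-core batch size written by update_config_for_tpu relates to the effective
+# global batch size. Assumes the rnn_args.yaml layout used in this repo.
+def effective_batch_size(config, num_cores):
+    """Return the effective global batch size for a TPU config dict."""
+    per_core = config['dataset']['batch_size']
+    accum_steps = config.get('gradient_accumulation_steps', 1)
+    # effective batch = per-core batch size x number of cores x accumulation steps
+    return per_core * num_cores * accum_steps
+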
Please install PyTorch XLA for TPU support.") + return False + except Exception as e: + print(f"ERROR: TPU not available - {e}") + return False + +def run_tpu_training(config_path, num_cores=8): + """ + Launch TPU training using accelerate + """ + # Check TPU environment + if not check_tpu_environment(): + sys.exit(1) + + # Update config for TPU + tpu_config_path = update_config_for_tpu(config_path, num_cores) + + # Set TPU environment variables + os.environ['TPU_CORES'] = str(num_cores) + os.environ['XLA_USE_BF16'] = '1' # Enable bfloat16 + + # Launch training with accelerate + cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}" + + print(f"Launching TPU training with command:") + print(f" {cmd}") + print(f"Using {num_cores} TPU cores") + print("-" * 60) + + # Execute training + os.system(cmd) + +def main(): + parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN') + parser.add_argument('--config', default='rnn_args.yaml', + help='Path to configuration file (default: rnn_args.yaml)') + parser.add_argument('--num_cores', type=int, default=8, + help='Number of TPU cores to use (default: 8)') + parser.add_argument('--check_only', action='store_true', + help='Only check TPU environment, do not launch training') + + args = parser.parse_args() + + # Verify config file exists + if not os.path.exists(args.config): + print(f"ERROR: Configuration file {args.config} not found") + sys.exit(1) + + if args.check_only: + check_tpu_environment() + return + + # Run TPU training + run_tpu_training(args.config, args.num_cores) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/model_training_nnn/rnn_args.yaml b/model_training_nnn/rnn_args.yaml index 2944d29..e824035 100644 --- a/model_training_nnn/rnn_args.yaml +++ b/model_training_nnn/rnn_args.yaml @@ -18,6 +18,12 @@ gpu_number: '1' # GPU number to use for training, formatted as a string (e.g., ' mode: train use_amp: true # whether to use automatic mixed precision (AMP) for training +# TPU and distributed training settings +use_tpu: false # whether to use TPU for training (set to true for TPU) +num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8) +gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training +dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues + output_dir: trained_models/baseline_rnn # directory to save the trained model and logs checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training init_from_checkpoint: false # whether to initialize the model from a checkpoint diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index 123ead4..cad2966 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -182,12 +182,15 @@ class BrainToTextDecoder_Trainer: random_seed = self.args['dataset']['seed'], feature_subset = feature_subset ) + # Use TPU-optimized dataloader settings if TPU is enabled + num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers'] + self.train_loader = DataLoader( self.train_dataset, batch_size = None, # Dataset.__getitem__() already returns batches shuffle = self.args['dataset']['loader_shuffle'], - num_workers = self.args['dataset']['num_dataloader_workers'], - pin_memory = True + num_workers = num_workers, + pin_memory = True ) # 
         # val dataset and dataloader
@@ -204,9 +207,9 @@ class BrainToTextDecoder_Trainer:
         self.val_loader = DataLoader(
             self.val_dataset,
             batch_size = None, # Dataset.__getitem__() already returns batches
-            shuffle = False,
-            num_workers = 0,
-            pin_memory = True
+            shuffle = False,
+            num_workers = 0, # Keep validation dataloader single-threaded for consistency
+            pin_memory = True
         )
 
         self.logger.info("Successfully initialized datasets")
@@ -365,47 +368,52 @@ class BrainToTextDecoder_Trainer:
         return LambdaLR(optim, lr_lambdas, -1)
 
     def load_model_checkpoint(self, load_path):
         '''
-        Load a training checkpoint
+        Load a training checkpoint for distributed training
         '''
-        checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict
+        # Load the checkpoint on CPU first to avoid OOM issues
+        checkpoint = torch.load(load_path, map_location='cpu', weights_only = False) # checkpoint is just a dict
+
+        # Get the unwrapped model for loading the state dict
+        unwrapped_model = self.accelerator.unwrap_model(self.model)
+        unwrapped_model.load_state_dict(checkpoint['model_state_dict'])
 
-        self.model.load_state_dict(checkpoint['model_state_dict'])
         self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
         self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
         self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
         self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
 
-        self.model.to(self.device)
-
-        # Send optimizer params back to GPU
-        for state in self.optimizer.state.values():
-            for k, v in state.items():
-                if isinstance(v, torch.Tensor):
-                    state[k] = v.to(self.device)
+        # Device placement is managed by Accelerator, no need to move tensors manually
 
         self.logger.info("Loaded model from checkpoint: " + load_path)
 
    def save_model_checkpoint(self, save_path, PER, loss):
        '''
-        Save a training checkpoint
+        Save a training checkpoint using Accelerator for distributed training
        '''
+        # Only save on the main process to avoid conflicts
+        if self.accelerator.is_main_process:
+            # Unwrap the model to get the base model for saving
+            unwrapped_model = self.accelerator.unwrap_model(self.model)
 
-        checkpoint = {
-            'model_state_dict' : self.model.state_dict(),
-            'optimizer_state_dict' : self.optimizer.state_dict(),
-            'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
-            'val_PER' : PER,
-            'val_loss' : loss
-        }
-
-        torch.save(checkpoint, save_path)
-
-        self.logger.info("Saved model to checkpoint: " + save_path)
+            checkpoint = {
+                'model_state_dict' : unwrapped_model.state_dict(),
+                'optimizer_state_dict' : self.optimizer.state_dict(),
+                'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
+                'val_PER' : PER,
+                'val_loss' : loss
+            }
 
-        # Save the args file alongside the checkpoint
-        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
-            OmegaConf.save(config=self.args, f=f)
+            # accelerator.save is XLA-aware, so checkpoints are written safely on TPU
+            self.accelerator.save(checkpoint, save_path)
+
+            self.logger.info("Saved model to checkpoint: " + save_path)
+
+            # Save the args file alongside the checkpoint
+            with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
+                OmegaConf.save(config=self.args, f=f)
+
+        # Wait for all processes to complete checkpoint saving
+        self.accelerator.wait_for_everyone()
 
 
     def create_attention_mask(self, sequence_lengths):
@@ -685,13 +693,14 @@ class BrainToTextDecoder_Trainer:
             if self.args['dataset']['dataset_probability_val'][d] == 1:
                 day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
 
-        for i, batch in enumerate(loader):
+        for i, batch in enumerate(loader):
 
-            features = batch['input_features'].to(self.device)
-            labels = batch['seq_class_ids'].to(self.device)
-            n_time_steps = batch['n_time_steps'].to(self.device)
-            phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-            day_indicies = batch['day_indicies'].to(self.device)
+            # Data is automatically moved to device by Accelerator
+            features = batch['input_features']
+            labels = batch['seq_class_ids']
+            n_time_steps = batch['n_time_steps']
+            phone_seq_lens = batch['phone_seq_lens']
+            day_indicies = batch['day_indicies']
 
             # Determine if we should perform validation on this batch
             day = day_indicies[0].item()
@@ -702,7 +711,7 @@
 
             with torch.no_grad():
 
-                with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
+                with self.accelerator.autocast():
 
                     features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
                     adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
@@ -768,4 +777,44 @@
         metrics['avg_PER'] = avg_PER.item()
         metrics['avg_loss'] = np.mean(metrics['losses'])
 
-        return metrics
\ No newline at end of file
+        return metrics
+
+    def inference(self, features, day_indicies, n_time_steps, mode='inference'):
+        '''
+        TPU-compatible inference method for generating phoneme logits
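+
+        Illustrative usage (inputs follow the same keys/shapes as batches from
+        the data loaders; `trainer` is a constructed BrainToTextDecoder_Trainer):
+            logits = trainer.inference(batch['input_features'],
+                                       batch['day_indicies'],
+                                       batch['n_time_steps'])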
validation PER: {trainer.best_val_PER:.5f}") + +if __name__ == "__main__": + main() \ No newline at end of file