#!/usr/bin/env python3
"""
TPU Training Launch Script for the Brain-to-Text RNN Model

This script sets up TPU training using the Accelerate library. It supports
both single-core and multi-core (e.g. 8-core) TPU training.

Usage:
    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""

import argparse
import os
import subprocess
import sys

import yaml


def update_config_for_tpu(config_path, num_cores=8):
    """Update the configuration file to enable TPU training."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # Multi-worker dataloading is not supported on TPU
    config['use_amp'] = True  # Enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # The configured batch size is global; divide it across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save the updated config alongside the original
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path


def check_tpu_environment():
    """Check whether the TPU environment is properly set up."""
    try:
        import torch_xla.core.xla_model as xm

        # Check that a TPU device is available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
        # Note: xrt_world_size() is deprecated in recent torch_xla releases
        # in favor of torch_xla.runtime.world_size()
        print(f"TPU world size: {xm.xrt_world_size()}")
        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False


def run_tpu_training(config_path, num_cores=8):
    """Launch TPU training via `accelerate launch`."""
    # Check the TPU environment before doing anything else
    if not check_tpu_environment():
        sys.exit(1)

    # Write a TPU-enabled copy of the config
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Set TPU environment variables BEFORE launching training. They must be in
    # place before the training subprocess imports torch_xla; setting them on
    # os.environ here lets the subprocess inherit them.
    os.environ['TPU_CORES'] = str(num_cores)
    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16

    # XLA multi-threading settings to speed up compilation
    cpu_count = os.cpu_count()
    os.environ['XLA_FLAGS'] = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)

    print(f"Set XLA compilation to use {cpu_count} CPU threads")
    print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

    # Build the launch command as a list so a config path containing spaces
    # survives intact
    cmd = [
        'accelerate', 'launch',
        '--tpu',
        '--num_processes', str(num_cores),
        'train_model.py',
        '--config_path', tpu_config_path,
    ]

    print("Launching TPU training with command:")
    print(f"  {' '.join(cmd)}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    # Copy the environment and pass it explicitly so the XLA settings made
    # above are guaranteed to reach the child process
    env = os.environ.copy()

    print("Environment variables set for subprocess:")
    print(f"  XLA_FLAGS: {env['XLA_FLAGS']}")
    print(f"  PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)

    # Execute training with the prepared environment
    result = subprocess.run(cmd, env=env)
    return result.returncode


def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check TPU environment, do not launch training')

    args = parser.parse_args()

    # Verify the config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training and propagate the subprocess exit code
    sys.exit(run_tpu_training(args.config, args.num_cores))


if __name__ == "__main__":
    main()
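

# ---------------------------------------------------------------------------
# Illustrative sketch (never called): roughly how a training script could
# consume the generated *_tpu.yaml under `accelerate launch`. The real
# train_model.py in this repo defines its own loading and training logic, so
# the function and config keys below are assumptions for illustration only.
def _example_consume_tpu_config(tpu_config_path='rnn_args_tpu.yaml'):
    from accelerate import Accelerator

    with open(tpu_config_path) as f:
        config = yaml.safe_load(f)

    # `accelerate launch` starts one process per TPU core; Accelerator()
    # attaches this process to its XLA device. mixed_precision='bf16'
    # mirrors the XLA_USE_BF16=1 setting made by the launcher above.
    accelerator = Accelerator(
        mixed_precision='bf16' if config.get('use_amp') else 'no',
        gradient_accumulation_steps=config.get('gradient_accumulation_steps', 1),
    )
    print(f"Process {accelerator.process_index} on device {accelerator.device}")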