#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model

This script provides an easy TPU training setup using the Accelerate library.
It supports both single-core and multi-core (8 cores) TPU training.

Usage:
    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""

import argparse
import yaml
import os
import sys
from pathlib import Path
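
# Example invocations (the flags are defined in main() below):
#
#   # Verify the TPU runtime only, without launching training
#   python launch_tpu_training.py --check_only
#
#   # Train with a custom config on a single TPU core
#   python launch_tpu_training.py --config rnn_args.yaml --num_cores 1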

def update_config_for_tpu(config_path, num_cores=8):
    """
    Update the configuration file to enable TPU training.
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # Required for TPU
    config['use_amp'] = True  # Enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # Distribute the global batch size across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))

        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save the updated config alongside the original
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path
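
# For reference, a sketch of the keys this helper sets or overrides in the
# generated *_tpu.yaml (everything else in rnn_args.yaml is written back unchanged):
#
#   use_tpu: true
#   num_tpu_cores: 8
#   dataloader_num_workers: 0
#   use_amp: true
#   gradient_accumulation_steps: 1   # or the existing value, if larger
#   dataset:
#     batch_size: <original batch_size // num_cores, at least 1>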

def check_tpu_environment():
    """
    Check whether the TPU environment is properly set up.
    """
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check if a TPU device is available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
        print(f"TPU world size: {xm.xrt_world_size()}")

        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False
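
# NOTE (hedged): recent torch_xla releases (2.x) deprecate xm.xrt_world_size() and
# xm.get_ordinal() in favor of the torch_xla.runtime API. If the calls above warn
# or fail on a newer install, the equivalent is roughly:
#
#     import torch_xla.runtime as xr
#     print(f"TPU world size: {xr.world_size()}")
#     print(f"TPU ordinal: {xr.global_ordinal()}")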

def run_tpu_training(config_path, num_cores=8):
    """
    Launch TPU training using Accelerate.
    """
    # Check the TPU environment
    if not check_tpu_environment():
        sys.exit(1)

    # Update the config for TPU
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Set TPU environment variables BEFORE launching training
    os.environ['TPU_CORES'] = str(num_cores)
    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16
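    # NOTE (assumption): newer torch_xla releases warn that XLA_USE_BF16 is being
    # phased out in favor of casting model/optimizer dtypes to bfloat16 directly;
    # the variable is kept here to match the mixed-precision setting (use_amp)
    # enabled in the generated config.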

    # Critical XLA multi-threading settings - must be set before torch_xla is imported
    cpu_count = os.cpu_count()
    os.environ['XLA_FLAGS'] = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)

    print(f"Set XLA compilation to use {cpu_count} CPU threads")
    print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

    # Launch training with accelerate; use a subprocess so the environment variables are passed through
    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"

    print("Launching TPU training with command:")
    print(f"  {cmd}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    # Use a subprocess so the environment variables are properly inherited
    import subprocess

    # Create an environment with the XLA settings above
    env = os.environ.copy()
    env.update({
        'TPU_CORES': str(num_cores),
        'XLA_USE_BF16': '1',
        'XLA_FLAGS': (
            '--xla_cpu_multi_thread_eigen=true '
            '--xla_cpu_enable_fast_math=true '
            f'--xla_force_host_platform_device_count={cpu_count}'
        ),
        'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
    })

    print("Environment variables set for subprocess:")
    print(f"  XLA_FLAGS: {env['XLA_FLAGS']}")
    print(f"  PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)

    # Execute training with the prepared environment
    result = subprocess.run(cmd.split(), env=env)
    return result.returncode
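
# For reference, an equivalent manual launch (a sketch; the flags mirror the command
# and environment built above, assuming the default 8 cores and rnn_args.yaml):
#
#   XLA_USE_BF16=1 accelerate launch --tpu --num_processes 8 \
#       train_model.py --config_path rnn_args_tpu.yaml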

def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check the TPU environment, do not launch training')

    args = parser.parse_args()

    # Verify that the config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training and propagate its exit code
    sys.exit(run_tpu_training(args.config, args.num_cores))

if __name__ == "__main__":
    main()