# TPU Training Setup Guide for Brain-to-Text RNN
This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
## Prerequisites

### 1. Install PyTorch XLA for TPU Support

```bash
# Install PyTorch XLA (adjust version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

# Or pin a specific PyTorch version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

### 2. Install the Accelerate Library

```bash
pip install accelerate
```

### 3. Verify TPU Access

```bash
# Check that a TPU device is visible
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```
## Configuration Setup

### 1. Enable TPU in the Configuration File

Update your `rnn_args.yaml` file with TPU settings:

```yaml
# TPU and distributed training settings
use_tpu: true                   # Enable TPU training
num_tpu_cores: 8                # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1  # Gradient accumulation for a larger effective batch size
dataloader_num_workers: 0       # Must be 0 on TPU to avoid multiprocessing issues
use_amp: true                   # Enable mixed precision (bfloat16)

# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8                 # Per-core batch size (total = 8 cores × 8 = 64)
```
### 2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

```yaml
# Learning rate scaling for distributed training
lr_max: 0.005        # May need to scale with the number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8      # Per-core batch size
  days_per_batch: 4  # Keep consistent across cores
```
## Training Launch Options

### Method 1: Using the TPU Launch Script (Recommended)

```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check the TPU environment only
python launch_tpu_training.py --check_only

# Use a custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```
### Method 2: Direct Accelerate Launch

```bash
# Configure accelerate (one-time setup)
accelerate config

# Or use the provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```
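`accelerate launch` only sets up the processes; device placement, data sharding, and gradient synchronization come from the `Accelerator` object used inside the training script. The following is a minimal, self-contained sketch of that pattern with a dummy linear model and random data, not the repository's actual `train_model.py`:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Mirrors the settings in accelerate_config_tpu.yaml / rnn_args.yaml
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=1)

# Dummy model and data standing in for the brain-to-text RNN and its dataset
model = nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
loader = DataLoader(TensorDataset(torch.randn(64, 512), torch.randint(0, 41, (64,))), batch_size=8)

# prepare() moves everything to the XLA device and wraps the loader for per-core sharding
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

loss_fn = nn.CrossEntropyLoss()
for features, labels in loader:
    with accelerator.accumulate(model):  # no-op unless gradient_accumulation_steps > 1
        loss = loss_fn(model(features), labels)
        accelerator.backward(loss)       # replaces loss.backward() for correct TPU/bf16 handling
        optimizer.step()
        optimizer.zero_grad()
```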
### Method 3: Manual XLA Launch (Advanced)

```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```
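For reference, the same multi-core launch can also be done programmatically with `torch_xla.distributed.xla_multiprocessing`; a minimal sketch with a placeholder `train_fn` (the real training loop would go where the comment is):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def train_fn(index):
    # One process per TPU core; `index` is this process's rank
    device = xm.xla_device()
    print(f"process {index} running on {device}")
    # ... build the model and dataloader, then run the training loop here ...

if __name__ == "__main__":
    # nprocs=None means "use every available core" (8 on a v3-8)
    xmp.spawn(train_fn, args=(), nprocs=None)
```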
## Key TPU Features Implemented

### 1. Distributed Training Support
- Automatic model parallelization across 8 TPU cores
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training

### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization
- Faster training with maintained numerical stability
- Reduced memory usage

### 3. TPU-Optimized Data Loading
- Single-threaded data loading (`num_workers=0`) for TPU compatibility
- Automatic data distribution across TPU cores (see the sketch below)
- Efficient batch processing
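One common way such data distribution is achieved in PyTorch/XLA code is to wrap an ordinary `DataLoader` in `MpDeviceLoader`, which prefetches each batch onto the core's device. A minimal sketch with a placeholder dataset (this may differ from what the Accelerate-based trainer does internally):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader

device = xm.xla_device()
dataset = TensorDataset(torch.randn(64, 512))               # placeholder for the neural-data dataset
loader = DataLoader(dataset, batch_size=8, num_workers=0)   # num_workers=0 as required on TPU

# MpDeviceLoader moves each batch to the TPU core in the background
tpu_loader = MpDeviceLoader(loader, device)
for (batch,) in tpu_loader:
    print(batch.device)  # an xla device, e.g. xla:0
    break
```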
### 4. Inference Support
- TPU-compatible inference methods added to the trainer class
- `inference()` and `inference_batch()` methods for production use
- Automatic mixed precision during inference (see the generic sketch below)
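The `inference()` and `inference_batch()` methods belong to the project's trainer class; as a generic illustration of mixed-precision TPU inference only, here is a sketch using a stand-in GRU (the model, shapes, and variable names are placeholders, not the trainer's implementation):

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast

device = xm.xla_device()
model = torch.nn.GRU(512, 256, batch_first=True).to(device).eval()  # stand-in for the trained RNN
features = torch.randn(1, 100, 512, device=device)                  # one trial: (batch, time, channels)

with torch.no_grad(), autocast(device):  # bfloat16 autocast on the XLA device
    outputs, _ = model(features)
xm.mark_step()                           # compile and run the queued XLA graph
print(outputs.shape)                     # torch.Size([1, 100, 256])
```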
## Performance Optimization Tips

### 1. Batch Size Tuning
- Start with a total batch size of 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor TPU utilization with the `top` command
### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps (see the worked example below)
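A quick worked example of that formula with the values used in this guide (the accumulation value of 4 is only for illustration; the default above is 1):

```python
batch_size = 8                   # dataset.batch_size (per core)
num_cores = 8                    # num_tpu_cores
gradient_accumulation_steps = 4  # illustrative value; the guide's default is 1

effective_batch_size = batch_size * num_cores * gradient_accumulation_steps
print(effective_batch_size)      # 8 * 8 * 4 = 256 sequences per optimizer step
```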
### 3. Learning Rate Scaling
- Consider scaling the learning rate with the number of cores
- Linear scaling: `lr_new = lr_base × num_cores`
- May need a warmup adjustment for large-batch training (see the sketch below)
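A sketch of the linear-scaling rule applied to this guide's defaults; whether to scale at all, and how long to warm up, are tuning decisions rather than something the training code does automatically (the warmup value below is purely illustrative):

```python
lr_base = 0.005                  # lr_max tuned for single-device training
num_cores = 8                    # num_tpu_cores

lr_scaled = lr_base * num_cores  # linear scaling rule: 0.005 * 8 = 0.04
warmup_steps = 1000              # illustrative; ramp up to lr_scaled so early large-batch updates stay stable
print(lr_scaled, warmup_steps)
```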
### 4. Memory Management
- TPU v3-8: 128 GB HBM memory total
- TPU v4-8: 512 GB HBM memory total
- Monitor memory usage to avoid OOM errors (see the sketch below)
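To check per-core HBM usage from Python, torch_xla exposes `xm.get_memory_info`; a minimal sketch (the exact keys in the returned dict vary between torch_xla versions):

```python
import torch_xla.core.xla_model as xm

device = xm.xla_device()
mem_info = xm.get_memory_info(device)  # dict with free/total HBM figures for this core
print(mem_info)
```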
## Monitoring and Debugging

### 1. TPU Utilization

```bash
# Monitor TPU usage
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```
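Beyond the core count, torch_xla's built-in metrics report helps spot common TPU problems such as repeated graph recompilation or excessive device-to-host transfers; a minimal sketch:

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(8, 8, device=device)
_ = (x @ x).sum()
xm.mark_step()  # execute the queued XLA graph

# Counters such as CompileTime and TransferToDeviceTime point at typical TPU bottlenecks
print(met.metrics_report())
```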
### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training
### 3. Common Issues and Solutions

**Issue:** "No TPU devices found"
- **Solution:** Verify that the TPU runtime is started and accessible

**Issue:** "DataLoader workers > 0 causes hangs"
- **Solution:** Set `dataloader_num_workers: 0` in the config

**Issue:** "Mixed precision errors"
- **Solution:** Ensure `use_amp: true` is set and that your PyTorch XLA version supports bfloat16

**Issue:** "Gradient synchronization timeouts"
- **Solution:** Check network connectivity between TPU cores
## Example Training Command

```bash
# Complete TPU training example
cd model_training_nnn

# 1. Update the config for TPU
vim rnn_args.yaml   # set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```
## Configuration Reference

### Required TPU Settings

```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```

### Optional TPU Optimizations

```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8        # Per-core batch size
mixed_precision: bf16
```
This TPU implementation lets you use all 8 TPU cores for both training and inference, with distributed training managed automatically by the Accelerate library.