# TPU Training Setup Guide for Brain-to-Text RNN
This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.
## Prerequisites
### 1. Install PyTorch XLA for TPU Support
```bash
# Install PyTorch XLA (adjust version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
# Or for specific PyTorch version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```
### 2. Install Accelerate Library
```bash
pip install accelerate
```
### 3. Verify TPU Access
```bash
# Check if TPU is available
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```
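If the import check passes, a slightly deeper smoke test is to run a small computation on the TPU; a minimal sketch, assuming the packages above are installed:
```python
# Allocate a tensor on the TPU and run a small op to confirm the device works.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()          # first available TPU core
x = torch.randn(2, 3, device=device)
y = (x @ x.T).sum()
xm.mark_step()                    # force XLA to compile and execute the graph
print(f"Device: {device}, result: {y.item():.4f}")
```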
## Configuration Setup
### 1. Enable TPU in Configuration File
Update your `rnn_args.yaml` file with TPU settings:
```yaml
# TPU and distributed training settings
use_tpu: true # Enable TPU training
num_tpu_cores: 8 # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1 # Gradient accumulation for large effective batch size
dataloader_num_workers: 0 # Must be 0 for TPU to avoid multiprocessing issues
use_amp: true # Enable mixed precision (bfloat16)
# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8 # Per-core batch size (total = 8 cores × 8 = 64)
```
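The per-core versus global batch size distinction is easy to get wrong. As a sanity check, the effective global batch size can be computed directly from the config file; this sketch assumes `rnn_args.yaml` is plain YAML readable with PyYAML (the training code may use its own config loader):
```python
import yaml

with open("rnn_args.yaml") as f:
    cfg = yaml.safe_load(f)

per_core = cfg["dataset"]["batch_size"]
cores = cfg["num_tpu_cores"]
accum = cfg.get("gradient_accumulation_steps", 1)

# Effective batch size seen by the optimizer per update
print(f"Global batch size: {per_core * cores * accum}")  # e.g. 8 × 8 × 1 = 64
```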
### 2. TPU-Optimized Hyperparameters
Recommended adjustments for TPU training:
```yaml
# Learning rate scaling for distributed training
lr_max: 0.005 # May need to scale with number of cores
lr_max_day: 0.005
# Batch size considerations
dataset:
  batch_size: 8 # Per-core batch size
  days_per_batch: 4 # Keep consistent across cores
```
## Training Launch Options
### Method 1: Using the TPU Launch Script (Recommended)
```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
# Check TPU environment only
python launch_tpu_training.py --check_only
# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```
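The launch script ships with the repository; internally, TPU launchers of this kind are usually thin wrappers around `xmp.spawn`. The sketch below shows only that general pattern and is not the contents of `launch_tpu_training.py`; `_mp_fn` and the config path are placeholders.
```python
# Hypothetical minimal launcher pattern using torch_xla multiprocessing.
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index, config_path):
    # Each spawned process is pinned to one TPU core.
    device = xm.xla_device()
    print(f"Process {index} starting on {device} with config {config_path}")
    # train(config_path)  # hand off to the real training entry point here

if __name__ == "__main__":
    # Spawns one process per visible TPU core (8 on a v3-8 / v4-8).
    xmp.spawn(_mp_fn, args=("rnn_args.yaml",))
```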
### Method 2: Direct Accelerate Launch
```bash
# Configure accelerate (one-time setup)
accelerate config
# Or use provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml
# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```
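For this route the training script needs to build an `Accelerator` and let it place the model, optimizer, and dataloader; the sketch below shows the general Accelerate pattern with stand-in modules, not the repository's actual `train_model.py`.
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# bf16 mixed precision and gradient accumulation mirror the YAML settings above.
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=1)

model = nn.Linear(16, 4)                       # stand-in for the RNN model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
loader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,))),
                    batch_size=8)

model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):
        loss = nn.functional.cross_entropy(model(x), y)
        accelerator.backward(loss)             # replaces loss.backward()
        optimizer.step()
        optimizer.zero_grad()
```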
### Method 3: Manual XLA Launch (Advanced)
```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1
# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```
## Key TPU Features Implemented
### 1. Distributed Training Support
- Automatic model parallelization across 8 TPU cores
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training (see the sketch below)
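
On the checkpointing point above: with PyTorch/XLA the usual pattern is `xm.save`, which moves tensors to the CPU and writes only from the master process, so every core can call it safely. A minimal sketch with a stand-in model (the trainer's actual checkpoint logic may differ):
```python
import torch
from torch import nn
import torch_xla.core.xla_model as xm

model = nn.Linear(16, 4).to(xm.xla_device())   # stand-in for the trained RNN

# xm.save gathers tensors to CPU and writes from the master ordinal only,
# so all 8 processes can call it without clobbering each other.
xm.save({"model": model.state_dict()}, "checkpoint.pt")

# Loading: every process reads the file and copies weights onto its own core.
state = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(state["model"])
```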
### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization (see the sketch after this list)
- Faster training with maintained numerical stability
- Reduced memory usage
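
One common way bfloat16 is applied in PyTorch/XLA code is autocast around the forward pass; the sketch below assumes a recent torch_xla that supports `torch.autocast` on the `xla` device type, and uses a stand-in model.
```python
import torch
from torch import nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(16, 4).to(device)            # stand-in model
x = torch.randn(8, 16, device=device)

# Forward pass runs in bfloat16; parameters stay in float32.
with torch.autocast("xla", dtype=torch.bfloat16):
    out = model(x)
print(out.dtype)  # torch.bfloat16
```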
### 3. TPU-Optimized Data Loading
- Single-threaded data loading (num_workers=0) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing (see the loader sketch below)
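
The per-core pipeline typically wraps a standard `DataLoader` in `MpDeviceLoader`, which prefetches batches onto the local TPU core; a minimal sketch with a placeholder dataset:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 4, (64,)))
loader = DataLoader(dataset, batch_size=8, num_workers=0)   # num_workers must stay 0

# MpDeviceLoader streams each batch onto this process's TPU core.
device_loader = pl.MpDeviceLoader(loader, device)

for x, y in device_loader:
    print(x.device)   # xla device
    break
```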
### 4. Inference Support
- TPU-compatible inference methods added to trainer class
- `inference()` and `inference_batch()` methods for production use
- Automatic mixed precision during inference (a generic sketch follows)
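
The exact signatures of `inference()` and `inference_batch()` are defined in the trainer class; generically, TPU inference moves inputs to the XLA device, runs under `no_grad` (optionally with bfloat16 autocast), and calls `xm.mark_step()` to flush the graph. A sketch with a placeholder GRU model:
```python
import torch
from torch import nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.GRU(input_size=16, hidden_size=32, batch_first=True).to(device).eval()

batch = torch.randn(4, 100, 16, device=device)   # (batch, time, features) placeholder

with torch.no_grad(), torch.autocast("xla", dtype=torch.bfloat16):
    outputs, _ = model(batch)
    xm.mark_step()                                # execute the queued XLA graph

print(outputs.shape)
```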
## Performance Optimization Tips
### 1. Batch Size Tuning
- Start with total batch size = 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor TPU utilization via the PyTorch/XLA metrics report (see Monitoring and Debugging below); host-side `top` only reflects CPU usage
### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps (worked example below)
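
For example, with the per-core batch size above and two accumulation steps:
```python
batch_size = 8                     # per core
num_cores = 8
gradient_accumulation_steps = 2    # example value; the default above is 1

effective = batch_size * num_cores * gradient_accumulation_steps
print(effective)                   # 128 samples per optimizer update
```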
### 3. Learning Rate Scaling
- Consider scaling learning rate with number of cores
- Linear scaling: `lr_new = lr_base × num_cores` (see the sketch after this list)
- May need warmup adjustment for large batch training
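
If you do apply linear scaling, a warmup ramp is the usual companion. The sketch below is illustrative only; the `warmup_steps` value is an assumption rather than a repository setting, and any scaled rate should be validated against the baseline.
```python
base_lr = 0.005
num_cores = 8
warmup_steps = 1000               # example value, not from the repo config

scaled_lr = base_lr * num_cores   # linear scaling rule

def lr_at(step):
    # Ramp linearly from base_lr to scaled_lr over the warmup period.
    if step < warmup_steps:
        return base_lr + (scaled_lr - base_lr) * step / warmup_steps
    return scaled_lr

print(lr_at(0), lr_at(500), lr_at(2000))  # 0.005, 0.0225, 0.04
```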
### 4. Memory Management
- TPU v3-8: 128 GB HBM total (16 GB per core)
- TPU v4-8: 128 GB HBM total (32 GB per chip across 4 chips)
- Monitor memory usage to avoid OOM errors
## Monitoring and Debugging
### 1. TPU Utilization
```bash
# Monitor TPU usage
# Periodically print the number of TPU cores visible to this process
# (note: this reports core count, not utilization; xrt_world_size() is deprecated
# in newer torch_xla in favor of torch_xla.runtime.world_size())
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```
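For utilization and performance debugging, the PyTorch/XLA metrics report is usually more informative than polling the core count; a minimal sketch:
```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
x = torch.randn(8, 8, device=device)
(x @ x).sum()
xm.mark_step()

# Summarizes compile counts, execution times, host-device transfers, etc.
print(met.metrics_report())
```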
### 2. Training Logs
- Training logs include device information and core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training
### 3. Common Issues and Solutions
**Issue**: "No TPU devices found"

- **Solution**: Verify the TPU runtime is started and accessible

**Issue**: "DataLoader workers > 0 causes hangs"

- **Solution**: Set `dataloader_num_workers: 0` in the config

**Issue**: "Mixed precision errors"

- **Solution**: Ensure `use_amp: true` is set and your PyTorch XLA version supports bfloat16

**Issue**: "Gradient synchronization timeouts"

- **Solution**: Check network connectivity between TPU hosts
## Example Training Command
```bash
# Complete TPU training example
cd model_training_nnn
# 1. Update config for TPU
vim rnn_args.yaml # Set use_tpu: true, num_tpu_cores: 8
# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```
## Configuration Reference
### Required TPU Settings
```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```
### Optional TPU Optimizations
```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8 # Per-core batch size
mixed_precision: bf16
```
This TPU implementation allows you to leverage all 8 cores of your TPU for both training and inference, with automatic distributed training management through the Accelerate library.