TPU support
New file: `model_training_nnn/TPU_SETUP_GUIDE.md`

# TPU Training Setup Guide for Brain-to-Text RNN

This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.

## Prerequisites

### 1. Install PyTorch XLA for TPU Support

```bash
# Install PyTorch XLA (adjust the version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

# Or pin a specific torch_xla version to match your PyTorch install:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html
```

### 2. Install the Accelerate Library

```bash
pip install accelerate
```

### 3. Verify TPU Access

```bash
# Check that a TPU device is visible
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
```
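
If the one-liner fails, the same check can be run from a short script. This is a convenience sketch using only public `torch_xla` calls, not part of the repository:

```python
# Minimal TPU visibility check (illustrative).
import torch_xla.core.xla_model as xm

device = xm.xla_device()                    # default XLA (TPU) device for this process
visible = xm.get_xla_supported_devices()    # all XLA devices the runtime can see

print(f"Default XLA device: {device}")
print(f"Visible XLA devices: {visible}")
```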

## Configuration Setup

### 1. Enable TPU in the Configuration File

Update your `rnn_args.yaml` file with the TPU settings:

```yaml
# TPU and distributed training settings
use_tpu: true                   # Enable TPU training
num_tpu_cores: 8                # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1  # Gradient accumulation for a larger effective batch size
dataloader_num_workers: 0       # Must be 0 on TPU to avoid multiprocessing issues
use_amp: true                   # Enable mixed precision (bfloat16)

# Adjust the batch size for a multi-core TPU
dataset:
  batch_size: 8                 # Per-core batch size (total = 8 cores × 8 = 64)
```
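
For reference, a script can branch on these flags roughly as follows. This is a hedged sketch rather than the repository's actual trainer code; it assumes only the `use_tpu` key shown above and the standard `yaml`/`torch_xla` APIs:

```python
# Illustrative only: choose a device based on the use_tpu flag in rnn_args.yaml.
import yaml
import torch

with open("rnn_args.yaml") as f:
    cfg = yaml.safe_load(f)

if cfg.get("use_tpu", False):
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()   # one XLA device per training process
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Selected device: {device}")
```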

### 2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

```yaml
# Learning rate scaling for distributed training
lr_max: 0.005        # May need to scale with the number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8      # Per-core batch size
  days_per_batch: 4  # Keep consistent across cores
```

## Training Launch Options

### Method 1: Using the TPU Launch Script (Recommended)

```bash
# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check the TPU environment only
python launch_tpu_training.py --check_only

# Use a custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8
```

### Method 2: Direct Accelerate Launch

```bash
# Configure accelerate (one-time setup)
accelerate config

# Or use the provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml
```
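
Under the hood, `accelerate launch` expects the training script to follow the usual Accelerate pattern. The snippet below is a minimal, hedged sketch of that pattern using only documented `accelerate` APIs; the real loop lives in `train_model.py` and will differ in detail (`loss_fn` here is a placeholder):

```python
# Sketch of the Accelerate training pattern on TPU (illustrative, not the repo's code).
from accelerate import Accelerator

def train(model, optimizer, loss_fn, train_loader, num_epochs=1):
    # bf16 mixed precision and gradient accumulation mirror the YAML settings above.
    accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=1)

    # prepare() moves the model/optimizer to the XLA device and shards the dataloader per core.
    model, optimizer, train_loader = accelerator.prepare(model, optimizer, train_loader)

    model.train()
    for _ in range(num_epochs):
        for batch in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model, batch)   # placeholder for the repo's loss computation
            accelerator.backward(loss)     # handles scaling and cross-core gradient sync
            optimizer.step()

    # Save once, from the main process, after all cores have finished.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.save(accelerator.unwrap_model(model).state_dict(), "checkpoint.pt")
```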

### Method 3: Manual XLA Launch (Advanced)

```bash
# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml
```
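
If `xla_dist` is not available in your torch_xla version, the multiprocessing API is an alternative manual path. This is a generic sketch of that API, not the repository's launcher; `_mp_fn` is a hypothetical entry point:

```python
# Generic xmp.spawn launch sketch (illustrative).
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    device = xm.xla_device()
    print(f"worker {index} running on {device}")
    # ...build the model and dataloaders, then run the training loop on `device`...

if __name__ == "__main__":
    xmp.spawn(_mp_fn)   # spawns one process per visible TPU device
```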

## Key TPU Features Implemented

### 1. Distributed Training Support
- Automatic data-parallel training across 8 TPU cores (each core holds a model replica)
- Synchronized gradient updates across all cores
- Proper checkpoint saving/loading for distributed training (see the sketch below)
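
One common XLA checkpointing pattern is shown below. It is a hedged sketch (the repository may rely on Accelerate's saving utilities instead) built on the documented `xm.save` call, which gathers tensors to the host and writes only from the master process by default:

```python
# Illustrative distributed checkpointing on TPU via torch_xla (not the repo's exact code).
import torch_xla.core.xla_model as xm

def save_checkpoint(model, optimizer, path="checkpoint.pt"):
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    # Safe to call from every core: xm.save writes only on the master ordinal by default.
    xm.save(state, path)
```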

### 2. Mixed Precision Training
- Automatic bfloat16 precision for TPU optimization
- Faster training while maintaining numerical stability
- Reduced memory usage

### 3. TPU-Optimized Data Loading
- Single-threaded data loading (`num_workers=0`) for TPU compatibility
- Automatic data distribution across TPU cores
- Efficient batch processing

### 4. Inference Support
- TPU-compatible inference methods added to the trainer class
- `inference()` and `inference_batch()` methods for production use
- Automatic mixed precision during inference (the general pattern is sketched below)
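
The exact signatures of `inference()` and `inference_batch()` live in the trainer class. As a hedged illustration of the underlying XLA pattern they build on (not their actual implementation), a single forward pass on TPU looks roughly like this, where `run_inference` is a hypothetical helper and `batch` is assumed to be one input tensor:

```python
# Generic TPU inference pattern (illustrative).
import torch
import torch_xla.core.xla_model as xm

def run_inference(model, batch):
    device = xm.xla_device()
    model = model.to(device).eval()
    batch = batch.to(device)

    with torch.no_grad():
        logits = model(batch)

    xm.mark_step()       # flush the lazily built XLA graph so results materialize
    return logits.cpu()  # move results back to the host
```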

## Performance Optimization Tips

### 1. Batch Size Tuning
- Start with a total batch size of 64 (8 cores × 8 per core)
- Increase gradually if memory allows
- Monitor TPU utilization (`top` only shows host CPU; see the XLA metrics report under Monitoring and Debugging)

### 2. Gradient Accumulation
- Use `gradient_accumulation_steps` to simulate larger batch sizes
- Effective batch size = batch_size × num_cores × gradient_accumulation_steps (worked example below)
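
With the per-core batch size of 8 and the 8 cores used throughout this guide, the arithmetic works out as follows:

```python
# Effective batch size for the example configuration in this guide.
per_core_batch_size = 8
num_tpu_cores = 8
gradient_accumulation_steps = 2   # example value; the default above is 1

effective_batch_size = per_core_batch_size * num_tpu_cores * gradient_accumulation_steps
print(effective_batch_size)       # 8 * 8 * 2 = 128
```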

### 3. Learning Rate Scaling
- Consider scaling the learning rate with the number of cores
- Linear scaling: `lr_new = lr_base × num_cores`
- Large-batch training may also need a warmup adjustment

### 4. Memory Management
- TPU v3-8: 128 GB HBM in total (16 GB per core)
- TPU v4-8: 128 GB HBM in total (32 GB per chip)
- Monitor memory usage to avoid out-of-memory (OOM) errors

## Monitoring and Debugging

### 1. TPU Utilization
```bash
# Report the number of TPU cores participating in training
watch -n 1 'python -c "import torch_xla.core.xla_model as xm; print(f\"TPU cores: {xm.xrt_world_size()}\")"'
```
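
For a more detailed view than the core count, `torch_xla` ships a built-in metrics report covering compile counts, execution times, and host-device transfers:

```python
# Print torch_xla's built-in metrics report.
import torch_xla.debug.metrics as met

print(met.metrics_report())
```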

### 2. Training Logs
- Training logs include device information and the core count
- Monitor validation metrics across all cores
- Check for synchronization issues in distributed training

### 3. Common Issues and Solutions

**Issue**: "No TPU devices found"
- **Solution**: Verify that the TPU runtime is started and accessible

**Issue**: DataLoader workers > 0 cause hangs
- **Solution**: Set `dataloader_num_workers: 0` in the config

**Issue**: Mixed precision errors
- **Solution**: Ensure `use_amp: true` is set and that your PyTorch XLA build supports bfloat16

**Issue**: Gradient synchronization timeouts
- **Solution**: Check network connectivity between TPU hosts/cores (a minimal barrier check is sketched below)
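
If you suspect one core is falling behind, an explicit barrier is a quick synchronization test. This is a generic sketch using the documented `xm.rendezvous` call, not something the training code necessarily does:

```python
# Minimal cross-core synchronization check (illustrative).
import torch_xla.core.xla_model as xm

def sync_check(tag="debug_sync"):
    xm.rendezvous(tag)   # blocks until every participating process reaches this point
    print(f"ordinal {xm.get_ordinal()} passed barrier '{tag}'")
```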

## Example Training Command

```bash
# Complete TPU training example
cd model_training_nnn

# 1. Update config for TPU
vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log
```

## Configuration Reference

### Required TPU Settings
```yaml
use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true
```

### Optional TPU Optimizations
```yaml
gradient_accumulation_steps: 1
dataset:
  batch_size: 8        # Per-core batch size
mixed_precision: bf16
```

This TPU implementation lets you leverage all 8 cores of your TPU for both training and inference, with distributed training managed automatically by the Accelerate library.