TPU Training Setup Guide for Brain-to-Text RNN

This guide explains how to use the TPU support that has been added to the brain-to-text RNN training code.

Prerequisites

1. Install PyTorch XLA for TPU Support

# Install PyTorch XLA (adjust version as needed)
pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html

# Or for specific PyTorch version:
pip install torch_xla==2.1.0 -f https://storage.googleapis.com/libtpu-releases/index.html

2. Install Accelerate Library

pip install accelerate

3. Verify TPU Access

# Check if TPU is available
python -c "import torch_xla; import torch_xla.core.xla_model as xm; print(f'TPU device: {xm.xla_device()}')"
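
For a slightly more detailed check, the sketch below lists all visible XLA devices and runs a small computation on the TPU (device names and counts depend on your TPU topology and torch_xla version):

# optional_tpu_check.py - more detailed visibility check (illustrative)
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()                     # default XLA (TPU) device
devices = xm.get_xla_supported_devices()     # all XLA devices visible to this process
print(f"Default device: {device}")
print(f"Visible devices: {devices}")

# Small smoke test: the matmul is compiled and executed on the TPU
x = torch.randn(128, 128, device=device)
print((x @ x).sum().item())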

Configuration Setup

1. Enable TPU in Configuration File

Update your rnn_args.yaml file with TPU settings:

# TPU and distributed training settings
use_tpu: true                        # Enable TPU training
num_tpu_cores: 8                     # Number of TPU cores (8 for v3-8 or v4-8)
gradient_accumulation_steps: 1       # Gradient accumulation for large effective batch size
dataloader_num_workers: 0           # Must be 0 for TPU to avoid multiprocessing issues
use_amp: true                       # Enable mixed precision (bfloat16)

# Adjust batch size for multi-core TPU
dataset:
  batch_size: 8                     # Per-core batch size (total = 8 cores × 8 = 64)

2. TPU-Optimized Hyperparameters

Recommended adjustments for TPU training:

# Learning rate scaling for distributed training
lr_max: 0.005                       # May need to scale with number of cores
lr_max_day: 0.005

# Batch size considerations
dataset:
  batch_size: 8                     # Per-core batch size
  days_per_batch: 4                 # Keep consistent across cores

Training Launch Options

Method 1: TPU Launch Script (launch_tpu_training.py)

# Basic TPU training with 8 cores
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# Check TPU environment only
python launch_tpu_training.py --check_only

# Custom configuration file
python launch_tpu_training.py --config my_tpu_config.yaml --num_cores 8

Method 2: Direct Accelerate Launch

# Configure accelerate (one-time setup)
accelerate config

# Or use provided TPU config
export ACCELERATE_CONFIG_FILE=accelerate_config_tpu.yaml

# Launch training
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py --config_path rnn_args.yaml

Method 3: Manual XLA Launch (Advanced)

# Set TPU environment variables
export TPU_CORES=8
export XLA_USE_BF16=1

# Launch with PyTorch XLA
python -m torch_xla.distributed.xla_dist --tpu --num_devices 8 train_model.py --config_path rnn_args.yaml

Key TPU Features Implemented

1. Distributed Training Support

  • Automatic model parallelization across 8 TPU cores
  • Synchronized gradient updates across all cores
  • Proper checkpoint saving/loading for distributed training (see the sketch below)
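
As an illustration of the distributed pattern described above, here is a minimal Accelerate-based sketch (the toy model, dataset, and checkpoint path are placeholders, not the repo's actual classes; run it under accelerate launch or the TPU launch script to use all 8 cores):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# bf16 mixed precision, matching use_amp: true in rnn_args.yaml
accelerator = Accelerator(mixed_precision="bf16")

# Toy stand-ins for the brain-to-text RNN, optimizer, and dataset
model = nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
data = TensorDataset(torch.randn(64, 512), torch.randint(0, 41, (64,)))
loader = DataLoader(data, batch_size=8, num_workers=0)   # per-core batch size

# prepare() moves everything to the XLA device and wraps it for distributed training
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

model.train()
for x, y in loader:                          # each core receives its own shard
    optimizer.zero_grad()
    loss = nn.functional.cross_entropy(model(x), y)
    accelerator.backward(loss)               # synchronizes gradients across cores
    optimizer.step()

# Save checkpoints from the main process only, after unwrapping the model
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    accelerator.save(accelerator.unwrap_model(model).state_dict(), "checkpoint.pt")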

2. Mixed Precision Training

  • Automatic bfloat16 precision for TPU optimization (a quick dtype check follows below)
  • Faster training with maintained numerical stability
  • Reduced memory usage
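
One quick way to confirm that bfloat16 is active is to inspect an output dtype inside the autocast region (a sketch; with the legacy XLA_USE_BF16=1 path, tensors can still report float32 even though they are stored as bfloat16):

import torch
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")
model = accelerator.prepare(torch.nn.Linear(16, 16))

x = torch.randn(4, 16, device=accelerator.device)
with accelerator.autocast():                 # bf16 autocast region
    out = model(x)
print(out.dtype)                             # typically torch.bfloat16 when bf16 is active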

3. TPU-Optimized Data Loading

  • Single-threaded data loading (num_workers=0) for TPU compatibility
  • Automatic data distribution across TPU cores (sketched below)
  • Efficient batch processing
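
The following sketch shows the data-loading side in isolation: a standard DataLoader with num_workers=0 that accelerator.prepare() shards across cores (the dataset contents are dummy tensors):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()
dataset = TensorDataset(torch.randn(640, 512))

# num_workers must stay 0 on TPU; prepare() takes care of sharding across cores
loader = accelerator.prepare(
    DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
)

for (x,) in loader:
    # When launched on 8 cores, each process sees a different batch of 8 samples,
    # so one loader step covers 8 x 8 = 64 samples globally.
    print(accelerator.process_index, x.shape)
    break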

4. Inference Support

  • TPU-compatible inference methods added to trainer class
  • inference() and inference_batch() methods for production use (a simplified sketch follows below)
  • Automatic mixed precision during inference
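
A simplified sketch of what such an inference helper can look like (the function name and arguments here are illustrative, not the trainer's exact signatures):

import torch

@torch.no_grad()
def inference_batch(model, batch, accelerator):
    """Forward one batch of neural features and return logits on the host."""
    model.eval()
    batch = batch.to(accelerator.device)     # move inputs to the TPU core
    with accelerator.autocast():             # bf16 mixed precision during inference
        logits = model(batch)
    return logits.float().cpu()              # back to float32 on CPU for decoding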

Performance Optimization Tips

1. Batch Size Tuning

  • Start with total batch size = 64 (8 cores × 8 per core)
  • Increase gradually if memory allows
  • Monitor TPU utilization with a TPU-aware tool such as tpu-info (the host-side top command only reports CPU usage)

2. Gradient Accumulation

  • Use gradient_accumulation_steps to simulate larger batch sizes
  • Effective batch size = batch_size × num_cores × gradient_accumulation_steps (worked example below)
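
With Accelerate, accumulation is handled by the accumulate() context manager. A sketch reusing the same toy setup as the earlier training example, with 4 accumulation steps:

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# Effective batch size: 8 (per-core) x 8 (cores) x 4 (accumulation) = 256 per update
accelerator = Accelerator(mixed_precision="bf16", gradient_accumulation_steps=4)
model = nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)
loader = DataLoader(TensorDataset(torch.randn(64, 512), torch.randint(0, 41, (64,))),
                    batch_size=8, num_workers=0)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    with accelerator.accumulate(model):      # gradients sync only every 4th micro-batch
        loss = nn.functional.cross_entropy(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()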

3. Learning Rate Scaling

  • Consider scaling learning rate with number of cores
  • Linear scaling: lr_new = lr_base × num_cores (see the snippet below)
  • May need warmup adjustment for large batch training
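
The snippet below illustrates the linear-scaling rule with a simple warmup ramp (values are illustrative; the trainer's actual schedule may differ):

# Linear LR scaling with a short warmup
base_lr = 0.005
num_cores = 8
scaled_lr = base_lr * num_cores              # lr_new = lr_base x num_cores = 0.04

warmup_steps = 1000
def lr_at(step):
    # ramp linearly from 0 to scaled_lr over warmup_steps, then hold
    return scaled_lr * min(1.0, step / warmup_steps)

print(lr_at(100), lr_at(5000))               # 0.004 during warmup, 0.04 afterwards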

4. Memory Management

  • TPU v3-8: 128GB HBM total (16GB per core × 8 cores)
  • TPU v4-8: 128GB HBM total (32GB per chip × 4 chips)
  • Monitor memory usage to avoid OOM errors

Monitoring and Debugging

1. TPU Utilization

# Report the number of TPU cores visible to the process (world size), refreshed every second
watch -n 1 "python -c 'import torch_xla.core.xla_model as xm; print(xm.xrt_world_size())'"
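
From inside a training process you can also query per-core memory through torch_xla (a sketch; the keys returned by get_memory_info differ between torch_xla versions):

import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Dict of memory statistics for this core, e.g. kb_free / kb_total in older
# releases or bytes_used / bytes_limit in newer ones
print(xm.get_memory_info(device))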

2. Training Logs

  • Training logs include device information and core count
  • Monitor validation metrics across all cores
  • Check for synchronization issues in distributed training

3. Common Issues and Solutions

Issue: "No TPU devices found"

  • Solution: Verify TPU runtime is started and accessible

Issue: "DataLoader workers > 0 causes hangs"

  • Solution: Set dataloader_num_workers: 0 in config

Issue: "Mixed precision errors"

  • Solution: Ensure use_amp: true is set and that your PyTorch XLA version supports bfloat16

Issue: "Gradient synchronization timeouts"

  • Solution: Verify connectivity between TPU hosts/workers and confirm that all cores initialized successfully

Example Training Command

# Complete TPU training example
cd model_training_nnn

# 1. Update config for TPU
vim rnn_args.yaml  # Set use_tpu: true, num_tpu_cores: 8

# 2. Launch TPU training
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

# 3. Monitor training progress
tail -f trained_models/baseline_rnn/training_log

Configuration Reference

Required TPU Settings

use_tpu: true
num_tpu_cores: 8
dataloader_num_workers: 0
use_amp: true

Optional TPU Optimizations

gradient_accumulation_steps: 1
dataset:
  batch_size: 8  # Per-core batch size
mixed_precision: bf16

This TPU implementation lets you use all 8 cores of your TPU for both training and inference, with distributed training managed automatically by the Accelerate library.