b2txt25/model_training_nnn_tpu/README.md
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00


# TPU-Optimized Brain-to-Text Model Training
This directory contains TPU-optimized code for training the brain-to-text RNN model with an adversarial training architecture. The model is based on "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024). It extends that model with three-model adversarial training and XLA optimizations for efficient TPU training.
## Key Features
- **Triple-Model Adversarial Architecture**: NoiseModel + CleanSpeechModel + NoisySpeechModel for robust neural decoding
- **XLA/TPU Optimizations**: Comprehensive optimizations for fast compilation and efficient TPU utilization
- **Mixed Precision Training**: bfloat16 support with full dtype consistency
- **Distributed Training**: 8-core TPU support through the Hugging Face Accelerate library
- **687M Parameters**: Large-scale model with patch processing and day-specific adaptations

For detailed technical documentation, see [TPU_MODEL_SUMMARY.md](TPU_MODEL_SUMMARY.md).
## Setup
1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. See the main [README.md](../README.md) file for more details on the included datasets and the proper `data` directory structure.
## TPU Training
### Triple-Model Adversarial Architecture
This implementation trains three GRU-based models jointly in an adversarial setup:
- **NoiseModel**: 2-layer GRU that estimates noise in neural data
- **CleanSpeechModel**: 3-layer GRU that processes denoised signals for speech recognition
- **NoisySpeechModel**: 2-layer GRU that processes noise signals for adversarial training

The architecture uses residual connections and gradient reversal layers (GRLs) to improve robustness. All three models include day-specific input layers (512×512 linear layers with softsign activation) and patch processing (14 timesteps per patch). All three are also optimized for XLA compilation on TPU.
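
The NumPy sketch below shows the day-specific input layer and patch processing. It is a minimal illustration, not the code in this directory. Only three details come from the text above: the 512×512 size, the softsign activation and the 14-timestep patch length. The function names, the per-day weight tensor layout and the patch stride of 4 are assumptions.

```python
import numpy as np

def softsign(x: np.ndarray) -> np.ndarray:
    """Softsign activation: x / (1 + |x|), bounded to (-1, 1)."""
    return x / (1.0 + np.abs(x))

def day_input_layer(x, day_idx, weights, biases):
    """Apply each trial's recording-day linear layer, then softsign.

    x:       (batch, time, 512) neural features
    day_idx: (batch,) integer day index per trial
    weights: (n_days, 512, 512) one matrix per day (hypothetical layout)
    biases:  (n_days, 512)
    """
    w = weights[day_idx]                      # (batch, 512, 512)
    b = biases[day_idx]                       # (batch, 512)
    out = np.einsum("btf,bfg->btg", x, w) + b[:, None, :]
    return softsign(out)

def make_patches(x, patch_size=14, stride=4):
    """Group consecutive timesteps into flattened patches.

    x: (batch, time, features) -> (batch, n_patches, patch_size * features)
    stride=4 is an assumed value for illustration only.
    """
    batch, time, feat = x.shape
    n_patches = (time - patch_size) // stride + 1
    # Row i holds the timestep indices covered by patch i.
    idx = stride * np.arange(n_patches)[:, None] + np.arange(patch_size)[None, :]
    patches = x[:, idx, :]                    # (batch, n_patches, patch_size, feat)
    return patches.reshape(batch, n_patches, patch_size * feat)

# Usage: 2 trials of 100 timesteps, 3 hypothetical recording days.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 100, 512)).astype(np.float32)
weights = (rng.standard_normal((3, 512, 512)) * 0.01).astype(np.float32)
biases = np.zeros((3, 512), dtype=np.float32)
h = day_input_layer(x, np.array([0, 2]), weights, biases)
p = make_patches(h)
print(h.shape, p.shape)  # (2, 100, 512) (2, 22, 7168)
```

Indexing `weights[day_idx]` gives each trial its own day's matrix. A batch that mixes recording days therefore needs no Python loop.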
### Training Methods
#### Option 1: Direct Training
```bash
conda activate b2txt25
python train_model.py --config_path rnn_args.yaml
```
#### Option 2: Launcher Script (Recommended)
```bash
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
```
#### Option 3: Accelerate
```bash
accelerate launch --config_file accelerate_config_tpu.yaml train_model.py
```
The model trains for 120,000 mini-batches with bfloat16 mixed precision, distributed across 8 TPU cores. Training time depends on the TPU type and configuration. All hyperparameters are specified in [`rnn_args.yaml`](rnn_args.yaml).
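
bfloat16 keeps float32's 8 exponent bits but only 7 explicit mantissa bits. It therefore covers the same numeric range with much less precision. The pure-Python sketch below rounds a value to the nearest bfloat16 using round-to-nearest-even, which shows how much precision is lost when values are cast to bfloat16. The helper names are hypothetical, NaN handling is omitted, and the input is first rounded to float32 by `struct`.

```python
import struct

def float_to_bf16_bits(x: float) -> int:
    """Round a float32 value to the nearest bfloat16 (ties to even); return the 16-bit pattern."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # raw float32 bits
    lsb = (bits >> 16) & 1                                 # lowest bit that bfloat16 keeps
    rounded = bits + 0x7FFF + lsb                          # round-to-nearest-even bias
    return (rounded >> 16) & 0xFFFF                        # sign, 8 exponent, 7 mantissa bits

def bf16_bits_to_float(b: int) -> float:
    """Expand a bfloat16 bit pattern back to a float by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

def to_bf16(x: float) -> float:
    return bf16_bits_to_float(float_to_bf16_bits(x))

print(to_bf16(3.14159))          # 3.140625
print(to_bf16(1.0 + 2 ** -8))    # 1.0 (exact tie, rounds to even)
```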
## Model Configuration
### Key Configuration Files
- **`rnn_args.yaml`**: Main training configuration with adversarial training settings
- **`accelerate_config_tpu.yaml`**: Accelerate library configuration for TPU
- **`launch_tpu_training.py`**: Convenient TPU training launcher
### Adversarial Training Settings
```yaml
adversarial:
  enabled: true
  grl_lambda: 0.5          # Gradient reversal layer (GRL) strength
  noisy_loss_weight: 0.2   # Weight for the noisy-branch CTC loss
  noise_l2_weight: 0.0     # L2 regularization on the noise output
  warmup_steps: 0          # Steps before adversarial training is enabled
```
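
The NumPy sketch below shows one plausible way these settings could combine. It is an illustration under assumptions, not the logic in `train_model.py`. The following are hypothetical:

- the `GradientReversal` class and the `adversarial_loss` helper;
- gating both extra loss terms on `warmup_steps`;
- using the mean squared noise as the L2 term.

A GRL is the identity in the forward pass and multiplies incoming gradients by `-grl_lambda` in the backward pass. Assume the GRL sits between the NoiseModel output and the NoisySpeechModel input. The NoisySpeechModel then learns to decode from the noise estimate, while the NoiseModel receives the reversed gradient.

```python
import numpy as np

class GradientReversal:
    """Identity on the forward pass; scales gradients by -grl_lambda on the backward pass."""

    def __init__(self, grl_lambda: float):
        self.grl_lambda = grl_lambda

    def forward(self, x: np.ndarray) -> np.ndarray:
        return x  # activations pass through unchanged

    def backward(self, grad_output: np.ndarray) -> np.ndarray:
        return -self.grl_lambda * grad_output  # reversed, scaled gradient


def adversarial_loss(clean_ctc, noisy_ctc, noise, step, cfg):
    """Combine branch losses from the `adversarial` config section (hypothetical rule)."""
    adv = cfg["adversarial"]
    loss = clean_ctc
    if adv["enabled"] and step >= adv["warmup_steps"]:
        loss += adv["noisy_loss_weight"] * noisy_ctc
        loss += adv["noise_l2_weight"] * float(np.mean(noise ** 2))
    return loss


cfg = {"adversarial": {"enabled": True, "grl_lambda": 0.5, "noisy_loss_weight": 0.2,
                       "noise_l2_weight": 0.0, "warmup_steps": 0}}
grl = GradientReversal(cfg["adversarial"]["grl_lambda"])
print(grl.backward(np.array([1.0, -2.0])))                        # [-0.5  1. ]
print(adversarial_loss(1.0, 2.0, np.zeros(4), step=0, cfg=cfg))    # 1.4
```

In an autograd framework such as PyTorch, a GRL is usually written as a custom `torch.autograd.Function` whose `backward` returns the negated, scaled gradient.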
### TPU-Specific Settings
```yaml
use_tpu: true
num_tpu_cores: 8
gradient_accumulation_steps: 2
use_amp: true # bfloat16 mixed precision
batch_size: 32 # Per-core batch size
num_dataloader_workers: 0 # Required for TPU
```
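
These three settings determine how many samples contribute to each optimizer update. The helper below is a hypothetical illustration. It assumes `batch_size` is per core, as the comment states, and that gradients are averaged across all cores and accumulation steps before each update.

```python
def effective_batch_size(per_core_batch: int, num_cores: int, grad_accum_steps: int) -> int:
    """Samples contributing to one optimizer step under data parallelism plus accumulation."""
    return per_core_batch * num_cores * grad_accum_steps

print(effective_batch_size(32, 8, 2))  # 512 (256 samples per micro-step across 8 cores, x2)
```

Halving `batch_size` while doubling `gradient_accumulation_steps` keeps this product unchanged and lowers per-core memory use.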
## Evaluation
Evaluating the trained TripleGRUDecoder requires the language model pipeline. See the main project README for complete evaluation setup instructions. The evaluation scripts in this directory are still being adapted for TPU compatibility.