CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

This repository contains the code and data for "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" published in the New England Journal of Medicine (2024). It implements a brain-to-text system that converts neural signals from the speech motor cortex into text using RNN models and n-gram language models.

Development Environment Setup

Main Environment (b2txt25)

./setup.sh
conda activate b2txt25

Language Model Environment (b2txt25_lm)

./setup_lm.sh
conda activate b2txt25_lm

Important: The project requires two separate conda environments due to conflicting PyTorch versions:

  • b2txt25: PyTorch with CUDA 12.6 for model training/evaluation
  • b2txt25_lm: PyTorch 1.13.1 for Kaldi-based n-gram language models

Redis Setup

Redis is required for inter-process communication. Install on Ubuntu:

curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update && sudo apt-get install redis
sudo systemctl disable redis-server

Architecture Overview

High-Level System Flow

  1. Neural Data Input: 512 features (2 per electrode × 256 electrodes) binned at 20ms resolution
  2. RNN Model: Converts neural features to phoneme logits via CTC loss
  3. Language Model: Decodes phoneme logits to words using n-gram models + OPT rescoring
  4. Redis Communication: Coordinates between RNN inference and language model processes
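
The two processes hand data to each other through Redis. The snippet below is a minimal sketch of that pattern using Redis streams; the stream names, field names, and array shapes are hypothetical, and the actual protocol lives in model_training/evaluate_model.py and language_model/language-model-standalone.py.

# Minimal sketch of the Redis hand-off (hypothetical stream/field names)
import numpy as np
import redis

r = redis.Redis(host='localhost', port=6379)

# RNN inference process: publish one trial's phoneme logits
logits = np.zeros((500, 41), dtype=np.float32)  # (time bins, phoneme classes); 41 is illustrative
r.xadd('lm_input', {'trial_id': '0', 'logits': logits.tobytes()})

# Language model process: block until logits arrive, decode, publish text
_, messages = r.xread({'lm_input': '0'}, count=1, block=0)[0]
entry_id, fields = messages[0]
received = np.frombuffer(fields[b'logits'], dtype=np.float32).reshape(-1, 41)
# ... n-gram decoding + OPT rescoring happens here ...
r.xadd('lm_output', {'trial_id': fields[b'trial_id'], 'text': 'decoded sentence'})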

Key Components

Model Training (model_training/)

  • Core Script: train_model.py (loads config from rnn_args.yaml)
  • Model Architecture: rnn_model.py - 5-layer GRU with 768 hidden units
  • Trainer: rnn_trainer.py - Custom PyTorch trainer with CTC loss
  • Evaluation: evaluate_model.py - Inference pipeline with Redis communication
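
For orientation, here is a minimal sketch of the architecture described above (5-layer GRU, 768 hidden units, 512 input features, CTC loss). It is not the code in rnn_model.py; the number of output classes (41) and all other details beyond those listed above are illustrative.

# Sketch of the described decoder, not the actual rnn_model.py
import torch
import torch.nn as nn

class SketchGRUDecoder(nn.Module):
    def __init__(self, n_features=512, hidden_size=768, n_layers=5, n_classes=41):
        super().__init__()
        self.gru = nn.GRU(n_features, hidden_size, num_layers=n_layers, batch_first=True)
        self.out = nn.Linear(hidden_size, n_classes)  # per-bin phoneme logits (incl. CTC blank)

    def forward(self, x):          # x: (batch, time, 512) neural features
        h, _ = self.gru(x)
        return self.out(h)         # (batch, time, n_classes)

model = SketchGRUDecoder()
logits = model(torch.randn(2, 100, 512))
log_probs = logits.log_softmax(-1).transpose(0, 1)   # CTC expects (time, batch, classes)
loss = nn.CTCLoss(blank=0)(log_probs,
                           torch.randint(1, 41, (2, 20)),          # phoneme targets
                           torch.full((2,), 100), torch.full((2,), 20))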

Language Model (language_model/)

  • Standalone Server: language-model-standalone.py - Redis-based LM server
  • Kaldi Integration: Uses custom C++ bindings for efficient n-gram decoding
  • OPT Rescoring: Facebook OPT 6.7B for language model rescoring
  • Build System: Complex CMake-based build for Kaldi/SRILM integration
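
A rough sketch of the OPT rescoring step is below: each n-best hypothesis from the n-gram decoder is scored by OPT and the scores are interpolated. The interpolation with alpha and the function names are assumptions for illustration; the real logic (and how --nbest, --acoustic_scale, and --alpha interact) is in language-model-standalone.py.

# Illustrative n-best rescoring with OPT (not the exact logic of the LM server)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
opt = AutoModelForCausalLM.from_pretrained("facebook/opt-6.7b", torch_dtype=torch.float16).cuda().eval()

def opt_logprob(sentence):
    ids = tok(sentence, return_tensors="pt").input_ids.cuda()
    with torch.no_grad():
        nll = opt(ids, labels=ids).loss.item()   # mean negative log-likelihood per token
    return -nll * ids.size(1)                    # approximate total log-probability

def rescore(nbest, alpha=0.55):
    # nbest: list of (hypothesis_text, ngram_score) pairs from the n-gram decoder
    return max(nbest, key=lambda h: alpha * opt_logprob(h[0]) + (1.0 - alpha) * h[1])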

Utilities (nejm_b2txt_utils/)

  • General Utils: general_utils.py - Shared utility functions
  • Package: Installed via setup.py as nejm_b2txt_utils

Analysis (analyses/)

  • Jupyter Notebooks: figure_2.ipynb, figure_4.ipynb for paper figures

Common Development Tasks

Training a Model

conda activate b2txt25
cd model_training
python train_model.py

Running Evaluation Pipeline

  1. Start Redis server:

    redis-server
    
  2. Start language model (separate terminal):

    conda activate b2txt25_lm
    python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
    
  3. Run evaluation (separate terminal):

    conda activate b2txt25
    cd model_training
    python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/hdf5_data_final --eval_type test --gpu_number 1
    
  4. Shutdown Redis:

    redis-cli shutdown
    

Building Language Model from Scratch

# Build SRILM (in language_model/srilm-1.7.3/)
export SRILM=$PWD
make MAKE_PIC=yes World

# Build Kaldi components (in language_model/runtime/server/x86/)
mkdir build && cd build
cmake .. && make -j8

Data Structure

Neural Data Format

  • File Type: HDF5 files in data/hdf5_data_final/
  • Features: 512 neural features per 20ms bin:
    • 0-63: ventral 6v threshold crossings
    • 64-127: area 4 threshold crossings
    • 128-191: 55b threshold crossings
    • 192-255: dorsal 6v threshold crossings
    • 256-319: ventral 6v spike band power
    • 320-383: area 4 spike band power
    • 384-447: 55b spike band power
    • 448-511: dorsal 6v spike band power

Data Loading

Use load_h5py_file() in model_training/evaluate_model_helpers.py as a reference for HDF5 data loading.
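
As a rough illustration of the layout described above (the file, group, and dataset names below are hypothetical; load_h5py_file() is the authoritative reference):

# Illustrative HDF5 access; file/group/dataset names are assumptions
import h5py
import numpy as np

with h5py.File('data/hdf5_data_final/example_session.hdf5', 'r') as f:
    for trial_key in f.keys():
        feats = np.asarray(f[trial_key]['input_features'])  # (n_bins, 512), one row per 20 ms bin
        threshold_crossings = feats[:, :256]   # first half: threshold crossings (4 arrays x 64)
        spike_band_power = feats[:, 256:]      # second half: spike band power (4 arrays x 64)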

Important Notes

  • GPU Requirements: OPT 6.7B requires ~12.4GB VRAM; RTX 4090s recommended
  • Memory Requirements: 3-gram LM needs ~60GB RAM, 5-gram needs ~300GB RAM
  • Environment Isolation: Always use correct conda environment for each component
  • Redis Dependency: Many scripts require Redis server to be running
  • Build Dependencies: CMake ≥3.14 and GCC ≥10.1 required for language model builds

XLA Optimizations (TPU-Friendly Model)

The RNN model has been optimized for XLA compilation and TPU training while preserving the original model architecture. These optimizations improve compilation speed and reduce memory usage on TPUs.

Applied XLA Optimizations

1. Dynamic Shape Operations → Static Operations

Problem: The XLA compiler struggles with dynamic tensor shapes and indexing.
Solution: Replace dynamic operations with XLA-friendly alternatives.

# Before (XLA-unfriendly):
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

# After (XLA-friendly):
all_day_weights = torch.stack(list(self.day_weights), dim=0)  # Static stack
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)
day_weights = torch.index_select(all_day_weights, 0, day_idx)  # Static gather
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)

2. Matrix Operations → XLA Primitives

Problem: Complex einsum operations are less optimized than native XLA ops.
Solution: Use batch matrix multiplication (bmm) for better XLA performance.

# Before:
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases

# After (XLA-optimized):
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)  # bmm + dtype consistency

3. Mixed Precision Dtype Consistency

Problem: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing.
Solution: Ensure all operands match the input tensor's dtype and preserve that dtype through every operation.

# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
# Fix 1: Add dtype conversions for all bmm operands
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)

# Fix 2: Ensure dtype consistency in adversarial training residual connections
denoised_input = x_processed - noise_output.to(x_processed.dtype)

# Fix 3: Preserve dtype through patch processing operations
if self.patch_size > 0:
    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
    x = x.unsqueeze(1)
    x = x.permute(0, 3, 1, 2)
    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
    x_unfold = x_unfold.squeeze(2)
    x_unfold = x_unfold.permute(0, 2, 3, 1)
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
    # Ensure dtype consistency after patch processing operations
    x = x.to(original_dtype)

4. Hidden State Initialization

Problem: Dynamic batch size allocation causes XLA recompilation.
Solution: Use static shapes and avoid x.shape[0] in tensor creation.

# Before:
if states is None:
    states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()

# After (XLA-friendly):
batch_size = x.size(0)  # Extract once
if states is None:
    states = self.h0.expand(2, batch_size, self.input_size).contiguous()

5. Return Value Optimization

Problem: Complex dictionary returns cause XLA compilation issues.
Solution: Use tuples instead of dictionaries for cleaner XLA graphs.

# Before (XLA-unfriendly):
return {
    'clean_logits': clean_logits,
    'noisy_logits': noisy_logits,
    'noise_output': noise_output
}

# After (XLA-friendly):
return clean_logits, noisy_logits, noise_output  # Simple tuple return

Files Modified for XLA Optimization

  • model_training_nnn/rnn_model.py: All three models optimized
    • NoiseModel.forward(): Dynamic indexing → static gather operations + dtype consistency
    • CleanSpeechModel.forward(): Same optimizations + bmm for matrix ops + dtype consistency
    • NoisySpeechModel.forward(): Hidden state optimization
    • TripleGRUDecoder.forward(): Complex return values → tuple returns + adversarial residual connection dtype fix
    • TripleGRUDecoder._apply_preprocessing(): Static preprocessing operations + dtype consistency

Benefits of XLA Optimizations

  1. Faster Compilation: Static shapes allow XLA to pre-compile optimized kernels
  2. Better Memory Usage: Reduced dynamic allocation during training
  3. Improved TPU Utilization: XLA primitives map directly to TPU matrix units
  4. Consistent Performance: Eliminates recompilation caused by dynamic shapes

Testing and Validation

Test scripts were created to verify model consistency:

  • test_xla_model.py: Comprehensive model validation testing
  • quick_test_xla.py: Fast verification of basic functionality

Important: These optimizations preserve the exact model architecture and mathematical operations. Only the implementation has been made XLA-friendly.
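
A quick equivalence check along these lines (a sketch, not the contents of the test scripts) can confirm that the static gather matches the original dynamic indexing:

# Sketch: verify the XLA-friendly index_select path matches dynamic indexing
import torch

day_weights = torch.nn.ParameterList([torch.nn.Parameter(torch.randn(512, 512)) for _ in range(4)])
day_idx = torch.tensor([0, 2, 2, 1])

dynamic = torch.stack([day_weights[i] for i in day_idx.tolist()], dim=0)               # original path
static = torch.index_select(torch.stack(list(day_weights), dim=0), 0, day_idx)         # optimized path

assert torch.allclose(dynamic, static), "optimized path changed the math"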

Usage Notes

  • All original model interfaces remain unchanged
  • Both 'inference' and 'full' modes are supported
  • Backward compatibility with existing training scripts is maintained
  • TPU training should now show improved compilation times and memory efficiency

Competition Context

This codebase also serves as the baseline for the Brain-to-Text '25 Competition on Kaggle, providing reference implementations for neural signal decoding.