CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

This repository contains the code and data for "An Accurate and Rapidly Calibrating Speech Neuroprosthesis" published in the New England Journal of Medicine (2024). It implements a brain-to-text system that converts neural signals from speech motor cortex into text using RNN models and n-gram language models.

Development Environment Setup

Main Environment (b2txt25)

./setup.sh
conda activate b2txt25

Language Model Environment (b2txt25_lm)

./setup_lm.sh
conda activate b2txt25_lm

Important: The project requires two separate conda environments due to conflicting PyTorch versions:

  • b2txt25: PyTorch with CUDA 12.6 for model training/evaluation
  • b2txt25_lm: PyTorch 1.13.1 for Kaldi-based n-gram language models

Redis Setup

Redis is required for inter-process communication. Install on Ubuntu:

curl -fsSL https://packages.redis.io/gpg | sudo gpg --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update && sudo apt-get install redis
sudo systemctl disable redis-server

Architecture Overview

High-Level System Flow

  1. Neural Data Input: 512 features (2 per electrode × 256 electrodes) binned at 20ms resolution
  2. RNN Model: Converts neural features to phoneme logits via CTC loss
  3. Language Model: Decodes phoneme logits to words using n-gram models + OPT rescoring
  4. Redis Communication: Coordinates between RNN inference and language model processes
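To make step 2 concrete, here is a minimal sketch of what CTC-style phoneme logits encode. It uses greedy best-path decoding (argmax per frame, merge repeats, drop blanks), which is a simplification for illustration only: the pipeline described above decodes with an n-gram language model instead. The blank index 0 and the toy logits are assumptions.

```python
from itertools import groupby

BLANK = 0  # assumed index of the CTC blank symbol


def ctc_greedy_decode(frame_logits):
    """Collapse per-frame logits into a label sequence (best-path decoding)."""
    # Pick the highest-scoring class in every frame.
    best_path = [max(range(len(frame)), key=frame.__getitem__) for frame in frame_logits]
    # Merge consecutive repeats, then drop blanks.
    collapsed = [label for label, _ in groupby(best_path)]
    return [label for label in collapsed if label != BLANK]


# Toy example: 6 frames, 3 classes (0 = blank, 1 and 2 = hypothetical phonemes)
logits = [
    [0.1, 2.0, 0.3],  # argmax 1
    [0.2, 1.5, 0.1],  # argmax 1 (repeat, merged with the previous frame)
    [3.0, 0.1, 0.2],  # argmax 0 (blank)
    [0.1, 2.2, 0.4],  # argmax 1 (new emission because a blank intervened)
    [0.0, 0.3, 1.9],  # argmax 2
    [2.5, 0.1, 0.2],  # argmax 0 (blank)
]
print(ctc_greedy_decode(logits))  # [1, 1, 2]
```

The blank between the two runs of class 1 is what lets CTC emit the same phoneme twice in a row.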

Key Components

Model Training (model_training/)

  • Core Script: train_model.py (loads config from rnn_args.yaml)
  • Model Architecture: rnn_model.py - 5-layer GRU with 768 hidden units
  • Trainer: rnn_trainer.py - Custom PyTorch trainer with CTC loss
  • Evaluation: evaluate_model.py - Inference pipeline with Redis communication

Language Model (language_model/)

  • Standalone Server: language-model-standalone.py - Redis-based LM server
  • Kaldi Integration: Uses custom C++ bindings for efficient n-gram decoding
  • OPT Rescoring: Facebook OPT 6.7B for language model rescoring
  • Build System: Complex CMake-based build for Kaldi/SRILM integration

Utilities (nejm_b2txt_utils/)

  • General Utils: general_utils.py - Shared utility functions
  • Package: Installed via setup.py as nejm_b2txt_utils

Analysis (analyses/)

  • Jupyter Notebooks: figure_2.ipynb, figure_4.ipynb for paper figures

Common Development Tasks

Training a Model

conda activate b2txt25
cd model_training
python train_model.py

Running Evaluation Pipeline

  1. Start Redis server:

    redis-server
    
  2. Start language model (separate terminal):

    conda activate b2txt25_lm
    python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
    
  3. Run evaluation (separate terminal):

    conda activate b2txt25
    cd model_training
    python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/hdf5_data_final --eval_type test --gpu_number 1
    
  4. Shutdown Redis:

    redis-cli shutdown
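Since steps 2 and 3 both depend on Redis already running, a quick reachability check before launching them can save a confusing failure. The helper below is a hypothetical sketch using only the standard library; it assumes Redis listens on its default port 6379 on localhost, matching the --redis_ip localhost argument above.

```python
import socket


def redis_reachable(host="localhost", port=6379, timeout=1.0):
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timeout, or name resolution failure.
        return False


print("Redis reachable:", redis_reachable())
```

This only confirms that something accepts connections on the port; it does not verify that the listener is actually Redis.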
    

Building Language Model from Scratch

# Build SRILM (in language_model/srilm-1.7.3/)
export SRILM=$PWD
make MAKE_PIC=yes World

# Build Kaldi components (in language_model/runtime/server/x86/)
mkdir build && cd build
cmake .. && make -j8

Data Structure

Neural Data Format

  • File Type: HDF5 files in data/hdf5_data_final/
  • Features: 512 neural features per 20ms bin (0-based indices, 64 features per block):
    • 0-63: ventral 6v threshold crossings
    • 64-127: area 4 threshold crossings
    • 128-191: 55b threshold crossings
    • 192-255: dorsal 6v threshold crossings
    • 256-319: ventral 6v spike band power
    • 320-383: area 4 spike band power
    • 384-447: 55b spike band power
    • 448-511: dorsal 6v spike band power
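The layout can be expressed as index arithmetic. This sketch assumes 0-based indices and eight consecutive blocks of 64 features in the order listed above; the function names are hypothetical and not part of the repository.

```python
ARRAYS = ["ventral 6v", "area 4", "55b", "dorsal 6v"]
FEATURE_TYPES = ["threshold crossings", "spike band power"]
BLOCK = 64  # features per array and feature type (512 / 8)


def describe_feature(index):
    """Map a 0-based feature index (0..511) to (array, feature type, electrode)."""
    if not 0 <= index < BLOCK * len(ARRAYS) * len(FEATURE_TYPES):
        raise IndexError(f"feature index out of range: {index}")
    block, electrode = divmod(index, BLOCK)
    return ARRAYS[block % len(ARRAYS)], FEATURE_TYPES[block // len(ARRAYS)], electrode


def block_slice(array, feature_type):
    """Return the slice that selects one array's features of one type."""
    block = FEATURE_TYPES.index(feature_type) * len(ARRAYS) + ARRAYS.index(array)
    return slice(block * BLOCK, (block + 1) * BLOCK)


print(describe_feature(0))                     # ('ventral 6v', 'threshold crossings', 0)
print(block_slice("55b", "spike band power"))  # slice(384, 448, None)
```

A slice from block_slice can be applied to the last axis of a loaded feature array to isolate a single cortical area.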

Data Loading

Use load_h5py_file() in model_training/evaluate_model_helpers.py as reference for HDF5 data loading.

Important Notes

  • GPU Requirements: OPT 6.7B requires ~12.4GB VRAM; RTX 4090s recommended
  • Memory Requirements: 3-gram LM needs ~60GB RAM, 5-gram needs ~300GB RAM
  • Environment Isolation: Always use correct conda environment for each component
  • Redis Dependency: Many scripts require Redis server to be running
  • Build Dependencies: CMake ≥3.14 and GCC ≥10.1 required for language model builds

XLA Optimizations (TPU-Friendly Model)

The RNN model has been optimized for XLA compilation and TPU training while preserving the original model architecture. These optimizations improve compilation speed and reduce memory usage on TPUs.

Applied XLA Optimizations

1. Dynamic Shape Operations → Static Operations

Problem: The XLA compiler struggles with dynamic tensor shapes and indexing
Solution: Replace dynamic operations with XLA-friendly alternatives

# Before (XLA-unfriendly):
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

# After (XLA-friendly):
all_day_weights = torch.stack(list(self.day_weights), dim=0)  # Static stack
all_day_biases = torch.stack([bias.squeeze(0) for bias in self.day_biases], dim=0)
day_weights = torch.index_select(all_day_weights, 0, day_idx)  # Static gather
day_biases = torch.index_select(all_day_biases, 0, day_idx).unsqueeze(1)

2. Matrix Operations → XLA Primitives

Problem: Complex einsum operations are less optimized than native XLA ops
Solution: Use batch matrix multiplication (bmm) for better XLA performance

# Before:
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases

# After (XLA-optimized):
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)  # bmm + dtype consistency
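To see why the swap is safe, the sketch below spells out the contraction "btd,bdk->btk" on nested lists and compares it with a batched matrix multiply. It is pure Python for illustration and makes no claim about performance; both function names are hypothetical.

```python
def einsum_btd_bdk(x, w):
    """Literal reading of "btd,bdk->btk": sum over d for each (b, t, k)."""
    B, T, D, K = len(x), len(x[0]), len(w[0]), len(w[0][0])
    out = [[[0.0] * K for _ in range(T)] for _ in range(B)]
    for b in range(B):
        for t in range(T):
            for d in range(D):
                for k in range(K):
                    out[b][t][k] += x[b][t][d] * w[b][d][k]
    return out


def bmm(x, w):
    """Batched matmul: one ordinary [T][D] @ [D][K] product per batch element."""
    return [
        [[sum(row[d] * w_b[d][k] for d in range(len(w_b))) for k in range(len(w_b[0]))]
         for row in x_b]
        for x_b, w_b in zip(x, w)
    ]


x = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.0], [1.0, 1.0]]]  # shape [B=2][T=2][D=2]
w = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 1.0], [0.0, 3.0]]]  # shape [B=2][D=2][K=2]
print(einsum_btd_bdk(x, w) == bmm(x, w))  # True
```

The two functions compute the same sums, which is why replacing einsum with bmm changes only how the operation is lowered, not the result.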

5. Mixed Precision Dtype Consistency (Comprehensive Fix)

Problem: Mixed precision training causes dtype mismatches throughout the adversarial training pipeline
Error: Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]

Root Cause Analysis: The error occurred at dimension 7168 = 512 * 14, indicating patch processing with patch_size=14. The dtype mismatch cascaded through multiple layers:

  1. Initial bmm operations in day-specific transformations
  2. Adversarial training residual connections between models
  3. Patch processing operations (unfold, permute, reshape)
  4. Gradient Reversal Layer (GRL) operations
  5. Hidden state initialization in adversarial training helper methods

Comprehensive Solution: Implement dtype consistency across the entire adversarial training data flow:

# Fix 1: Basic bmm operations with dtype consistency
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)

# Fix 2: Patch processing with explicit dtype preservation
if self.patch_size > 0:
    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
    x = x.unsqueeze(1)
    x = x.permute(0, 3, 1, 2)
    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
    x_unfold = x_unfold.squeeze(2)
    x_unfold = x_unfold.permute(0, 2, 3, 1)
    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
    # Ensure dtype consistency after patch processing operations
    x = x.to(original_dtype)

# Fix 3: Adversarial training residual connections
noise_output = noise_output.to(x_processed.dtype)
denoised_input = x_processed - noise_output

# Fix 4: Gradient Reversal Layer dtype handling
noisy_input = gradient_reverse(noise_output, grl_lambda) if grl_lambda else noise_output
# Ensure dtype consistency after GRL (preserves input dtype but explicit check)
noisy_input = noisy_input.to(x_processed.dtype)

# Fix 5: Hidden state dtype consistency in helper methods
# In _clean_forward_with_processed_input:
if states is None:
    states = self.clean_speech_model.h0.expand(3, batch_size, self.clean_speech_model.n_units).contiguous()
    # Ensure hidden states match input dtype for mixed precision training
    states = states.to(x_processed.dtype)

# In _noisy_forward_with_processed_input:
if states is None:
    states = self.noisy_speech_model.h0.expand(2, batch_size, self.noisy_speech_model.n_units).contiguous()
    # Ensure hidden states match input dtype for mixed precision training
    states = states.to(x_processed.dtype)

Key Implementation Details:

  • GradientReversalFn: Preserves input dtype automatically (identity forward, gradient reversal backward)
  • Patch Processing: Explicit dtype preservation prevents unfold operations from changing precision
  • Residual Connections: All tensor arithmetic operations ensure matching dtypes
  • Helper Methods: Hidden state initialization matches processed input dtype
  • Data Flow: NoiseModel → GRL → NoisySpeechModel maintains dtype consistency throughout
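The 7168-wide tensors come from the patch step. The sketch below reproduces the shape arithmetic of that unfold/permute/reshape sequence on nested lists: each patch concatenates patch_size consecutive frames, time step first and features within each step. The stride of 4 and the 100-frame input are assumptions chosen for illustration; only patch_size=14 is stated above.

```python
def unfold_time(frames, patch_size, patch_stride):
    """Sliding windows over time: [T][F] -> [N][patch_size * F]."""
    if len(frames) < patch_size:
        raise ValueError("need at least patch_size frames")
    n_patches = (len(frames) - patch_size) // patch_stride + 1
    patches = []
    for p in range(n_patches):
        window = frames[p * patch_stride : p * patch_stride + patch_size]
        # Flatten the window frame by frame, keeping feature order inside each frame.
        patches.append([value for frame in window for value in frame])
    return patches


frames = [[0.0] * 512 for _ in range(100)]  # 100 bins of 512 features (assumed length)
patches = unfold_time(frames, patch_size=14, patch_stride=4)
print(len(patches), len(patches[0]))  # 22 7168
```

Seeing 7168 in an XLA error therefore points at a tensor produced after patching rather than at the raw 512-feature input.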

3. Hidden State Initialization

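Problem: Dynamic batch size allocation causes XLA recompilation
Solution: Read the batch size once and reuse it when creating tensors. Note that x.size(0) and x.shape[0] return the same value, so recompilation is only avoided if the batch size itself stays constant across steps.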

# Before:
if states is None:
    states = self.h0.expand(2, x.shape[0], self.input_size).contiguous()

# After (XLA-friendly):
batch_size = x.size(0)  # Extract once
if states is None:
    states = self.h0.expand(2, batch_size, self.input_size).contiguous()

4. Return Value Optimization

Problem: Complex dictionary returns cause XLA compilation issues
Solution: Use tuples instead of dictionaries for cleaner XLA graphs

# Before (XLA-unfriendly):
return {
    'clean_logits': clean_logits,
    'noisy_logits': noisy_logits,
    'noise_output': noise_output
}

# After (XLA-friendly):
return clean_logits, noisy_logits, noise_output  # Simple tuple return
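If named access is still wanted after switching to tuples, one option (an assumption, not something the repository is known to do) is a typing.NamedTuple. It is a plain tuple subclass, so it unpacks positionally just like the return statement above. The class name DecoderOutput is hypothetical, and whether XLA tracing treats it identically to a bare tuple should be verified before relying on it.

```python
from typing import Any, NamedTuple


class DecoderOutput(NamedTuple):
    """Hypothetical container; fields mirror the tuple order used above."""
    clean_logits: Any
    noisy_logits: Any
    noise_output: Any


# Placeholder values stand in for tensors.
out = DecoderOutput(clean_logits=[0.1], noisy_logits=[0.2], noise_output=[0.3])

# Positional unpacking works exactly as with a bare tuple.
clean_logits, noisy_logits, noise_output = out
print(isinstance(out, tuple), out.noise_output)
```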

Files Modified for XLA Optimization

  • model_training_nnn/rnn_model.py: Comprehensive XLA optimization with dtype consistency
    • GradientReversalFn: Added adversarial training gradient reversal layer
    • NoiseModel.forward(): Dynamic indexing → static gather operations + comprehensive dtype consistency + patch processing dtype preservation
    • CleanSpeechModel.forward(): Same optimizations + bmm for matrix ops + comprehensive dtype consistency + patch processing dtype preservation
    • NoisySpeechModel.forward(): Hidden state optimization (no day layers, simplified)
    • TripleGRUDecoder.forward(): Complex return values → tuple returns + comprehensive adversarial training dtype fixes + residual connection dtype consistency + GRL dtype handling
    • TripleGRUDecoder._apply_preprocessing(): Static preprocessing operations + dtype consistency + patch processing dtype preservation
    • TripleGRUDecoder._clean_forward_with_processed_input(): Helper method with hidden state dtype consistency for mixed precision
    • TripleGRUDecoder._noisy_forward_with_processed_input(): Helper method with hidden state dtype consistency for mixed precision

Specific Dtype Consistency Fixes Applied:

  1. Basic Operations: All torch.bmm() operations with .to(x.dtype) conversions
  2. Patch Processing: Explicit dtype preservation through unfold/permute/reshape operations
  3. Adversarial Training: Residual connections with .to(x_processed.dtype) conversions
  4. Gradient Reversal: Dtype consistency after GRL operations
  5. Hidden States: All hidden state initialization with .to(x_processed.dtype) conversions
  6. Data Flow: End-to-end dtype consistency in NoiseModel → GRL → NoisySpeechModel pipeline

Error Resolved: f32[32,7168] vs bf16[32,7168] dtype mismatch in mixed precision TPU training

Benefits of XLA Optimizations

  1. Faster Compilation: Static shapes allow XLA to pre-compile optimized kernels
  2. Better Memory Usage: Reduced dynamic allocation during training
  3. Improved TPU Utilization: XLA primitives map directly to TPU matrix units
  4. Consistent Performance: Eliminates recompilation caused by dynamic shapes

Testing and Validation

Two test scripts verify model consistency:

  • test_xla_model.py: Comprehensive model validation testing
  • quick_test_xla.py: Fast verification of basic functionality

Important: These optimizations preserve the exact model architecture and mathematical operations. Only the implementation has been made XLA-friendly.

Usage Notes

  • All original model interfaces remain unchanged
  • Both 'inference' and 'full' modes are supported
  • Backward compatibility with existing training scripts is maintained
  • TPU training should now show improved compilation times and memory efficiency

Troubleshooting Dtype Issues in Mixed Precision Training

Common Error Pattern:

Status: INVALID_ARGUMENT: Call parameter must match argument; got parameter 0 shape: f32[X,Y], argument shape: bf16[X,Y]

Diagnosis Steps:

  1. Identify Operation: Look at the tensor dimensions to identify which operation is failing

    • 7168 = 512 * 14: Patch processing operation with patch_size=14
    • 512: Basic neural features
    • Other patterns may indicate different operations
  2. Check Data Flow: Trace the tensor through the adversarial training pipeline

    • Input → NoiseModel → residual connection → CleanSpeechModel
    • Input → NoiseModel → GRL → NoisySpeechModel
  3. Verify Dtype Consistency: Ensure all operations maintain input dtype

    • Use .to(x.dtype) for all operand tensors
    • Preserve dtype through complex operations (unfold, permute, reshape)
    • Match hidden state dtype to input tensor dtype
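Step 1 can be partly automated. The sketch below parses the error text shown under "Common Error Pattern" and checks whether the last dimension is a multiple of the 512 input features. The helper names and the returned dictionary layout are hypothetical, and the regular expression assumes the message format quoted above.

```python
import re

MISMATCH = re.compile(
    r"parameter (\d+) shape: (\w+)\[([\d,]*)\], argument shape: (\w+)\[([\d,]*)\]"
)


def parse_dtype_mismatch(message):
    """Extract dtype and dims of the compiled parameter and the passed argument."""
    match = MISMATCH.search(message)
    if match is None:
        return None
    index, param_dtype, param_dims, arg_dtype, arg_dims = match.groups()

    def to_dims(text):
        return [int(d) for d in text.split(",") if d]

    return {
        "parameter_index": int(index),
        "parameter": (param_dtype, to_dims(param_dims)),
        "argument": (arg_dtype, to_dims(arg_dims)),
    }


def guess_patch_size(last_dim, n_features=512):
    """Return last_dim / n_features if it divides evenly, else None."""
    return last_dim // n_features if last_dim % n_features == 0 else None


msg = ("Status: INVALID_ARGUMENT: Call parameter must match argument; "
       "got parameter 0 shape: f32[32,7168], argument shape: bf16[32,7168]")
info = parse_dtype_mismatch(msg)
print(info["parameter"], info["argument"], guess_patch_size(info["argument"][1][-1]))
```

A result of 14 from guess_patch_size points at the patch processing path, while 1 points at the raw 512-feature input.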

Quick Fix Template:

# For any tensor operation between tensors a and b:
result = operation(a, b.to(a.dtype))

# For complex operations that might change dtype:
original_dtype = tensor.dtype
tensor = complex_operation(tensor)
tensor = tensor.to(original_dtype)

# For hidden state initialization:
states = states.to(input_tensor.dtype)

PyTorch XLA API Updates and Warnings

Deprecated APIs (as of 2024)

Important: Several torch_xla APIs have been deprecated and should be updated in new code:

1. Device API Changes

# ❌ Deprecated (shows DeprecationWarning):
import torch_xla.core.xla_model as xm
device = xm.xla_device()

# ✅ Modern API:
import torch_xla
device = torch_xla.device()

2. Synchronization API Changes

# ❌ Deprecated (shows DeprecationWarning):
import torch_xla.core.xla_model as xm
xm.mark_step()

# ✅ Modern API:
import torch_xla
torch_xla.sync()

3. Mixed Precision Environment Variables

# ⚠️ Will be deprecated after PyTorch XLA 2.6:
os.environ['XLA_USE_BF16'] = '1'

# 💡 Recommended: Convert model to bf16 directly in code
model = model.to(torch.bfloat16)

TPU Performance Warnings

Transparent Hugepages Warning

UserWarning: Transparent hugepages are not enabled. TPU runtime startup and
shutdown time should be significantly improved on TPU v5e and newer.

Solution (for TPU v5e and newer):

sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled"

Note: This warning appears on TPU environments and can be safely ignored if you don't have root access (e.g., Kaggle, Colab).

Updated Code Patterns

Modern XLA Synchronization Pattern

import torch_xla.core.xla_model as xm  # Still needed for other functions
import torch_xla

# Modern pattern:
def train_step(step, sync_frequency):
    # ... training code ...

    # Synchronize every N steps
    if step % sync_frequency == 0:
        torch_xla.sync()  # Instead of xm.mark_step()

# Legacy pattern (still works but deprecated):
def train_step_legacy(step, sync_frequency):
    # ... training code ...

    # Old way (shows deprecation warning)
    if step % sync_frequency == 0:
        xm.mark_step()
        xm.wait_device_ops()  # This is still current

Device Detection Pattern

# Modern approach:
import torch

try:
    import torch_xla
    device = torch_xla.device()
    print(f"Using XLA device: {device}")
except Exception:  # e.g. torch_xla not installed or no XLA device available
    device = torch.device('cpu')
    print("Falling back to CPU")

# Legacy approach (shows warnings):
try:
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()  # DeprecationWarning
    print(f"Using XLA device: {device}")
except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
    device = torch.device('cpu')

Migration Guidelines

When updating existing code:

  1. Replace xm.xla_device() with torch_xla.device()
  2. Replace xm.mark_step() with torch_xla.sync()
  3. Keep xm.wait_device_ops() (still current API)
  4. Update imports to include torch_xla directly
  5. Consider explicit bf16 conversion instead of environment variables
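Steps 1 and 2 are mechanical enough to check with a small scan. The sketch below reports lines that still use the two deprecated calls named above, together with the suggested replacement; it does not edit files, and the function name is hypothetical.

```python
import re

# Deprecated call pattern -> modern replacement (from the guidelines above)
REPLACEMENTS = {
    r"\bxm\.xla_device\(\)": "torch_xla.device()",
    r"\bxm\.mark_step\(\)": "torch_xla.sync()",
}


def find_deprecated_calls(source):
    """Return (line number, old call, suggested replacement) for each hit."""
    findings = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for pattern, replacement in REPLACEMENTS.items():
            for match in re.finditer(pattern, line):
                findings.append((lineno, match.group(0), replacement))
    return findings


sample = "device = xm.xla_device()\nxm.mark_step()\nxm.wait_device_ops()\n"
for lineno, old, new in find_deprecated_calls(sample):
    print(f"line {lineno}: {old} -> {new}")
```

xm.wait_device_ops() is deliberately not flagged, in line with step 3. After replacing a call, the file also needs import torch_xla (step 4).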

Backward Compatibility

The deprecated APIs still work but generate warnings. For production code:

  • Update to modern APIs to avoid warnings
  • Test thoroughly as synchronization behavior may differ slightly
  • Legacy code will continue to function until removed in future versions

TensorFlow TPU Implementation

The original PyTorch implementation has been converted to TensorFlow for optimal performance on TPU v5e-8 environments, particularly for the Brain-to-Text '25 Competition on Kaggle.

Key TensorFlow Components (model_training_nnn_tpu/)

Core Files

  • rnn_model_tf.py: TensorFlow implementation of TripleGRUDecoder architecture

    • NoiseModel: 2-layer GRU for noise estimation with day-specific layers
    • CleanSpeechModel: 3-layer GRU for clean speech recognition with day-specific layers
    • NoisySpeechModel: 2-layer GRU for noisy speech recognition (no day layers)
    • TripleGRUDecoder: Main adversarial architecture combining all three models
    • CTCLoss: Custom CTC loss implementation for TPU compatibility
    • create_tpu_strategy(): Enhanced TPU connection function with robust environment detection
  • trainer_tf.py: TensorFlow training pipeline with distributed TPU support

  • dataset_tf.py: TensorFlow data loading with augmentation pipeline optimized for TPU

  • train_model_tf.py: Main training script entry point

  • evaluate_model_tf.py: Evaluation pipeline for model performance analysis

TPU v5e-8 Specific Optimizations

1. Enhanced TPU Connection

The create_tpu_strategy() function provides robust TPU detection across different environments:

def create_tpu_strategy():
    """Create TPU strategy for distributed training on TPU v5e-8"""
    # Multi-environment TPU detection
    if 'COLAB_TPU_ADDR' in os.environ:
        tpu_address = os.environ['COLAB_TPU_ADDR']
    elif 'TPU_NAME' in os.environ:
        tpu_name = os.environ['TPU_NAME']
    elif 'TPU_WORKER_ID' in os.environ:
        # Kaggle TPU environment
        tpu_address = 'grpc://10.0.0.2:8470'  # Default Kaggle TPU address

    # Enhanced error handling and debugging output
    # Fallback to default strategy if TPU connection fails

Environment Variables Detected:

  • COLAB_TPU_ADDR: Google Colab TPU environment
  • TPU_NAME: Generic TPU name specification
  • TPU_WORKER_ID: Kaggle TPU environment indicator
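The detection order can be isolated from the actual TPU connection so that it is easy to test without hardware. This is a hypothetical sketch, not the repository's function: it only classifies the environment, checking the variables in the same order as listed above, and the returned labels are made up for illustration.

```python
import os


def detect_tpu_environment(env=None):
    """Classify the TPU environment from environment variables, or return None."""
    env = os.environ if env is None else env
    if "COLAB_TPU_ADDR" in env:
        return "colab"
    if "TPU_NAME" in env:
        return "named"
    if "TPU_WORKER_ID" in env:
        return "kaggle"
    return None  # caller should fall back to the default strategy


print(detect_tpu_environment({"TPU_WORKER_ID": "0"}))  # kaggle
print(detect_tpu_environment({}))                      # None
```

Passing a plain dict instead of os.environ makes each branch checkable in a unit test.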

Troubleshooting TPU Connection Issues:

  • Error: "Failed to initialize TPU: Please provide a TPU Name to connect to."
  • Solution: The function automatically detects and uses appropriate TPU addresses based on environment
  • Debugging: All TPU-related environment variables are printed during initialization

2. Mixed Precision Training

Configured for optimal TPU v5e-8 performance:

from tensorflow import keras

def configure_mixed_precision():
    """Configure mixed precision for optimal TPU v5e-8 performance"""
    # Compute in bfloat16 while keeping variables in float32
    policy = keras.mixed_precision.Policy('mixed_bfloat16')
    keras.mixed_precision.set_global_policy(policy)

3. XLA-Optimized Operations

  • Static Tensor Operations: Using tf.stack() and tf.gather() instead of dynamic indexing
  • Efficient Matrix Operations: tf.linalg.matmul() for batch matrix multiplication
  • TPU-Friendly GRU Layers: Disabled recurrent dropout for better TPU performance
  • Patch Processing: TensorFlow equivalent of PyTorch's unfold using tf.image.extract_patches()

Key Architecture Differences from PyTorch

1. Gradient Reversal Layer (GRL)

@tf.custom_gradient
def gradient_reverse(x, lambd=1.0):
    """Gradient Reversal Layer for TensorFlow"""
    def grad(dy):
        return -lambd * dy  # Only return gradient w.r.t. x
    return tf.identity(x), grad

2. CTC Loss Implementation

Custom sparse tensor conversion for TPU compatibility:

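import tensorflow as tf

def dense_to_sparse(dense_tensor, sequence_lengths):
    """Convert dense tensor to sparse tensor for CTC"""
    # Entries equal to 0 are treated as padding (sequence_lengths is unused in this excerpt)
    mask = tf.not_equal(dense_tensor, 0)
    indices = tf.where(mask)  # int64 coordinates of the kept entries
    values = tf.gather_nd(dense_tensor, indices)
    # SparseTensor expects an int64 dense_shape
    dense_shape = tf.shape(dense_tensor, out_type=tf.int64)
    return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)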
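The conversion is easier to reason about on plain lists. The pure-Python sketch below mirrors the same masking logic, under the assumption that 0 marks padding; the function name is hypothetical and it is not a drop-in replacement for the TensorFlow version.

```python
def dense_to_sparse_py(dense_rows, pad_value=0):
    """Return (indices, values, dense_shape) for all entries != pad_value."""
    indices, values = [], []
    for r, row in enumerate(dense_rows):
        for c, value in enumerate(row):
            if value != pad_value:
                indices.append((r, c))
                values.append(value)
    n_cols = len(dense_rows[0]) if dense_rows else 0
    return indices, values, (len(dense_rows), n_cols)


# Two label sequences padded with 0 to length 3
labels = [[3, 5, 0],
          [7, 0, 0]]
print(dense_to_sparse_py(labels))
# ([(0, 0), (0, 1), (1, 0)], [3, 5, 7], (2, 3))
```

One consequence of masking on the value is that a genuine label with id 0 would be dropped too, so this approach relies on 0 never being a real target label.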

3. Day-Specific Layers

Using add_weight() for TPU-compatible variable management:

for i in range(n_days):
    weight = self.add_weight(
        name=f'day_weight_{i}',
        shape=(neural_dim, neural_dim),
        initializer=tf.keras.initializers.Identity(),
        trainable=True
    )

Training on TPU v5e-8

Basic Training Command

# In Kaggle TPU v5e-8 environment
python train_model_tf.py

Expected Output

🔍 Detecting TPU environment...
📍 Kaggle TPU detected, worker ID: 0, address: grpc://10.0.0.2:8470
✅ TPU initialized successfully!
🎉 Number of TPU cores: 8
Training on 8 TPU cores  # Should show 8 cores, not 1

Performance Benefits

  1. Multi-Core Utilization: Properly configured TPU strategy utilizes all 8 TPU v5e-8 cores
  2. Mixed Precision: bfloat16 precision optimized for TPU matrix units
  3. XLA Compilation: Static operations enable efficient XLA graph compilation
  4. Memory Efficiency: Optimized for TPU memory constraints and batch processing

Common Issues and Solutions

Issue: "Training on 1 TPU cores" instead of 8

Cause: TPU connection fallback to default strategy
Solution: Enhanced create_tpu_strategy() function with environment detection
Check: Verify TPU environment variables are properly set

Issue: CTC Loss dtype errors

Cause: Mixed precision dtype mismatches
Solution: Explicit dtype casting in CTCLoss.call()

Issue: Gradient Reversal Layer errors

Cause: Incorrect gradient return format
Solution: Return only gradient w.r.t. input tensor, not lambda parameter

Competition Context

This codebase serves as a baseline for the Brain-to-Text '25 Competition on Kaggle, providing both PyTorch and TensorFlow reference implementations for neural signal decoding, with optimizations for TPU v5e-8 training environments.