# TensorFlow Brain-to-Text Model for TPU v5e-8

This directory contains a complete TensorFlow implementation of the brain-to-text neural speech decoding system, optimized for TPU v5e-8 hardware. It provides equivalent functionality to the PyTorch version, with TensorFlow operations designed for maximum TPU performance.

## Architecture Overview

The TensorFlow implementation maintains the same three-model adversarial architecture:

### Core Models
- **NoiseModel**: 2-layer GRU that estimates noise in the neural data
- **CleanSpeechModel**: 3-layer GRU that processes the denoised signal for speech recognition
- **NoisySpeechModel**: 2-layer GRU that processes the noise signal for adversarial training

### Key Features
- **Day-specific transformations**: Learnable input layers for each recording session
- **Patch processing**: Temporal patching for improved sequence modeling
- **Gradient Reversal Layer**: For adversarial training between the noise and speech models (see the sketch after this list)
- **Mixed precision**: bfloat16 optimization for TPU v5e-8 memory efficiency
- **CTC Loss**: Connectionist Temporal Classification for sequence alignment

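The gradient reversal layer acts as the identity in the forward pass and negates (and scales) gradients in the backward pass, so the adversarial loss pushes the noise and speech models in opposite directions. Below is a minimal sketch using `tf.custom_gradient`; the actual layer in `rnn_model_tf.py` may differ in name and details.

```python
import tensorflow as tf

def gradient_reversal(x, lambd=1.0):
    """Hypothetical gradient reversal: identity forward, -lambd * gradient backward."""
    @tf.custom_gradient
    def _reverse(x):
        def grad(dy):
            return -lambd * dy
        return tf.identity(x), grad
    return _reverse(x)
```
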
## Files Overview

### Core Implementation
- `rnn_model_tf.py`: TensorFlow model architecture with TPU optimizations
- `trainer_tf.py`: Training pipeline with distributed TPU strategy
- `dataset_tf.py`: Data loading and augmentation optimized for TPU
- `train_model_tf.py`: Main training script
- `evaluate_model_tf.py`: Evaluation and inference script

### Configuration and Setup
- `rnn_args.yaml`: Training configuration (shared with the PyTorch version)
- `setup_tensorflow_tpu.sh`: Environment setup script
- `requirements_tf.txt`: Python dependencies
- `README_TensorFlow.md`: This documentation

## Quick Start

### 1. Environment Setup
```bash
# Run the setup script to configure the TPU environment
./setup_tensorflow_tpu.sh

# Activate the conda environment
conda activate b2txt_tf
```

### 2. Verify TPU Access
```python
import tensorflow as tf

# Check TPU availability
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print(f"TPU cores available: {strategy.num_replicas_in_sync}")
```

### 3. Start Training
```bash
# Basic training with the default config
python train_model_tf.py --config_path rnn_args.yaml

# Training with custom settings
python train_model_tf.py \
    --config_path rnn_args.yaml \
    --batch_size 64 \
    --num_batches 50000 \
    --output_dir ./trained_models/custom_run
```

### 4. Run Evaluation
```bash
# Evaluate a trained model
python evaluate_model_tf.py \
    --model_path ./trained_models/baseline_rnn/checkpoint/best_checkpoint \
    --config_path rnn_args.yaml \
    --eval_type test
```

## TPU v5e-8 Optimizations

### Hardware-Specific Features
- **Mixed Precision**: Automatic bfloat16 conversion for roughly 2x memory efficiency (see the snippet after this list)
- **XLA Compilation**: Just-in-time compilation for optimal TPU performance
- **Distributed Strategy**: Automatic sharding across 8 TPU cores
- **Memory Management**: Efficient tensor operations to avoid out-of-memory (OOM) errors

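For reference, this is how bfloat16 mixed precision is typically enabled through the Keras API; in this repo the trainer is expected to handle it via the `use_amp` config flag, so the snippet is illustrative only.

```python
import tensorflow as tf

# Compute in bfloat16, keep variables in float32. Set before building the model.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
```
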
### Performance Optimizations
- **Batch Matrix Operations**: `tf.linalg.matmul` instead of element-wise operations
- **Static Shapes**: Avoiding dynamic tensor shapes for better compilation
- **Efficient Gathering**: `tf.gather` for day-specific parameter selection (sketched after this list)
- **Gradient Reversal**: Custom gradient function for adversarial training

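To illustrate the day-specific parameter selection, here is a hedged sketch of gathering per-session input transforms and applying them with batched matrix multiplies. The variable names and shapes are assumptions, not the repo's actual code.

```python
import tensorflow as tf

# Hypothetical day-specific input transform: one weight matrix and bias per
# recording session, selected per example with tf.gather.
n_days, n_features = 21, 512
day_weights = tf.Variable(tf.eye(n_features, batch_shape=[n_days]))  # [days, feat, feat]
day_biases = tf.Variable(tf.zeros([n_days, 1, n_features]))          # [days, 1, feat]

def apply_day_transform(x, day_indices):
    """x: [batch, time, feat] neural features; day_indices: [batch] int session ids."""
    w = tf.gather(day_weights, day_indices)   # [batch, feat, feat]
    b = tf.gather(day_biases, day_indices)    # [batch, 1, feat]
    return tf.linalg.matmul(x, w) + b         # batched matmul, bias broadcast over time
```
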
## Configuration

The model uses the same `rnn_args.yaml` configuration as the PyTorch version. Key TPU-specific settings:

```yaml
# TPU-specific settings
use_amp: true                    # Enable mixed precision (bfloat16)
dataset:
  batch_size: 32                # Optimized for TPU memory
  num_dataloader_workers: 0     # Disable multiprocessing on TPU

# Model architecture (same as PyTorch)
model:
  n_input_features: 512         # Neural features per timestep
  n_units: 768                  # Hidden units per GRU layer
  patch_size: 14                # Temporal patch size
  patch_stride: 4               # Patch stride
```

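If you need to inspect or override these settings programmatically, a minimal sketch is shown below. It assumes an attribute-style config loader such as OmegaConf and only the nested keys shown above; the training script's own config handling may differ.

```python
from omegaconf import OmegaConf  # assumption: attribute-style config loader

# Load rnn_args.yaml and override a nested setting before constructing the trainer.
config = OmegaConf.load('rnn_args.yaml')
config.dataset.batch_size = 64   # keys as shown in the YAML snippet above
```
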
## Performance Comparison

### TPU v5e-8 vs Other Hardware
- **Memory**: ~2x improvement with bfloat16 mixed precision
- **Throughput**: ~3-4x faster training than a single V100 GPU
- **Scalability**: Automatic distribution across 8 cores
- **Cost Efficiency**: Better performance per dollar for large models

### Expected Training Times (120k batches)
- **TPU v5e-8**: ~4-6 hours
- **Single V100**: ~15-20 hours
- **RTX 4090**: ~12-18 hours

## Model Architecture Details

### TripleGRUDecoder Forward Pass
```python
# Training mode (adversarial)
clean_logits, noisy_logits, noise_output = model(
    features, day_indices, mode='full',
    grl_lambda=0.5, training=True
)

# Inference mode (production)
clean_logits = model(
    features, day_indices, mode='inference',
    training=False
)
```

### Loss Functions
```python
# Clean speech CTC loss
clean_loss = ctc_loss(clean_logits, labels, input_lengths, label_lengths)

# Adversarial noisy speech loss (with gradient reversal)
noisy_loss = ctc_loss(noisy_logits, labels, input_lengths, label_lengths)

# Combined loss
total_loss = clean_loss + 0.2 * noisy_loss + 0.001 * noise_l2_loss
```

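The `ctc_loss` helper above is shorthand. A minimal sketch of how such a wrapper could be built on `tf.nn.ctc_loss` follows; the batch-major layout and blank index 0 are assumptions, and `trainer_tf.py` is the authoritative implementation.

```python
import tensorflow as tf

def ctc_loss(logits, labels, input_lengths, label_lengths, blank_index=0):
    """Hypothetical CTC wrapper; assumes batch-major logits and dense int labels."""
    per_example = tf.nn.ctc_loss(
        labels=labels,                 # [batch, max_label_len] int32
        logits=logits,                 # [batch, time, n_classes] float
        label_length=label_lengths,    # [batch] true label lengths
        logit_length=input_lengths,    # [batch] valid timesteps per example
        logits_time_major=False,
        blank_index=blank_index,
    )
    return tf.reduce_mean(per_example)
```
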
## Data Pipeline

### HDF5 Data Loading
The TensorFlow implementation efficiently loads data from HDF5 files:
- **Batch creation**: Pre-batched data with padding
- **Feature subsets**: Configurable neural feature selection
- **Day balancing**: Ensures even representation across recording sessions
- **Memory efficiency**: Lazy loading with `tf.data.Dataset` (see the sketch after this list)

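As a rough illustration of the lazy-loading pattern, here is a sketch using `h5py` with `tf.data.Dataset.from_generator`. The file layout, key names, and feature dimension are assumptions; see `dataset_tf.py` for the real pipeline.

```python
import h5py
import tensorflow as tf

def make_hdf5_dataset(h5_path):
    """Hypothetical lazy HDF5 loader; group/key names are assumptions."""
    def generator():
        with h5py.File(h5_path, 'r') as f:
            for trial in f.keys():
                yield f[trial]['neural_features'][()], f[trial]['seq_class_ids'][()]

    return tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(None, 512), dtype=tf.float32),  # [time, features]
            tf.TensorSpec(shape=(None,), dtype=tf.int32),        # phoneme label ids
        ),
    )

# Pad variable-length trials to a common length within each batch
ds = make_hdf5_dataset('data/session_0.hdf5').padded_batch(32)
```
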
### Data Augmentations
- **Gaussian smoothing**: Temporal smoothing of neural signals
- **White noise**: Additive Gaussian noise for robustness (white noise and static gain are sketched after this list)
- **Static gain**: Channel-wise multiplicative noise
- **Random walk**: Temporal drift simulation
- **Random cutoff**: Variable sequence lengths

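A minimal sketch of the white-noise and static-gain augmentations in TensorFlow follows; the noise scales and per-example application inside `tf.data` are assumptions rather than the repo's exact parameters.

```python
import tensorflow as tf

def augment_features(x, noise_std=0.1, gain_std=0.05):
    """Hypothetical augmentation for x of shape [time, channels]; std values are assumptions."""
    # White noise: additive Gaussian noise, independent per timestep and channel
    x = x + tf.random.normal(tf.shape(x), stddev=noise_std)
    # Static gain: one multiplicative factor per channel, held constant over time
    gain_shape = tf.stack([1, tf.shape(x)[-1]])
    gain = 1.0 + tf.random.normal(gain_shape, stddev=gain_std)
    return x * gain

# Applied on the fly inside the tf.data pipeline:
# ds = ds.map(lambda feats, labels: (augment_features(feats), labels))
```
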
## Troubleshooting

### Common TPU Issues

#### "Resource exhausted" errors
```bash
# Reduce the batch size
python train_model_tf.py --batch_size 16

# Enable gradient accumulation
# Modify config: gradient_accumulation_steps: 4
```

#### TPU not detected
```bash
# Check environment variables
echo $TPU_NAME
echo $COLAB_TPU_ADDR

# Verify TPU access
gcloud compute tpus list
```

#### Mixed precision issues
```bash
# Disable mixed precision if needed
python train_model_tf.py --disable_mixed_precision
```

### Performance Debugging
```python
import os

# Enable XLA auto-JIT compilation (set before importing TensorFlow)
os.environ['TF_XLA_FLAGS'] = '--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit'

import tensorflow as tf

# Profile TPU usage
tf.profiler.experimental.start('logdir')
# ... training code ...
tf.profiler.experimental.stop()
```

## Advanced Usage

### Custom Training Loop
```python
from trainer_tf import BrainToTextDecoderTrainerTF

# Initialize the trainer
trainer = BrainToTextDecoderTrainerTF(config)

# Custom training with checkpointing
for epoch in range(num_epochs):
    stats = trainer.train()
    if epoch % 5 == 0:
        trainer._save_checkpoint(f'epoch_{epoch}', epoch)
```

### Model Inference
```python
# Load the trained model
model = trainer.model
model.load_weights('path/to/checkpoint.weights.h5')

# Run inference
logits = trainer.inference(features, day_indices, n_time_steps)

# Frame-level argmax predictions (see the CTC greedy decode sketch below)
predictions = tf.argmax(logits, axis=-1)
```

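The per-frame argmax above does not collapse repeated symbols or remove CTC blanks. A hedged sketch of a greedy CTC decode is shown below, continuing from the previous snippet; the batch-major logits layout and blank index 0 are assumptions.

```python
import tensorflow as tf

# Hypothetical greedy CTC decode: merges repeats and drops blanks.
time_major_logits = tf.transpose(logits, [1, 0, 2])      # [time, batch, classes]
decoded, _ = tf.nn.ctc_greedy_decoder(
    time_major_logits,
    sequence_length=tf.cast(n_time_steps, tf.int32),      # valid frames per example
    blank_index=0,                                         # assumed blank id
)
phoneme_ids = tf.sparse.to_dense(decoded[0])               # [batch, decoded_len]
```
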
### Hyperparameter Tuning
```python
# Grid search over learning rates
learning_rates = [0.001, 0.005, 0.01]
for lr in learning_rates:
    config.lr_max = lr
    trainer = BrainToTextDecoderTrainerTF(config)
    stats = trainer.train()
```

## Research and Development

This TensorFlow implementation maintains full compatibility with the published research while providing:

1. **Reproducible Results**: Same model architecture and training procedures
2. **Hardware Optimization**: TPU-specific performance improvements
3. **Scalability**: Easy scaling to larger models and datasets
4. **Extensibility**: Clean APIs for research modifications

### Key Research Features
- **Adversarial Training**: Domain adaptation through gradient reversal
- **Multi-day Learning**: Session-specific input transformations
- **Temporal Modeling**: Patch-based sequence processing
- **Robust Training**: Comprehensive data augmentation pipeline

## Citation

If you use this TensorFlow implementation in your research, please cite the original paper:

```bibtex
@article{card2024accurate,
  title={An Accurate and Rapidly Calibrating Speech Neuroprosthesis},
  author={Card, Nicholas S and others},
  journal={New England Journal of Medicine},
  year={2024}
}
```

## Support

For questions specific to the TensorFlow implementation:
1. Check this README and the PyTorch documentation in `../CLAUDE.md`
2. Review the configuration options in `rnn_args.yaml`
3. Examine the example scripts in this directory
4. Open an issue on the project repository

For TPU-specific questions, consult the Google Cloud TPU documentation and the TensorFlow TPU guides.