# Model Training & Evaluation
This directory contains code and resources for training the brain-to-text RNN model. This model is largely based on the architecture described in the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), but also contains modifications to improve performance, efficiency, and usability.

A pretrained baseline RNN model is included in the [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85), as is the neural data required to train that model. The code for training the same model is included here.
All model training and evaluation code was tested on a computer running Ubuntu 22.04 with two RTX 4090s and 512 GB of RAM.
## Setup
1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory at the root of this repository, and be sure to unzip both `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip` there.
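
After unzipping, a quick way to confirm the layout is a small Python check run from the repository root. This is a minimal sketch, not part of the repository: it assumes the archives extract to folders named after the zip files (`t15_copyTask_neuralData` and `t15_pretrained_rnn_baseline`), matching the paths used by the evaluation command later in this file, and the helper name `missing_data_dirs` is hypothetical.

```python
from pathlib import Path

# Assumed folder names after unzipping the Dryad archives into <repo_root>/data.
# Adjust these if your extracted folders are named differently.
EXPECTED_DIRS = ("t15_copyTask_neuralData", "t15_pretrained_rnn_baseline")


def missing_data_dirs(data_root):
    """Return the expected subdirectories that are absent from data_root."""
    root = Path(data_root)
    return [name for name in EXPECTED_DIRS if not (root / name).is_dir()]


if __name__ == "__main__":
    missing = missing_data_dirs("data")  # run from the repository root
    if missing:
        print("Missing (did you unzip the archives?):", ", ".join(missing))
    else:
        print("Data directory looks complete.")
```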
## Training
### Baseline RNN Model
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while taking advantage of PyTorch's features, resulting in more efficient training and slightly higher decoding accuracy. The model consists of day-specific input layers (a 512x512 linear layer with softsign activation for each recording day), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. It is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer, and the training data are augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the [`rnn_args.yaml`](rnn_args.yaml) file.
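
To illustrate the idea behind day-specific input layers, the dependency-free sketch below keeps one weight matrix and bias per recording day, applies the one matching a sample's day, and passes the result through softsign, `x / (1 + |x|)`. It is a toy illustration, not the PyTorch code in this directory: the class name `DaySpecificInput`, the identity initialization, and the three-feature size are assumptions (the real layers are 512x512 and are trained along with the rest of the model).

```python
def softsign(x):
    """Softsign activation: x / (1 + |x|), which squashes values into (-1, 1)."""
    return x / (1.0 + abs(x))


class DaySpecificInput:
    """Hypothetical toy per-day input layer: one weight matrix and bias per day.

    Illustration only; the real model uses learned 512x512 PyTorch layers.
    """

    def __init__(self, n_days, n_features):
        # Identity weights and zero biases are an assumption for this sketch.
        self.weights = [
            [[1.0 if i == j else 0.0 for j in range(n_features)]
             for i in range(n_features)]
            for _ in range(n_days)
        ]
        self.biases = [[0.0] * n_features for _ in range(n_days)]

    def __call__(self, x, day):
        """Apply the linear layer for `day` to feature vector x, then softsign."""
        weights, bias = self.weights[day], self.biases[day]
        return [
            softsign(sum(w * xi for w, xi in zip(row, x)) + b)
            for row, b in zip(weights, bias)
        ]


# Example: two days, three features; only day 1's weights are changed.
layer = DaySpecificInput(n_days=2, n_features=3)
layer.weights[1][0][0] = 3.0
features = [1.0, -3.0, 0.0]
print(layer(features, day=0))  # [0.5, -0.75, 0.0]
print(layer(features, day=1))  # [0.75, -0.75, 0.0]
```

The usual motivation for this design is that a separate input transform per day can absorb day-to-day drift in the neural recordings, so the shared GRU sees features that are more consistent across sessions.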
### Model training script
To train the baseline RNN model, use the `b2txt25` conda environment to run the `train_model.py` script from the `model_training` directory:
```bash
conda activate b2txt25
python train_model.py
```
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See [`rnn_args.yaml`](rnn_args.yaml) for a list of all hyperparameters.
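
The aggregate phoneme error rate (PER) is conventionally the total number of phoneme edits (substitutions, insertions, and deletions) summed over all validation sentences, divided by the total number of reference phonemes. The sketch below shows that computation together with greedy (best-path) CTC decoding of per-frame argmax labels. Whether `train_model.py` computes its validation PER exactly this way, and the use of label `0` as the CTC blank, are assumptions; the function names are hypothetical.

```python
def ctc_greedy_decode(frame_labels, blank=0):
    """Best-path CTC decoding: merge repeated labels, then drop blanks."""
    decoded, prev = [], None
    for label in frame_labels:
        if label != prev and label != blank:
            decoded.append(label)
        prev = label
    return decoded


def edit_distance(ref, hyp):
    """Levenshtein distance (substitutions, insertions, deletions)."""
    prev_row = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        row = [i]
        for j, h in enumerate(hyp, start=1):
            row.append(min(prev_row[j] + 1,              # delete r
                           row[j - 1] + 1,               # insert h
                           prev_row[j - 1] + (r != h)))  # substitute or match
        prev_row = row
    return prev_row[-1]


def aggregate_per(refs, hyps):
    """Total phoneme edits divided by total reference phonemes."""
    edits = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    return edits / sum(len(r) for r in refs)


# Per-frame argmax labels (0 = blank); the blank between the 5s keeps them distinct.
hyp = ctc_greedy_decode([0, 5, 5, 0, 5, 7, 7, 0])
print(hyp)                                   # [5, 5, 7]
print(aggregate_per([[5, 5, 7, 9]], [hyp]))  # 0.25
```

Summing edits before dividing weights every phoneme equally, so long sentences count more than short ones; averaging per-sentence error rates would give a different number.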
## Evaluation
### Start Redis server
To evaluate the model, first start a Redis server on `localhost` in a terminal:
```bash
redis-server
```
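
Before starting the language model, you can confirm that the server is reachable with `redis-cli ping`, which should print `PONG`. If you would rather check from Python without extra dependencies, the sketch below sends a raw `PING` encoded in the Redis serialization protocol (RESP) over a plain socket. It assumes the default Redis port 6379, and the helper names `encode_command` and `redis_is_up` are hypothetical.

```python
import socket


def encode_command(*args):
    """Encode a Redis command as a RESP array of bulk strings."""
    parts = [f"*{len(args)}\r\n".encode()]
    for arg in args:
        data = str(arg).encode()
        parts.append(b"$%d\r\n%s\r\n" % (len(data), data))
    return b"".join(parts)


def redis_is_up(host="localhost", port=6379, timeout=1.0):
    """Return True if a Redis server at host:port answers PING with +PONG."""
    try:
        with socket.create_connection((host, port), timeout=timeout) as sock:
            sock.sendall(encode_command("PING"))
            return sock.recv(64).startswith(b"+PONG")
    except OSError:  # refused, unreachable, or timed out
        return False


if __name__ == "__main__":
    print("Redis reachable:", redis_is_up())
```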
### Start language model
Next, use the `b2txt25_lm` conda environment to start the n-gram language model in a separate terminal window. For example, the 1gram language model can be started with the command below. Note that the 1gram model has no grammatical structure built into it. Details on downloading and running the pretrained 3gram and 5gram language models can be found in the `README.md` in the `language_model` directory.

To run the 1gram language model from the root directory of this repository:
```bash
conda activate b2txt25_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```
### Evaluate
Finally, use the `b2txt25` conda environment to run the `evaluate_model.py` script from the `model_training` directory. The script loads the pretrained baseline RNN, runs inference on the held-out validation or test set to obtain phoneme logits, passes those logits through the language model via Redis to obtain word predictions, and saves the predicted sentences to a .txt file in the format required for competition submission. An example output file for the validation split can be found at `rnn_baseline_submission_file_valsplit.txt`.
```bash
conda activate b2txt25
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/t15_copyTask_neuralData --eval_type test --gpu_number 1
```
### Shut down Redis
When you're done, you can shut down the Redis server from any terminal with `redis-cli shutdown`.