readme file linking
README.md
@@ -23,7 +23,7 @@ The code is organized into five main directories: `utils`, `analyses`, `data`, `
- The `analyses` directory contains the code necessary to reproduce results shown in the main text and supplemental appendix.
- The `data` directory contains the data necessary to reproduce the results in the paper. Download it from Dryad using the link above and place it in this directory.
- The `model_training` directory contains the code necessary to train and evaluate the brain-to-text model. See the README.md in that folder for more detailed instructions.
- The `language_model` directory contains the ngram language model implementation and a pretrained 1gram language model. Pretrained 3gram and 5gram language models can be downloaded [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). See the `README.md` in this directory for more information.
- The `language_model` directory contains the ngram language model implementation and a pretrained 1gram language model. Pretrained 3gram and 5gram language models can be downloaded [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). See [`language_model/README.md`](language_model/README.md) for more information.
## Data
The data used in this repository consists of various datasets for recreating figures and training/evaluating the brain-to-text model:
@@ -44,9 +44,9 @@ The data used in this repository consists of various datasets for recreating fig
- The ground truth sentence label
- The ground truth phoneme sequence label
- The data is split into training, validation, and test sets. The test set does not include ground truth sentence or phoneme labels.
- Data for each session/split is stored in `.hdf5` files. An example of how to load this data using the Python `h5py` library is provided in the `model_training/evaluate_model_helpers.py` file in the `load_h5py_file()` function.
- Data for each session/split is stored in `.hdf5` files. An example of how to load this data using the Python `h5py` library is provided in the [`model_training/evaluate_model_helpers.py`](model_training/evaluate_model_helpers.py) file in the `load_h5py_file()` function.
- Each block of data contains sentences drawn from a range of corpora (Switchboard, OpenWebText2, a 50-word corpus, a custom frequent-word corpus, and a corpus of random word sequences). Furthermore, the majority of the data was collected during attempted vocalized speaking, but some of it was collected during attempted silent speaking.
- `t15_pretrained_rnn_baseline.zip`: This dataset contains the pretrained RNN baseline model checkpoint and args. An example of how to load this model and use it for inference is provided in the `model_training/evaluate_model.py` file.
- `t15_pretrained_rnn_baseline.zip`: This dataset contains the pretrained RNN baseline model checkpoint and args. An example of how to load this model and use it for inference is provided in the [`model_training/evaluate_model.py`](model_training/evaluate_model.py) file.
Please download these datasets from [Dryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.dncjsxm85) and place them in the `data` directory. Be sure to unzip both datasets before running the code.
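Each session/split file mentioned above is a standard HDF5 container, so it can also be explored directly. Below is a minimal sketch of opening one of these files with `h5py` and listing what it contains; the file path is a placeholder, and the actual group/dataset names are whatever `load_h5py_file()` in `model_training/evaluate_model_helpers.py` expects, so treat that function as the ground truth.

```python
import h5py

# Placeholder path -- point this at one of the unzipped session .hdf5 files in ./data.
path = "data/some_session/data_train.hdf5"  # hypothetical filename

with h5py.File(path, "r") as f:
    print(list(f.keys()))  # top-level groups in the file

    # Print every dataset's name, shape, and dtype to see what is stored.
    def show(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(name, obj.shape, obj.dtype)

    f.visititems(show)
```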
@@ -64,7 +64,7 @@ Please download these datasets from [Dryad](https://datadryad.org/stash/dataset/
sudo apt-get update
sudo apt-get install redis
```
- Turn of autorestarting for the redis server in terminal:
- Turn off autorestarting for the redis server in terminal:
- `sudo systemctl disable redis-server`
- `CMake >= 3.14` and `gcc >= 10.1` are required for the ngram language model installation. You can install these on linux with `sudo apt-get install build-essential`.
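With the redis server installed (and autorestart disabled so you can start it manually when needed), a quick way to confirm that a running server is reachable is to ping it from Python. This assumes the `redis` Python package is installed in your environment; it is only a sanity check, not part of the setup steps above.

```python
import redis

# Connect to a locally running redis server (these are the redis defaults).
r = redis.Redis(host="localhost", port=6379)
print(r.ping())  # True if the server is up and reachable
```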
@@ -75,9 +75,9 @@ To create a conda environment with the necessary dependencies, run the following
```
## Python environment setup for ngram language model and OPT rescoring
We use an ngram language model plus rescoring via the [Facebook OPT 6.7b](https://huggingface.co/facebook/opt-6.7b) LLM. A pretrained 1gram language model is included in this repository at `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`. Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.
We use an ngram language model plus rescoring via the [Facebook OPT 6.7b](https://huggingface.co/facebook/opt-6.7b) LLM. A pretrained 1gram language model is included in this repository at [`language_model/pretrained_language_models/openwebtext_1gram_lm_sil`](language_model/pretrained_language_models/openwebtext_1gram_lm_sil). Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.
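To make the rescoring idea concrete, the sketch below scores a small n-best list with OPT 6.7b by summing token log-probabilities and keeps the best-scoring sentence. This is a generic illustration of LLM rescoring, not the exact procedure used in this repository; loading in float16 keeps the memory footprint in the ballpark of the VRAM figure quoted above.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b", torch_dtype=torch.float16
).to("cuda:0").eval()

def sentence_logprob(text: str) -> float:
    """Total log-probability of the sentence under OPT (higher = more fluent)."""
    ids = tok(text, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        out = model(ids, labels=ids)
    # out.loss is the mean negative log-likelihood per predicted token.
    return -out.loss.item() * (ids.shape[1] - 1)

# Hypothetical n-best candidates from the ngram decoder.
candidates = ["i would like a glass of water", "i would like a class of water"]
print(max(candidates, key=sentence_logprob))
```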
Our Kaldi-based ngram implementation requires a different version of torch than our model training pipeline, so running the ngram language models requires an additional separate Python conda environment. To create this conda environment, run the following command from the root directory of this repository. For more detailed instructions, see the README.md in the `language_model` subdirectory.
Our Kaldi-based ngram implementation requires a different version of torch than our model training pipeline, so running the ngram language models requires an additional separate Python conda environment. To create this conda environment, run the following command from the root directory of this repository. For more detailed instructions, see the README.md in the [`language_model`](language_model) subdirectory.
```bash
./setup_lm.sh
```
language_model/README.md
@@ -1,5 +1,5 @@
# Pretrained ngram language models
A pretrained 1gram language model is included in this repository at `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`. Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`) and should likewise be placed in the `language_model/pretrained_language_models/` directory. Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.
A pretrained 1gram language model is included in this repository at `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`. Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`) and should likewise be placed in the [`pretrained_language_models`](pretrained_language_models) directory. Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.
# Dependencies
```
@@ -13,11 +13,11 @@ sudo apt-get install build-essential
```
# Install language model python package
Use the `setup_lm.sh` script in the root directory of this repository to create the `b2txt_lm` conda env and install the `lm-decoder` package to it. Before installing, make sure that there is no `build` or `fc_base` directory in your `language_model/runtime/server/x86` directory, as this may cause the build to fail.
Use the `setup_lm.sh` script in the root directory of this repository to create the `b2txt_lm` conda env and install the `lm-decoder` package to it. Before installing, make sure that there is no `build` or `fc_base` directory in your [`runtime/server/x86`](runtime/server/x86) directory, as this may cause the build to fail.
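If an earlier build attempt left artifacts behind, a small sketch like the one below removes them before re-running `setup_lm.sh`; it assumes you run it from the repository root, and the path is simply the directory named above.

```python
import shutil
from pathlib import Path

# Stale CMake build artifacts in these directories can make the lm-decoder build fail.
x86_dir = Path("language_model/runtime/server/x86")
for name in ("build", "fc_base"):
    target = x86_dir / name
    if target.exists():
        print(f"removing {target}")
        shutil.rmtree(target)
```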
# Using a pretrained ngram language model
The `language-model-standalone.py` script included here is made to work with the `evaluate_model.py` script in the `model_training` directory. `language-model-standalone.py` will do the following when run:
The [`language-model-standalone.py`](language-model-standalone.py) script included here is made to work with [`evaluate_model.py`](../model_training/evaluate_model.py). `language-model-standalone.py` will do the following when run:
1. Initialize `opt-6.7b` on the specified GPU (`--gpu_number` arg). The first time you run the script, it will automatically download `opt-6.7b` from huggingface.
2. Initialize the ngram language model (specified with the `--lm_path` arg)
3. Connect to the `localhost` redis server (or a different server, specified by the `--redis_ip` and `--redis_port` args)
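Putting the three steps together, one hypothetical way to launch the script (equivalent to running it in a terminal) is sketched below. The argument names come from the list above, but the values are placeholders, and any defaults or additional required flags should be checked against `language-model-standalone.py` itself.

```python
import subprocess

# Illustrative only -- paths and values are placeholders.
subprocess.run([
    "python", "language-model-standalone.py",
    "--lm_path", "pretrained_language_models/openwebtext_1gram_lm_sil",  # ngram LM to load
    "--gpu_number", "0",         # GPU that opt-6.7b is placed on
    "--redis_ip", "localhost",   # redis server to connect to
    "--redis_port", "6379",      # assumed default redis port
], check=True)
```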
model_training/README.md
@@ -12,7 +12,7 @@ All model training and evaluation code was tested on a computer running Ubuntu 2
## Training
### Baseline RNN Model
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while leveraging PyTorch's features, resulting in a more efficient training process with a slight increase in decoding accuracy. This model includes day-specific input layers (512x512 linear input layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. The model is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer. Data is augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the `rnn_args.yaml` file.
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while leveraging PyTorch's features, resulting in a more efficient training process with a slight increase in decoding accuracy. This model includes day-specific input layers (512x512 linear input layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. The model is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer. Data is augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the [`rnn_args.yaml`](rnn_args.yaml) file.
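For readers who want to see the shape of that architecture in code, here is a rough, untested PyTorch sketch assembled only from the description above. The layer sizes follow the text; the number of recording days and the phoneme vocabulary size are placeholders, and the actual model and hyperparameters live in this directory and in `rnn_args.yaml`.

```python
import torch.nn as nn

class BaselineRNNSketch(nn.Module):
    """Illustrative only: day-specific linear inputs -> 5-layer GRU -> phoneme logits."""

    def __init__(self, n_days: int, n_features: int = 512, n_classes: int = 41):
        super().__init__()
        # One 512x512 linear input layer per recording day, with softsign activation.
        self.day_inputs = nn.ModuleList(nn.Linear(n_features, 512) for _ in range(n_days))
        self.softsign = nn.Softsign()
        # 5-layer GRU with 768 hidden units per layer.
        self.gru = nn.GRU(input_size=512, hidden_size=768, num_layers=5, batch_first=True)
        # Linear readout over phoneme classes (n_classes is a placeholder; CTC also needs a blank).
        self.readout = nn.Linear(768, n_classes)

    def forward(self, x, day_idx):
        # x: (batch, time, n_features) neural features from a single session/day.
        x = self.softsign(self.day_inputs[day_idx](x))
        h, _ = self.gru(x)
        return self.readout(h)  # pair with nn.CTCLoss and torch.optim.AdamW, as described above
```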
### Model training script
To train the baseline RNN model, use the `b2txt25` conda environment to run the `train_model.py` script from the `model_training` directory:
@@ -20,7 +20,7 @@ To train the baseline RNN model, use the `b2txt25` conda environment to run the
conda activate b2txt25
python train_model.py
```
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See `rnn_args.yaml` for a list of all hyperparameters.
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See [`rnn_args.yaml`](rnn_args.yaml) for a list of all hyperparameters.
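Phoneme error rate is conventionally an edit-distance metric: the number of insertions, deletions, and substitutions needed to turn the predicted phoneme sequence into the reference, divided by the reference length, aggregated over the partition. A minimal reference implementation (not the repository's own evaluation code) looks like this:

```python
def phoneme_error_rate(ref, hyp):
    """Levenshtein distance between two phoneme sequences, normalized by reference length."""
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1,          # deletion
                          d[i][j - 1] + 1,          # insertion
                          d[i - 1][j - 1] + cost)   # substitution
    return d[len(ref)][len(hyp)] / max(len(ref), 1)

print(phoneme_error_rate("HH AH L OW".split(), "HH AH L UW".split()))  # 0.25
```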
## Evaluation
### Start redis server