additional documentation
This commit is contained in:
@@ -11,7 +11,11 @@ All model training and evaluation code was tested on a computer running Ubuntu 2
|
||||
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. Be sure to unzip `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip`.
|
||||
|
||||
## Training
|
||||
To train the baseline RNN model, run the following command from the `model_training` directory:
|
||||
### Baseline RNN Model
|
||||
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while leveraging PyTorch's features, resulting in a more efficient training process with a slight increase in decoding accuracy. This model includes day-specific input layers (512x512 linear input layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. The model is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer. Data is augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the `rnn_args.yaml` file.
|
||||
|
||||
### Model training script
|
||||
To train the baseline RNN model, use the `b2txt25` conda environment to run the `train_model.py` script from the `model_training` directory:
|
||||
```bash
|
||||
conda activate b2txt25
|
||||
python train_model.py
|
||||
@@ -20,13 +24,13 @@ The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and sh
|
||||
|
||||
## Evaluation
|
||||
### Start redis server
|
||||
To evaluate the model, first start a redis server in terminal with:
|
||||
To evaluate the model, first start a redis server on `localhost` in terminal with:
|
||||
```bash
|
||||
redis-server
|
||||
```
|
||||
|
||||
### Start language model
|
||||
Next, start the ngram language model in a seperate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no gramatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
|
||||
Next, use the `b2txt25_lm` conda environment to start the ngram language model in a seperate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no gramatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
|
||||
To run the 1gram language model from the root directory of this repository:
|
||||
```bash
|
||||
conda activate b2txt_lm
|
||||
@@ -34,7 +38,7 @@ python language_model/language-model-standalone.py --lm_path language_model/pret
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
Finally, run the `evaluate_model.py` script to load the pretrained baseline RNN, use it for inference on the heldout val or test sets to get phoneme logits, pass them through the language model via redis to get word predictions, and then save the predicted sentences to a .txt file in the format required for competition submission.
|
||||
Finally, use the `b2txt25` conda environment to run the `evaluate_model.py` script to load the pretrained baseline RNN, use it for inference on the heldout val or test sets to get phoneme logits, pass them through the language model via redis to get word predictions, and then save the predicted sentences to a .txt file in the format required for competition submission. An example output file for the val split can be found at `rnn_baseline_submission_file_valsplit.txt`.
|
||||
```bash
|
||||
conda activate b2txt25
|
||||
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/t15_copyTask_neuralData --eval_type test --gpu_number 1
|
||||
|
@@ -262,4 +262,9 @@ if eval_type == 'val':
|
||||
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.txt')
|
||||
with open(output_file, 'w') as f:
|
||||
for i in range(len(lm_results['pred_sentence'])):
|
||||
f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
|
||||
if i < len(lm_results['pred_sentence']) - 1:
|
||||
# write sentence + newline
|
||||
f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
|
||||
else:
|
||||
# don't add a newline at the end of the last sentence
|
||||
f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}")
|
@@ -130,7 +130,7 @@ dataset:
|
||||
- t15.2025.03.30
|
||||
- t15.2025.04.13
|
||||
dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
|
||||
- 0
|
||||
- 0 # no val or test data from this day
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
@@ -158,12 +158,12 @@ dataset:
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
- 0
|
||||
- 0 # no val or test data from this day
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
- 0
|
||||
- 0
|
||||
- 0 # no val or test data from this day
|
||||
- 0 # no val or test data from this day
|
||||
- 1
|
||||
- 1
|
||||
- 1
|
||||
|
Reference in New Issue
Block a user