Copy Task figure and environment setup
.gitignore (vendored, 5 lines changed)
@@ -160,3 +160,8 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# ignore data folder
data/*

.DS_Store
README.md (38 lines changed)
@@ -1 +1,37 @@
# nejm-brain-to-text
# An Accurate and Rapidly Calibrating Speech Neuroprosthesis
*The New England Journal of Medicine* (2024)

Nicholas S. Card, Maitreyee Wairagkar, Carrina Iacobacci,
Xianda Hou, Tyler Singer-Clark, Francis R. Willett,
Erin M. Kunz, Chaofei Fan, Maryam Vahdati Nia,
Darrel R. Deo, Aparna Srinivasan, Eun Young Choi,
Matthew F. Glasser, Leigh R. Hochberg,
Jaimie M. Henderson, Kiarash Shahlaie,
Sergey D. Stavisky*, and David M. Brandman*.

<span style="font-size:0.8em;">\* denotes co-senior authors</span>

![](b2txt_methods_overview.png)

## Overview
This repository contains the code and data necessary to reproduce the results of the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), *N Engl J Med*.

The code is written in Python, and the data can be downloaded from Dryad [here](https://google.com). Please download the data and place it in the `data` directory before running the code.

The data is currently limited to what is necessary to reproduce the results in the paper, plus some simulated neural data that can be used to demonstrate the model training pipeline. A few language models of varying size and computational resource requirements are also included. We intend to share real neural data in the coming months.

The code is organized into four main directories: `utils`, `analyses`, `data`, and `model_training`:
- The `utils` directory contains utility functions used throughout the code.
- The `analyses` directory contains the code necessary to reproduce the results shown in the main text and supplemental appendix.
- The `data` directory contains the data necessary to reproduce the results in the paper. Download it from Dryad using the link above and place it in this directory.
- The `model_training` directory contains the code necessary to train the brain-to-text model, including offline model training, an offline simulation of the online finetuning pipeline, and the language model. Note that the data used in the model training pipeline is simulated neural data, as the real neural data is not yet available.
## Python environment setup
The code is written in Python 3.9 and tested on Ubuntu 22.04. We recommend using a conda environment to manage the dependencies.

To install miniconda, follow the instructions [here](https://docs.anaconda.com/miniconda/miniconda-install/).

To create a conda environment with the necessary dependencies, run the following command from the root directory of this repository:
```bash
./setup.sh
```
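As a quick check that the environment and the local `nejm_b2txt_utils` package installed correctly, a minimal sketch (assuming `setup.sh` completed and the `tf-gpu` conda environment it creates is active):

```python
# Minimal import check; the second call mirrors the docstring example in general_utils.py.
from nejm_b2txt_utils.general_utils import remove_punctuation, calculate_error_rate

print(remove_punctuation("Hello, world!"))                                # -> "hello world"
print(calculate_error_rate("who is there".split(), "is there".split()))  # -> 1
```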
analyses/figure_2.ipynb (new file, 295 lines; diff suppressed because one or more lines are too long)

b2txt_methods_overview.png (new binary file, 954 KiB; not shown)

nejm_b2txt_utils/__init__.py (new file, empty)

nejm_b2txt_utils/general_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import numpy as np
import re
from g2p_en import G2p


# Phoneme label for each RNN logit index: blank, silence, then 39 ARPAbet phonemes
LOGIT_PHONE_DEF = [
    'BLANK', 'SIL',  # blank and silence
    'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B', 'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH'
]
SIL_DEF = ['SIL']

# Remove punctuation from text
def remove_punctuation(sentence):
    # Remove punctuation
    sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
    sentence = sentence.replace('--', '').lower()
    sentence = sentence.replace(" '", "'").lower()

    sentence = sentence.strip()
    sentence = ' '.join(sentence.split())

    return sentence

# Convert RNN logits to argmax phonemes
|
||||
def logits_to_phonemes(logits):
|
||||
seq = np.argmax(logits, axis=1)
|
||||
seq2 = np.array([seq[0]] + [seq[i] for i in range(1, len(seq)) if seq[i] != seq[i-1]])
|
||||
|
||||
phones = []
|
||||
for i in range(len(seq2)):
|
||||
phones.append(LOGIT_PHONE_DEF[seq2[i]])
|
||||
|
||||
# Remove blank and repeated phonemes
|
||||
phones = [p for p in phones if p!='BLANK']
|
||||
phones = [phones[0]] + [phones[i] for i in range(1, len(phones)) if phones[i] != phones[i-1]]
|
||||
|
||||
return phones
|
||||
|
||||
|
||||
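As an illustration of what `logits_to_phonemes` does (a toy example with made-up logits, not part of the committed file), each row of `logits` holds one score per entry of `LOGIT_PHONE_DEF`:

```python
import numpy as np

# 5 time steps x 41 classes; the argmax sequence is [BLANK, HH, HH, BLANK, AH]
logits = np.zeros((5, 41))
logits[0, 0] = 1.0   # BLANK
logits[1, 17] = 1.0  # HH
logits[2, 17] = 1.0  # HH
logits[3, 0] = 1.0   # BLANK
logits[4, 4] = 1.0   # AH

print(logits_to_phonemes(logits))  # -> ['HH', 'AH']
```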
# Convert text to phonemes
|
||||
def sentence_to_phonemes(thisTranscription, g2p_instance=None):
|
||||
if not g2p_instance:
|
||||
g2p_instance = G2p()
|
||||
|
||||
# Remove punctuation
|
||||
thisTranscription = remove_punctuation(thisTranscription)
|
||||
|
||||
# Convert to phonemes
|
||||
phonemes = []
|
||||
if len(thisTranscription) == 0:
|
||||
phonemes = SIL_DEF
|
||||
else:
|
||||
for p in g2p_instance(thisTranscription):
|
||||
if p==' ':
|
||||
phonemes.append('SIL')
|
||||
|
||||
p = re.sub(r'[0-9]', '', p) # Remove stress
|
||||
if re.match(r'[A-Z]+', p): # Only keep phonemes
|
||||
phonemes.append(p)
|
||||
|
||||
#add one SIL symbol at the end so there's one at the end of each word
|
||||
phonemes.append('SIL')
|
||||
|
||||
return phonemes, thisTranscription
|
||||
|
||||
|
||||
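A small usage sketch (the exact phoneme sequence depends on the g2p_en model, so the output shown is approximate):

```python
phonemes, cleaned = sentence_to_phonemes("Hello, world!")
print(cleaned)   # -> "hello world"
print(phonemes)  # -> roughly ['HH', 'AH', 'L', 'OW', 'SIL', 'W', 'ER', 'L', 'D', 'SIL']
```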
# Calculate WER or PER
|
||||
def calculate_error_rate(r, h):
|
||||
"""
|
||||
Calculation of WER or PER with Levenshtein distance.
|
||||
Works only for iterables up to 254 elements (uint8).
|
||||
O(nm) time ans space complexity.
|
||||
----------
|
||||
Parameters:
|
||||
r : list of true words or phonemes
|
||||
h : list of predicted words or phonemes
|
||||
----------
|
||||
Returns:
|
||||
Word error rate (WER) or phoneme error rate (PER) [int]
|
||||
----------
|
||||
Examples:
|
||||
>>> calculate_wer("who is there".split(), "is there".split())
|
||||
1
|
||||
>>> calculate_wer("who is there".split(), "".split())
|
||||
3
|
||||
>>> calculate_wer("".split(), "who is there".split())
|
||||
3
|
||||
"""
|
||||
# initialization
|
||||
d = np.zeros((len(r)+1)*(len(h)+1), dtype=np.uint8)
|
||||
d = d.reshape((len(r)+1, len(h)+1))
|
||||
for i in range(len(r)+1):
|
||||
for j in range(len(h)+1):
|
||||
if i == 0:
|
||||
d[0][j] = j
|
||||
elif j == 0:
|
||||
d[i][0] = i
|
||||
|
||||
# computation
|
||||
for i in range(1, len(r)+1):
|
||||
for j in range(1, len(h)+1):
|
||||
if r[i-1] == h[j-1]:
|
||||
d[i][j] = d[i-1][j-1]
|
||||
else:
|
||||
substitution = d[i-1][j-1] + 1
|
||||
insertion = d[i][j-1] + 1
|
||||
deletion = d[i-1][j] + 1
|
||||
d[i][j] = min(substitution, insertion, deletion)
|
||||
|
||||
return d[len(r)][len(h)]
|
||||
|
||||
|
||||
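Since the function returns a raw edit count, divide by the reference length to get a rate; for example:

```python
ref = "the birds are singing".split()
hyp = "the bird is singing".split()

n_errors = calculate_error_rate(ref, hyp)  # two substitutions
print(n_errors / len(ref))                 # -> 0.5 (word error rate)
```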
# Calculate aggregate WER or PER across multiple sentences, with a bootstrap 95% CI
def calculate_aggregate_error_rate(r, h):

    # list setup
    err_count = []
    item_count = []
    error_rate_ind = []

    # calculate individual error rates
    for x in range(len(h)):
        r_x = r[x]
        h_x = h[x]

        n_err = calculate_error_rate(r_x, h_x)

        item_count.append(len(r_x))
        err_count.append(n_err)
        error_rate_ind.append(n_err / len(r_x))

    # Calculate aggregate error rate
    error_rate_agg = np.sum(err_count) / np.sum(item_count)

    # calculate 95% CI by bootstrap resampling over sentences
    item_count = np.array(item_count)
    err_count = np.array(err_count)
    nResamples = 10000
    resampled_error_rate = np.zeros([nResamples,])
    for n in range(nResamples):
        resampleIdx = np.random.randint(0, item_count.shape[0], [item_count.shape[0]])
        resampled_error_rate[n] = np.sum(err_count[resampleIdx]) / np.sum(item_count[resampleIdx])
    error_rate_agg_CI = np.percentile(resampled_error_rate, [2.5, 97.5])

    # return everything as a tuple
    return (error_rate_agg, error_rate_agg_CI[0], error_rate_agg_CI[1], error_rate_ind)
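A usage sketch with two hypothetical sentence pairs, showing the aggregate rate, its bootstrap 95% CI, and the per-sentence rates:

```python
refs = [s.split() for s in ["who is there", "the birds are singing"]]
hyps = [s.split() for s in ["who was there", "the bird is singing"]]

wer, ci_low, ci_high, per_sentence = calculate_aggregate_error_rate(refs, hyps)
print(wer)              # -> 3 errors / 7 reference words ≈ 0.429
print(ci_low, ci_high)  # bootstrap 95% confidence interval
print(per_sentence)     # -> [0.333..., 0.5]
```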
setup.py (new file, 12 lines)
@@ -0,0 +1,12 @@
from setuptools import setup, find_packages

setup(
    name='nejm_b2txt_utils',
    version='0.0.0',
    packages=['nejm_b2txt_utils'],
    # # Specify any packages that our package itself requires.
    # install_requires=[
    #     'numpy',
    #     'g2p_en',
    # ]
)
setup.sh (new file, 27 lines)
@@ -0,0 +1,27 @@
#!/bin/bash
source ~/miniconda3/bin/activate

# Create a new conda environment with the name "tf-gpu" and Python 3.9
conda create -n tf-gpu python=3.9 -y

# Activate the conda environment
conda activate tf-gpu

# Install tensorflow-gpu version 2.10.0
conda install -c conda-forge tensorflow-gpu=2.10.0 -y

# Install numpy, scikit-learn
conda install numpy==1.26.0 scikit-learn==1.3.0 -y

# Install omegaconf, pyyaml, redis, matplotlib, jupyter, transformers, g2p_en
pip install omegaconf==2.3.0 pyyaml==6.0.1 redis==5.0.1 matplotlib==3.8.1 jupyter==1.0.0 transformers==4.35.0 g2p_en==2.1.0 coloredlogs==15.0.1 numba==0.58.1

# install punctuation model
pip install deepmultilingualpunctuation==1.0.1

# install local repository
pip install -e .

# install lm-decoder
cd LanguageModelDecoder/runtime/server/x86
python setup.py install