Copy Task figure and environment setup

This commit is contained in:
nckcard
2024-08-14 12:00:20 -07:00
parent 439d704bfb
commit aad9276b49
8 changed files with 532 additions and 1 deletion

5
.gitignore vendored

@@ -160,3 +160,8 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# ignore data folder
data/*
.DS_Store

37
README.md

@@ -1 +1,37 @@
# nejm-brain-to-text
# An Accurate and Rapidly Calibrating Speech Neuroprosthesis
*The New England Journal of Medicine* (2024)
Nicholas S. Card, Maitreyee Wairagkar, Carrina Iacobacci,
Xianda Hou, Tyler Singer-Clark, Francis R. Willett,
Erin M. Kunz, Chaofei Fan, Maryam Vahdati Nia,
Darrel R. Deo, Aparna Srinivasan, Eun Young Choi,
Matthew F. Glasser, Leigh R. Hochberg,
Jaimie M. Henderson, Kiarash Shahlaie,
Sergey D. Stavisky*, and David M. Brandman*.
<span style="font-size:0.8em;">\* denotes co-senior authors</span>
![Speech neuroprosthesis overview](b2txt_methods_overview.png)
## Overview
This repository contains the code and data necessary to reproduce the results of the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), *N Engl J Med*.
The code is written in Python, and the data can be downloaded from Dryad [here](https://google.com). Please download the data and place it in the `data` directory before running the code.
The data is currently limited to what is necessary to reproduce the results in the paper, plus some additional simulated neural data that can be used to demonstrate the model training pipeline. A few language models of varying sizes and computational resource requirements are also included. We intend to share the real neural data in the coming months.
The code is organized into four main directories: `utils`, `analyses`, `data`, and `model_training`:
- The `utils` directory contains utility functions used throughout the code.
- The `analyses` directory contains the code necessary to reproduce results shown in the main text and supplemental appendix.
- The `data` directory contains the data necessary to reproduce the results in the paper. Download it from Dryad using the link above and place it in this directory.
- The `model_training` directory contains the code for training the brain-to-text model, including offline model training, an offline simulation of the online fine-tuning pipeline, and running the language model. Note that the model training pipeline currently uses simulated neural data, as the real neural data is not yet available.
## Python environment setup
The code is written in Python 3.9 and tested on Ubuntu 22.04. We recommend using a conda environment to manage the dependencies.
To install miniconda, follow the instructions [here](https://docs.anaconda.com/miniconda/miniconda-install/).
To create a conda environment with the necessary dependencies, run the following command from the root directory of this repository:
```bash
./setup.sh
```
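Once `setup.sh` finishes, a quick way to confirm the environment is usable is to import the key dependencies it installs. This is a minimal sanity check, not part of the repository, and it assumes the `tf-gpu` conda environment created by `setup.sh` is active:

```python
# Minimal sanity check (assumes the "tf-gpu" conda environment created by setup.sh is active).
import tensorflow as tf
from g2p_en import G2p

print(tf.__version__)          # expected: 2.10.0
print(G2p()('brain to text'))  # expected: ARPAbet phonemes, e.g. ['B', 'R', 'EY1', 'N', ' ', 'T', 'UW1', ...]
```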

295
analyses/figure_2.ipynb Normal file

File diff suppressed because one or more lines are too long

BIN
b2txt_methods_overview.png Normal file

Binary file not shown (954 KiB).

@@ -0,0 +1,156 @@
import numpy as np
import re
from g2p_en import G2p

LOGIT_PHONE_DEF = [
    'BLANK', 'SIL',  # blank and silence
    'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B', 'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH'
]

SIL_DEF = ['SIL']


# Remove punctuation from text
def remove_punctuation(sentence):
    # Remove punctuation
    sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
    sentence = sentence.replace('--', '').lower()
    sentence = sentence.replace(" '", "'").lower()
    sentence = sentence.strip()
    sentence = ' '.join(sentence.split())
    return sentence


# Convert RNN logits to argmax phonemes
def logits_to_phonemes(logits):
    # Greedy (argmax) decoding: take the most likely class at each time step
    seq = np.argmax(logits, axis=1)
    # Collapse consecutive repeats of the same class
    seq2 = np.array([seq[0]] + [seq[i] for i in range(1, len(seq)) if seq[i] != seq[i-1]])

    phones = []
    for i in range(len(seq2)):
        phones.append(LOGIT_PHONE_DEF[seq2[i]])

    # Remove blank and repeated phonemes
    phones = [p for p in phones if p != 'BLANK']
    phones = [phones[0]] + [phones[i] for i in range(1, len(phones)) if phones[i] != phones[i-1]]

    return phones


# Convert text to phonemes
def sentence_to_phonemes(thisTranscription, g2p_instance=None):
    if not g2p_instance:
        g2p_instance = G2p()

    # Remove punctuation
    thisTranscription = remove_punctuation(thisTranscription)

    # Convert to phonemes
    phonemes = []
    if len(thisTranscription) == 0:
        phonemes = SIL_DEF
    else:
        for p in g2p_instance(thisTranscription):
            if p == ' ':
                phonemes.append('SIL')
            p = re.sub(r'[0-9]', '', p)  # Remove stress
            if re.match(r'[A-Z]+', p):   # Only keep phonemes
                phonemes.append(p)

        # Add one SIL symbol at the end so there's one at the end of each word
        phonemes.append('SIL')

    return phonemes, thisTranscription


# Calculate WER or PER
def calculate_error_rate(r, h):
    """
    Calculation of WER or PER with Levenshtein distance.
    Works only for iterables up to 254 elements (uint8).
    O(nm) time and space complexity.
    ----------
    Parameters:
        r : list of true words or phonemes
        h : list of predicted words or phonemes
    ----------
    Returns:
        Number of word or phoneme errors (the edit distance), used to compute WER or PER [int]
    ----------
    Examples:
        >>> calculate_error_rate("who is there".split(), "is there".split())
        1
        >>> calculate_error_rate("who is there".split(), "".split())
        3
        >>> calculate_error_rate("".split(), "who is there".split())
        3
    """
    # initialization
    d = np.zeros((len(r)+1)*(len(h)+1), dtype=np.uint8)
    d = d.reshape((len(r)+1, len(h)+1))
    for i in range(len(r)+1):
        for j in range(len(h)+1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitution = d[i-1][j-1] + 1
                insertion = d[i][j-1] + 1
                deletion = d[i-1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)

    return d[len(r)][len(h)]


# Calculate aggregate WER or PER, with a bootstrapped 95% confidence interval
def calculate_aggregate_error_rate(r, h):
    # list setup
    err_count = []
    item_count = []
    error_rate_ind = []

    # calculate individual error rates
    for x in range(len(h)):
        r_x = r[x]
        h_x = h[x]

        n_err = calculate_error_rate(r_x, h_x)

        item_count.append(len(r_x))
        err_count.append(n_err)
        error_rate_ind.append(n_err / len(r_x))

    # calculate aggregate error rate
    error_rate_agg = np.sum(err_count) / np.sum(item_count)

    # calculate 95% CI by bootstrap resampling over sentences
    item_count = np.array(item_count)
    err_count = np.array(err_count)
    nResamples = 10000
    resampled_error_rate = np.zeros([nResamples, ])
    for n in range(nResamples):
        resampleIdx = np.random.randint(0, item_count.shape[0], [item_count.shape[0]])
        resampled_error_rate[n] = np.sum(err_count[resampleIdx]) / np.sum(item_count[resampleIdx])
    error_rate_agg_CI = np.percentile(resampled_error_rate, [2.5, 97.5])

    # return everything as a tuple
    return (error_rate_agg, error_rate_agg_CI[0], error_rate_agg_CI[1], error_rate_ind)
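For context, here is a hedged sketch of how these utilities can be combined to score decoder output against reference transcriptions. The sentences are made up, and the import path is an assumption for illustration (the exact module name inside the `nejm_b2txt_utils` package is not shown in this diff):

```python
# Illustrative usage sketch; the import path below is an assumed placeholder.
from nejm_b2txt_utils.general_utils import (
    remove_punctuation, sentence_to_phonemes, calculate_aggregate_error_rate)

# Hypothetical reference and predicted sentences.
refs = ['How are you today?', 'I am doing well.']
hyps = ['how are you today', 'i am doing good']

# Word lists for WER and phoneme lists for PER.
ref_words = [remove_punctuation(s).split() for s in refs]
hyp_words = [remove_punctuation(s).split() for s in hyps]
ref_phones = [sentence_to_phonemes(s)[0] for s in refs]
hyp_phones = [sentence_to_phonemes(s)[0] for s in hyps]

# Aggregate error rates with bootstrapped 95% confidence intervals.
wer, wer_lo, wer_hi, _ = calculate_aggregate_error_rate(ref_words, hyp_words)
per, per_lo, per_hi, _ = calculate_aggregate_error_rate(ref_phones, hyp_phones)
print(f'WER = {wer:.3f} (95% CI {wer_lo:.3f}-{wer_hi:.3f})')
print(f'PER = {per:.3f} (95% CI {per_lo:.3f}-{per_hi:.3f})')
```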

12
setup.py Normal file

@@ -0,0 +1,12 @@
from setuptools import setup, find_packages

setup(
    name='nejm_b2txt_utils',
    version='0.0.0',
    packages=['nejm_b2txt_utils'],
    # # Specify any packages that our package itself requires.
    # install_requires=[
    #     'numpy',
    #     'g2p_en',
    # ]
)

27
setup.sh Normal file

@@ -0,0 +1,27 @@
#!/bin/bash
source ~/miniconda3/bin/activate
# Create a new conda environment with the name "tf-gpu" and Python 3.9
conda create -n tf-gpu python=3.9 -y
# Activate the conda environment
conda activate tf-gpu
# Install tensorflow-gpu version 2.10.0
conda install -c conda-forge tensorflow-gpu=2.10.0 -y
# Install numpy, scikit-learn
conda install numpy==1.26.0 scikit-learn==1.3.0 -y
# Install omegaconf, pyyaml, redis, matplotlib, jupyter, transformers, g2p_en
pip install omegaconf==2.3.0 pyyaml==6.0.1 redis==5.0.1 matplotlib==3.8.1 jupyter==1.0.0 transformers==4.35.0 g2p_en==2.1.0 coloredlogs==15.0.1 numba==0.58.1
# install punctuation model
pip install deepmultilingualpunctuation==1.0.1
# install local repository
pip install -e .
# install lm-decoder
cd LanguageModelDecoder/runtime/server/x86
python setup.py install
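As an optional final check (a sketch, not part of the script), the editable install of this repository can be verified from Python inside the same environment:

```python
# Optional post-install check (run inside the "tf-gpu" environment after setup.sh completes).
# setup.py above registers the distribution as "nejm_b2txt_utils" with version 0.0.0.
from importlib.metadata import version

print(version('nejm_b2txt_utils'))  # expected: 0.0.0
```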