competition update
This commit is contained in:
@@ -1,23 +1,44 @@
|
||||
# Model Training
|
||||
|
||||
# Model Training & Evaluation
|
||||
This directory contains code and resources for training the brain-to-text RNN model. This model is largely based on the architecture described in the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), but also contains modifications to improve performance, efficiency, and usability.
|
||||
|
||||
A pretrained baseline RNN model is included in the [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85), as is the neural data required to train that model. The code for training the same model is included here.
|
||||
|
||||
All model training and evaluation code was tested on a computer running Ubuntu 22.04 with two RTX 4090's and 512 GB of RAM.
|
||||
|
||||
## Setup
|
||||
1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
|
||||
|
||||
1. Install the required conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
|
||||
|
||||
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory.
|
||||
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. Be sure to unzip `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip`.
|
||||
|
||||
## Training
|
||||
|
||||
To train the baseline RNN model, run the following command:
|
||||
To train the baseline RNN model, run the following command from the `model_training` directory:
|
||||
```bash
|
||||
conda activate b2txt25
|
||||
python train_model.py
|
||||
```
|
||||
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See `rnn_args.yaml` for a list of all hyperparameters.
|
||||
|
||||
## Evaluation
|
||||
|
||||
To evaluate the model, run:
|
||||
### Start redis server
|
||||
To evaluate the model, first start a redis server in terminal with:
|
||||
```bash
|
||||
python evaluate_model.py
|
||||
```
|
||||
redis-server
|
||||
```
|
||||
|
||||
### Start language model
|
||||
Next, start the ngram language model in a seperate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no gramatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
|
||||
To run the 1gram language model from the root directory of this repository:
|
||||
```bash
|
||||
conda activate b2txt_lm
|
||||
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
|
||||
```
|
||||
|
||||
### Evaluate
|
||||
Finally, run the `evaluate_model.py` script to load the pretrained baseline RNN, use it for inference on the heldout val or test sets to get phoneme logits, pass them through the language model via redis to get word predictions, and then save the predicted sentences to a .txt file in the format required for competition submission.
|
||||
```bash
|
||||
conda activate b2txt25
|
||||
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/t15_copyTask_neuralData --eval_type test --gpu_number 1
|
||||
```
|
||||
|
||||
### Shutdown redis
|
||||
When you're done, you can shutdown the redis server from any terminal using `redis-cli shutdown`.
|
@@ -124,7 +124,7 @@ class BrainToTextDataset(Dataset):
|
||||
for t in index[d]:
|
||||
|
||||
try:
|
||||
g = f[f'trial_{t}']
|
||||
g = f[f'trial_{t:04d}']
|
||||
|
||||
# Remove features is neccessary
|
||||
input_features = torch.from_numpy(g['input_features'][:]) # neural data
|
||||
@@ -277,7 +277,7 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_
|
||||
with h5py.File(path, 'r') as f:
|
||||
num_trials = len(list(f.keys()))
|
||||
for t in range(num_trials):
|
||||
key = f'trial_{t}'
|
||||
key = f'trial_{t:04d}'
|
||||
|
||||
block_num = f[key].attrs['block_num']
|
||||
trial_num = f[key].attrs['trial_num']
|
||||
|
265
model_training/evaluate_model.py
Normal file
265
model_training/evaluate_model.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import numpy as np
|
||||
import redis
|
||||
from omegaconf import OmegaConf
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import editdistance
|
||||
import argparse
|
||||
|
||||
from rnn_model import GRUDecoder
|
||||
from evaluate_model_helpers import *
|
||||
|
||||
# argument parser for command line arguments
|
||||
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
|
||||
parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
|
||||
help='Path to the pretrained model directory (relative to the current working directory).')
|
||||
parser.add_argument('--data_dir', type=str, default='../data/t15_copyTask_neuralData',
|
||||
help='Path to the dataset directory (relative to the current working directory).')
|
||||
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
|
||||
help='Evaluation type: "val" for validation set, "test" for test set. '
|
||||
'If "test", ground truth is not available.')
|
||||
parser.add_argument('--gpu_number', type=int, default=1,
|
||||
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# paths to model and data directories
|
||||
# Note: these paths are relative to the current working directory
|
||||
model_path = args.model_path
|
||||
data_dir = args.data_dir
|
||||
|
||||
# define evaluation type
|
||||
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
|
||||
|
||||
# load model args
|
||||
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
|
||||
|
||||
# set up gpu device
|
||||
gpu_number = args.gpu_number
|
||||
if torch.cuda.is_available() and gpu_number >= 0:
|
||||
if gpu_number >= torch.cuda.device_count():
|
||||
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
|
||||
device = f'cuda:{gpu_number}'
|
||||
device = torch.device(device)
|
||||
print(f'Using {device} for model inference.')
|
||||
else:
|
||||
if gpu_number >= 0:
|
||||
print(f'GPU number {gpu_number} requested but not available.')
|
||||
print('Using CPU for model inference.')
|
||||
device = torch.device('cpu')
|
||||
|
||||
# define model
|
||||
model = GRUDecoder(
|
||||
neural_dim = model_args['model']['n_input_features'],
|
||||
n_units = model_args['model']['n_units'],
|
||||
n_days = len(model_args['dataset']['sessions']),
|
||||
n_classes = model_args['dataset']['n_classes'],
|
||||
rnn_dropout = model_args['model']['rnn_dropout'],
|
||||
input_dropout = model_args['model']['input_network']['input_layer_dropout'],
|
||||
n_layers = model_args['model']['n_layers'],
|
||||
patch_size = model_args['model']['patch_size'],
|
||||
patch_stride = model_args['model']['patch_stride'],
|
||||
)
|
||||
|
||||
# load model weights
|
||||
checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'), weights_only=False)
|
||||
# rename keys to not start with "module." (happens if model was saved with DataParallel)
|
||||
for key in list(checkpoint['model_state_dict'].keys()):
|
||||
checkpoint['model_state_dict'][key.replace("module.", "")] = checkpoint['model_state_dict'].pop(key)
|
||||
checkpoint['model_state_dict'][key.replace("_orig_mod.", "")] = checkpoint['model_state_dict'].pop(key)
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
|
||||
# add model to device
|
||||
model.to(device)
|
||||
|
||||
# set model to eval mode
|
||||
model.eval()
|
||||
|
||||
# load data for each session
|
||||
test_data = {}
|
||||
total_test_trials = 0
|
||||
for session in model_args['dataset']['sessions']:
|
||||
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
|
||||
if f'data_{eval_type}.hdf5' in files:
|
||||
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
|
||||
|
||||
data = load_h5py_file(eval_file)
|
||||
test_data[session] = data
|
||||
|
||||
total_test_trials += len(test_data[session]["neural_features"])
|
||||
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
|
||||
print(f'Total number of {eval_type} trials: {total_test_trials}')
|
||||
print()
|
||||
|
||||
|
||||
# put neural data through the pretrained model to get phoneme predictions (logits)
|
||||
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
|
||||
for session, data in test_data.items():
|
||||
|
||||
data['logits'] = []
|
||||
data['pred_seq'] = []
|
||||
input_layer = model_args['dataset']['sessions'].index(session)
|
||||
|
||||
for trial in range(len(data['neural_features'])):
|
||||
# get neural input for the trial
|
||||
neural_input = data['neural_features'][trial]
|
||||
|
||||
# add batch dimension
|
||||
neural_input = np.expand_dims(neural_input, axis=0)
|
||||
|
||||
# convert to torch tensor
|
||||
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
|
||||
|
||||
# run decoding step
|
||||
logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
|
||||
data['logits'].append(logits)
|
||||
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
|
||||
|
||||
# convert logits to phoneme sequences and print them out
|
||||
for session, data in test_data.items():
|
||||
data['pred_seq'] = []
|
||||
for trial in range(len(data['logits'])):
|
||||
logits = data['logits'][trial][0]
|
||||
pred_seq = np.argmax(logits, axis=-1)
|
||||
# remove blanks (0)
|
||||
pred_seq = [int(p) for p in pred_seq if p != 0]
|
||||
# remove consecutive duplicates
|
||||
pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
|
||||
# convert to phonemes
|
||||
pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
|
||||
# add to data
|
||||
data['pred_seq'].append(pred_seq)
|
||||
|
||||
# print out the predicted sequences
|
||||
block_num = data['block_num'][trial]
|
||||
trial_num = data['trial_num'][trial]
|
||||
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
|
||||
if eval_type == 'val':
|
||||
sentence_label = data['sentence_label'][trial]
|
||||
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
|
||||
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
|
||||
|
||||
print(f'Sentence label: {sentence_label}')
|
||||
print(f'True sequence: {" ".join(true_seq)}')
|
||||
print(f'Predicted Sequence: {" ".join(pred_seq)}')
|
||||
print()
|
||||
|
||||
|
||||
# language model inference via redis
|
||||
# make sure that the standalone language model is running on the localhost redis ip
|
||||
# see README.md for instructions on how to run the language model
|
||||
r = redis.Redis(host='localhost', port=6379, db=0)
|
||||
remote_lm_input_stream = 'remote_lm_input'
|
||||
remote_lm_output_partial_stream = 'remote_lm_output_partial'
|
||||
remote_lm_output_final_stream = 'remote_lm_output_final'
|
||||
|
||||
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
|
||||
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
|
||||
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
|
||||
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
|
||||
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
|
||||
|
||||
lm_results = {
|
||||
'session': [],
|
||||
'block': [],
|
||||
'trial': [],
|
||||
'true_sentence': [],
|
||||
'pred_sentence': [],
|
||||
}
|
||||
|
||||
# loop through all trials and put logits into the remote language model to get text predictions
|
||||
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
|
||||
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
|
||||
for session in test_data.keys():
|
||||
for trial in range(len(test_data[session]['logits'])):
|
||||
# get trial logits and rearrange them for the LM
|
||||
logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]
|
||||
|
||||
# reset language model
|
||||
remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)
|
||||
|
||||
'''
|
||||
# update language model parameters
|
||||
remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
|
||||
r,
|
||||
remote_lm_done_updating_lastEntrySeen,
|
||||
acoustic_scale=0.35,
|
||||
blank_penalty=90.0,
|
||||
alpha=0.55,
|
||||
)
|
||||
'''
|
||||
|
||||
# put logits into LM
|
||||
remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
|
||||
r,
|
||||
remote_lm_input_stream,
|
||||
remote_lm_output_partial_stream,
|
||||
remote_lm_output_partial_lastEntrySeen,
|
||||
logits,
|
||||
)
|
||||
|
||||
# finalize remote LM
|
||||
remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
|
||||
r,
|
||||
remote_lm_output_final_stream,
|
||||
remote_lm_output_final_lastEntrySeen,
|
||||
)
|
||||
|
||||
# get the best candidate sentence
|
||||
best_candidate_sentence = lm_out['candidate_sentences'][0]
|
||||
|
||||
# store results
|
||||
lm_results['session'].append(session)
|
||||
lm_results['block'].append(test_data[session]['block_num'][trial])
|
||||
lm_results['trial'].append(test_data[session]['trial_num'][trial])
|
||||
if eval_type == 'val':
|
||||
lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
|
||||
else:
|
||||
lm_results['true_sentence'].append(None)
|
||||
lm_results['pred_sentence'].append(best_candidate_sentence)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(1)
|
||||
pbar.close()
|
||||
|
||||
|
||||
# if using the validation set, lets calculate the aggregate word error rate (WER)
|
||||
if eval_type == 'val':
|
||||
total_true_length = 0
|
||||
total_edit_distance = 0
|
||||
|
||||
lm_results['edit_distance'] = []
|
||||
lm_results['num_words'] = []
|
||||
|
||||
for i in range(len(lm_results['pred_sentence'])):
|
||||
true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
|
||||
pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
|
||||
ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
|
||||
|
||||
total_true_length += len(true_sentence.split())
|
||||
total_edit_distance += ed
|
||||
|
||||
lm_results['edit_distance'].append(ed)
|
||||
lm_results['num_words'].append(len(true_sentence.split()))
|
||||
|
||||
print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
|
||||
print(f'True sentence: {true_sentence}')
|
||||
print(f'Predicted sentence: {pred_sentence}')
|
||||
print(f'WER: {ed} / {100 * len(true_sentence.split())} = {ed / len(true_sentence.split()):.2f}%')
|
||||
print()
|
||||
|
||||
print(f'Total true sentence length: {total_true_length}')
|
||||
print(f'Total edit distance: {total_edit_distance}')
|
||||
print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
|
||||
|
||||
|
||||
# write predicted sentences to a text file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
|
||||
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.txt')
|
||||
with open(output_file, 'w') as f:
|
||||
for i in range(len(lm_results['pred_sentence'])):
|
||||
f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
|
290
model_training/evaluate_model_helpers.py
Normal file
290
model_training/evaluate_model_helpers.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import h5py
|
||||
import time
|
||||
import re
|
||||
|
||||
from data_augmentations import gauss_smooth
|
||||
|
||||
LOGIT_TO_PHONEME = [
|
||||
'BLANK',
|
||||
'AA', 'AE', 'AH', 'AO', 'AW',
|
||||
'AY', 'B', 'CH', 'D', 'DH',
|
||||
'EH', 'ER', 'EY', 'F', 'G',
|
||||
'HH', 'IH', 'IY', 'JH', 'K',
|
||||
'L', 'M', 'N', 'NG', 'OW',
|
||||
'OY', 'P', 'R', 'S', 'SH',
|
||||
'T', 'TH', 'UH', 'UW', 'V',
|
||||
'W', 'Y', 'Z', 'ZH',
|
||||
' | ',
|
||||
]
|
||||
|
||||
def _extract_transcription(input):
|
||||
endIdx = np.argwhere(input == 0)[0, 0]
|
||||
trans = ''
|
||||
for c in range(endIdx):
|
||||
trans += chr(input[c])
|
||||
return trans
|
||||
|
||||
def load_h5py_file(file_path):
|
||||
data = {
|
||||
'neural_features': [],
|
||||
'n_time_steps': [],
|
||||
'seq_class_ids': [],
|
||||
'seq_len': [],
|
||||
'transcriptions': [],
|
||||
'sentence_label': [],
|
||||
'session': [],
|
||||
'block_num': [],
|
||||
'trial_num': []
|
||||
}
|
||||
# Open the hdf5 file for that day
|
||||
with h5py.File(file_path, 'r') as f:
|
||||
|
||||
keys = list(f.keys())
|
||||
|
||||
# For each trial in the selected trials in that day
|
||||
for key in keys:
|
||||
g = f[key]
|
||||
|
||||
neural_features = g['input_features'][:]
|
||||
n_time_steps = g.attrs['n_time_steps']
|
||||
seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
|
||||
seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
|
||||
transcription = g['transcription'][:] if 'transcription' in g else None
|
||||
sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
|
||||
session = g.attrs['session']
|
||||
block_num = g.attrs['block_num']
|
||||
trial_num = g.attrs['trial_num']
|
||||
|
||||
|
||||
data['neural_features'].append(neural_features)
|
||||
data['n_time_steps'].append(n_time_steps)
|
||||
data['seq_class_ids'].append(seq_class_ids)
|
||||
data['seq_len'].append(seq_len)
|
||||
data['transcriptions'].append(transcription)
|
||||
data['sentence_label'].append(sentence_label)
|
||||
data['session'].append(session)
|
||||
data['block_num'].append(block_num)
|
||||
data['trial_num'].append(trial_num)
|
||||
return data
|
||||
|
||||
def rearrange_speech_logits_pt(logits):
|
||||
# original order is [BLANK, phonemes..., SIL]
|
||||
# rearrange so the order is [BLANK, SIL, phonemes...]
|
||||
logits = np.concatenate((logits[:, :, 0:1], logits[:, :, -1:], logits[:, :, 1:-1]), axis=-1)
|
||||
return logits
|
||||
|
||||
# single decoding step function.
|
||||
# smooths data and puts it through the model.
|
||||
def runSingleDecodingStep(x, input_layer, model, model_args, device):
|
||||
|
||||
# Use autocast for efficiency
|
||||
with torch.autocast(device_type = "cuda", enabled = model_args['use_amp'], dtype = torch.bfloat16):
|
||||
|
||||
x = gauss_smooth(
|
||||
inputs = x,
|
||||
device = device,
|
||||
smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],
|
||||
smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],
|
||||
padding = 'valid',
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
logits, _ = model(
|
||||
x = x,
|
||||
day_idx = torch.tensor([input_layer], device=device),
|
||||
states = None, # no initial states
|
||||
return_state = True,
|
||||
)
|
||||
|
||||
# convert logits from bfloat16 to float32
|
||||
logits = logits.float().cpu().numpy()
|
||||
|
||||
# # original order is [BLANK, phonemes..., SIL]
|
||||
# # rearrange so the order is [BLANK, SIL, phonemes...]
|
||||
# logits = rearrange_speech_logits_pt(logits)
|
||||
|
||||
return logits
|
||||
|
||||
def remove_punctuation(sentence):
|
||||
# Remove punctuation
|
||||
sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
|
||||
sentence = sentence.replace('- ', ' ').lower()
|
||||
sentence = sentence.replace('--', '').lower()
|
||||
sentence = sentence.replace(" '", "'").lower()
|
||||
|
||||
sentence = sentence.strip()
|
||||
sentence = ' '.join([word for word in sentence.split() if word != ''])
|
||||
|
||||
return sentence
|
||||
|
||||
def get_current_redis_time_ms(redis_conn):
|
||||
t = redis_conn.time()
|
||||
return int(t[0]*1000 + t[1]/1000)
|
||||
|
||||
|
||||
######### language model helper functions ##########
|
||||
|
||||
def reset_remote_language_model(
|
||||
r,
|
||||
remote_lm_done_resetting_lastEntrySeen,
|
||||
):
|
||||
|
||||
r.xadd('remote_lm_reset', {'done': 0})
|
||||
time.sleep(0.001)
|
||||
# print('Resetting remote language model before continuing...')
|
||||
remote_lm_done_resetting = []
|
||||
while len(remote_lm_done_resetting) == 0:
|
||||
remote_lm_done_resetting = r.xread(
|
||||
{'remote_lm_done_resetting': remote_lm_done_resetting_lastEntrySeen},
|
||||
count=1,
|
||||
block=10000,
|
||||
)
|
||||
if len(remote_lm_done_resetting) == 0:
|
||||
print(f'Still waiting for remote lm reset from ts {remote_lm_done_resetting_lastEntrySeen}...')
|
||||
for entry_id, entry_data in remote_lm_done_resetting[0][1]:
|
||||
remote_lm_done_resetting_lastEntrySeen = entry_id
|
||||
# print('Remote language model reset.')
|
||||
|
||||
return remote_lm_done_resetting_lastEntrySeen
|
||||
|
||||
|
||||
def update_remote_lm_params(
|
||||
r,
|
||||
remote_lm_done_updating_lastEntrySeen,
|
||||
acoustic_scale=0.35,
|
||||
blank_penalty=90.0,
|
||||
alpha=0.55,
|
||||
):
|
||||
|
||||
# update remote lm params
|
||||
entry_dict = {
|
||||
# 'max_active': max_active,
|
||||
# 'min_active': min_active,
|
||||
# 'beam': beam,
|
||||
# 'lattice_beam': lattice_beam,
|
||||
'acoustic_scale': acoustic_scale,
|
||||
# 'ctc_blank_skip_threshold': ctc_blank_skip_threshold,
|
||||
# 'length_penalty': length_penalty,
|
||||
# 'nbest': nbest,
|
||||
'blank_penalty': blank_penalty,
|
||||
'alpha': alpha,
|
||||
# 'do_opt': do_opt,
|
||||
# 'rescore': rescore,
|
||||
# 'top_candidates_to_augment': top_candidates_to_augment,
|
||||
# 'score_penalty_percent': score_penalty_percent,
|
||||
# 'specific_word_bias': specific_word_bias,
|
||||
}
|
||||
|
||||
r.xadd('remote_lm_update_params', entry_dict)
|
||||
time.sleep(0.001)
|
||||
remote_lm_done_updating = []
|
||||
while len(remote_lm_done_updating) == 0:
|
||||
remote_lm_done_updating = r.xread(
|
||||
{'remote_lm_done_updating_params': remote_lm_done_updating_lastEntrySeen},
|
||||
block=10000,
|
||||
count=1,
|
||||
)
|
||||
if len(remote_lm_done_updating) == 0:
|
||||
print(f'Still waiting for remote lm to update parameters from ts {remote_lm_done_updating_lastEntrySeen}...')
|
||||
for entry_id, entry_data in remote_lm_done_updating[0][1]:
|
||||
remote_lm_done_updating_lastEntrySeen = entry_id
|
||||
# print('Remote language model params updated.')
|
||||
|
||||
return remote_lm_done_updating_lastEntrySeen
|
||||
|
||||
|
||||
def send_logits_to_remote_lm(
|
||||
r,
|
||||
remote_lm_input_stream,
|
||||
remote_lm_output_partial_stream,
|
||||
remote_lm_output_partial_lastEntrySeen,
|
||||
logits,
|
||||
):
|
||||
|
||||
# put logits into remote lm and get partial output
|
||||
r.xadd(remote_lm_input_stream, {'logits': np.float32(logits).tobytes()})
|
||||
remote_lm_output = []
|
||||
while len(remote_lm_output) == 0:
|
||||
remote_lm_output = r.xread(
|
||||
{remote_lm_output_partial_stream: remote_lm_output_partial_lastEntrySeen},
|
||||
block=10000,
|
||||
count=1,
|
||||
)
|
||||
if len(remote_lm_output) == 0:
|
||||
print(f'Still waiting for remote lm partial output from ts {remote_lm_output_partial_lastEntrySeen}...')
|
||||
for entry_id, entry_data in remote_lm_output[0][1]:
|
||||
remote_lm_output_partial_lastEntrySeen = entry_id
|
||||
decoded = entry_data[b'lm_response_partial'].decode()
|
||||
|
||||
return remote_lm_output_partial_lastEntrySeen, decoded
|
||||
|
||||
|
||||
def finalize_remote_lm(
|
||||
r,
|
||||
remote_lm_output_final_stream,
|
||||
remote_lm_output_final_lastEntrySeen,
|
||||
):
|
||||
|
||||
# finalize remote lm
|
||||
r.xadd('remote_lm_finalize', {'done': 0})
|
||||
time.sleep(0.005)
|
||||
remote_lm_output = []
|
||||
while len(remote_lm_output) == 0:
|
||||
remote_lm_output = r.xread(
|
||||
{remote_lm_output_final_stream: remote_lm_output_final_lastEntrySeen},
|
||||
block=10000,
|
||||
count=1,
|
||||
)
|
||||
if len(remote_lm_output) == 0:
|
||||
print(f'Still waiting for remote lm final output from ts {remote_lm_output_final_lastEntrySeen}...')
|
||||
# print('Received remote lm final output.')
|
||||
|
||||
for entry_id, entry_data in remote_lm_output[0][1]:
|
||||
remote_lm_output_final_lastEntrySeen = entry_id
|
||||
|
||||
candidate_sentences = [str(c) for c in entry_data[b'scoring'].decode().split(';')[::5]]
|
||||
candidate_acoustic_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[1::5]]
|
||||
candidate_ngram_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[2::5]]
|
||||
candidate_llm_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[3::5]]
|
||||
candidate_total_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[4::5]]
|
||||
|
||||
|
||||
# account for a weird edge case where there are no candidate sentences
|
||||
if len(candidate_sentences) == 0 or len(candidate_total_scores) == 0:
|
||||
print('No candidate sentences were received from the language model.')
|
||||
candidate_sentences = ['']
|
||||
candidate_acoustic_scores = [0]
|
||||
candidate_ngram_scores = [0]
|
||||
candidate_llm_scores = [0]
|
||||
candidate_total_scores = [0]
|
||||
|
||||
else:
|
||||
# sort candidate sentences by total score (higher is better)
|
||||
sort_order = np.argsort(candidate_total_scores)[::-1]
|
||||
|
||||
candidate_sentences = [candidate_sentences[i] for i in sort_order]
|
||||
candidate_acoustic_scores = [candidate_acoustic_scores[i] for i in sort_order]
|
||||
candidate_ngram_scores = [candidate_ngram_scores[i] for i in sort_order]
|
||||
candidate_llm_scores = [candidate_llm_scores[i] for i in sort_order]
|
||||
candidate_total_scores = [candidate_total_scores[i] for i in sort_order]
|
||||
|
||||
# loop through candidates backwards and remove any duplicates
|
||||
for i in range(len(candidate_sentences)-1, 0, -1):
|
||||
if candidate_sentences[i] in candidate_sentences[:i]:
|
||||
candidate_sentences.pop(i)
|
||||
candidate_acoustic_scores.pop(i)
|
||||
candidate_ngram_scores.pop(i)
|
||||
candidate_llm_scores.pop(i)
|
||||
candidate_total_scores.pop(i)
|
||||
|
||||
lm_out = {
|
||||
'candidate_sentences': candidate_sentences,
|
||||
'candidate_acoustic_scores': candidate_acoustic_scores,
|
||||
'candidate_ngram_scores': candidate_ngram_scores,
|
||||
'candidate_llm_scores': candidate_llm_scores,
|
||||
'candidate_total_scores': candidate_total_scores,
|
||||
}
|
||||
|
||||
return remote_lm_output_final_lastEntrySeen, lm_out
|
@@ -1,81 +1,89 @@
|
||||
model:
|
||||
n_input_features: 512
|
||||
n_units: 768
|
||||
rnn_dropout: 0.4
|
||||
rnn_trainable: true
|
||||
n_layers: 5
|
||||
bidirectional: false
|
||||
patch_size: 14
|
||||
patch_stride: 4
|
||||
n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
|
||||
n_units: 768 # number of units per GRU layer
|
||||
rnn_dropout: 0.4 # dropout rate for the GRU layers
|
||||
rnn_trainable: true # whether the GRU layers are trainable
|
||||
n_layers: 5 # number of GRU layers
|
||||
patch_size: 14 # size of the input patches (14 time steps)
|
||||
patch_stride: 4 # stride for the input patches (4 time steps)
|
||||
|
||||
input_network:
|
||||
n_input_layers: 1
|
||||
n_input_layers: 1 # number of input layers per network (one network for each day)
|
||||
input_layer_sizes:
|
||||
- 512
|
||||
input_trainable: true
|
||||
input_layer_dropout: 0.2
|
||||
gpu_number: '1'
|
||||
distributed_training: false
|
||||
- 512 # size of the input layer (number of input features)
|
||||
input_trainable: true # whether the input layer is trainable
|
||||
input_layer_dropout: 0.2 # dropout rate for the input layer
|
||||
|
||||
gpu_number: '1' # GPU number to use for training, formatted as a string (e.g., '0', '1', etc.)
|
||||
mode: train
|
||||
use_amp: true
|
||||
output_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter
|
||||
init_from_checkpoint: false
|
||||
checkpoint_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter/checkpoint
|
||||
init_checkpoint_path: None
|
||||
save_best_checkpoint: true
|
||||
save_all_val_steps: false
|
||||
save_final_model: false
|
||||
save_val_metrics: true
|
||||
early_stopping: false
|
||||
early_stopping_val_steps: 20
|
||||
num_training_batches: 120000
|
||||
lr_scheduler_type: cosine
|
||||
lr_max: 0.005
|
||||
lr_min: 0.0001
|
||||
lr_decay_steps: 120000
|
||||
lr_warmup_steps: 1000
|
||||
lr_max_day: 0.005
|
||||
lr_min_day: 0.0001
|
||||
lr_decay_steps_day: 120000
|
||||
lr_warmup_steps_day: 1000
|
||||
beta0: 0.9
|
||||
beta1: 0.999
|
||||
epsilon: 0.1
|
||||
weight_decay: 0.001
|
||||
weight_decay_day: 0
|
||||
seed: 10
|
||||
grad_norm_clip_value: 10
|
||||
batches_per_train_log: 200
|
||||
batches_per_val_step: 2000
|
||||
batches_per_save: 0
|
||||
log_individual_day_val_PER: true
|
||||
log_val_skip_logs: false
|
||||
save_val_logits: true
|
||||
save_val_data: false
|
||||
use_amp: true # whether to use automatic mixed precision (AMP) for training
|
||||
|
||||
output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
|
||||
checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
|
||||
init_from_checkpoint: false # whether to initialize the model from a checkpoint
|
||||
init_checkpoint_path: None # path to the checkpoint to initialize the model from, if any
|
||||
save_best_checkpoint: true # whether to save the best checkpoint based on validation metrics
|
||||
save_all_val_steps: false # whether to save checkpoints at all validation steps
|
||||
save_final_model: false # whether to save the final model after training
|
||||
save_val_metrics: true # whether to save validation metrics during training
|
||||
early_stopping: false # whether to use early stopping based on validation metrics
|
||||
early_stopping_val_steps: 20 # number of validation steps to wait before stopping training if no improvement is seen
|
||||
|
||||
num_training_batches: 120000 # number of training batches to run
|
||||
lr_scheduler_type: cosine # type of learning rate scheduler to use
|
||||
lr_max: 0.005 # maximum learning rate for the main model
|
||||
lr_min: 0.0001 # minimum learning rate for the main model
|
||||
lr_decay_steps: 120000 # number of steps for the learning rate decay
|
||||
lr_warmup_steps: 1000 # number of warmup steps for the learning rate scheduler
|
||||
lr_max_day: 0.005 # maximum learning rate for the day specific input layers
|
||||
lr_min_day: 0.0001 # minimum learning rate for the day specific input layers
|
||||
lr_decay_steps_day: 120000 # number of steps for the learning rate decay for the day specific input layers
|
||||
lr_warmup_steps_day: 1000 # number of warmup steps for the learning rate scheduler for the day specific input layers
|
||||
|
||||
beta0: 0.9 # beta0 parameter for the Adam optimizer
|
||||
beta1: 0.999 # beta1 parameter for the Adam optimizer
|
||||
epsilon: 0.1 # epsilon parameter for the Adam optimizer
|
||||
weight_decay: 0.001 # weight decay for the main model
|
||||
weight_decay_day: 0 # weight decay for the day specific input layers
|
||||
seed: 10 # random seed for reproducibility
|
||||
grad_norm_clip_value: 10 # gradient norm clipping value
|
||||
|
||||
batches_per_train_log: 200 # number of batches per training log
|
||||
batches_per_val_step: 2000 # number of batches per validation step
|
||||
|
||||
batches_per_save: 0 # number of batches per save
|
||||
log_individual_day_val_PER: true # whether to log individual day validation performance
|
||||
log_val_skip_logs: false # whether to skip logging validation metrics
|
||||
save_val_logits: true # whether to save validation logits
|
||||
save_val_data: false # whether to save validation data
|
||||
|
||||
dataset:
|
||||
data_transforms:
|
||||
white_noise_std: 1.0
|
||||
constant_offset_std: 0.2
|
||||
random_walk_std: 0.0
|
||||
random_walk_axis: -1
|
||||
static_gain_std: 0.0
|
||||
random_cut: 3 #0
|
||||
smooth_kernel_size: 100
|
||||
smooth_data: true
|
||||
smooth_kernel_std: 2
|
||||
neural_dim: 512
|
||||
batch_size: 64
|
||||
n_classes: 41
|
||||
max_seq_elements: 500
|
||||
days_per_batch: 4
|
||||
seed: 1
|
||||
num_dataloader_workers: 4
|
||||
loader_shuffle: false
|
||||
must_include_days: null
|
||||
test_percentage: 0.1
|
||||
feature_subset: null
|
||||
dataset_dir: /media/lm-pc/8tb_nvme/b2txt25/hdf5_data
|
||||
bad_trials_dict: null
|
||||
sessions:
|
||||
white_noise_std: 1.0 # standard deviation of the white noise added to the data
|
||||
constant_offset_std: 0.2 # standard deviation of the constant offset added to the data
|
||||
random_walk_std: 0.0 # standard deviation of the random walk added to the data
|
||||
random_walk_axis: -1 # axis along which the random walk is applied
|
||||
static_gain_std: 0.0 # standard deviation of the static gain applied to the data
|
||||
random_cut: 3 # number of time steps to randomly cut from the beginning of each batch of trials
|
||||
smooth_kernel_size: 100 # size of the smoothing kernel applied to the data
|
||||
smooth_data: true # whether to smooth the data
|
||||
smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
|
||||
|
||||
neural_dim: 512 # dimensionality of the neural data
|
||||
batch_size: 64 # batch size for training
|
||||
n_classes: 41 # number of classes (phonemes) in the dataset
|
||||
max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
|
||||
days_per_batch: 4 # number of randomly-selected days to include in each batch
|
||||
seed: 1 # random seed for reproducibility
|
||||
num_dataloader_workers: 4 # number of workers for the data loader
|
||||
loader_shuffle: false # whether to shuffle the data loader
|
||||
must_include_days: null # specific days to include in the dataset
|
||||
test_percentage: 0.1 # percentage of data to use for testing
|
||||
feature_subset: null # specific features to include in the dataset
|
||||
|
||||
dataset_dir: ../data/t15_copyTask_neuralData # directory containing the dataset
|
||||
bad_trials_dict: null # dictionary of bad trials to exclude from the dataset
|
||||
sessions: # list of sessions to include in the dataset
|
||||
- t15.2023.08.11
|
||||
- t15.2023.08.13
|
||||
- t15.2023.08.18
|
||||
@@ -121,7 +129,7 @@ dataset:
|
||||
- t15.2025.03.16
|
||||
- t15.2025.03.30
|
||||
- t15.2025.04.13
|
||||
dataset_probability_val:
|
||||
dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
|
||||
- 0
|
||||
- 1
|
||||
- 1
|
||||
|
@@ -57,11 +57,11 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
# Create output directory
|
||||
if args['mode'] == 'train':
|
||||
os.makedirs(self.args['output_dir'], exist_ok=True)
|
||||
os.makedirs(self.args['output_dir'], exist_ok=False)
|
||||
|
||||
# Create checkpoint directory
|
||||
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
|
||||
os.makedirs(self.args['checkpoint_dir'], exist_ok = True)
|
||||
os.makedirs(self.args['checkpoint_dir'], exist_ok=False)
|
||||
|
||||
# Set up logging
|
||||
self.logger = logging.getLogger(__name__)
|
||||
@@ -82,14 +82,10 @@ class BrainToTextDecoder_Trainer:
|
||||
self.logger.addHandler(sh)
|
||||
|
||||
# Configure device pytorch will use
|
||||
if not self.args['distributed_training']:
|
||||
if torch.cuda.is_available():
|
||||
self.device = f"cuda:{self.args['gpu_number']}"
|
||||
|
||||
else:
|
||||
self.device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
self.device = f"cuda:{self.args['gpu_number']}"
|
||||
else:
|
||||
self.device = "cuda"
|
||||
self.device = "cpu"
|
||||
|
||||
self.logger.info(f'Using device: {self.device}')
|
||||
|
||||
@@ -108,7 +104,6 @@ class BrainToTextDecoder_Trainer:
|
||||
rnn_dropout = self.args['model']['rnn_dropout'],
|
||||
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
|
||||
n_layers = self.args['model']['n_layers'],
|
||||
bidirectional = self.args['model']['bidirectional'],
|
||||
patch_size = self.args['model']['patch_size'],
|
||||
patch_stride = self.args['model']['patch_stride'],
|
||||
)
|
||||
|
Reference in New Issue
Block a user