competition update

This commit is contained in:
nckcard
2025-07-02 12:18:09 -07:00
parent 9e17716a4a
commit 77dbcf868f
2615 changed files with 1648116 additions and 125 deletions

View File

@@ -1,23 +1,44 @@
# Model Training
# Model Training & Evaluation
This directory contains code and resources for training the brain-to-text RNN model. This model is largely based on the architecture described in the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), but also contains modifications to improve performance, efficiency, and usability.
A pretrained baseline RNN model is included in the [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85), as is the neural data required to train that model. The code for training the same model is included here.
All model training and evaluation code was tested on a computer running Ubuntu 22.04 with two RTX 4090's and 512 GB of RAM.
## Setup
1. Install the required `b2txt25` conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
1. Install the required conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory.
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. Be sure to unzip `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip`.
## Training
To train the baseline RNN model, run the following command:
To train the baseline RNN model, run the following command from the `model_training` directory:
```bash
conda activate b2txt25
python train_model.py
```
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See `rnn_args.yaml` for a list of all hyperparameters.
## Evaluation
To evaluate the model, run:
### Start redis server
To evaluate the model, first start a redis server in terminal with:
```bash
python evaluate_model.py
```
redis-server
```
### Start language model
Next, start the ngram language model in a seperate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no gramatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
To run the 1gram language model from the root directory of this repository:
```bash
conda activate b2txt_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```
### Evaluate
Finally, run the `evaluate_model.py` script to load the pretrained baseline RNN, use it for inference on the heldout val or test sets to get phoneme logits, pass them through the language model via redis to get word predictions, and then save the predicted sentences to a .txt file in the format required for competition submission.
```bash
conda activate b2txt25
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/t15_copyTask_neuralData --eval_type test --gpu_number 1
```
### Shutdown redis
When you're done, you can shutdown the redis server from any terminal using `redis-cli shutdown`.

View File

@@ -124,7 +124,7 @@ class BrainToTextDataset(Dataset):
for t in index[d]:
try:
g = f[f'trial_{t}']
g = f[f'trial_{t:04d}']
# Remove features is neccessary
input_features = torch.from_numpy(g['input_features'][:]) # neural data
@@ -277,7 +277,7 @@ def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_
with h5py.File(path, 'r') as f:
num_trials = len(list(f.keys()))
for t in range(num_trials):
key = f'trial_{t}'
key = f'trial_{t:04d}'
block_num = f[key].attrs['block_num']
trial_num = f[key].attrs['trial_num']

View File

@@ -0,0 +1,265 @@
import os
import sys
import torch
import numpy as np
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
from rnn_model import GRUDecoder
from evaluate_model_helpers import *
# argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate a pretrained RNN model on the copy task dataset.')
parser.add_argument('--model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
help='Path to the pretrained model directory (relative to the current working directory).')
parser.add_argument('--data_dir', type=str, default='../data/t15_copyTask_neuralData',
help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set. '
'If "test", ground truth is not available.')
parser.add_argument('--gpu_number', type=int, default=1,
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
args = parser.parse_args()
# paths to model and data directories
# Note: these paths are relative to the current working directory
model_path = args.model_path
data_dir = args.data_dir
# define evaluation type
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
# load model args
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
# set up gpu device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
if gpu_number >= torch.cuda.device_count():
raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
device = f'cuda:{gpu_number}'
device = torch.device(device)
print(f'Using {device} for model inference.')
else:
if gpu_number >= 0:
print(f'GPU number {gpu_number} requested but not available.')
print('Using CPU for model inference.')
device = torch.device('cpu')
# define model
model = GRUDecoder(
neural_dim = model_args['model']['n_input_features'],
n_units = model_args['model']['n_units'],
n_days = len(model_args['dataset']['sessions']),
n_classes = model_args['dataset']['n_classes'],
rnn_dropout = model_args['model']['rnn_dropout'],
input_dropout = model_args['model']['input_network']['input_layer_dropout'],
n_layers = model_args['model']['n_layers'],
patch_size = model_args['model']['patch_size'],
patch_stride = model_args['model']['patch_stride'],
)
# load model weights
checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'), weights_only=False)
# rename keys to not start with "module." (happens if model was saved with DataParallel)
for key in list(checkpoint['model_state_dict'].keys()):
checkpoint['model_state_dict'][key.replace("module.", "")] = checkpoint['model_state_dict'].pop(key)
checkpoint['model_state_dict'][key.replace("_orig_mod.", "")] = checkpoint['model_state_dict'].pop(key)
model.load_state_dict(checkpoint['model_state_dict'])
# add model to device
model.to(device)
# set model to eval mode
model.eval()
# load data for each session
test_data = {}
total_test_trials = 0
for session in model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
data = load_h5py_file(eval_file)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])
print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
# put neural data through the pretrained model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:
for session, data in test_data.items():
data['logits'] = []
data['pred_seq'] = []
input_layer = model_args['dataset']['sessions'].index(session)
for trial in range(len(data['neural_features'])):
# get neural input for the trial
neural_input = data['neural_features'][trial]
# add batch dimension
neural_input = np.expand_dims(neural_input, axis=0)
# convert to torch tensor
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
# run decoding step
logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)
data['logits'].append(logits)
pbar.update(1)
pbar.close()
# convert logits to phoneme sequences and print them out
for session, data in test_data.items():
data['pred_seq'] = []
for trial in range(len(data['logits'])):
logits = data['logits'][trial][0]
pred_seq = np.argmax(logits, axis=-1)
# remove blanks (0)
pred_seq = [int(p) for p in pred_seq if p != 0]
# remove consecutive duplicates
pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
# convert to phonemes
pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
# add to data
data['pred_seq'].append(pred_seq)
# print out the predicted sequences
block_num = data['block_num'][trial]
trial_num = data['trial_num'][trial]
print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
if eval_type == 'val':
sentence_label = data['sentence_label'][trial]
true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
print(f'Sentence label: {sentence_label}')
print(f'True sequence: {" ".join(true_seq)}')
print(f'Predicted Sequence: {" ".join(pred_seq)}')
print()
# language model inference via redis
# make sure that the standalone language model is running on the localhost redis ip
# see README.md for instructions on how to run the language model
r = redis.Redis(host='localhost', port=6379, db=0)
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)
lm_results = {
'session': [],
'block': [],
'trial': [],
'true_sentence': [],
'pred_sentence': [],
}
# loop through all trials and put logits into the remote language model to get text predictions
# note: this takes ~15-20 minutes to run on the entire test split with the 5-gram LM + OPT rescoring (RTX 4090)
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
for session in test_data.keys():
for trial in range(len(test_data[session]['logits'])):
# get trial logits and rearrange them for the LM
logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]
# reset language model
remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)
'''
# update language model parameters
remote_lm_done_updating_lastEntrySeen = update_remote_lm_params(
r,
remote_lm_done_updating_lastEntrySeen,
acoustic_scale=0.35,
blank_penalty=90.0,
alpha=0.55,
)
'''
# put logits into LM
remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
r,
remote_lm_input_stream,
remote_lm_output_partial_stream,
remote_lm_output_partial_lastEntrySeen,
logits,
)
# finalize remote LM
remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
r,
remote_lm_output_final_stream,
remote_lm_output_final_lastEntrySeen,
)
# get the best candidate sentence
best_candidate_sentence = lm_out['candidate_sentences'][0]
# store results
lm_results['session'].append(session)
lm_results['block'].append(test_data[session]['block_num'][trial])
lm_results['trial'].append(test_data[session]['trial_num'][trial])
if eval_type == 'val':
lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
else:
lm_results['true_sentence'].append(None)
lm_results['pred_sentence'].append(best_candidate_sentence)
# update progress bar
pbar.update(1)
pbar.close()
# if using the validation set, lets calculate the aggregate word error rate (WER)
if eval_type == 'val':
total_true_length = 0
total_edit_distance = 0
lm_results['edit_distance'] = []
lm_results['num_words'] = []
for i in range(len(lm_results['pred_sentence'])):
true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
total_true_length += len(true_sentence.split())
total_edit_distance += ed
lm_results['edit_distance'].append(ed)
lm_results['num_words'].append(len(true_sentence.split()))
print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
print(f'True sentence: {true_sentence}')
print(f'Predicted sentence: {pred_sentence}')
print(f'WER: {ed} / {100 * len(true_sentence.split())} = {ed / len(true_sentence.split()):.2f}%')
print()
print(f'Total true sentence length: {total_true_length}')
print(f'Total edit distance: {total_edit_distance}')
print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')
# write predicted sentences to a text file. put a timestamp in the filename (YYYYMMDD_HHMMSS)
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.txt')
with open(output_file, 'w') as f:
for i in range(len(lm_results['pred_sentence'])):
f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")

View File

@@ -0,0 +1,290 @@
import torch
import numpy as np
import h5py
import time
import re
from data_augmentations import gauss_smooth
LOGIT_TO_PHONEME = [
'BLANK',
'AA', 'AE', 'AH', 'AO', 'AW',
'AY', 'B', 'CH', 'D', 'DH',
'EH', 'ER', 'EY', 'F', 'G',
'HH', 'IH', 'IY', 'JH', 'K',
'L', 'M', 'N', 'NG', 'OW',
'OY', 'P', 'R', 'S', 'SH',
'T', 'TH', 'UH', 'UW', 'V',
'W', 'Y', 'Z', 'ZH',
' | ',
]
def _extract_transcription(input):
endIdx = np.argwhere(input == 0)[0, 0]
trans = ''
for c in range(endIdx):
trans += chr(input[c])
return trans
def load_h5py_file(file_path):
data = {
'neural_features': [],
'n_time_steps': [],
'seq_class_ids': [],
'seq_len': [],
'transcriptions': [],
'sentence_label': [],
'session': [],
'block_num': [],
'trial_num': []
}
# Open the hdf5 file for that day
with h5py.File(file_path, 'r') as f:
keys = list(f.keys())
# For each trial in the selected trials in that day
for key in keys:
g = f[key]
neural_features = g['input_features'][:]
n_time_steps = g.attrs['n_time_steps']
seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
transcription = g['transcription'][:] if 'transcription' in g else None
sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
session = g.attrs['session']
block_num = g.attrs['block_num']
trial_num = g.attrs['trial_num']
data['neural_features'].append(neural_features)
data['n_time_steps'].append(n_time_steps)
data['seq_class_ids'].append(seq_class_ids)
data['seq_len'].append(seq_len)
data['transcriptions'].append(transcription)
data['sentence_label'].append(sentence_label)
data['session'].append(session)
data['block_num'].append(block_num)
data['trial_num'].append(trial_num)
return data
def rearrange_speech_logits_pt(logits):
# original order is [BLANK, phonemes..., SIL]
# rearrange so the order is [BLANK, SIL, phonemes...]
logits = np.concatenate((logits[:, :, 0:1], logits[:, :, -1:], logits[:, :, 1:-1]), axis=-1)
return logits
# single decoding step function.
# smooths data and puts it through the model.
def runSingleDecodingStep(x, input_layer, model, model_args, device):
# Use autocast for efficiency
with torch.autocast(device_type = "cuda", enabled = model_args['use_amp'], dtype = torch.bfloat16):
x = gauss_smooth(
inputs = x,
device = device,
smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],
smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],
padding = 'valid',
)
with torch.no_grad():
logits, _ = model(
x = x,
day_idx = torch.tensor([input_layer], device=device),
states = None, # no initial states
return_state = True,
)
# convert logits from bfloat16 to float32
logits = logits.float().cpu().numpy()
# # original order is [BLANK, phonemes..., SIL]
# # rearrange so the order is [BLANK, SIL, phonemes...]
# logits = rearrange_speech_logits_pt(logits)
return logits
def remove_punctuation(sentence):
# Remove punctuation
sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
sentence = sentence.replace('- ', ' ').lower()
sentence = sentence.replace('--', '').lower()
sentence = sentence.replace(" '", "'").lower()
sentence = sentence.strip()
sentence = ' '.join([word for word in sentence.split() if word != ''])
return sentence
def get_current_redis_time_ms(redis_conn):
t = redis_conn.time()
return int(t[0]*1000 + t[1]/1000)
######### language model helper functions ##########
def reset_remote_language_model(
r,
remote_lm_done_resetting_lastEntrySeen,
):
r.xadd('remote_lm_reset', {'done': 0})
time.sleep(0.001)
# print('Resetting remote language model before continuing...')
remote_lm_done_resetting = []
while len(remote_lm_done_resetting) == 0:
remote_lm_done_resetting = r.xread(
{'remote_lm_done_resetting': remote_lm_done_resetting_lastEntrySeen},
count=1,
block=10000,
)
if len(remote_lm_done_resetting) == 0:
print(f'Still waiting for remote lm reset from ts {remote_lm_done_resetting_lastEntrySeen}...')
for entry_id, entry_data in remote_lm_done_resetting[0][1]:
remote_lm_done_resetting_lastEntrySeen = entry_id
# print('Remote language model reset.')
return remote_lm_done_resetting_lastEntrySeen
def update_remote_lm_params(
r,
remote_lm_done_updating_lastEntrySeen,
acoustic_scale=0.35,
blank_penalty=90.0,
alpha=0.55,
):
# update remote lm params
entry_dict = {
# 'max_active': max_active,
# 'min_active': min_active,
# 'beam': beam,
# 'lattice_beam': lattice_beam,
'acoustic_scale': acoustic_scale,
# 'ctc_blank_skip_threshold': ctc_blank_skip_threshold,
# 'length_penalty': length_penalty,
# 'nbest': nbest,
'blank_penalty': blank_penalty,
'alpha': alpha,
# 'do_opt': do_opt,
# 'rescore': rescore,
# 'top_candidates_to_augment': top_candidates_to_augment,
# 'score_penalty_percent': score_penalty_percent,
# 'specific_word_bias': specific_word_bias,
}
r.xadd('remote_lm_update_params', entry_dict)
time.sleep(0.001)
remote_lm_done_updating = []
while len(remote_lm_done_updating) == 0:
remote_lm_done_updating = r.xread(
{'remote_lm_done_updating_params': remote_lm_done_updating_lastEntrySeen},
block=10000,
count=1,
)
if len(remote_lm_done_updating) == 0:
print(f'Still waiting for remote lm to update parameters from ts {remote_lm_done_updating_lastEntrySeen}...')
for entry_id, entry_data in remote_lm_done_updating[0][1]:
remote_lm_done_updating_lastEntrySeen = entry_id
# print('Remote language model params updated.')
return remote_lm_done_updating_lastEntrySeen
def send_logits_to_remote_lm(
r,
remote_lm_input_stream,
remote_lm_output_partial_stream,
remote_lm_output_partial_lastEntrySeen,
logits,
):
# put logits into remote lm and get partial output
r.xadd(remote_lm_input_stream, {'logits': np.float32(logits).tobytes()})
remote_lm_output = []
while len(remote_lm_output) == 0:
remote_lm_output = r.xread(
{remote_lm_output_partial_stream: remote_lm_output_partial_lastEntrySeen},
block=10000,
count=1,
)
if len(remote_lm_output) == 0:
print(f'Still waiting for remote lm partial output from ts {remote_lm_output_partial_lastEntrySeen}...')
for entry_id, entry_data in remote_lm_output[0][1]:
remote_lm_output_partial_lastEntrySeen = entry_id
decoded = entry_data[b'lm_response_partial'].decode()
return remote_lm_output_partial_lastEntrySeen, decoded
def finalize_remote_lm(
r,
remote_lm_output_final_stream,
remote_lm_output_final_lastEntrySeen,
):
# finalize remote lm
r.xadd('remote_lm_finalize', {'done': 0})
time.sleep(0.005)
remote_lm_output = []
while len(remote_lm_output) == 0:
remote_lm_output = r.xread(
{remote_lm_output_final_stream: remote_lm_output_final_lastEntrySeen},
block=10000,
count=1,
)
if len(remote_lm_output) == 0:
print(f'Still waiting for remote lm final output from ts {remote_lm_output_final_lastEntrySeen}...')
# print('Received remote lm final output.')
for entry_id, entry_data in remote_lm_output[0][1]:
remote_lm_output_final_lastEntrySeen = entry_id
candidate_sentences = [str(c) for c in entry_data[b'scoring'].decode().split(';')[::5]]
candidate_acoustic_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[1::5]]
candidate_ngram_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[2::5]]
candidate_llm_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[3::5]]
candidate_total_scores = [float(c) for c in entry_data[b'scoring'].decode().split(';')[4::5]]
# account for a weird edge case where there are no candidate sentences
if len(candidate_sentences) == 0 or len(candidate_total_scores) == 0:
print('No candidate sentences were received from the language model.')
candidate_sentences = ['']
candidate_acoustic_scores = [0]
candidate_ngram_scores = [0]
candidate_llm_scores = [0]
candidate_total_scores = [0]
else:
# sort candidate sentences by total score (higher is better)
sort_order = np.argsort(candidate_total_scores)[::-1]
candidate_sentences = [candidate_sentences[i] for i in sort_order]
candidate_acoustic_scores = [candidate_acoustic_scores[i] for i in sort_order]
candidate_ngram_scores = [candidate_ngram_scores[i] for i in sort_order]
candidate_llm_scores = [candidate_llm_scores[i] for i in sort_order]
candidate_total_scores = [candidate_total_scores[i] for i in sort_order]
# loop through candidates backwards and remove any duplicates
for i in range(len(candidate_sentences)-1, 0, -1):
if candidate_sentences[i] in candidate_sentences[:i]:
candidate_sentences.pop(i)
candidate_acoustic_scores.pop(i)
candidate_ngram_scores.pop(i)
candidate_llm_scores.pop(i)
candidate_total_scores.pop(i)
lm_out = {
'candidate_sentences': candidate_sentences,
'candidate_acoustic_scores': candidate_acoustic_scores,
'candidate_ngram_scores': candidate_ngram_scores,
'candidate_llm_scores': candidate_llm_scores,
'candidate_total_scores': candidate_total_scores,
}
return remote_lm_output_final_lastEntrySeen, lm_out

View File

@@ -1,81 +1,89 @@
model:
n_input_features: 512
n_units: 768
rnn_dropout: 0.4
rnn_trainable: true
n_layers: 5
bidirectional: false
patch_size: 14
patch_stride: 4
n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
n_units: 768 # number of units per GRU layer
rnn_dropout: 0.4 # dropout rate for the GRU layers
rnn_trainable: true # whether the GRU layers are trainable
n_layers: 5 # number of GRU layers
patch_size: 14 # size of the input patches (14 time steps)
patch_stride: 4 # stride for the input patches (4 time steps)
input_network:
n_input_layers: 1
n_input_layers: 1 # number of input layers per network (one network for each day)
input_layer_sizes:
- 512
input_trainable: true
input_layer_dropout: 0.2
gpu_number: '1'
distributed_training: false
- 512 # size of the input layer (number of input features)
input_trainable: true # whether the input layer is trainable
input_layer_dropout: 0.2 # dropout rate for the input layer
gpu_number: '1' # GPU number to use for training, formatted as a string (e.g., '0', '1', etc.)
mode: train
use_amp: true
output_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter
init_from_checkpoint: false
checkpoint_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter/checkpoint
init_checkpoint_path: None
save_best_checkpoint: true
save_all_val_steps: false
save_final_model: false
save_val_metrics: true
early_stopping: false
early_stopping_val_steps: 20
num_training_batches: 120000
lr_scheduler_type: cosine
lr_max: 0.005
lr_min: 0.0001
lr_decay_steps: 120000
lr_warmup_steps: 1000
lr_max_day: 0.005
lr_min_day: 0.0001
lr_decay_steps_day: 120000
lr_warmup_steps_day: 1000
beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0
seed: 10
grad_norm_clip_value: 10
batches_per_train_log: 200
batches_per_val_step: 2000
batches_per_save: 0
log_individual_day_val_PER: true
log_val_skip_logs: false
save_val_logits: true
save_val_data: false
use_amp: true # whether to use automatic mixed precision (AMP) for training
output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
init_from_checkpoint: false # whether to initialize the model from a checkpoint
init_checkpoint_path: None # path to the checkpoint to initialize the model from, if any
save_best_checkpoint: true # whether to save the best checkpoint based on validation metrics
save_all_val_steps: false # whether to save checkpoints at all validation steps
save_final_model: false # whether to save the final model after training
save_val_metrics: true # whether to save validation metrics during training
early_stopping: false # whether to use early stopping based on validation metrics
early_stopping_val_steps: 20 # number of validation steps to wait before stopping training if no improvement is seen
num_training_batches: 120000 # number of training batches to run
lr_scheduler_type: cosine # type of learning rate scheduler to use
lr_max: 0.005 # maximum learning rate for the main model
lr_min: 0.0001 # minimum learning rate for the main model
lr_decay_steps: 120000 # number of steps for the learning rate decay
lr_warmup_steps: 1000 # number of warmup steps for the learning rate scheduler
lr_max_day: 0.005 # maximum learning rate for the day specific input layers
lr_min_day: 0.0001 # minimum learning rate for the day specific input layers
lr_decay_steps_day: 120000 # number of steps for the learning rate decay for the day specific input layers
lr_warmup_steps_day: 1000 # number of warmup steps for the learning rate scheduler for the day specific input layers
beta0: 0.9 # beta0 parameter for the Adam optimizer
beta1: 0.999 # beta1 parameter for the Adam optimizer
epsilon: 0.1 # epsilon parameter for the Adam optimizer
weight_decay: 0.001 # weight decay for the main model
weight_decay_day: 0 # weight decay for the day specific input layers
seed: 10 # random seed for reproducibility
grad_norm_clip_value: 10 # gradient norm clipping value
batches_per_train_log: 200 # number of batches per training log
batches_per_val_step: 2000 # number of batches per validation step
batches_per_save: 0 # number of batches per save
log_individual_day_val_PER: true # whether to log individual day validation performance
log_val_skip_logs: false # whether to skip logging validation metrics
save_val_logits: true # whether to save validation logits
save_val_data: false # whether to save validation data
dataset:
data_transforms:
white_noise_std: 1.0
constant_offset_std: 0.2
random_walk_std: 0.0
random_walk_axis: -1
static_gain_std: 0.0
random_cut: 3 #0
smooth_kernel_size: 100
smooth_data: true
smooth_kernel_std: 2
neural_dim: 512
batch_size: 64
n_classes: 41
max_seq_elements: 500
days_per_batch: 4
seed: 1
num_dataloader_workers: 4
loader_shuffle: false
must_include_days: null
test_percentage: 0.1
feature_subset: null
dataset_dir: /media/lm-pc/8tb_nvme/b2txt25/hdf5_data
bad_trials_dict: null
sessions:
white_noise_std: 1.0 # standard deviation of the white noise added to the data
constant_offset_std: 0.2 # standard deviation of the constant offset added to the data
random_walk_std: 0.0 # standard deviation of the random walk added to the data
random_walk_axis: -1 # axis along which the random walk is applied
static_gain_std: 0.0 # standard deviation of the static gain applied to the data
random_cut: 3 # number of time steps to randomly cut from the beginning of each batch of trials
smooth_kernel_size: 100 # size of the smoothing kernel applied to the data
smooth_data: true # whether to smooth the data
smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data
neural_dim: 512 # dimensionality of the neural data
batch_size: 64 # batch size for training
n_classes: 41 # number of classes (phonemes) in the dataset
max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
days_per_batch: 4 # number of randomly-selected days to include in each batch
seed: 1 # random seed for reproducibility
num_dataloader_workers: 4 # number of workers for the data loader
loader_shuffle: false # whether to shuffle the data loader
must_include_days: null # specific days to include in the dataset
test_percentage: 0.1 # percentage of data to use for testing
feature_subset: null # specific features to include in the dataset
dataset_dir: ../data/t15_copyTask_neuralData # directory containing the dataset
bad_trials_dict: null # dictionary of bad trials to exclude from the dataset
sessions: # list of sessions to include in the dataset
- t15.2023.08.11
- t15.2023.08.13
- t15.2023.08.18
@@ -121,7 +129,7 @@ dataset:
- t15.2025.03.16
- t15.2025.03.30
- t15.2025.04.13
dataset_probability_val:
dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
- 0
- 1
- 1

View File

@@ -57,11 +57,11 @@ class BrainToTextDecoder_Trainer:
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=True)
os.makedirs(self.args['output_dir'], exist_ok=False)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok = True)
os.makedirs(self.args['checkpoint_dir'], exist_ok=False)
# Set up logging
self.logger = logging.getLogger(__name__)
@@ -82,14 +82,10 @@ class BrainToTextDecoder_Trainer:
self.logger.addHandler(sh)
# Configure device pytorch will use
if not self.args['distributed_training']:
if torch.cuda.is_available():
self.device = f"cuda:{self.args['gpu_number']}"
else:
self.device = "cpu"
if torch.cuda.is_available():
self.device = f"cuda:{self.args['gpu_number']}"
else:
self.device = "cuda"
self.device = "cpu"
self.logger.info(f'Using device: {self.device}')
@@ -108,7 +104,6 @@ class BrainToTextDecoder_Trainer:
rnn_dropout = self.args['model']['rnn_dropout'],
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
n_layers = self.args['model']['n_layers'],
bidirectional = self.args['model']['bidirectional'],
patch_size = self.args['model']['patch_size'],
patch_stride = self.args['model']['patch_stride'],
)