b2txt25 wip

nckcard
2025-07-01 09:39:24 -07:00
parent bfea146f99
commit 9e17716a4a
8 changed files with 1460 additions and 1 deletion

README.md

@@ -27,7 +27,7 @@ The code is organized into four main directories: `utils`, `analyses`, `data`, a
- The `model_training` directory contains the code necessary to train the brain-to-text model, including the offline model training and an offline simulation of the online finetuning pipeline, and also to run the language model. Note that the data used in the model training pipeline is simulated neural data, as the real neural data is not yet available.
## Python environment setup
-The code is written in Python 3.9 and tested on Ubuntu 22.04. We recommend using a conda environment to manage the dependencies.
+The code is written in Python 3.10 and tested on Ubuntu 22.04. We recommend using a conda environment to manage the dependencies.
To install miniconda, follow the instructions [here](https://docs.anaconda.com/miniconda/miniconda-install/).

model_training/README.md Normal file

@@ -0,0 +1,23 @@
# Model Training
This directory contains code and resources for training the brain-to-text RNN model. This model is largely based on the architecture described in the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), but also contains modifications to improve performance, efficiency, and usability.
## Setup
1. Install the required conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory.
## Training
To train the baseline RNN model, run the following command:
```bash
python train_model.py
```
## Evaluation
To evaluate the model, run:
```bash
python evaluate_model.py
```
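
The trainer builds per-session file paths of the form `<dataset_dir>/<session>/data_train.hdf5` and `<dataset_dir>/<session>/data_val.hdf5` (see `rnn_trainer.py`), so `dataset_dir` in `rnn_args.yaml` must point to wherever the Dryad data was placed. A minimal sketch (not part of the repo) for checking that the expected layout is present:
```python
import os
from omegaconf import OmegaConf

args = OmegaConf.load('rnn_args.yaml')
for session in args['dataset']['sessions']:
    for fname in ('data_train.hdf5', 'data_val.hdf5'):
        path = os.path.join(args['dataset']['dataset_dir'], session, fname)
        print(path, 'OK' if os.path.exists(path) else 'MISSING')
```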

model_training/data_augmentations.py Normal file

@@ -0,0 +1,37 @@
import torch
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import gaussian_filter1d
def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'):
"""
Applies a 1D Gaussian smoothing operation with PyTorch to smooth the data along the time axis.
Args:
inputs (tensor : B x T x N): A 3D tensor with batch size B, time steps T, and number of features N.
Assumed to already be on the correct device (e.g., GPU).
device (str): Device to use for computation (e.g., 'cuda' or 'cpu').
smooth_kernel_std (float): Standard deviation of the Gaussian smoothing kernel.
smooth_kernel_size (int): Size (in time steps) of the Gaussian kernel before truncation.
padding (str): Padding mode, either 'same' or 'valid'.
Returns:
smoothed (tensor : B x T x N): A smoothed 3D tensor with batch size B, time steps T, and number of features N.
"""
# Get Gaussian kernel
inp = np.zeros(smooth_kernel_size, dtype=np.float32)
inp[smooth_kernel_size // 2] = 1
gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
validIdx = np.argwhere(gaussKernel > 0.01)
gaussKernel = gaussKernel[validIdx]
gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))
# Convert to tensor
gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]
# Prepare convolution
B, T, C = inputs.shape
inputs = inputs.permute(0, 2, 1) # [B, C, T]
gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]
# Perform convolution
smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
return smoothed.permute(0, 2, 1) # [B, T, C]
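
A minimal usage sketch (not part of this file, shapes assumed) mirroring how `rnn_trainer.transform_data` calls `gauss_smooth`:
```python
import torch

x = torch.randn(4, 200, 512)  # [batch, time, features], already on the target device
smoothed = gauss_smooth(x, device='cpu', smooth_kernel_std=2, smooth_kernel_size=100)
print(smoothed.shape)  # torch.Size([4, 200, 512]) because padding defaults to 'same'
```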

model_training/dataset.py Normal file

@@ -0,0 +1,334 @@
import os
import torch
from torch.utils.data import Dataset
import h5py
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import math
class BrainToTextDataset(Dataset):
'''
Dataset for brain-to-text data
Returns an entire batch of data instead of a single example
'''
def __init__(
self,
trial_indicies,
n_batches,
split = 'train',
batch_size = 64,
days_per_batch = 1,
random_seed = -1,
must_include_days = None,
feature_subset = None
):
'''
trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
n_batches: (int) - number of random training batches to create
split: (string) - string specifying if this is a train or test dataset
batch_size: (int) - number of examples to include in the batch returned from __getitem__()
days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
to individual day layers in the GRU are not excessively noisy. Validation data will always have 1 day per batch
random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
must_include_days ([int]) - list of days that must be included in every batch
feature_subset ([int]) - list of neural feature indices that should be the only features included in the neural data
'''
# Set random seed for reproducibility
if random_seed != -1:
np.random.seed(random_seed)
torch.manual_seed(random_seed)
self.split = split
# Ensure the split is valid
if self.split not in ['train', 'test']:
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
self.days_per_batch = days_per_batch
self.batch_size = batch_size
self.n_batches = n_batches
self.days = {}
self.n_trials = 0
self.trial_indicies = trial_indicies
self.n_days = len(trial_indicies.keys())
self.feature_subset = feature_subset
# Calculate total number of trials in the dataset
for d in trial_indicies:
self.n_trials += len(trial_indicies[d]['trials'])
if must_include_days is not None and len(must_include_days) > days_per_batch:
raise ValueError(f'The number of days in must_include_days must be less than or equal to days_per_batch. Received {must_include_days} with days_per_batch {days_per_batch}')
if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days}, but there are only {self.n_days} days in the dataset')
if must_include_days is not None:
# Map must_include_days to correct indices if they are negative
for i, d in enumerate(must_include_days):
if d < 0:
must_include_days[i] = self.n_days + d
self.must_include_days = must_include_days
# Ensure that days_per_batch does not exceed the number of days in the dataset; raise an error if it does
if self.split == 'train' and self.days_per_batch > self.n_days:
raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')
if self.split == 'train':
self.batch_index = self.create_batch_index_train()
else:
self.batch_index = self.create_batch_index_test()
self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data
def __len__(self):
'''
How many batches are in this dataset.
Because training data is sampled randomly, there is no fixed dataset length;
however, this method is required for the DataLoader to work
'''
return self.n_batches
def __getitem__(self, idx):
'''
Gets an entire batch of data from the dataset, not just a single item
'''
batch = {
'input_features' : [],
'seq_class_ids' : [],
'n_time_steps' : [],
'phone_seq_lens' : [],
'day_indicies' : [],
'transcriptions' : [],
'block_nums' : [],
'trial_nums' : [],
}
index = self.batch_index[idx]
# Iterate through each day in the index
for d in index.keys():
# Open the hdf5 file for that day
with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:
# For each trial in the selected trials in that day
for t in index[d]:
try:
g = f[f'trial_{t}']
# Subset neural features if necessary
input_features = torch.from_numpy(g['input_features'][:]) # neural data
if self.feature_subset:
input_features = input_features[:,self.feature_subset]
batch['input_features'].append(input_features)
batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
batch['block_nums'].append(g.attrs['block_num'])
batch['trial_nums'].append(g.attrs['trial_num'])
except Exception as e:
print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
continue
# Pad data to form a cohesive batch
batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)
batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
batch['day_indicies'] = torch.tensor(batch['day_indicies'])
batch['transcriptions'] = torch.stack(batch['transcriptions'])
batch['block_nums'] = torch.tensor(batch['block_nums'])
batch['trial_nums'] = torch.tensor(batch['trial_nums'])
return batch
def create_batch_index_train(self):
'''
Create an index that maps a batch_number to batch_size number of trials
Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
(or as evenly as possible if batch_size is not divisible by days_per_batch)
'''
batch_index = {}
# Precompute the days that are not in must_include_days
if self.must_include_days is not None:
non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]
for batch_idx in range(self.n_batches):
batch = {}
# Which days will be used for this batch. Picked randomly without replacement
# TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day
# If must_include_days is not empty, we will use those days and then randomly sample the rest
if self.must_include_days is not None and len(self.must_include_days) > 0:
days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))
# Otherwise we will select random days without replacement
else:
days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)
# How many trials will be sampled from each day
num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials
for d in days:
# Trials are sampled with replacement, so it is not a problem if a day has fewer than (self.batch_size / days_per_batch) trials
trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
batch[d] = trial_idxs
# Remove extra trials
extra_trials = (num_trials * len(days)) - self.batch_size
# While we still have extra trials, remove the last trial from a random day
while extra_trials > 0:
d = np.random.choice(days)
batch[d] = batch[d][:-1]
extra_trials -= 1
batch_index[batch_idx] = batch
return batch_index
def create_batch_index_test(self):
'''
Create an index that is all validation/testing data in batches of up to self.batch_size
If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size
This index ensures that every trial in the validation set is seen once and only once
'''
batch_index = {}
batch_idx = 0
for d in self.trial_indicies.keys():
# Calculate how many batches we need for this day
num_trials = len(self.trial_indicies[d]['trials'])
num_batches = (num_trials + self.batch_size - 1) // self.batch_size
# Create batches for this day
for i in range(num_batches):
start_idx = i * self.batch_size
end_idx = min((i + 1) * self.batch_size, num_trials)
# Get the trial indices for this batch
batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx]
# Add to batch_index
batch_index[batch_idx] = {d : batch_trials}
batch_idx += 1
return batch_index
def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
'''
Split data from file_paths into train and test splits
Returns two dictionaries that detail which trials in each day will be a part of that split:
Example:
{
0: trials[1,2,3], session_path: 'path'
1: trials[2,5,6], session_path: 'path'
}
Args:
file_paths (list): List of file paths to the hdf5 files containing the data
test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
seed (int): Seed for reproducibility. If set to -1, the split will be random
bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
{
'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
...
}
'''
# Set seed for reproducibility
if seed != -1:
np.random.seed(seed)
# Get trials in each day
trials_per_day = {}
for i, path in enumerate(file_paths):
session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
good_trial_indices = []
if os.path.exists(path):
with h5py.File(path, 'r') as f:
num_trials = len(list(f.keys()))
for t in range(num_trials):
key = f'trial_{t}'
block_num = f[key].attrs['block_num']
trial_num = f[key].attrs['trial_num']
if (
bad_trials_dict is not None
and session in bad_trials_dict
and str(block_num) in bad_trials_dict[session]
and trial_num in bad_trials_dict[session][str(block_num)]
):
# print(f'Bad trial: {session}_{block_num}_{trial_num}')
continue
good_trial_indices.append(t)
trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}
# Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
train_trials = {}
test_trials = {}
for day in trials_per_day.keys():
num_trials = trials_per_day[day]['num_trials']
# Generate all trial indices for this day (assuming 0-indexed)
all_trial_indices = trials_per_day[day]['trial_indices']
# If test_percentage is 0 or 1, we can just assign all trials to either train or test
if test_percentage == 0:
train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
continue
elif test_percentage == 1:
train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
continue
else:
# Calculate how many trials to use for testing
num_test = max(1, int(num_trials * test_percentage))
# Randomly select indices for testing
test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
# Remaining indices go to training
train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
# Store the split indices
train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']}
test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']}
return train_trials, test_trials
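
A minimal wiring sketch (the file path below is a placeholder) mirroring how `rnn_trainer.py` uses these helpers: split trials, build a dataset that yields whole batches from `__getitem__`, and wrap it in a `DataLoader` with `batch_size=None`:
```python
from torch.utils.data import DataLoader

file_paths = ['/path/to/hdf5_data/t15.2023.08.11/data_train.hdf5']  # placeholder path
train_trials, _ = train_test_split_indicies(file_paths, test_percentage=0, seed=1)

train_dataset = BrainToTextDataset(
    trial_indicies=train_trials,
    n_batches=1000,
    split='train',
    batch_size=64,
    days_per_batch=1,
    random_seed=1,
)
train_loader = DataLoader(train_dataset, batch_size=None, shuffle=False, num_workers=0)
batch = next(iter(train_loader))
print(batch['input_features'].shape)  # [batch_size, max_time_steps, n_features]
```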

model_training/rnn_args.yaml Normal file

@@ -0,0 +1,169 @@
model:
n_input_features: 512
n_units: 768
rnn_dropout: 0.4
rnn_trainable: true
n_layers: 5
bidirectional: false
patch_size: 14
patch_stride: 4
input_network:
n_input_layers: 1
input_layer_sizes:
- 512
input_trainable: true
input_layer_dropout: 0.2
gpu_number: '1'
distributed_training: false
mode: train
use_amp: true
output_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter
init_from_checkpoint: false
checkpoint_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter/checkpoint
init_checkpoint_path: None
save_best_checkpoint: true
save_all_val_steps: false
save_final_model: false
save_val_metrics: true
early_stopping: false
early_stopping_val_steps: 20
num_training_batches: 120000
lr_scheduler_type: cosine
lr_max: 0.005
lr_min: 0.0001
lr_decay_steps: 120000
lr_warmup_steps: 1000
lr_max_day: 0.005
lr_min_day: 0.0001
lr_decay_steps_day: 120000
lr_warmup_steps_day: 1000
beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0
seed: 10
grad_norm_clip_value: 10
batches_per_train_log: 200
batches_per_val_step: 2000
batches_per_save: 0
log_individual_day_val_PER: true
log_val_skip_logs: false
save_val_logits: true
save_val_data: false
dataset:
data_transforms:
white_noise_std: 1.0
constant_offset_std: 0.2
random_walk_std: 0.0
random_walk_axis: -1
static_gain_std: 0.0
random_cut: 3 #0
smooth_kernel_size: 100
smooth_data: true
smooth_kernel_std: 2
neural_dim: 512
batch_size: 64
n_classes: 41
max_seq_elements: 500
days_per_batch: 4
seed: 1
num_dataloader_workers: 4
loader_shuffle: false
must_include_days: null
test_percentage: 0.1
feature_subset: null
dataset_dir: /media/lm-pc/8tb_nvme/b2txt25/hdf5_data
bad_trials_dict: null
sessions:
- t15.2023.08.11
- t15.2023.08.13
- t15.2023.08.18
- t15.2023.08.20
- t15.2023.08.25
- t15.2023.08.27
- t15.2023.09.01
- t15.2023.09.03
- t15.2023.09.24
- t15.2023.09.29
- t15.2023.10.01
- t15.2023.10.06
- t15.2023.10.08
- t15.2023.10.13
- t15.2023.10.15
- t15.2023.10.20
- t15.2023.10.22
- t15.2023.11.03
- t15.2023.11.04
- t15.2023.11.17
- t15.2023.11.19
- t15.2023.11.26
- t15.2023.12.03
- t15.2023.12.08
- t15.2023.12.10
- t15.2023.12.17
- t15.2023.12.29
- t15.2024.02.25
- t15.2024.03.03
- t15.2024.03.08
- t15.2024.03.15
- t15.2024.03.17
- t15.2024.04.25
- t15.2024.04.28
- t15.2024.05.10
- t15.2024.06.14
- t15.2024.07.19
- t15.2024.07.21
- t15.2024.07.28
- t15.2025.01.10
- t15.2025.01.12
- t15.2025.03.14
- t15.2025.03.16
- t15.2025.03.30
- t15.2025.04.13
dataset_probability_val:
- 0
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 0
- 1
- 1
- 1
- 0
- 0
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
- 1
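
`sessions` and `dataset_probability_val` are parallel lists: one entry per session, and a 0 entry means that session's validation data is skipped (see the checks in `rnn_trainer.validation`). A small sketch (not part of the repo) showing how the config is consumed via OmegaConf, as in `train_model.py`:
```python
from omegaconf import OmegaConf

args = OmegaConf.load('rnn_args.yaml')
assert len(args['dataset']['sessions']) == len(args['dataset']['dataset_probability_val'])
skipped = [s for s, p in zip(args['dataset']['sessions'],
                             args['dataset']['dataset_probability_val']) if p == 0]
print('sessions skipped during validation:', skipped)
```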

model_training/rnn_model.py Normal file

@@ -0,0 +1,136 @@
import torch
from torch import nn
class GRUDecoder(nn.Module):
'''
Defines the GRU decoder
This class combines day-specific input layers, a GRU, and an output classification layer
'''
def __init__(self,
neural_dim,
n_units,
n_days,
n_classes,
rnn_dropout = 0.0,
input_dropout = 0.0,
n_layers = 5,
patch_size = 0,
patch_stride = 0,
bidirectional = False,
):
'''
neural_dim (int) - number of channels in a single timestep (e.g. 512)
n_units (int) - number of hidden units in each recurrent layer - equal to the size of the hidden state
n_days (int) - number of days in the dataset
n_classes (int) - number of classes
rnn_dropout (float) - percentage of units to dropout during training
input_dropout (float) - percentage of input units to dropout during training
n_layers (int) - number of recurrent layers
patch_size (int) - the number of timesteps to concat on initial input layer - a value of 0 will disable this "input concat" step
patch_stride (int) - the number of timesteps to stride over when concatenating initial input
bidirectional (bool) - if True, use a bidirectional GRU (defaults to False; passed in from the training config)
'''
super(GRUDecoder, self).__init__()
self.neural_dim = neural_dim
self.n_units = n_units
self.n_classes = n_classes
self.n_layers = n_layers
self.n_days = n_days
self.rnn_dropout = rnn_dropout
self.input_dropout = input_dropout
self.patch_size = patch_size
self.patch_stride = patch_stride
self.bidirectional = bidirectional
# Parameters for the day-specific input layers
self.day_layer_activation = nn.Softsign() # basically a shallower tanh
# Set weights for day layers to be identity matrices so the model can learn its own day-specific transformations
self.day_weights = nn.ParameterList(
[nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
)
self.day_biases = nn.ParameterList(
[nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
)
self.day_layer_dropout = nn.Dropout(input_dropout)
self.input_size = self.neural_dim
# If we are using "strided inputs", then the input size of the first recurrent layer will actually be in_size * patch_size
if self.patch_size > 0:
self.input_size *= self.patch_size
self.gru = nn.GRU(
input_size = self.input_size,
hidden_size = self.n_units,
num_layers = self.n_layers,
dropout = self.rnn_dropout,
batch_first = True, # The first dim of our input is the batch dim
bidirectional = self.bidirectional,
)
# Set recurrent units to have orthogonal param init and input layers to have xavier init
for name, param in self.gru.named_parameters():
if "weight_hh" in name:
nn.init.orthogonal_(param)
if "weight_ih" in name:
nn.init.xavier_uniform_(param)
# Prediction head with Xavier weight init
self.out = nn.Linear(self.n_units * (2 if self.bidirectional else 1), self.n_classes)
nn.init.xavier_uniform_(self.out.weight)
# Learnable initial hidden states
self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))
def forward(self, x, day_idx, states = None, return_state = False):
'''
x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
day_idx (tensor) - tensor which is a list of day indices corresponding to the day of each example in the batch x.
'''
# Apply day-specific layer to (hopefully) project neural data from the different days to the same latent space
day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)
x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
x = self.day_layer_activation(x)
# Apply dropout to the output of the day-specific layer
if self.input_dropout > 0:
x = self.day_layer_dropout(x)
# (Optionally) Perform input concat operation
if self.patch_size > 0:
x = x.unsqueeze(1) # [batches, 1, timesteps, feature_dim]
x = x.permute(0, 3, 1, 2) # [batches, feature_dim, 1, timesteps]
# Extract patches using unfold (sliding window)
x_unfold = x.unfold(3, self.patch_size, self.patch_stride) # [batches, feature_dim, 1, num_patches, patch_size]
# Remove dummy height dimension and rearrange dimensions
x_unfold = x_unfold.squeeze(2) # [batches, feature_dim, num_patches, patch_size]
x_unfold = x_unfold.permute(0, 2, 3, 1) # [batches, num_patches, patch_size, feature_dim]
# Flatten last two dimensions (patch_size and features)
x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)
# Determine initial hidden states
if states is None:
states = self.h0.expand(self.n_layers * (2 if self.bidirectional else 1), x.shape[0], self.n_units).contiguous()
# Pass input through RNN
output, hidden_states = self.gru(x, states)
# Compute logits
logits = self.out(output)
if return_state:
return logits, hidden_states
return logits
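
A standalone sketch (not part of this file; hyperparameters taken from `rnn_args.yaml`) showing a forward pass and the patched output length `T' = (T - patch_size) // patch_stride + 1`, which is the same formula the trainer uses to compute `adjusted_lens` for the CTC loss:
```python
import torch

model = GRUDecoder(neural_dim=512, n_units=768, n_days=45, n_classes=41,
                   rnn_dropout=0.4, input_dropout=0.2, n_layers=5,
                   patch_size=14, patch_stride=4)
x = torch.randn(2, 200, 512)    # [batch, time, neural_dim]
day_idx = torch.tensor([0, 3])  # one day index per trial in the batch
logits = model(x, day_idx)
print(logits.shape)             # [2, (200 - 14) // 4 + 1, 41] -> [2, 47, 41]
```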

model_training/rnn_trainer.py Normal file

@@ -0,0 +1,754 @@
import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import os
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle
from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth
import torchaudio.functional as F # for edit distance
from omegaconf import OmegaConf
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64
from rnn_model import GRUDecoder
class BrainToTextDecoder_Trainer:
"""
This class will initialize and train a brain-to-text phoneme decoder
Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
"""
def __init__(self, args):
'''
args : dictionary of training arguments
'''
# Trainer fields
self.args = args
self.logger = None
self.device = None
self.model = None
self.optimizer = None
self.learning_rate_scheduler = None
self.ctc_loss = None
self.best_val_PER = torch.inf # track best PER for checkpointing
self.best_val_loss = torch.inf # track best loss for checkpointing
self.train_dataset = None
self.val_dataset = None
self.train_loader = None
self.val_loader = None
self.transform_args = self.args['dataset']['data_transforms']
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok = True)
# Set up logging
self.logger = logging.getLogger(__name__)
for handler in self.logger.handlers[:]: # make a copy of the list
self.logger.removeHandler(handler)
self.logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')
if args['mode']=='train':
# During training, save logs to file in output directory
fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
fh.setFormatter(formatter)
self.logger.addHandler(fh)
# Always print logs to stdout
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
self.logger.addHandler(sh)
# Configure device pytorch will use
if not self.args['distributed_training']:
if torch.cuda.is_available():
self.device = f"cuda:{self.args['gpu_number']}"
else:
self.device = "cpu"
else:
self.device = "cuda"
self.logger.info(f'Using device: {self.device}')
# Set seed if provided
if self.args['seed'] != -1:
np.random.seed(self.args['seed'])
random.seed(self.args['seed'])
torch.manual_seed(self.args['seed'])
# Initialize the model
self.model = GRUDecoder(
neural_dim = self.args['model']['n_input_features'],
n_units = self.args['model']['n_units'],
n_days = len(self.args['dataset']['sessions']),
n_classes = self.args['dataset']['n_classes'],
rnn_dropout = self.args['model']['rnn_dropout'],
input_dropout = self.args['model']['input_network']['input_layer_dropout'],
n_layers = self.args['model']['n_layers'],
bidirectional = self.args['model']['bidirectional'],
patch_size = self.args['model']['patch_size'],
patch_stride = self.args['model']['patch_stride'],
)
# Call torch.compile to speed up training
self.logger.info("Using torch.compile")
self.model = torch.compile(self.model)
self.logger.info(f"Initialized RNN decoding model")
self.logger.info(self.model)
# Log how many parameters are in the model
total_params = sum(p.numel() for p in self.model.parameters())
self.logger.info(f"Model has {total_params:,} parameters")
# Determine how many day-specific parameters are in the model
day_params = 0
for name, param in self.model.named_parameters():
if 'day' in name:
day_params += param.numel()
self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")
# Create datasets and dataloaders
train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]
# Ensure that there are no duplicate days
if len(set(train_file_paths)) != len(train_file_paths):
raise ValueError("There are duplicate sessions listed in the train dataset")
if len(set(val_file_paths)) != len(val_file_paths):
raise ValueError("There are duplicate sessions listed in the val dataset")
# Split trials into train and test sets
train_trials, _ = train_test_split_indicies(
file_paths = train_file_paths,
test_percentage = 0,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
_, val_trials = train_test_split_indicies(
file_paths = val_file_paths,
test_percentage = 1,
seed = self.args['dataset']['seed'],
bad_trials_dict = None,
)
# Save dictionaries to output directory to know which trials were train vs val
with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
json.dump({'train' : train_trials, 'val': val_trials}, f)
# Determine if only a subset of neural features should be used
feature_subset = None
if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None:
feature_subset = self.args['dataset']['feature_subset']
self.logger.info(f'Using only a subset of features: {feature_subset}')
# train dataset and dataloader
self.train_dataset = BrainToTextDataset(
trial_indicies = train_trials,
split = 'train',
days_per_batch = self.args['dataset']['days_per_batch'],
n_batches = self.args['num_training_batches'],
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
self.train_loader = DataLoader(
self.train_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = self.args['dataset']['num_dataloader_workers'],
pin_memory = True
)
# val dataset and dataloader
self.val_dataset = BrainToTextDataset(
trial_indicies = val_trials,
split = 'test',
days_per_batch = None,
n_batches = None,
batch_size = self.args['dataset']['batch_size'],
must_include_days = None,
random_seed = self.args['dataset']['seed'],
feature_subset = feature_subset
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
shuffle = False,
num_workers = 0,
pin_memory = True
)
self.logger.info("Successfully initialized datasets")
# Create optimizer, learning rate scheduler, and loss
self.optimizer = self.create_optimizer()
if self.args['lr_scheduler_type'] == 'linear':
self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer = self.optimizer,
start_factor = 1.0,
end_factor = self.args['lr_min'] / self.args['lr_max'],
total_iters = self.args['lr_decay_steps'],
)
elif self.args['lr_scheduler_type'] == 'cosine':
self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)
else:
raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")
self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)
# If a checkpoint is provided, then load from checkpoint
if self.args['init_from_checkpoint']:
self.load_model_checkpoint(self.args['init_checkpoint_path'])
# Set rnn and/or input layers to not trainable if specified
for name, param in self.model.named_parameters():
if not self.args['model']['rnn_trainable'] and 'gru' in name:
param.requires_grad = False
elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
param.requires_grad = False
# Send model to device
self.model.to(self.device)
def create_optimizer(self):
'''
Create the optimizer with special param groups
Biases and day weights should not be decayed
Day weights should have a separate learning rate
'''
bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]
if len(day_params) != 0:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
{'params' : other_params, 'group_type' : 'other'}
]
else:
param_groups = [
{'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
{'params' : other_params, 'group_type' : 'other'}
]
optim = torch.optim.AdamW(
param_groups,
lr = self.args['lr_max'],
betas = (self.args['beta0'], self.args['beta1']),
eps = self.args['epsilon'],
weight_decay = self.args['weight_decay'],
fused = True
)
return optim
def create_cosine_lr_scheduler(self, optim):
lr_max = self.args['lr_max']
lr_min = self.args['lr_min']
lr_decay_steps = self.args['lr_decay_steps']
lr_max_day = self.args['lr_max_day']
lr_min_day = self.args['lr_min_day']
lr_decay_steps_day = self.args['lr_decay_steps_day']
lr_warmup_steps = self.args['lr_warmup_steps']
lr_warmup_steps_day = self.args['lr_warmup_steps_day']
def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
'''
Create lr lambdas for each param group that implement cosine decay
Different lr lambda decaying for day params vs rest of the model
'''
# Warmup phase
if current_step < warmup_steps:
return float(current_step) / float(max(1, warmup_steps))
# Cosine decay phase
if current_step < decay_steps:
progress = float(current_step - warmup_steps) / float(
max(1, decay_steps - warmup_steps)
)
cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
# Scale from 1.0 to min_lr_ratio
return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)
# After cosine decay is complete, maintain min_lr_ratio
return min_lr_ratio
if len(optim.param_groups) == 3:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min_day / lr_max_day,
lr_decay_steps_day,
lr_warmup_steps_day,
), # day params
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
elif len(optim.param_groups) == 2:
lr_lambdas = [
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # biases
lambda step: lr_lambda(
step,
lr_min / lr_max,
lr_decay_steps,
lr_warmup_steps), # rest of model weights
]
else:
raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")
return LambdaLR(optim, lr_lambdas, -1)
def load_model_checkpoint(self, load_path):
'''
Load a training checkpoint
'''
checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict
self.model.load_state_dict(checkpoint['model_state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf
self.model.to(self.device)
# Send optimizer params back to GPU
for state in self.optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.device)
self.logger.info("Loaded model from checkpoint: " + load_path)
def save_model_checkpoint(self, save_path, PER, loss):
'''
Save a training checkpoint
'''
checkpoint = {
'model_state_dict' : self.model.state_dict(),
'optimizer_state_dict' : self.optimizer.state_dict(),
'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
'val_PER' : PER,
'val_loss' : loss
}
torch.save(checkpoint, save_path)
self.logger.info("Saved model to checkpoint: " + save_path)
# Save the args file alongside the checkpoint
with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
OmegaConf.save(config=self.args, f=f)
def create_attention_mask(self, sequence_lengths):
max_length = torch.max(sequence_lengths).item()
batch_size = sequence_lengths.size(0)
# Create a mask for valid key positions (columns)
# Shape: [batch_size, max_length]
key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
key_mask = key_mask < sequence_lengths.unsqueeze(1)
# Expand key_mask to [batch_size, 1, 1, max_length]
# This will be broadcast across all query positions
key_mask = key_mask.unsqueeze(1).unsqueeze(1)
# Create the attention mask of shape [batch_size, 1, max_length, max_length]
# by broadcasting key_mask across all query positions
attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)
# Convert boolean mask to float mask:
# - True (valid key positions) -> 0.0 (no change to attention scores)
# - False (padding key positions) -> -inf (will become 0 after softmax)
attention_mask_float = torch.where(attention_mask,
torch.tensor(0.0, device=sequence_lengths.device),
torch.tensor(float('-inf'), device=sequence_lengths.device))
return attention_mask_float
def transform_data(self, features, n_time_steps, mode = 'train'):
'''
Apply various augmentations and smoothing to data
Performing augmentations is much faster on GPU than CPU
'''
data_shape = features.shape
batch_size = data_shape[0]
channels = data_shape[-1]
# We only apply these augmentations in training
if mode == 'train':
# add static gain noise
if self.transform_args['static_gain_std'] > 0:
warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1))
warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']
features = torch.matmul(features, warp_mat)
# add white noise
if self.transform_args['white_noise_std'] > 0:
features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']
# add constant offset noise
if self.transform_args['constant_offset_std'] > 0:
features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']
# add random walk noise
if self.transform_args['random_walk_std'] > 0:
features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis'])
# randomly cutoff part of the data timecourse
if self.transform_args['random_cut'] > 0:
cut = np.random.randint(0, self.transform_args['random_cut'])
features = features[:, cut:, :]
n_time_steps = n_time_steps - cut
# Apply Gaussian smoothing to data
# This is done in both training and validation
if self.transform_args['smooth_data']:
features = gauss_smooth(
inputs = features,
device = self.device,
smooth_kernel_std = self.transform_args['smooth_kernel_std'],
smooth_kernel_size= self.transform_args['smooth_kernel_size'],
)
return features, n_time_steps
def train(self):
'''
Train the model
'''
# Set model to train mode (specifically to make sure dropout layers are engaged)
self.model.train()
# create vars to track performance
train_losses = []
val_losses = []
val_PERs = []
val_results = []
val_steps_since_improvement = 0
# training params
save_best_checkpoint = self.args.get('save_best_checkpoint', True)
early_stopping = self.args.get('early_stopping', True)
early_stopping_val_steps = self.args['early_stopping_val_steps']
train_start_time = time.time()
# train for specified number of batches
for i, batch in enumerate(self.train_loader):
self.model.train()
self.optimizer.zero_grad()
# Train step
start_time = time.time()
# Move data to device
features = batch['input_features'].to(self.device)
labels = batch['seq_class_ids'].to(self.device)
n_time_steps = batch['n_time_steps'].to(self.device)
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
day_indicies = batch['day_indicies'].to(self.device)
# Use autocast for efficiency
with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
# Apply augmentations to the data
features, n_time_steps = self.transform_data(features, n_time_steps, 'train')
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
# Get phoneme predictions
logits = self.model(features, day_indicies)
# Calculate CTC Loss
loss = self.ctc_loss(
log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
targets = labels,
input_lengths = adjusted_lens,
target_lengths = phone_seq_lens
)
loss = torch.mean(loss) # take mean loss over batches
loss.backward()
# Clip gradient
if self.args['grad_norm_clip_value'] > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
max_norm = self.args['grad_norm_clip_value'],
error_if_nonfinite = True,
foreach = True
)
self.optimizer.step()
self.learning_rate_scheduler.step()
# Save training metrics
train_step_duration = time.time() - start_time
train_losses.append(loss.detach().item())
# Incrementally log training progress
if i % self.args['batches_per_train_log'] == 0:
self.logger.info(f'Train batch {i}: ' +
f'loss: {(loss.detach().item()):.2f} ' +
f'grad norm: {grad_norm:.2f} '
f'time: {train_step_duration:.3f}')
# Incrementally run a test step
if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
self.logger.info(f"Running test after training batch: {i}")
# Calculate metrics on val data
start_time = time.time()
val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
val_step_duration = time.time() - start_time
# Log info
self.logger.info(f'Val batch {i}: ' +
f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
f'time: {val_step_duration:.3f}')
if self.args['log_individual_day_val_PER']:
for day in val_metrics['day_PERs'].keys():
self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")
# Save metrics
val_PERs.append(val_metrics['avg_PER'])
val_losses.append(val_metrics['avg_loss'])
val_results.append(val_metrics)
# Determine if this is a new best model: PER is lower, or PER ties and loss is lower
new_best = False
if val_metrics['avg_PER'] < self.best_val_PER:
self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
self.best_val_PER = val_metrics['avg_PER']
self.best_val_loss = val_metrics['avg_loss']
new_best = True
elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
self.best_val_loss = val_metrics['avg_loss']
new_best = True
if new_best:
# Checkpoint if metrics have improved
if save_best_checkpoint:
self.logger.info(f"Checkpointing model")
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)
# save validation metrics to pickle file
if self.args['save_val_metrics']:
with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
pickle.dump(val_metrics, f)
val_steps_since_improvement = 0
else:
val_steps_since_improvement +=1
# Optionally save this validation checkpoint, regardless of performance
if self.args['save_all_val_steps']:
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])
# Early stopping
if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
break
# Log final training steps
training_duration = time.time() - train_start_time
self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')
# Save final model
if self.args['save_final_model']:
self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], val_losses[-1])
train_stats = {}
train_stats['train_losses'] = train_losses
train_stats['val_losses'] = val_losses
train_stats['val_PERs'] = val_PERs
train_stats['val_metrics'] = val_results
return train_stats
def validation(self, loader, return_logits = False, return_data = False):
'''
Calculate metrics on the validation dataset
'''
self.model.eval()
metrics = {}
# Record metrics
if return_logits:
metrics['logits'] = []
metrics['n_time_steps'] = []
if return_data:
metrics['input_features'] = []
metrics['decoded_seqs'] = []
metrics['true_seq'] = []
metrics['phone_seq_lens'] = []
metrics['transcription'] = []
metrics['losses'] = []
metrics['block_nums'] = []
metrics['trial_nums'] = []
metrics['day_indicies'] = []
total_edit_distance = 0
total_seq_length = 0
# Calculate PER for each specific day
day_per = {}
for d in range(len(self.args['dataset']['sessions'])):
if self.args['dataset']['dataset_probability_val'][d] == 1:
day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}
for i, batch in enumerate(loader):
features = batch['input_features'].to(self.device)
labels = batch['seq_class_ids'].to(self.device)
n_time_steps = batch['n_time_steps'].to(self.device)
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
day_indicies = batch['day_indicies'].to(self.device)
# Determine if we should perform validation on this batch
day = day_indicies[0].item()
if self.args['dataset']['dataset_probability_val'][day] == 0:
if self.args['log_val_skip_logs']:
self.logger.info(f"Skipping validation on day {day}")
continue
with torch.no_grad():
with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
features, n_time_steps = self.transform_data(features, n_time_steps, 'val')
adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)
logits = self.model(features, day_indicies)
loss = self.ctc_loss(
torch.permute(logits.log_softmax(2), [1, 0, 2]),
labels,
adjusted_lens,
phone_seq_lens,
)
loss = torch.mean(loss)
metrics['losses'].append(loss.cpu().detach().numpy())
# Calculate PER per day and also avg over entire validation set
batch_edit_distance = 0
decoded_seqs = []
for iterIdx in range(logits.shape[0]):
decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1)
decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
decoded_seq = decoded_seq.cpu().detach().numpy()
decoded_seq = np.array([i for i in decoded_seq if i != 0])
trueSeq = np.array(
labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
)
batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)
decoded_seqs.append(decoded_seq)
day = batch['day_indicies'][0].item()
day_per[day]['total_edit_distance'] += batch_edit_distance
day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()
total_edit_distance += batch_edit_distance
total_seq_length += torch.sum(phone_seq_lens)
# Record metrics
if return_logits:
metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())
if return_data:
metrics['input_features'].append(batch['input_features'].cpu().numpy())
metrics['decoded_seqs'].append(decoded_seqs)
metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
metrics['losses'].append(loss.detach().item())
metrics['block_nums'].append(batch['block_nums'].numpy())
metrics['trial_nums'].append(batch['trial_nums'].numpy())
metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())
avg_PER = total_edit_distance / total_seq_length
metrics['day_PERs'] = day_per
metrics['avg_PER'] = avg_PER.item()
metrics['avg_loss'] = np.mean(metrics['losses'])
return metrics
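
A minimal sketch (not part of the repo, with random logits and a hypothetical label sequence) of the greedy CTC decode and phoneme error rate computation that `validation()` performs per trial:
```python
import numpy as np
import torch
import torchaudio.functional as F

logits = torch.randn(1, 120, 41)         # [batch, time, n_classes]; class 0 is the CTC blank
true_seq = np.array([5, 12, 12, 3, 29])  # hypothetical phoneme labels

decoded = torch.argmax(logits[0], dim=-1)                   # greedy per-frame argmax
decoded = torch.unique_consecutive(decoded, dim=-1)         # collapse repeated frames
decoded = np.array([p for p in decoded.numpy() if p != 0])  # drop CTC blanks

per = F.edit_distance(decoded, true_seq) / len(true_seq)
print(f'PER: {per:.3f}')
```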

model_training/train_model.py Normal file

@@ -0,0 +1,6 @@
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer
args = OmegaConf.load('rnn_args.yaml')
trainer = BrainToTextDecoder_Trainer(args)
metrics = trainer.train()
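
`trainer.train()` returns a dictionary of training statistics (`train_losses`, `val_losses`, `val_PERs`, `val_metrics`); an optional addition (not in the repo) to persist it for later analysis:
```python
import pickle

with open('train_stats.pkl', 'wb') as f:
    pickle.dump(metrics, f)
```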