b2txt25 wip

model_training/README.md (new file, 23 lines)

# Model Training

This directory contains code and resources for training the brain-to-text RNN model. This model is largely based on the architecture described in the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), but also contains modifications to improve performance, efficiency, and usability.

## Setup

1. Install the required conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code.

2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory.

## Training

To train the baseline RNN model, run the following command:
```bash
python train_model.py
```

## Evaluation

To evaluate the model, run:
```bash
python evaluate_model.py
```
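
Note: `rnn_trainer.py` looks for one sub-directory per session inside the `dataset_dir` set in `rnn_args.yaml`, each containing `data_train.hdf5` and `data_val.hdf5`. A minimal sketch for checking that layout before launching training (paths come from your own `rnn_args.yaml`):

```python
# Sketch: verify the per-session HDF5 layout that rnn_trainer.py expects.
import os
from omegaconf import OmegaConf

args = OmegaConf.load('rnn_args.yaml')
for session in args.dataset.sessions:
    for fname in ('data_train.hdf5', 'data_val.hdf5'):
        path = os.path.join(args.dataset.dataset_dir, session, fname)
        if not os.path.exists(path):
            print(f'Missing: {path}')
```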

model_training/data_augmentations.py (new file, 37 lines)

import torch
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import gaussian_filter1d

def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'):
    """
    Applies a 1D Gaussian smoothing operation with PyTorch to smooth the data along the time axis.
    Args:
        inputs (tensor : B x T x N): A 3D tensor with batch size B, time steps T, and number of features N.
            Assumed to already be on the correct device (e.g., GPU).
        device (str): Device to use for computation (e.g., 'cuda' or 'cpu').
        smooth_kernel_std (float): Standard deviation of the Gaussian smoothing kernel.
        smooth_kernel_size (int): Width of the Gaussian smoothing kernel in time steps.
        padding (str): Padding mode, either 'same' or 'valid'.
    Returns:
        smoothed (tensor : B x T x N): A smoothed 3D tensor with batch size B, time steps T, and number of features N.
    """
    # Get Gaussian kernel
    inp = np.zeros(smooth_kernel_size, dtype=np.float32)
    inp[smooth_kernel_size // 2] = 1
    gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)
    validIdx = np.argwhere(gaussKernel > 0.01)
    gaussKernel = gaussKernel[validIdx]
    gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))

    # Convert to tensor
    gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device)
    gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]

    # Prepare convolution
    B, T, C = inputs.shape
    inputs = inputs.permute(0, 2, 1) # [B, C, T]
    gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]

    # Perform convolution
    smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)
    return smoothed.permute(0, 2, 1) # [B, T, C]
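
A minimal usage sketch for `gauss_smooth`; the shapes and device below are illustrative, not fixed by the function:

```python
# Sketch: smooth a random batch along the time axis with gauss_smooth.
import torch
from data_augmentations import gauss_smooth

x = torch.randn(4, 200, 512)  # [batch B, time T, features N]
smoothed = gauss_smooth(x, device='cpu', smooth_kernel_std=2, smooth_kernel_size=100)
print(smoothed.shape)         # torch.Size([4, 200, 512]) with padding='same'
```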

model_training/dataset.py (new file, 334 lines)

import os
import torch
from torch.utils.data import Dataset
import h5py
import numpy as np
from torch.nn.utils.rnn import pad_sequence
import math

class BrainToTextDataset(Dataset):
    '''
    Dataset for brain-to-text data

    Returns an entire batch of data instead of a single example
    '''

    def __init__(
            self,
            trial_indicies,
            n_batches,
            split = 'train',
            batch_size = 64,
            days_per_batch = 1,
            random_seed = -1,
            must_include_days = None,
            feature_subset = None
    ):
        '''
        trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values
        n_batches: (int) - number of random training batches to create
        split: (string) - string specifying if this is a train or test dataset
        batch_size: (int) - number of examples to include in the batch returned from __getitem__()
        days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates
                        to individual day layers in the GRU are not excessively noisy. Validation data will always have 1 day per batch
        random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random
        must_include_days ([int]) - list of days that must be included in every batch
        feature_subset ([int]) - list of neural feature indices that should be the only features included in the neural data
        '''

        # Set random seed for reproducibility
        if random_seed != -1:
            np.random.seed(random_seed)
            torch.manual_seed(random_seed)

        self.split = split

        # Ensure the split is valid
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch

        self.batch_size = batch_size

        self.n_batches = n_batches

        self.days = {}
        self.n_trials = 0
        self.trial_indicies = trial_indicies
        self.n_days = len(trial_indicies.keys())

        self.feature_subset = feature_subset

        # Calculate total number of trials in the dataset
        for d in trial_indicies:
            self.n_trials += len(trial_indicies[d]['trials'])

        if must_include_days is not None and len(must_include_days) > days_per_batch:
            raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}')

        if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train':
            raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} but only {self.n_days} days in the dataset')

        if must_include_days is not None:
            # Map must_include_days to correct indices if they are negative
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        self.must_include_days = must_include_days

        # Ensure that days_per_batch is not greater than the number of days in the dataset
        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.')

        if self.split == 'train':
            self.batch_index = self.create_batch_index_train()
        else:
            self.batch_index = self.create_batch_index_test()
            self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data

    def __len__(self):
        '''
        How many batches are in this dataset.
        Because training data is sampled randomly, there is no fixed dataset length,
        however this method is required for DataLoader to work
        '''
        return self.n_batches

    def __getitem__(self, idx):
        '''
        Gets an entire batch of data from the dataset, not just a single item
        '''
        batch = {
            'input_features' : [],
            'seq_class_ids' : [],
            'n_time_steps' : [],
            'phone_seq_lens' : [],
            'day_indicies' : [],
            'transcriptions' : [],
            'block_nums' : [],
            'trial_nums' : [],
        }

        index = self.batch_index[idx]

        # Iterate through each day in the index
        for d in index.keys():

            # Open the hdf5 file for that day
            with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f:

                # For each trial in the selected trials in that day
                for t in index[d]:

                    try:
                        g = f[f'trial_{t}']

                        # Remove features if necessary
                        input_features = torch.from_numpy(g['input_features'][:]) # neural data
                        if self.feature_subset:
                            input_features = input_features[:,self.feature_subset]

                        batch['input_features'].append(input_features)

                        batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels
                        batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions
                        batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding
                        batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding
                        batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers
                        batch['block_nums'].append(g.attrs['block_num'])
                        batch['trial_nums'].append(g.attrs['trial_num'])

                    except Exception as e:
                        print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}')
                        continue

        # Pad data to form a cohesive batch
        batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0)
        batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0)

        batch['n_time_steps'] = torch.tensor(batch['n_time_steps'])
        batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens'])
        batch['day_indicies'] = torch.tensor(batch['day_indicies'])
        batch['transcriptions'] = torch.stack(batch['transcriptions'])
        batch['block_nums'] = torch.tensor(batch['block_nums'])
        batch['trial_nums'] = torch.tensor(batch['trial_nums'])

        return batch

    def create_batch_index_train(self):
        '''
        Create an index that maps a batch_number to batch_size number of trials

        Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days
        (or as even as possible if batch_size is not divisible by days_per_batch)
        '''

        batch_index = {}

        # Precompute the days that are not in must_include_days
        if self.must_include_days is not None:
            non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Which days will be used for this batch. Picked randomly without replacement
            # TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day

            # If must_include_days is not empty, we will use those days and then randomly sample the rest
            if self.must_include_days is not None and len(self.must_include_days) > 0:

                days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False)))

            # Otherwise we will select random days without replacement
            else:
                days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False)

            # How many trials will be sampled from each day
            num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials

            for d in days:

                # Trials are sampled with replacement, so days with fewer than (batch_size / days_per_batch) trials are not a problem
                trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True)
                batch[d] = trial_idxs

            # Remove extra trials
            extra_trials = (num_trials * len(days)) - self.batch_size

            # While we still have extra trials, remove the last trial from a random day
            while extra_trials > 0:
                d = np.random.choice(days)
                batch[d] = batch[d][:-1]
                extra_trials -= 1

            batch_index[batch_idx] = batch

        return batch_index

    def create_batch_index_test(self):
        '''
        Create an index that is all validation/testing data in batches of up to self.batch_size

        If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size

        This index ensures that every trial in the validation set is seen once and only once
        '''
        batch_index = {}
        batch_idx = 0

        for d in self.trial_indicies.keys():

            # Calculate how many batches we need for this day
            num_trials = len(self.trial_indicies[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            # Create batches for this day
            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)

                # Get the trial indices for this batch
                batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx]

                # Add to batch_index
                batch_index[batch_idx] = {d : batch_trials}
                batch_idx += 1

        return batch_index

def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None):
    '''
    Split data from file_paths into train and test splits
    Returns two dictionaries that detail which trials in each day will be a part of that split:
    Example:
    {
        0: trials[1,2,3], session_path: 'path'
        1: trials[2,5,6], session_path: 'path'
    }

    Args:
        file_paths (list): List of file paths to the hdf5 files containing the data
        test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing
        seed (int): Seed for reproducibility. If set to -1, the split will be random
        bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as:
            {
                'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
                'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...},
                ...
            }
    '''
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t}'

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    if (
                        bad_trials_dict is not None
                        and session in bad_trials_dict
                        and str(block_num) in bad_trials_dict[session]
                        and trial_num in bad_trials_dict[session][str(block_num)]
                    ):
                        # print(f'Bad trial: {session}_{block_num}_{trial_num}')
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path}

    # Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():

        num_trials = trials_per_day[day]['num_trials']

        # Generate all trial indices for this day (assuming 0-indexed)
        all_trial_indices = trials_per_day[day]['trial_indices']

        # If test_percentage is 0 or 1, we can just assign all trials to either train or test
        if test_percentage == 0:
            train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
            test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
            continue

        elif test_percentage == 1:
            train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']}
            test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']}
            continue

        else:
            # Calculate how many trials to use for testing
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select indices for testing
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices go to training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            # Store the split indices
            train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']}
            test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']}

    return train_trials, test_trials
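
A minimal sketch of how these pieces are combined in `rnn_trainer.py` (the file path is a placeholder; `__getitem__` already returns a full padded batch, so the `DataLoader` is created with `batch_size=None`):

```python
# Sketch: build a training DataLoader the same way rnn_trainer.py does.
from torch.utils.data import DataLoader
from dataset import BrainToTextDataset, train_test_split_indicies

file_paths = ['/path/to/hdf5_data/t15.2023.08.11/data_train.hdf5']  # placeholder, must exist
train_trials, _ = train_test_split_indicies(file_paths, test_percentage=0, seed=1)

train_dataset = BrainToTextDataset(
    trial_indicies=train_trials,
    n_batches=1000,
    split='train',
    batch_size=64,
    days_per_batch=1,   # only one session in this sketch
    random_seed=1,
)
train_loader = DataLoader(train_dataset, batch_size=None, shuffle=False, num_workers=0)
```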

model_training/rnn_args.yaml (new file, 169 lines)

model:
  n_input_features: 512
  n_units: 768
  rnn_dropout: 0.4
  rnn_trainable: true
  n_layers: 5
  bidirectional: false
  patch_size: 14
  patch_stride: 4
  input_network:
    n_input_layers: 1
    input_layer_sizes:
    - 512
    input_trainable: true
    input_layer_dropout: 0.2
gpu_number: '1'
distributed_training: false
mode: train
use_amp: true
output_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter
init_from_checkpoint: false
checkpoint_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter/checkpoint
init_checkpoint_path: None
save_best_checkpoint: true
save_all_val_steps: false
save_final_model: false
save_val_metrics: true
early_stopping: false
early_stopping_val_steps: 20
num_training_batches: 120000
lr_scheduler_type: cosine
lr_max: 0.005
lr_min: 0.0001
lr_decay_steps: 120000
lr_warmup_steps: 1000
lr_max_day: 0.005
lr_min_day: 0.0001
lr_decay_steps_day: 120000
lr_warmup_steps_day: 1000
beta0: 0.9
beta1: 0.999
epsilon: 0.1
weight_decay: 0.001
weight_decay_day: 0
seed: 10
grad_norm_clip_value: 10
batches_per_train_log: 200
batches_per_val_step: 2000
batches_per_save: 0
log_individual_day_val_PER: true
log_val_skip_logs: false
save_val_logits: true
save_val_data: false
dataset:
  data_transforms:
    white_noise_std: 1.0
    constant_offset_std: 0.2
    random_walk_std: 0.0
    random_walk_axis: -1
    static_gain_std: 0.0
    random_cut: 3 #0
    smooth_kernel_size: 100
    smooth_data: true
    smooth_kernel_std: 2
  neural_dim: 512
  batch_size: 64
  n_classes: 41
  max_seq_elements: 500
  days_per_batch: 4
  seed: 1
  num_dataloader_workers: 4
  loader_shuffle: false
  must_include_days: null
  test_percentage: 0.1
  feature_subset: null
  dataset_dir: /media/lm-pc/8tb_nvme/b2txt25/hdf5_data
  bad_trials_dict: null
  sessions:
  - t15.2023.08.11
  - t15.2023.08.13
  - t15.2023.08.18
  - t15.2023.08.20
  - t15.2023.08.25
  - t15.2023.08.27
  - t15.2023.09.01
  - t15.2023.09.03
  - t15.2023.09.24
  - t15.2023.09.29
  - t15.2023.10.01
  - t15.2023.10.06
  - t15.2023.10.08
  - t15.2023.10.13
  - t15.2023.10.15
  - t15.2023.10.20
  - t15.2023.10.22
  - t15.2023.11.03
  - t15.2023.11.04
  - t15.2023.11.17
  - t15.2023.11.19
  - t15.2023.11.26
  - t15.2023.12.03
  - t15.2023.12.08
  - t15.2023.12.10
  - t15.2023.12.17
  - t15.2023.12.29
  - t15.2024.02.25
  - t15.2024.03.03
  - t15.2024.03.08
  - t15.2024.03.15
  - t15.2024.03.17
  - t15.2024.04.25
  - t15.2024.04.28
  - t15.2024.05.10
  - t15.2024.06.14
  - t15.2024.07.19
  - t15.2024.07.21
  - t15.2024.07.28
  - t15.2025.01.10
  - t15.2025.01.12
  - t15.2025.03.14
  - t15.2025.03.16
  - t15.2025.03.30
  - t15.2025.04.13
  dataset_probability_val:
  - 0
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 0
  - 1
  - 1
  - 1
  - 0
  - 0
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
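
`dataset_probability_val` has one entry per session; the validation loop in `rnn_trainer.py` skips any day whose entry is 0. A small sketch for sanity-checking that the two lists line up:

```python
# Sketch: check that dataset_probability_val matches the sessions list.
from omegaconf import OmegaConf

args = OmegaConf.load('rnn_args.yaml')
assert len(args.dataset.dataset_probability_val) == len(args.dataset.sessions)
skipped = [s for s, p in zip(args.dataset.sessions, args.dataset.dataset_probability_val) if p == 0]
print('Sessions excluded from validation:', skipped)
```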

model_training/rnn_model.py (new file, 136 lines)

import torch
from torch import nn

class GRUDecoder(nn.Module):
    '''
    Defines the GRU decoder

    This class combines day-specific input layers, a GRU, and an output classification layer
    '''
    def __init__(self,
                 neural_dim,
                 n_units,
                 n_days,
                 n_classes,
                 rnn_dropout = 0.0,
                 input_dropout = 0.0,
                 n_layers = 5,
                 patch_size = 0,
                 patch_stride = 0,
                 bidirectional = False,
                 ):
        '''
        neural_dim (int) - number of channels in a single timestep (e.g. 512)
        n_units (int) - number of hidden units in each recurrent layer - equal to the size of the hidden state
        n_days (int) - number of days in the dataset
        n_classes (int) - number of classes
        rnn_dropout (float) - percentage of units to dropout during training
        input_dropout (float) - percentage of input units to dropout during training
        n_layers (int) - number of recurrent layers
        patch_size (int) - the number of timesteps to concat on initial input layer - a value of 0 will disable this "input concat" step
        patch_stride (int) - the number of timesteps to stride over when concatenating initial input
        bidirectional (bool) - whether the GRU is bidirectional (accepted so the trainer can pass the config value; the default config uses false)
        '''
        super(GRUDecoder, self).__init__()

        self.neural_dim = neural_dim
        self.n_units = n_units
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.n_days = n_days

        self.rnn_dropout = rnn_dropout
        self.input_dropout = input_dropout

        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.bidirectional = bidirectional

        # Parameters for the day-specific input layers
        self.day_layer_activation = nn.Softsign() # basically a shallower tanh

        # Set weights for day layers to be identity matrices so the model can learn its own day-specific transformations
        self.day_weights = nn.ParameterList(
            [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)]
        )
        self.day_biases = nn.ParameterList(
            [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)]
        )

        self.day_layer_dropout = nn.Dropout(input_dropout)

        self.input_size = self.neural_dim

        # If we are using "strided inputs", then the input size of the first recurrent layer will actually be in_size * patch_size
        if self.patch_size > 0:
            self.input_size *= self.patch_size

        self.gru = nn.GRU(
            input_size = self.input_size,
            hidden_size = self.n_units,
            num_layers = self.n_layers,
            dropout = self.rnn_dropout,
            batch_first = True, # The first dim of our input is the batch dim
            bidirectional = self.bidirectional,
        )

        # Set recurrent units to have orthogonal param init and input layers to have xavier init
        for name, param in self.gru.named_parameters():
            if "weight_hh" in name:
                nn.init.orthogonal_(param)
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)

        # Prediction head. Weight init to xavier
        self.out = nn.Linear(self.n_units, self.n_classes)
        nn.init.xavier_uniform_(self.out.weight)

        # Learnable initial hidden states
        self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units)))

    def forward(self, x, day_idx, states = None, return_state = False):
        '''
        x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim)
        day_idx (tensor) - tensor of day indices corresponding to the day of each example in the batch x
        '''

        # Apply day-specific layer to (hopefully) project neural data from the different days to the same latent space
        day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0)
        day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1)

        x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases
        x = self.day_layer_activation(x)

        # Apply dropout to the output of the day specific layer
        if self.input_dropout > 0:
            x = self.day_layer_dropout(x)

        # (Optionally) Perform input concat operation
        if self.patch_size > 0:

            x = x.unsqueeze(1) # [batches, 1, timesteps, feature_dim]
            x = x.permute(0, 3, 1, 2) # [batches, feature_dim, 1, timesteps]

            # Extract patches using unfold (sliding window)
            x_unfold = x.unfold(3, self.patch_size, self.patch_stride) # [batches, feature_dim, 1, num_patches, patch_size]

            # Remove dummy height dimension and rearrange dimensions
            x_unfold = x_unfold.squeeze(2) # [batches, feature_dim, num_patches, patch_size]
            x_unfold = x_unfold.permute(0, 2, 3, 1) # [batches, num_patches, patch_size, feature_dim]

            # Flatten last two dimensions (patch_size and features)
            x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1)

        # Determine initial hidden states
        if states is None:
            states = self.h0.expand(self.n_layers, x.shape[0], self.n_units).contiguous()

        # Pass input through RNN
        output, hidden_states = self.gru(x, states)

        # Compute logits
        logits = self.out(output)

        if return_state:
            return logits, hidden_states

        return logits
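
A minimal sketch of a forward pass with the dimensions used in `rnn_args.yaml`; with patching enabled the output sequence is shortened to roughly `(T - patch_size) // patch_stride + 1` steps, the same formula `rnn_trainer.py` uses for `adjusted_lens`:

```python
# Sketch: dummy forward pass through GRUDecoder (sizes taken from rnn_args.yaml).
import torch
from rnn_model import GRUDecoder

model = GRUDecoder(neural_dim=512, n_units=768, n_days=45, n_classes=41,
                   rnn_dropout=0.4, input_dropout=0.2, n_layers=5,
                   patch_size=14, patch_stride=4)

x = torch.randn(2, 100, 512)    # [batch, time, neural_dim]
day_idx = torch.tensor([0, 3])  # one day index per trial in the batch
logits = model(x, day_idx)
print(logits.shape)             # [2, (100 - 14) // 4 + 1, 41] -> [2, 22, 41]
```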

model_training/rnn_trainer.py (new file, 754 lines)

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LambdaLR
import random
import time
import os
import numpy as np
import math
import pathlib
import logging
import sys
import json
import pickle

from dataset import BrainToTextDataset, train_test_split_indicies
from data_augmentations import gauss_smooth

import torchaudio.functional as F # for edit distance
from omegaconf import OmegaConf

torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
torch.backends.cudnn.deterministic = True # makes training more reproducible
torch._dynamo.config.cache_size_limit = 64

from rnn_model import GRUDecoder

class BrainToTextDecoder_Trainer:
    """
    This class will initialize and train a brain-to-text phoneme decoder

    Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function
    """

    def __init__(self, args):
        '''
        args : dictionary of training arguments
        '''

        # Trainer fields
        self.args = args
        self.logger = None
        self.device = None
        self.model = None
        self.optimizer = None
        self.learning_rate_scheduler = None
        self.ctc_loss = None

        self.best_val_PER = torch.inf # track best PER for checkpointing
        self.best_val_loss = torch.inf # track best loss for checkpointing

        self.train_dataset = None
        self.val_dataset = None
        self.train_loader = None
        self.val_loader = None

        self.transform_args = self.args['dataset']['data_transforms']

        # Create output directory
        if args['mode'] == 'train':
            os.makedirs(self.args['output_dir'], exist_ok=True)

        # Create checkpoint directory
        if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
            os.makedirs(self.args['checkpoint_dir'], exist_ok = True)

        # Set up logging
        self.logger = logging.getLogger(__name__)
        for handler in self.logger.handlers[:]: # make a copy of the list
            self.logger.removeHandler(handler)
        self.logger.setLevel(logging.INFO)
        formatter = logging.Formatter(fmt='%(asctime)s: %(message)s')

        if args['mode']=='train':
            # During training, save logs to file in output directory
            fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log')))
            fh.setFormatter(formatter)
            self.logger.addHandler(fh)

        # Always print logs to stdout
        sh = logging.StreamHandler(sys.stdout)
        sh.setFormatter(formatter)
        self.logger.addHandler(sh)

        # Configure device pytorch will use
        if not self.args['distributed_training']:
            if torch.cuda.is_available():
                self.device = f"cuda:{self.args['gpu_number']}"

            else:
                self.device = "cpu"
        else:
            self.device = "cuda"

        self.logger.info(f'Using device: {self.device}')

        # Set seed if provided
        if self.args['seed'] != -1:
            np.random.seed(self.args['seed'])
            random.seed(self.args['seed'])
            torch.manual_seed(self.args['seed'])

        # Initialize the model
        self.model = GRUDecoder(
            neural_dim = self.args['model']['n_input_features'],
            n_units = self.args['model']['n_units'],
            n_days = len(self.args['dataset']['sessions']),
            n_classes = self.args['dataset']['n_classes'],
            rnn_dropout = self.args['model']['rnn_dropout'],
            input_dropout = self.args['model']['input_network']['input_layer_dropout'],
            n_layers = self.args['model']['n_layers'],
            bidirectional = self.args['model']['bidirectional'],
            patch_size = self.args['model']['patch_size'],
            patch_stride = self.args['model']['patch_stride'],
        )

        # Call torch.compile to speed up training
        self.logger.info("Using torch.compile")
        self.model = torch.compile(self.model)

        self.logger.info(f"Initialized RNN decoding model")

        self.logger.info(self.model)

        # Log how many parameters are in the model
        total_params = sum(p.numel() for p in self.model.parameters())
        self.logger.info(f"Model has {total_params:,} parameters")

        # Determine how many day-specific parameters are in the model
        day_params = 0
        for name, param in self.model.named_parameters():
            if 'day' in name:
                day_params += param.numel()

        self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters")

        # Create datasets and dataloaders
        train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']]
        val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']]

        # Ensure that there are no duplicate days
        if len(set(train_file_paths)) != len(train_file_paths):
            raise ValueError("There are duplicate sessions listed in the train dataset")
        if len(set(val_file_paths)) != len(val_file_paths):
            raise ValueError("There are duplicate sessions listed in the val dataset")

        # Split trials into train and test sets
        train_trials, _ = train_test_split_indicies(
            file_paths = train_file_paths,
            test_percentage = 0,
            seed = self.args['dataset']['seed'],
            bad_trials_dict = None,
        )
        _, val_trials = train_test_split_indicies(
            file_paths = val_file_paths,
            test_percentage = 1,
            seed = self.args['dataset']['seed'],
            bad_trials_dict = None,
        )

        # Save dictionaries to output directory to know which trials were train vs val
        with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
            json.dump({'train' : train_trials, 'val': val_trials}, f)

        # Determine if only a subset of neural features should be used
        feature_subset = None
        if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None:
            feature_subset = self.args['dataset']['feature_subset']
            self.logger.info(f'Using only a subset of features: {feature_subset}')

        # train dataset and dataloader
        self.train_dataset = BrainToTextDataset(
            trial_indicies = train_trials,
            split = 'train',
            days_per_batch = self.args['dataset']['days_per_batch'],
            n_batches = self.args['num_training_batches'],
            batch_size = self.args['dataset']['batch_size'],
            must_include_days = None,
            random_seed = self.args['dataset']['seed'],
            feature_subset = feature_subset
        )
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size = None, # Dataset.__getitem__() already returns batches
            shuffle = self.args['dataset']['loader_shuffle'],
            num_workers = self.args['dataset']['num_dataloader_workers'],
            pin_memory = True
        )

        # val dataset and dataloader
        self.val_dataset = BrainToTextDataset(
            trial_indicies = val_trials,
            split = 'test',
            days_per_batch = None,
            n_batches = None,
            batch_size = self.args['dataset']['batch_size'],
            must_include_days = None,
            random_seed = self.args['dataset']['seed'],
            feature_subset = feature_subset
        )
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size = None, # Dataset.__getitem__() already returns batches
            shuffle = False,
            num_workers = 0,
            pin_memory = True
        )

        self.logger.info("Successfully initialized datasets")

        # Create optimizer, learning rate scheduler, and loss
        self.optimizer = self.create_optimizer()

        if self.args['lr_scheduler_type'] == 'linear':
            self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer = self.optimizer,
                start_factor = 1.0,
                end_factor = self.args['lr_min'] / self.args['lr_max'],
                total_iters = self.args['lr_decay_steps'],
            )
        elif self.args['lr_scheduler_type'] == 'cosine':
            self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer)

        else:
            raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}")

        self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False)

        # If a checkpoint is provided, then load from checkpoint
        if self.args['init_from_checkpoint']:
            self.load_model_checkpoint(self.args['init_checkpoint_path'])

        # Set rnn and/or input layers to not trainable if specified
        for name, param in self.model.named_parameters():
            if not self.args['model']['rnn_trainable'] and 'gru' in name:
                param.requires_grad = False

            elif not self.args['model']['input_network']['input_trainable'] and 'day' in name:
                param.requires_grad = False

        # Send model to device
        self.model.to(self.device)

    def create_optimizer(self):
        '''
        Create the optimizer with special param groups

        Biases and day weights should not be decayed

        Day weights should have a separate learning rate
        '''
        bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name]
        day_params = [p for name, p in self.model.named_parameters() if 'day_' in name]
        other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name]

        if len(day_params) != 0:
            param_groups = [
                {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
                {'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'},
                {'params' : other_params, 'group_type' : 'other'}
            ]
        else:
            param_groups = [
                {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'},
                {'params' : other_params, 'group_type' : 'other'}
            ]

        optim = torch.optim.AdamW(
            param_groups,
            lr = self.args['lr_max'],
            betas = (self.args['beta0'], self.args['beta1']),
            eps = self.args['epsilon'],
            weight_decay = self.args['weight_decay'],
            fused = True
        )

        return optim

    def create_cosine_lr_scheduler(self, optim):
        lr_max = self.args['lr_max']
        lr_min = self.args['lr_min']
        lr_decay_steps = self.args['lr_decay_steps']

        lr_max_day = self.args['lr_max_day']
        lr_min_day = self.args['lr_min_day']
        lr_decay_steps_day = self.args['lr_decay_steps_day']

        lr_warmup_steps = self.args['lr_warmup_steps']
        lr_warmup_steps_day = self.args['lr_warmup_steps_day']

        def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps):
            '''
            Create lr lambdas for each param group that implement cosine decay

            Different lr lambda decaying for day params vs rest of the model
            '''
            # Warmup phase
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))

            # Cosine decay phase
            if current_step < decay_steps:
                progress = float(current_step - warmup_steps) / float(
                    max(1, decay_steps - warmup_steps)
                )
                cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
                # Scale from 1.0 to min_lr_ratio
                return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay)

            # After cosine decay is complete, maintain min_lr_ratio
            return min_lr_ratio

        if len(optim.param_groups) == 3:
            lr_lambdas = [
                lambda step: lr_lambda(
                    step,
                    lr_min / lr_max,
                    lr_decay_steps,
                    lr_warmup_steps), # biases
                lambda step: lr_lambda(
                    step,
                    lr_min_day / lr_max_day,
                    lr_decay_steps_day,
                    lr_warmup_steps_day,
                ), # day params
                lambda step: lr_lambda(
                    step,
                    lr_min / lr_max,
                    lr_decay_steps,
                    lr_warmup_steps), # rest of model weights
            ]
        elif len(optim.param_groups) == 2:
            lr_lambdas = [
                lambda step: lr_lambda(
                    step,
                    lr_min / lr_max,
                    lr_decay_steps,
                    lr_warmup_steps), # biases
                lambda step: lr_lambda(
                    step,
                    lr_min / lr_max,
                    lr_decay_steps,
                    lr_warmup_steps), # rest of model weights
            ]
        else:
            raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}")

        return LambdaLR(optim, lr_lambdas, -1)

    def load_model_checkpoint(self, load_path):
        '''
        Load a training checkpoint
        '''
        checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate
        self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf

        self.model.to(self.device)

        # Send optimizer params back to GPU
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(self.device)

        self.logger.info("Loaded model from checkpoint: " + load_path)

    def save_model_checkpoint(self, save_path, PER, loss):
        '''
        Save a training checkpoint
        '''

        checkpoint = {
            'model_state_dict' : self.model.state_dict(),
            'optimizer_state_dict' : self.optimizer.state_dict(),
            'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(),
            'val_PER' : PER,
            'val_loss' : loss
        }

        torch.save(checkpoint, save_path)

        self.logger.info("Saved model to checkpoint: " + save_path)

        # Save the args file alongside the checkpoint
        with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f:
            OmegaConf.save(config=self.args, f=f)

    def create_attention_mask(self, sequence_lengths):

        max_length = torch.max(sequence_lengths).item()

        batch_size = sequence_lengths.size(0)

        # Create a mask for valid key positions (columns)
        # Shape: [batch_size, max_length]
        key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length)
        key_mask = key_mask < sequence_lengths.unsqueeze(1)

        # Expand key_mask to [batch_size, 1, 1, max_length]
        # This will be broadcast across all query positions
        key_mask = key_mask.unsqueeze(1).unsqueeze(1)

        # Create the attention mask of shape [batch_size, 1, max_length, max_length]
        # by broadcasting key_mask across all query positions
        attention_mask = key_mask.expand(batch_size, 1, max_length, max_length)

        # Convert boolean mask to float mask:
        # - True (valid key positions) -> 0.0 (no change to attention scores)
        # - False (padding key positions) -> -inf (will become 0 after softmax)
        attention_mask_float = torch.where(attention_mask,
                                           0.0,
                                           float('-inf'))

        return attention_mask_float

    def transform_data(self, features, n_time_steps, mode = 'train'):
        '''
        Apply various augmentations and smoothing to data
        Performing augmentations is much faster on GPU than CPU
        '''

        data_shape = features.shape
        batch_size = data_shape[0]
        channels = data_shape[-1]

        # We only apply these augmentations in training
        if mode == 'train':
            # add static gain noise
            if self.transform_args['static_gain_std'] > 0:
                warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels, device=self.device), dim = 0), (batch_size, 1, 1))
                warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std']

                features = torch.matmul(features, warp_mat)

            # add white noise
            if self.transform_args['white_noise_std'] > 0:
                features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std']

            # add constant offset noise
            if self.transform_args['constant_offset_std'] > 0:
                features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std']

            # add random walk noise
            if self.transform_args['random_walk_std'] > 0:
                features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim = self.transform_args['random_walk_axis'])

            # randomly cut off part of the data timecourse
            if self.transform_args['random_cut'] > 0:
                cut = np.random.randint(0, self.transform_args['random_cut'])
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing to data
        # This is done in both training and validation
        if self.transform_args['smooth_data']:
            features = gauss_smooth(
                inputs = features,
                device = self.device,
                smooth_kernel_std = self.transform_args['smooth_kernel_std'],
                smooth_kernel_size = self.transform_args['smooth_kernel_size'],
            )

        return features, n_time_steps

    def train(self):
        '''
        Train the model
        '''

        # Set model to train mode (specifically to make sure dropout layers are engaged)
        self.model.train()

        # create vars to track performance
        train_losses = []
        val_losses = []
        val_PERs = []
        val_results = []

        val_steps_since_improvement = 0

        # training params
        save_best_checkpoint = self.args.get('save_best_checkpoint', True)
        early_stopping = self.args.get('early_stopping', True)

        early_stopping_val_steps = self.args['early_stopping_val_steps']

        train_start_time = time.time()

        # train for specified number of batches
        for i, batch in enumerate(self.train_loader):

            self.model.train()
            self.optimizer.zero_grad()

            # Train step
            start_time = time.time()

            # Move data to device
            features = batch['input_features'].to(self.device)
            labels = batch['seq_class_ids'].to(self.device)
            n_time_steps = batch['n_time_steps'].to(self.device)
            phone_seq_lens = batch['phone_seq_lens'].to(self.device)
            day_indicies = batch['day_indicies'].to(self.device)

            # Use autocast for efficiency
            with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):

                # Apply augmentations to the data
                features, n_time_steps = self.transform_data(features, n_time_steps, 'train')

                adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                # Get phoneme predictions
                logits = self.model(features, day_indicies)

                # Calculate CTC Loss
                loss = self.ctc_loss(
                    log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]),
                    targets = labels,
                    input_lengths = adjusted_lens,
                    target_lengths = phone_seq_lens
                )

                loss = torch.mean(loss) # take mean loss over batches

            loss.backward()

            # Clip gradient
            if self.args['grad_norm_clip_value'] > 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                           max_norm = self.args['grad_norm_clip_value'],
                                                           error_if_nonfinite = True,
                                                           foreach = True
                                                           )

            self.optimizer.step()
            self.learning_rate_scheduler.step()

            # Save training metrics
            train_step_duration = time.time() - start_time
            train_losses.append(loss.detach().item())

            # Incrementally log training progress
            if i % self.args['batches_per_train_log'] == 0:
                self.logger.info(f'Train batch {i}: ' +
                                 f'loss: {(loss.detach().item()):.2f} ' +
                                 f'grad norm: {grad_norm:.2f} '
                                 f'time: {train_step_duration:.3f}')

            # Incrementally run a test step
            if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)):
                self.logger.info(f"Running test after training batch: {i}")

                # Calculate metrics on val data
                start_time = time.time()
                val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data'])
                val_step_duration = time.time() - start_time

                # Log info
                self.logger.info(f'Val batch {i}: ' +
                                 f'PER (avg): {val_metrics["avg_PER"]:.4f} ' +
                                 f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' +
                                 f'time: {val_step_duration:.3f}')

                if self.args['log_individual_day_val_PER']:
                    for day in val_metrics['day_PERs'].keys():
                        self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}")

                # Save metrics
                val_PERs.append(val_metrics['avg_PER'])
                val_losses.append(val_metrics['avg_loss'])
                val_results.append(val_metrics)

                # Determine if this is a new best model: PER is lower, or in the case of a PER tie, loss is lower
                new_best = False
                if val_metrics['avg_PER'] < self.best_val_PER:
                    self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}")
                    self.best_val_PER = val_metrics['avg_PER']
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True
                elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss):
                    self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}")
                    self.best_val_loss = val_metrics['avg_loss']
                    new_best = True

                if new_best:

                    # Checkpoint if metrics have improved
                    if save_best_checkpoint:
                        self.logger.info(f"Checkpointing model")
                        self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss)

                    # save validation metrics to pickle file
                    if self.args['save_val_metrics']:
                        with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f:
                            pickle.dump(val_metrics, f)

                    val_steps_since_improvement = 0

                else:
                    val_steps_since_improvement += 1

                # Optionally save this validation checkpoint, regardless of performance
                if self.args['save_all_val_steps']:
                    self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER'], val_metrics['avg_loss'])

                # Early stopping
                if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps):
                    self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}')
                    break

        # Log final training steps
        training_duration = time.time() - train_start_time

        self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}')
        self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes')

        # Save final model
        if self.args['save_final_model']:
            self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1], val_losses[-1])

        train_stats = {}
        train_stats['train_losses'] = train_losses
        train_stats['val_losses'] = val_losses
        train_stats['val_PERs'] = val_PERs
        train_stats['val_metrics'] = val_results

        return train_stats

    def validation(self, loader, return_logits = False, return_data = False):
        '''
        Calculate metrics on the validation dataset
        '''
        self.model.eval()

        metrics = {}

        # Record metrics
        if return_logits:
            metrics['logits'] = []
            metrics['n_time_steps'] = []

        if return_data:
            metrics['input_features'] = []

        metrics['decoded_seqs'] = []
        metrics['true_seq'] = []
        metrics['phone_seq_lens'] = []
        metrics['transcription'] = []
        metrics['losses'] = []
        metrics['block_nums'] = []
        metrics['trial_nums'] = []
        metrics['day_indicies'] = []

        total_edit_distance = 0
        total_seq_length = 0

        # Calculate PER for each specific day
        day_per = {}
        for d in range(len(self.args['dataset']['sessions'])):
            if self.args['dataset']['dataset_probability_val'][d] == 1:
                day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0}

        for i, batch in enumerate(loader):

            features = batch['input_features'].to(self.device)
            labels = batch['seq_class_ids'].to(self.device)
            n_time_steps = batch['n_time_steps'].to(self.device)
            phone_seq_lens = batch['phone_seq_lens'].to(self.device)
            day_indicies = batch['day_indicies'].to(self.device)

            # Determine if we should perform validation on this batch
            day = day_indicies[0].item()
            if self.args['dataset']['dataset_probability_val'][day] == 0:
                if self.args['log_val_skip_logs']:
                    self.logger.info(f"Skipping validation on day {day}")
                continue

            with torch.no_grad():

                with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16):
                    features, n_time_steps = self.transform_data(features, n_time_steps, 'val')

                    adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32)

                    logits = self.model(features, day_indicies)

                    loss = self.ctc_loss(
                        torch.permute(logits.log_softmax(2), [1, 0, 2]),
                        labels,
                        adjusted_lens,
                        phone_seq_lens,
                    )
                    loss = torch.mean(loss)

                metrics['losses'].append(loss.cpu().detach().numpy())

                # Calculate PER per day and also avg over entire validation set
                batch_edit_distance = 0
                decoded_seqs = []
                for iterIdx in range(logits.shape[0]):
                    decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(), dim=-1)
                    decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1)
                    decoded_seq = decoded_seq.cpu().detach().numpy()
                    decoded_seq = np.array([i for i in decoded_seq if i != 0])

                    trueSeq = np.array(
                        labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach()
                    )

                    batch_edit_distance += F.edit_distance(decoded_seq, trueSeq)

                    decoded_seqs.append(decoded_seq)

                day = batch['day_indicies'][0].item()

                day_per[day]['total_edit_distance'] += batch_edit_distance
                day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item()

                total_edit_distance += batch_edit_distance
                total_seq_length += torch.sum(phone_seq_lens)

                # Record metrics
                if return_logits:
                    metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32
                    metrics['n_time_steps'].append(adjusted_lens.cpu().numpy())

                if return_data:
                    metrics['input_features'].append(batch['input_features'].cpu().numpy())

                metrics['decoded_seqs'].append(decoded_seqs)
                metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy())
                metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy())
                metrics['transcription'].append(batch['transcriptions'].cpu().numpy())
                metrics['losses'].append(loss.detach().item())
                metrics['block_nums'].append(batch['block_nums'].numpy())
                metrics['trial_nums'].append(batch['trial_nums'].numpy())
                metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy())

        avg_PER = total_edit_distance / total_seq_length

        metrics['day_PERs'] = day_per
        metrics['avg_PER'] = avg_PER.item()
        metrics['avg_loss'] = np.mean(metrics['losses'])

        return metrics
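
The validation loop above decodes greedily from the CTC logits (argmax over classes, collapse repeats, drop the blank id 0) and reports PER as total edit distance divided by total label length. A standalone sketch of that post-processing for a single trial, using dummy data:

```python
# Sketch: greedy CTC decoding and PER, mirroring the validation loop above.
import numpy as np
import torch
import torchaudio.functional as F

def greedy_ctc_decode(trial_logits):                 # [time, n_classes], blank id = 0
    ids = torch.argmax(trial_logits, dim=-1)
    ids = torch.unique_consecutive(ids, dim=-1)      # collapse repeated predictions
    return np.array([i for i in ids.cpu().numpy() if i != 0])

trial_logits = torch.randn(50, 41)                   # dummy logits for one trial
true_seq = np.array([5, 12, 3, 9])                   # dummy phoneme labels
pred_seq = greedy_ctc_decode(trial_logits)
per = F.edit_distance(pred_seq, true_seq) / len(true_seq)
print(f'PER: {per:.2f}')
```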

model_training/train_model.py (new file, 6 lines)

from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer

args = OmegaConf.load('rnn_args.yaml')
trainer = BrainToTextDecoder_Trainer(args)
metrics = trainer.train()
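
The paths in `rnn_args.yaml` (output, checkpoint, and dataset directories) are specific to the authors' machine. A hedged variant of this script that overrides them after loading the config (the replacement paths are placeholders):

```python
# Sketch: train_model.py variant that overrides machine-specific paths.
from omegaconf import OmegaConf
from rnn_trainer import BrainToTextDecoder_Trainer

args = OmegaConf.load('rnn_args.yaml')
args.output_dir = '/path/to/output'                  # placeholder
args.checkpoint_dir = '/path/to/output/checkpoint'   # placeholder
args.dataset.dataset_dir = '/path/to/hdf5_data'      # placeholder

trainer = BrainToTextDecoder_Trainer(args)
metrics = trainer.train()
```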