diff --git a/README.md b/README.md index 690e8ac..a8a1911 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ The code is organized into four main directories: `utils`, `analyses`, `data`, a - The `model_training` directory contains the code necessary to train the brain-to-text model, including the offline model training and an offline simulation of the online finetuning pipeline, and also to run the language model. Note that the data used in the model training pipeline is simulated neural data, as the real neural data is not yet available. ## Python environment setup -The code is written in Python 3.9 and tested on Ubuntu 22.04. We recommend using a conda environment to manage the dependencies. +The code is written in Python 3.10 and tested on Ubuntu 22.04. We recommend using a conda environment to manage the dependencies. To install miniconda, follow the instructions [here](https://docs.anaconda.com/miniconda/miniconda-install/). diff --git a/model_training/README.md b/model_training/README.md new file mode 100644 index 0000000..b2d30aa --- /dev/null +++ b/model_training/README.md @@ -0,0 +1,23 @@ +# Model Training + +This directory contains code and resources for training the brain-to-text RNN model. This model is largely based on the architecture described in the paper "*An Accurate and Rapidly Calibrating Speech Neuroprosthesis*" by Card et al. (2024), but also contains modifications to improve performance, efficiency, and usability. + +## Setup + +1. Install the required conda environment by following the instructions in the root `README.md` file. This will set up the necessary dependencies for running the model training and evaluation code. + +2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. + +## Training + +To train the baseline RNN model, run the following command: +```bash +python train_model.py +``` + +## Evaluation + +To evaluate the model, run: +```bash +python evaluate_model.py +``` \ No newline at end of file diff --git a/model_training/data_augmentations.py b/model_training/data_augmentations.py new file mode 100644 index 0000000..7f4505a --- /dev/null +++ b/model_training/data_augmentations.py @@ -0,0 +1,37 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy.ndimage import gaussian_filter1d + +def gauss_smooth(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='same'): + """ + Applies a 1D Gaussian smoothing operation with PyTorch to smooth the data along the time axis. + Args: + inputs (tensor : B x T x N): A 3D tensor with batch size B, time steps T, and number of features N. + Assumed to already be on the correct device (e.g., GPU). + kernelSD (float): Standard deviation of the Gaussian smoothing kernel. + padding (str): Padding mode, either 'same' or 'valid'. + device (str): Device to use for computation (e.g., 'cuda' or 'cpu'). + Returns: + smoothed (tensor : B x T x N): A smoothed 3D tensor with batch size B, time steps T, and number of features N. + """ + # Get Gaussian kernel + inp = np.zeros(smooth_kernel_size, dtype=np.float32) + inp[smooth_kernel_size // 2] = 1 + gaussKernel = gaussian_filter1d(inp, smooth_kernel_std) + validIdx = np.argwhere(gaussKernel > 0.01) + gaussKernel = gaussKernel[validIdx] + gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel)) + + # Convert to tensor + gaussKernel = torch.tensor(gaussKernel, dtype=torch.float32, device=device) + gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size] + + # Prepare convolution + B, T, C = inputs.shape + inputs = inputs.permute(0, 2, 1) # [B, C, T] + gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size] + + # Perform convolution + smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C) + return smoothed.permute(0, 2, 1) # [B, T, C] \ No newline at end of file diff --git a/model_training/dataset.py b/model_training/dataset.py new file mode 100644 index 0000000..2442f6e --- /dev/null +++ b/model_training/dataset.py @@ -0,0 +1,334 @@ +import os +import torch +from torch.utils.data import Dataset +import h5py +import numpy as np +from torch.nn.utils.rnn import pad_sequence +import math + +class BrainToTextDataset(Dataset): + ''' + Dataset for brain-to-text data + + Returns an entire batch of data instead of a single example + ''' + + def __init__( + self, + trial_indicies, + n_batches, + split = 'train', + batch_size = 64, + days_per_batch = 1, + random_seed = -1, + must_include_days = None, + feature_subset = None + ): + ''' + trial_indicies: (dict) - dictionary with day numbers as keys and lists of trial indices as values + n_batches: (int) - number of random training batches to create + split: (string) - string specifying if this is a train or test dataset + batch_size: (int) - number of examples to include in batch returned from __getitem_() + days_per_batch: (int) - how many unique days can exist in a batch; this is important for making sure that updates + to individual day layers in the GRU are not excesively noisy. Validation data will always have 1 day per batch + random_seed: (int) - seed to set for randomly assigning trials to a batch. If set to -1, trial assignment will be random + must_include_days ([int]) - list of days that must be included in every batch + feature_subset ([int]) - list of neural feature indicies that should be the only features included in the neural data + ''' + + # Set random seed for reproducibility + if random_seed != -1: + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + self.split = split + + # Ensure the split is valid + if self.split not in ['train', 'test']: + raise ValueError(f'split must be either "train" or "test". Received {self.split}') + + self.days_per_batch = days_per_batch + + self.batch_size = batch_size + + self.n_batches = n_batches + + self.days = {} + self.n_trials = 0 + self.trial_indicies = trial_indicies + self.n_days = len(trial_indicies.keys()) + + self.feature_subset = feature_subset + + # Calculate total number of trials in the dataset + for d in trial_indicies: + self.n_trials += len(trial_indicies[d]['trials']) + + if must_include_days is not None and len(must_include_days) > days_per_batch: + raise ValueError(f'must_include_days must be less than or equal to days_per_batch. Received {must_include_days} and days_per_batch {days_per_batch}') + + if must_include_days is not None and len(must_include_days) > self.n_days and split != 'train': + raise ValueError(f'must_include_days is not valid for test data. Received {must_include_days} and but only {self.n_days} in the dataset') + + if must_include_days is not None: + # Map must_include_days to correct indicies if they are negative + for i, d in enumerate(must_include_days): + if d < 0: + must_include_days[i] = self.n_days + d + + self.must_include_days = must_include_days + + # Ensure that the days_per_batch is not greater than the number of days in the dataset. Raise error + if self.split == 'train' and self.days_per_batch > self.n_days: + raise ValueError(f'Requested days_per_batch: {days_per_batch} is greater than available days {self.n_days}.') + + + if self.split == 'train': + self.batch_index = self.create_batch_index_train() + else: + self.batch_index = self.create_batch_index_test() + self.n_batches = len(self.batch_index.keys()) # The validation data has a fixed amount of data + + def __len__(self): + ''' + How many batches are in this dataset. + Because training data is sampled randomly, there is no fixed dataset length, + however this method is required for DataLoader to work + ''' + return self.n_batches + + def __getitem__(self, idx): + ''' + Gets an entire batch of data from the dataset, not just a single item + ''' + batch = { + 'input_features' : [], + 'seq_class_ids' : [], + 'n_time_steps' : [], + 'phone_seq_lens' : [], + 'day_indicies' : [], + 'transcriptions' : [], + 'block_nums' : [], + 'trial_nums' : [], + } + + index = self.batch_index[idx] + + # Iterate through each day in the index + for d in index.keys(): + + # Open the hdf5 file for that day + with h5py.File(self.trial_indicies[d]['session_path'], 'r') as f: + + # For each trial in the selected trials in that day + for t in index[d]: + + try: + g = f[f'trial_{t}'] + + # Remove features is neccessary + input_features = torch.from_numpy(g['input_features'][:]) # neural data + if self.feature_subset: + input_features = input_features[:,self.feature_subset] + + batch['input_features'].append(input_features) + + batch['seq_class_ids'].append(torch.from_numpy(g['seq_class_ids'][:])) # phoneme labels + batch['transcriptions'].append(torch.from_numpy(g['transcription'][:])) # character level transcriptions + batch['n_time_steps'].append(g.attrs['n_time_steps']) # number of time steps in the trial - required since we are padding + batch['phone_seq_lens'].append(g.attrs['seq_len']) # number of phonemes in the label - required since we are padding + batch['day_indicies'].append(int(d)) # day index of each trial - required for the day specific layers + batch['block_nums'].append(g.attrs['block_num']) + batch['trial_nums'].append(g.attrs['trial_num']) + + except Exception as e: + print(f'Error loading trial {t} from session {self.trial_indicies[d]["session_path"]}: {e}') + continue + + # Pad data to form a cohesive batch + batch['input_features'] = pad_sequence(batch['input_features'], batch_first = True, padding_value = 0) + batch['seq_class_ids'] = pad_sequence(batch['seq_class_ids'], batch_first = True, padding_value = 0) + + batch['n_time_steps'] = torch.tensor(batch['n_time_steps']) + batch['phone_seq_lens'] = torch.tensor(batch['phone_seq_lens']) + batch['day_indicies'] = torch.tensor(batch['day_indicies']) + batch['transcriptions'] = torch.stack(batch['transcriptions']) + batch['block_nums'] = torch.tensor(batch['block_nums']) + batch['trial_nums'] = torch.tensor(batch['trial_nums']) + + return batch + + + def create_batch_index_train(self): + ''' + Create an index that maps a batch_number to batch_size number of trials + + Each batch will have days_per_batch unique days of data, with the number of trials for each day evenly split between the days + (or as even as possible if batch_size is not divisible by days_per_batch) + ''' + + batch_index = {} + + # Precompute the days that are not in must_include_days + if self.must_include_days is not None: + non_must_include_days = [d for d in self.trial_indicies.keys() if d not in self.must_include_days] + + for batch_idx in range(self.n_batches): + batch = {} + + # Which days will be used for this batch. Picked randomly without replacement + # TODO: In the future we may want to consider sampling days in proportion to the number of trials in each day + + # If must_include_days is not empty, we will use those days and then randomly sample the rest + if self.must_include_days is not None and len(self.must_include_days) > 0: + + days = np.concatenate((self.must_include_days, np.random.choice(non_must_include_days, size = self.days_per_batch - len(self.must_include_days), replace = False))) + + # Otherwise we will select random days without replacement + else: + days = np.random.choice(list(self.trial_indicies.keys()), size = self.days_per_batch, replace = False) + + # How many trials will be sampled from each day + num_trials = math.ceil(self.batch_size / self.days_per_batch) # Use ceiling to make sure we get at least batch_size trials + + for d in days: + + # Trials are sampled with replacement, so if a day has less than (self.batch_size / days_per_batch trials) trials, it won't be a problem + trial_idxs = np.random.choice(self.trial_indicies[d]['trials'], size = num_trials, replace = True) + batch[d] = trial_idxs + + # Remove extra trials + extra_trials = (num_trials * len(days)) - self.batch_size + + # While we still have extra trials, remove the last trial from a random day + while extra_trials > 0: + d = np.random.choice(days) + batch[d] = batch[d][:-1] + extra_trials -= 1 + + batch_index[batch_idx] = batch + + return batch_index + + def create_batch_index_test(self): + ''' + Create an index that is all validation/testing data in batches of up to self.batch_size + + If a day does not have at least self.batch_size trials, then the batch size will be less than self.batch_size + + This index will ensures that every trial in the validation set is seen once and only once + ''' + batch_index = {} + batch_idx = 0 + + for d in self.trial_indicies.keys(): + + # Calculate how many batches we need for this day + num_trials = len(self.trial_indicies[d]['trials']) + num_batches = (num_trials + self.batch_size - 1) // self.batch_size + + # Create batches for this day + for i in range(num_batches): + start_idx = i * self.batch_size + end_idx = min((i + 1) * self.batch_size, num_trials) + + # Get the trial indices for this batch + batch_trials = self.trial_indicies[d]['trials'][start_idx:end_idx] + + # Add to batch_index + batch_index[batch_idx] = {d : batch_trials} + batch_idx += 1 + + return batch_index + +def train_test_split_indicies(file_paths, test_percentage = 0.1, seed = -1, bad_trials_dict = None): + ''' + Split data from file_paths into train and test splits + Returns two dictionaries that detail which trials in each day will be a part of that split: + Example: + { + 0: trials[1,2,3], session_path: 'path' + 1: trials[2,5,6], session_path: 'path' + } + + Args: + file_paths (list): List of file paths to the hdf5 files containing the data + test_percentage (float): Percentage of trials to use for testing. 0 will use all trials for training, 1 will use all trials for testing + seed (int): Seed for reproducibility. If set to -1, the split will be random + bad_trials_dict (dict): Dictionary of trials to exclude from the dataset. Formatted as: + { + 'session_name_1': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...}, + 'session_name_2': {block_num_1: [trial_nums], block_num_2: [trial_nums], ...}, + ... + } + ''' + # Set seed for reporoducibility + if seed != -1: + np.random.seed(seed) + + # Get trials in each day + trials_per_day = {} + for i, path in enumerate(file_paths): + session = [s for s in path.split('/') if (s.startswith('t15.20') or s.startswith('t12.20'))][0] + + good_trial_indices = [] + + if os.path.exists(path): + with h5py.File(path, 'r') as f: + num_trials = len(list(f.keys())) + for t in range(num_trials): + key = f'trial_{t}' + + block_num = f[key].attrs['block_num'] + trial_num = f[key].attrs['trial_num'] + + if ( + bad_trials_dict is not None + and session in bad_trials_dict + and str(block_num) in bad_trials_dict[session] + and trial_num in bad_trials_dict[session][str(block_num)] + ): + # print(f'Bad trial: {session}_{block_num}_{trial_num}') + continue + + good_trial_indices.append(t) + + trials_per_day[i] = {'num_trials': len(good_trial_indices), 'trial_indices': good_trial_indices, 'session_path': path} + + # Pick test_percentage of trials from each day for testing and (1 - test_percentage) for training + train_trials = {} + test_trials = {} + + for day in trials_per_day.keys(): + + num_trials = trials_per_day[day]['num_trials'] + + # Generate all trial indices for this day (assuming 0-indexed) + all_trial_indices = trials_per_day[day]['trial_indices'] + + # If test_percentage is 0 or 1, we can just assign all trials to either train or test + if test_percentage == 0: + train_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']} + test_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']} + continue + + elif test_percentage == 1: + train_trials[day] = {'trials' : [], 'session_path' : trials_per_day[day]['session_path']} + test_trials[day] = {'trials' : all_trial_indices, 'session_path' : trials_per_day[day]['session_path']} + continue + + else: + # Calculate how many trials to use for testing + num_test = max(1, int(num_trials * test_percentage)) + + # Randomly select indices for testing + test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist() + + # Remaining indices go to training + train_indices = [idx for idx in all_trial_indices if idx not in test_indices] + + # Store the split indices + train_trials[day] = {'trials' : train_indices, 'session_path' : trials_per_day[day]['session_path']} + test_trials[day] = {'trials' : test_indices, 'session_path' : trials_per_day[day]['session_path']} + + return train_trials, test_trials \ No newline at end of file diff --git a/model_training/rnn_args.yaml b/model_training/rnn_args.yaml new file mode 100644 index 0000000..01e914c --- /dev/null +++ b/model_training/rnn_args.yaml @@ -0,0 +1,169 @@ +model: + n_input_features: 512 + n_units: 768 + rnn_dropout: 0.4 + rnn_trainable: true + n_layers: 5 + bidirectional: false + patch_size: 14 + patch_stride: 4 + input_network: + n_input_layers: 1 + input_layer_sizes: + - 512 + input_trainable: true + input_layer_dropout: 0.2 +gpu_number: '1' +distributed_training: false +mode: train +use_amp: true +output_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter +init_from_checkpoint: false +checkpoint_dir: /media/lm-pc/8tb_nvme/b2txt25/rnn_v2_jitter/checkpoint +init_checkpoint_path: None +save_best_checkpoint: true +save_all_val_steps: false +save_final_model: false +save_val_metrics: true +early_stopping: false +early_stopping_val_steps: 20 +num_training_batches: 120000 +lr_scheduler_type: cosine +lr_max: 0.005 +lr_min: 0.0001 +lr_decay_steps: 120000 +lr_warmup_steps: 1000 +lr_max_day: 0.005 +lr_min_day: 0.0001 +lr_decay_steps_day: 120000 +lr_warmup_steps_day: 1000 +beta0: 0.9 +beta1: 0.999 +epsilon: 0.1 +weight_decay: 0.001 +weight_decay_day: 0 +seed: 10 +grad_norm_clip_value: 10 +batches_per_train_log: 200 +batches_per_val_step: 2000 +batches_per_save: 0 +log_individual_day_val_PER: true +log_val_skip_logs: false +save_val_logits: true +save_val_data: false +dataset: + data_transforms: + white_noise_std: 1.0 + constant_offset_std: 0.2 + random_walk_std: 0.0 + random_walk_axis: -1 + static_gain_std: 0.0 + random_cut: 3 #0 + smooth_kernel_size: 100 + smooth_data: true + smooth_kernel_std: 2 + neural_dim: 512 + batch_size: 64 + n_classes: 41 + max_seq_elements: 500 + days_per_batch: 4 + seed: 1 + num_dataloader_workers: 4 + loader_shuffle: false + must_include_days: null + test_percentage: 0.1 + feature_subset: null + dataset_dir: /media/lm-pc/8tb_nvme/b2txt25/hdf5_data + bad_trials_dict: null + sessions: + - t15.2023.08.11 + - t15.2023.08.13 + - t15.2023.08.18 + - t15.2023.08.20 + - t15.2023.08.25 + - t15.2023.08.27 + - t15.2023.09.01 + - t15.2023.09.03 + - t15.2023.09.24 + - t15.2023.09.29 + - t15.2023.10.01 + - t15.2023.10.06 + - t15.2023.10.08 + - t15.2023.10.13 + - t15.2023.10.15 + - t15.2023.10.20 + - t15.2023.10.22 + - t15.2023.11.03 + - t15.2023.11.04 + - t15.2023.11.17 + - t15.2023.11.19 + - t15.2023.11.26 + - t15.2023.12.03 + - t15.2023.12.08 + - t15.2023.12.10 + - t15.2023.12.17 + - t15.2023.12.29 + - t15.2024.02.25 + - t15.2024.03.03 + - t15.2024.03.08 + - t15.2024.03.15 + - t15.2024.03.17 + - t15.2024.04.25 + - t15.2024.04.28 + - t15.2024.05.10 + - t15.2024.06.14 + - t15.2024.07.19 + - t15.2024.07.21 + - t15.2024.07.28 + - t15.2025.01.10 + - t15.2025.01.12 + - t15.2025.03.14 + - t15.2025.03.16 + - t15.2025.03.30 + - t15.2025.04.13 + dataset_probability_val: + - 0 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 0 + - 1 + - 1 + - 1 + - 0 + - 0 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 + - 1 \ No newline at end of file diff --git a/model_training/rnn_model.py b/model_training/rnn_model.py new file mode 100644 index 0000000..ec1a624 --- /dev/null +++ b/model_training/rnn_model.py @@ -0,0 +1,136 @@ +import torch +from torch import nn + +class GRUDecoder(nn.Module): + ''' + Defines the GRU decoder + + This class combines day-specific input layers, a GRU, and an output classification layer + ''' + def __init__(self, + neural_dim, + n_units, + n_days, + n_classes, + rnn_dropout = 0.0, + input_dropout = 0.0, + n_layers = 5, + patch_size = 0, + patch_stride = 0, + ): + ''' + neural_dim (int) - number of channels in a single timestep (e.g. 512) + n_units (int) - number of hidden units in each recurrent layer - equal to the size of the hidden state + n_days (int) - number of days in the dataset + n_classes (int) - number of classes + rnn_dropout (float) - percentage of units to droupout during training + input_dropout (float) - percentage of input units to dropout during training + n_layers (int) - number of recurrent layers + patch_size (int) - the number of timesteps to concat on initial input layer - a value of 0 will disable this "input concat" step + patch_stride(int) - the number of timesteps to stride over when concatenating initial input + ''' + super(GRUDecoder, self).__init__() + + self.neural_dim = neural_dim + self.n_units = n_units + self.n_classes = n_classes + self.n_layers = n_layers + self.n_days = n_days + + self.rnn_dropout = rnn_dropout + self.input_dropout = input_dropout + + self.patch_size = patch_size + self.patch_stride = patch_stride + + # Parameters for the day-specific input layers + self.day_layer_activation = nn.Softsign() # basically a shallower tanh + + # Set weights for day layers to be identity matrices so the model can learn its own day-specific transformations + self.day_weights = nn.ParameterList( + [nn.Parameter(torch.eye(self.neural_dim)) for _ in range(self.n_days)] + ) + self.day_biases = nn.ParameterList( + [nn.Parameter(torch.zeros(1, self.neural_dim)) for _ in range(self.n_days)] + ) + + self.day_layer_dropout = nn.Dropout(input_dropout) + + self.input_size = self.neural_dim + + # If we are using "strided inputs", then the input size of the first recurrent layer will actually be in_size * patch_size + if self.patch_size > 0: + self.input_size *= self.patch_size + + self.gru = nn.GRU( + input_size = self.input_size, + hidden_size = self.n_units, + num_layers = self.n_layers, + dropout = self.rnn_dropout, + batch_first = True, # The first dim of our input is the batch dim + bidirectional = False, + ) + + # Set recurrent units to have orthogonal param init and input layers to have xavier init + for name, param in self.gru.named_parameters(): + if "weight_hh" in name: + nn.init.orthogonal_(param) + if "weight_ih" in name: + nn.init.xavier_uniform_(param) + + # Prediciton head. Weight init to xavier + self.out = nn.Linear(self.n_units, self.n_classes) + nn.init.xavier_uniform_(self.out.weight) + + # Learnable initial hidden states + self.h0 = nn.Parameter(nn.init.xavier_uniform_(torch.zeros(1, 1, self.n_units))) + + def forward(self, x, day_idx, states = None, return_state = False): + ''' + x (tensor) - batch of examples (trials) of shape: (batch_size, time_series_length, neural_dim) + day_idx (tensor) - tensor which is a list of day indexs corresponding to the day of each example in the batch x. + ''' + + # Apply day-specific layer to (hopefully) project neural data from the different days to the same latent space + day_weights = torch.stack([self.day_weights[i] for i in day_idx], dim=0) + day_biases = torch.cat([self.day_biases[i] for i in day_idx], dim=0).unsqueeze(1) + + x = torch.einsum("btd,bdk->btk", x, day_weights) + day_biases + x = self.day_layer_activation(x) + + # Apply dropout to the ouput of the day specific layer + if self.input_dropout > 0: + x = self.day_layer_dropout(x) + + # (Optionally) Perform input concat operation + if self.patch_size > 0: + + x = x.unsqueeze(1) # [batches, 1, timesteps, feature_dim] + x = x.permute(0, 3, 1, 2) # [batches, feature_dim, 1, timesteps] + + # Extract patches using unfold (sliding window) + x_unfold = x.unfold(3, self.patch_size, self.patch_stride) # [batches, feature_dim, 1, num_patches, patch_size] + + # Remove dummy height dimension and rearrange dimensions + x_unfold = x_unfold.squeeze(2) # [batches, feature_dum, num_patches, patch_size] + x_unfold = x_unfold.permute(0, 2, 3, 1) # [batches, num_patches, patch_size, feature_dim] + + # Flatten last two dimensions (patch_size and features) + x = x_unfold.reshape(x.size(0), x_unfold.size(1), -1) + + # Determine initial hidden states + if states is None: + states = self.h0.expand(self.n_layers, x.shape[0], self.n_units).contiguous() + + # Pass input through RNN + output, hidden_states = self.gru(x, states) + + # Compute logits + logits = self.out(output) + + if return_state: + return logits, hidden_states + + return logits + + diff --git a/model_training/rnn_trainer.py b/model_training/rnn_trainer.py new file mode 100644 index 0000000..942b671 --- /dev/null +++ b/model_training/rnn_trainer.py @@ -0,0 +1,754 @@ +import torch +from torch.utils.data import DataLoader +from torch.optim.lr_scheduler import LambdaLR +import random +import time +import os +import numpy as np +import math +import pathlib +import logging +import sys +import json +import pickle + +from dataset import BrainToTextDataset, train_test_split_indicies +from data_augmentations import gauss_smooth + +import torchaudio.functional as F # for edit distance +from omegaconf import OmegaConf + +torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs +torch.backends.cudnn.deterministic = True # makes training more reproducible +torch._dynamo.config.cache_size_limit = 64 + +from rnn_model import GRUDecoder + +class BrainToTextDecoder_Trainer: + """ + This class will initialize and train a brain-to-text phoneme decoder + + Written by Nick Card and Zachery Fogg with reference to Stanford NPTL's decoding function + """ + + def __init__(self, args): + ''' + args : dictionary of training arguments + ''' + + # Trainer fields + self.args = args + self.logger = None + self.device = None + self.model = None + self.optimizer = None + self.learning_rate_scheduler = None + self.ctc_loss = None + + self.best_val_PER = torch.inf # track best PER for checkpointing + self.best_val_loss = torch.inf # track best loss for checkpointing + + self.train_dataset = None + self.val_dataset = None + self.train_loader = None + self.val_loader = None + + self.transform_args = self.args['dataset']['data_transforms'] + + # Create output directory + if args['mode'] == 'train': + os.makedirs(self.args['output_dir'], exist_ok=True) + + # Create checkpoint directory + if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']: + os.makedirs(self.args['checkpoint_dir'], exist_ok = True) + + # Set up logging + self.logger = logging.getLogger(__name__) + for handler in self.logger.handlers[:]: # make a copy of the list + self.logger.removeHandler(handler) + self.logger.setLevel(logging.INFO) + formatter = logging.Formatter(fmt='%(asctime)s: %(message)s') + + if args['mode']=='train': + # During training, save logs to file in output directory + fh = logging.FileHandler(str(pathlib.Path(self.args['output_dir'],'training_log'))) + fh.setFormatter(formatter) + self.logger.addHandler(fh) + + # Always print logs to stdout + sh = logging.StreamHandler(sys.stdout) + sh.setFormatter(formatter) + self.logger.addHandler(sh) + + # Configure device pytorch will use + if not self.args['distributed_training']: + if torch.cuda.is_available(): + self.device = f"cuda:{self.args['gpu_number']}" + + else: + self.device = "cpu" + else: + self.device = "cuda" + + self.logger.info(f'Using device: {self.device}') + + # Set seed if provided + if self.args['seed'] != -1: + np.random.seed(self.args['seed']) + random.seed(self.args['seed']) + torch.manual_seed(self.args['seed']) + + # Initialize the model + self.model = GRUDecoder( + neural_dim = self.args['model']['n_input_features'], + n_units = self.args['model']['n_units'], + n_days = len(self.args['dataset']['sessions']), + n_classes = self.args['dataset']['n_classes'], + rnn_dropout = self.args['model']['rnn_dropout'], + input_dropout = self.args['model']['input_network']['input_layer_dropout'], + n_layers = self.args['model']['n_layers'], + bidirectional = self.args['model']['bidirectional'], + patch_size = self.args['model']['patch_size'], + patch_stride = self.args['model']['patch_stride'], + ) + + # Call torch.compile to speed up training + self.logger.info("Using torch.compile") + self.model = torch.compile(self.model) + + self.logger.info(f"Initialized RNN decoding model") + + self.logger.info(self.model) + + # Log how many parameters are in the model + total_params = sum(p.numel() for p in self.model.parameters()) + self.logger.info(f"Model has {total_params:,} parameters") + + # Determine how many day-specific parameters are in the model + day_params = 0 + for name, param in self.model.named_parameters(): + if 'day' in name: + day_params += param.numel() + + self.logger.info(f"Model has {day_params:,} day-specific parameters | {((day_params / total_params) * 100):.2f}% of total parameters") + + # Create datasets and dataloaders + train_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_train.hdf5') for s in self.args['dataset']['sessions']] + val_file_paths = [os.path.join(self.args["dataset"]["dataset_dir"],s,'data_val.hdf5') for s in self.args['dataset']['sessions']] + + # Ensure that there are no duplicate days + if len(set(train_file_paths)) != len(train_file_paths): + raise ValueError("There are duplicate sessions listed in the train dataset") + if len(set(val_file_paths)) != len(val_file_paths): + raise ValueError("There are duplicate sessions listed in the val dataset") + + # Split trials into train and test sets + train_trials, _ = train_test_split_indicies( + file_paths = train_file_paths, + test_percentage = 0, + seed = self.args['dataset']['seed'], + bad_trials_dict = None, + ) + _, val_trials = train_test_split_indicies( + file_paths = val_file_paths, + test_percentage = 1, + seed = self.args['dataset']['seed'], + bad_trials_dict = None, + ) + + # Save dictionaries to output directory to know which trials were train vs val + with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f: + json.dump({'train' : train_trials, 'val': val_trials}, f) + + # Determine if a only a subset of neural features should be used + feature_subset = None + if ('feature_subset' in self.args['dataset']) and self.args['dataset']['feature_subset'] != None: + feature_subset = self.args['dataset']['feature_subset'] + self.logger.info(f'Using only a subset of features: {feature_subset}') + + # train dataset and dataloader + self.train_dataset = BrainToTextDataset( + trial_indicies = train_trials, + split = 'train', + days_per_batch = self.args['dataset']['days_per_batch'], + n_batches = self.args['num_training_batches'], + batch_size = self.args['dataset']['batch_size'], + must_include_days = None, + random_seed = self.args['dataset']['seed'], + feature_subset = feature_subset + ) + self.train_loader = DataLoader( + self.train_dataset, + batch_size = None, # Dataset.__getitem__() already returns batches + shuffle = self.args['dataset']['loader_shuffle'], + num_workers = self.args['dataset']['num_dataloader_workers'], + pin_memory = True + ) + + # val dataset and dataloader + self.val_dataset = BrainToTextDataset( + trial_indicies = val_trials, + split = 'test', + days_per_batch = None, + n_batches = None, + batch_size = self.args['dataset']['batch_size'], + must_include_days = None, + random_seed = self.args['dataset']['seed'], + feature_subset = feature_subset + ) + self.val_loader = DataLoader( + self.val_dataset, + batch_size = None, # Dataset.__getitem__() already returns batches + shuffle = False, + num_workers = 0, + pin_memory = True + ) + + self.logger.info("Successfully initialized datasets") + + # Create optimizer, learning rate scheduler, and loss + self.optimizer = self.create_optimizer() + + if self.args['lr_scheduler_type'] == 'linear': + self.learning_rate_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer = self.optimizer, + start_factor = 1.0, + end_factor = self.args['lr_min'] / self.args['lr_max'], + total_iters = self.args['lr_decay_steps'], + ) + elif self.args['lr_scheduler_type'] == 'cosine': + self.learning_rate_scheduler = self.create_cosine_lr_scheduler(self.optimizer) + + else: + raise ValueError(f"Invalid learning rate scheduler type: {self.args['lr_scheduler_type']}") + + self.ctc_loss = torch.nn.CTCLoss(blank = 0, reduction = 'none', zero_infinity = False) + + # If a checkpoint is provided, then load from checkpoint + if self.args['init_from_checkpoint']: + self.load_model_checkpoint(self.args['init_checkpoint_path']) + + # Set rnn and/or input layers to not trainable if specified + for name, param in self.model.named_parameters(): + if not self.args['model']['rnn_trainable'] and 'gru' in name: + param.requires_grad = False + + elif not self.args['model']['input_network']['input_trainable'] and 'day' in name: + param.requires_grad = False + + # Send model to device + self.model.to(self.device) + + def create_optimizer(self): + ''' + Create the optimizer with special param groups + + Biases and day weights should not be decayed + + Day weights should have a separate learning rate + ''' + bias_params = [p for name, p in self.model.named_parameters() if 'gru.bias' in name or 'out.bias' in name] + day_params = [p for name, p in self.model.named_parameters() if 'day_' in name] + other_params = [p for name, p in self.model.named_parameters() if 'day_' not in name and 'gru.bias' not in name and 'out.bias' not in name] + + if len(day_params) != 0: + param_groups = [ + {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'}, + {'params' : day_params, 'lr' : self.args['lr_max_day'], 'weight_decay' : self.args['weight_decay_day'], 'group_type' : 'day_layer'}, + {'params' : other_params, 'group_type' : 'other'} + ] + else: + param_groups = [ + {'params' : bias_params, 'weight_decay' : 0, 'group_type' : 'bias'}, + {'params' : other_params, 'group_type' : 'other'} + ] + + optim = torch.optim.AdamW( + param_groups, + lr = self.args['lr_max'], + betas = (self.args['beta0'], self.args['beta1']), + eps = self.args['epsilon'], + weight_decay = self.args['weight_decay'], + fused = True + ) + + return optim + + def create_cosine_lr_scheduler(self, optim): + lr_max = self.args['lr_max'] + lr_min = self.args['lr_min'] + lr_decay_steps = self.args['lr_decay_steps'] + + lr_max_day = self.args['lr_max_day'] + lr_min_day = self.args['lr_min_day'] + lr_decay_steps_day = self.args['lr_decay_steps_day'] + + lr_warmup_steps = self.args['lr_warmup_steps'] + lr_warmup_steps_day = self.args['lr_warmup_steps_day'] + + def lr_lambda(current_step, min_lr_ratio, decay_steps, warmup_steps): + ''' + Create lr lambdas for each param group that implement cosine decay + + Different lr lambda decaying for day params vs rest of the model + ''' + # Warmup phase + if current_step < warmup_steps: + return float(current_step) / float(max(1, warmup_steps)) + + # Cosine decay phase + if current_step < decay_steps: + progress = float(current_step - warmup_steps) / float( + max(1, decay_steps - warmup_steps) + ) + cosine_decay = 0.5 * (1 + math.cos(math.pi * progress)) + # Scale from 1.0 to min_lr_ratio + return max(min_lr_ratio, min_lr_ratio + (1 - min_lr_ratio) * cosine_decay) + + # After cosine decay is complete, maintain min_lr_ratio + return min_lr_ratio + + if len(optim.param_groups) == 3: + lr_lambdas = [ + lambda step: lr_lambda( + step, + lr_min / lr_max, + lr_decay_steps, + lr_warmup_steps), # biases + lambda step: lr_lambda( + step, + lr_min_day / lr_max_day, + lr_decay_steps_day, + lr_warmup_steps_day, + ), # day params + lambda step: lr_lambda( + step, + lr_min / lr_max, + lr_decay_steps, + lr_warmup_steps), # rest of model weights + ] + elif len(optim.param_groups) == 2: + lr_lambdas = [ + lambda step: lr_lambda( + step, + lr_min / lr_max, + lr_decay_steps, + lr_warmup_steps), # biases + lambda step: lr_lambda( + step, + lr_min / lr_max, + lr_decay_steps, + lr_warmup_steps), # rest of model weights + ] + else: + raise ValueError(f"Invalid number of param groups in optimizer: {len(optim.param_groups)}") + + return LambdaLR(optim, lr_lambdas, -1) + + def load_model_checkpoint(self, load_path): + ''' + Load a training checkpoint + ''' + checkpoint = torch.load(load_path, weights_only = False) # checkpoint is just a dict + + self.model.load_state_dict(checkpoint['model_state_dict']) + self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + self.best_val_PER = checkpoint['val_PER'] # best phoneme error rate + self.best_val_loss = checkpoint['val_loss'] if 'val_loss' in checkpoint.keys() else torch.inf + + self.model.to(self.device) + + # Send optimizer params back to GPU + for state in self.optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.to(self.device) + + self.logger.info("Loaded model from checkpoint: " + load_path) + + def save_model_checkpoint(self, save_path, PER, loss): + ''' + Save a training checkpoint + ''' + + checkpoint = { + 'model_state_dict' : self.model.state_dict(), + 'optimizer_state_dict' : self.optimizer.state_dict(), + 'scheduler_state_dict' : self.learning_rate_scheduler.state_dict(), + 'val_PER' : PER, + 'val_loss' : loss + } + + torch.save(checkpoint, save_path) + + self.logger.info("Saved model to checkpoint: " + save_path) + + # Save the args file alongside the checkpoint + with open(os.path.join(self.args['checkpoint_dir'], 'args.yaml'), 'w') as f: + OmegaConf.save(config=self.args, f=f) + + def create_attention_mask(self, sequence_lengths): + + max_length = torch.max(sequence_lengths).item() + + batch_size = sequence_lengths.size(0) + + # Create a mask for valid key positions (columns) + # Shape: [batch_size, max_length] + key_mask = torch.arange(max_length, device=sequence_lengths.device).expand(batch_size, max_length) + key_mask = key_mask < sequence_lengths.unsqueeze(1) + + # Expand key_mask to [batch_size, 1, 1, max_length] + # This will be broadcast across all query positions + key_mask = key_mask.unsqueeze(1).unsqueeze(1) + + # Create the attention mask of shape [batch_size, 1, max_length, max_length] + # by broadcasting key_mask across all query positions + attention_mask = key_mask.expand(batch_size, 1, max_length, max_length) + + # Convert boolean mask to float mask: + # - True (valid key positions) -> 0.0 (no change to attention scores) + # - False (padding key positions) -> -inf (will become 0 after softmax) + attention_mask_float = torch.where(attention_mask, + True, + False) + + return attention_mask_float + + def transform_data(self, features, n_time_steps, mode = 'train'): + ''' + Apply various augmentations and smoothing to data + Performing augmentations is much faster on GPU than CPU + ''' + + data_shape = features.shape + batch_size = data_shape[0] + channels = data_shape[-1] + + # We only apply these augmentations in training + if mode == 'train': + # add static gain noise + if self.transform_args['static_gain_std'] > 0: + warp_mat = torch.tile(torch.unsqueeze(torch.eye(channels), dim = 0), (batch_size, 1, 1)) + warp_mat += torch.randn_like(warp_mat, device=self.device) * self.transform_args['static_gain_std'] + + features = torch.matmul(features, warp_mat) + + # add white noise + if self.transform_args['white_noise_std'] > 0: + features += torch.randn(data_shape, device=self.device) * self.transform_args['white_noise_std'] + + # add constant offset noise + if self.transform_args['constant_offset_std'] > 0: + features += torch.randn((batch_size, 1, channels), device=self.device) * self.transform_args['constant_offset_std'] + + # add random walk noise + if self.transform_args['random_walk_std'] > 0: + features += torch.cumsum(torch.randn(data_shape, device=self.device) * self.transform_args['random_walk_std'], dim =self.transform_args['random_walk_axis']) + + # randomly cutoff part of the data timecourse + if self.transform_args['random_cut'] > 0: + cut = np.random.randint(0, self.transform_args['random_cut']) + features = features[:, cut:, :] + n_time_steps = n_time_steps - cut + + # Apply Gaussian smoothing to data + # This is done in both training and validation + if self.transform_args['smooth_data']: + features = gauss_smooth( + inputs = features, + device = self.device, + smooth_kernel_std = self.transform_args['smooth_kernel_std'], + smooth_kernel_size= self.transform_args['smooth_kernel_size'], + ) + + + return features, n_time_steps + + def train(self): + ''' + Train the model + ''' + + # Set model to train mode (specificially to make sure dropout layers are engaged) + self.model.train() + + # create vars to track performance + train_losses = [] + val_losses = [] + val_PERs = [] + val_results = [] + + val_steps_since_improvement = 0 + + # training params + save_best_checkpoint = self.args.get('save_best_checkpoint', True) + early_stopping = self.args.get('early_stopping', True) + + early_stopping_val_steps = self.args['early_stopping_val_steps'] + + train_start_time = time.time() + + # train for specified number of batches + for i, batch in enumerate(self.train_loader): + + self.model.train() + self.optimizer.zero_grad() + + # Train step + start_time = time.time() + + # Move data to device + features = batch['input_features'].to(self.device) + labels = batch['seq_class_ids'].to(self.device) + n_time_steps = batch['n_time_steps'].to(self.device) + phone_seq_lens = batch['phone_seq_lens'].to(self.device) + day_indicies = batch['day_indicies'].to(self.device) + + # Use autocast for efficiency + with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16): + + # Apply augmentations to the data + features, n_time_steps = self.transform_data(features, n_time_steps, 'train') + + adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) + + # Get phoneme predictions + logits = self.model(features, day_indicies) + + # Calculate CTC Loss + loss = self.ctc_loss( + log_probs = torch.permute(logits.log_softmax(2), [1, 0, 2]), + targets = labels, + input_lengths = adjusted_lens, + target_lengths = phone_seq_lens + ) + + loss = torch.mean(loss) # take mean loss over batches + + loss.backward() + + # Clip gradient + if self.args['grad_norm_clip_value'] > 0: + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), + max_norm = self.args['grad_norm_clip_value'], + error_if_nonfinite = True, + foreach = True + ) + + self.optimizer.step() + self.learning_rate_scheduler.step() + + # Save training metrics + train_step_duration = time.time() - start_time + train_losses.append(loss.detach().item()) + + # Incrementally log training progress + if i % self.args['batches_per_train_log'] == 0: + self.logger.info(f'Train batch {i}: ' + + f'loss: {(loss.detach().item()):.2f} ' + + f'grad norm: {grad_norm:.2f} ' + f'time: {train_step_duration:.3f}') + + # Incrementally run a test step + if i % self.args['batches_per_val_step'] == 0 or i == ((self.args['num_training_batches'] - 1)): + self.logger.info(f"Running test after training batch: {i}") + + # Calculate metrics on val data + start_time = time.time() + val_metrics = self.validation(loader = self.val_loader, return_logits = self.args['save_val_logits'], return_data = self.args['save_val_data']) + val_step_duration = time.time() - start_time + + + # Log info + self.logger.info(f'Val batch {i}: ' + + f'PER (avg): {val_metrics["avg_PER"]:.4f} ' + + f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} ' + + f'time: {val_step_duration:.3f}') + + if self.args['log_individual_day_val_PER']: + for day in val_metrics['day_PERs'].keys(): + self.logger.info(f"{self.args['dataset']['sessions'][day]} val PER: {val_metrics['day_PERs'][day]['total_edit_distance'] / val_metrics['day_PERs'][day]['total_seq_length']:0.4f}") + + # Save metrics + val_PERs.append(val_metrics['avg_PER']) + val_losses.append(val_metrics['avg_loss']) + val_results.append(val_metrics) + + # Determine if new best day. Based on if PER is lower, or in the case of a PER tie, if loss is lower + new_best = False + if val_metrics['avg_PER'] < self.best_val_PER: + self.logger.info(f"New best test PER {self.best_val_PER:.4f} --> {val_metrics['avg_PER']:.4f}") + self.best_val_PER = val_metrics['avg_PER'] + self.best_val_loss = val_metrics['avg_loss'] + new_best = True + elif val_metrics['avg_PER'] == self.best_val_PER and (val_metrics['avg_loss'] < self.best_val_loss): + self.logger.info(f"New best test loss {self.best_val_loss:.4f} --> {val_metrics['avg_loss']:.4f}") + self.best_val_loss = val_metrics['avg_loss'] + new_best = True + + if new_best: + + # Checkpoint if metrics have improved + if save_best_checkpoint: + self.logger.info(f"Checkpointing model") + self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/best_checkpoint', self.best_val_PER, self.best_val_loss) + + # save validation metrics to pickle file + if self.args['save_val_metrics']: + with open(f'{self.args["checkpoint_dir"]}/val_metrics.pkl', 'wb') as f: + pickle.dump(val_metrics, f) + + val_steps_since_improvement = 0 + + else: + val_steps_since_improvement +=1 + + # Optionally save this validation checkpoint, regardless of performance + if self.args['save_all_val_steps']: + self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/checkpoint_batch_{i}', val_metrics['avg_PER']) + + # Early stopping + if early_stopping and (val_steps_since_improvement >= early_stopping_val_steps): + self.logger.info(f'Overall validation PER has not improved in {early_stopping_val_steps} validation steps. Stopping training early at batch: {i}') + break + + # Log final training steps + training_duration = time.time() - train_start_time + + + self.logger.info(f'Best avg val PER achieved: {self.best_val_PER:.5f}') + self.logger.info(f'Total training time: {(training_duration / 60):.2f} minutes') + + # Save final model + if self.args['save_final_model']: + self.save_model_checkpoint(f'{self.args["checkpoint_dir"]}/final_checkpoint_batch_{i}', val_PERs[-1]) + + train_stats = {} + train_stats['train_losses'] = train_losses + train_stats['val_losses'] = val_losses + train_stats['val_PERs'] = val_PERs + train_stats['val_metrics'] = val_results + + return train_stats + + def validation(self, loader, return_logits = False, return_data = False): + ''' + Calculate metrics on the validation dataset + ''' + self.model.eval() + + metrics = {} + + # Record metrics + if return_logits: + metrics['logits'] = [] + metrics['n_time_steps'] = [] + + if return_data: + metrics['input_features'] = [] + + metrics['decoded_seqs'] = [] + metrics['true_seq'] = [] + metrics['phone_seq_lens'] = [] + metrics['transcription'] = [] + metrics['losses'] = [] + metrics['block_nums'] = [] + metrics['trial_nums'] = [] + metrics['day_indicies'] = [] + + total_edit_distance = 0 + total_seq_length = 0 + + # Calculate PER for each specific day + day_per = {} + for d in range(len(self.args['dataset']['sessions'])): + if self.args['dataset']['dataset_probability_val'][d] == 1: + day_per[d] = {'total_edit_distance' : 0, 'total_seq_length' : 0} + + for i, batch in enumerate(loader): + + features = batch['input_features'].to(self.device) + labels = batch['seq_class_ids'].to(self.device) + n_time_steps = batch['n_time_steps'].to(self.device) + phone_seq_lens = batch['phone_seq_lens'].to(self.device) + day_indicies = batch['day_indicies'].to(self.device) + + # Determine if we should perform validation on this batch + day = day_indicies[0].item() + if self.args['dataset']['dataset_probability_val'][day] == 0: + if self.args['log_val_skip_logs']: + self.logger.info(f"Skipping validation on day {day}") + continue + + with torch.no_grad(): + + with torch.autocast(device_type = "cuda", enabled = self.args['use_amp'], dtype = torch.bfloat16): + features, n_time_steps = self.transform_data(features, n_time_steps, 'val') + + adjusted_lens = ((n_time_steps - self.args['model']['patch_size']) / self.args['model']['patch_stride'] + 1).to(torch.int32) + + logits = self.model(features, day_indicies) + + loss = self.ctc_loss( + torch.permute(logits.log_softmax(2), [1, 0, 2]), + labels, + adjusted_lens, + phone_seq_lens, + ) + loss = torch.mean(loss) + + metrics['losses'].append(loss.cpu().detach().numpy()) + + # Calculate PER per day and also avg over entire validation set + batch_edit_distance = 0 + decoded_seqs = [] + for iterIdx in range(logits.shape[0]): + decoded_seq = torch.argmax(logits[iterIdx, 0 : adjusted_lens[iterIdx], :].clone().detach(),dim=-1) + decoded_seq = torch.unique_consecutive(decoded_seq, dim=-1) + decoded_seq = decoded_seq.cpu().detach().numpy() + decoded_seq = np.array([i for i in decoded_seq if i != 0]) + + trueSeq = np.array( + labels[iterIdx][0 : phone_seq_lens[iterIdx]].cpu().detach() + ) + + batch_edit_distance += F.edit_distance(decoded_seq, trueSeq) + + decoded_seqs.append(decoded_seq) + + day = batch['day_indicies'][0].item() + + day_per[day]['total_edit_distance'] += batch_edit_distance + day_per[day]['total_seq_length'] += torch.sum(phone_seq_lens).item() + + + total_edit_distance += batch_edit_distance + total_seq_length += torch.sum(phone_seq_lens) + + # Record metrics + if return_logits: + metrics['logits'].append(logits.cpu().float().numpy()) # Will be in bfloat16 if AMP is enabled, so need to set back to float32 + metrics['n_time_steps'].append(adjusted_lens.cpu().numpy()) + + if return_data: + metrics['input_features'].append(batch['input_features'].cpu().numpy()) + + metrics['decoded_seqs'].append(decoded_seqs) + metrics['true_seq'].append(batch['seq_class_ids'].cpu().numpy()) + metrics['phone_seq_lens'].append(batch['phone_seq_lens'].cpu().numpy()) + metrics['transcription'].append(batch['transcriptions'].cpu().numpy()) + metrics['losses'].append(loss.detach().item()) + metrics['block_nums'].append(batch['block_nums'].numpy()) + metrics['trial_nums'].append(batch['trial_nums'].numpy()) + metrics['day_indicies'].append(batch['day_indicies'].cpu().numpy()) + + avg_PER = total_edit_distance / total_seq_length + + metrics['day_PERs'] = day_per + metrics['avg_PER'] = avg_PER.item() + metrics['avg_loss'] = np.mean(metrics['losses']) + + return metrics \ No newline at end of file diff --git a/model_training/train_model.py b/model_training/train_model.py new file mode 100644 index 0000000..d456731 --- /dev/null +++ b/model_training/train_model.py @@ -0,0 +1,6 @@ +from omegaconf import OmegaConf +from rnn_trainer import BrainToTextDecoder_Trainer + +args = OmegaConf.load('rnn_args.yaml') +trainer = BrainToTextDecoder_Trainer(args) +metrics = trainer.train() \ No newline at end of file