model:
  n_input_features: 512 # number of input features in the neural data. (2 features per electrode, 256 electrodes)
  n_units: 768 # number of units per GRU layer
  rnn_dropout: 0.4 # dropout rate for the GRU layers
  rnn_trainable: true # whether the GRU layers are trainable
  n_layers: 5 # number of GRU layers
  patch_size: 14 # size of the input patches (14 time steps)
  patch_stride: 4 # stride for the input patches (4 time steps)
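  # Worked example (added for clarity; assumes the usual sliding-window patching, which this file does
  # not spell out): each patch concatenates patch_size = 14 consecutive 512-feature time steps, i.e.
  # 14 x 512 = 7168 values per patch, and successive patches start patch_stride = 4 steps apart, so the
  # patched sequence is roughly 4x shorter than the raw input sequence.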

  input_network:
    n_input_layers: 1 # number of input layers per network (one network for each day)
    input_layer_sizes:
    - 512 # size of the input layer (number of input features)
    input_trainable: true # whether the input layer is trainable
    input_layer_dropout: 0.2 # dropout rate for the input layer

gpu_number: '1' # GPU number to use for training, formatted as a string (e.g., '0', '1', etc.)
mode: train
use_amp: true # whether to use automatic mixed precision (AMP) for training

# TPU and distributed training settings
use_tpu: false # whether to use TPU for training (set to true for TPU)
num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
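# Note (not stated in this file): the effective batch size per optimizer step is dataset.batch_size x
# gradient_accumulation_steps (64 x 1 = 64 here), and presumably also scales with the number of
# data-parallel replicas (e.g. num_tpu_cores) when use_tpu is enabled.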

output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
init_from_checkpoint: false # whether to initialize the model from a checkpoint
init_checkpoint_path: None # path to the checkpoint to initialize the model from, if any
save_best_checkpoint: true # whether to save the best checkpoint based on validation metrics
save_all_val_steps: false # whether to save checkpoints at all validation steps
save_final_model: false # whether to save the final model after training
save_val_metrics: true # whether to save validation metrics during training
early_stopping: false # whether to use early stopping based on validation metrics
early_stopping_val_steps: 20 # number of validation steps to wait before stopping training if no improvement is seen

num_training_batches: 120000 # number of training batches to run
lr_scheduler_type: cosine # type of learning rate scheduler to use
lr_max: 0.005 # maximum learning rate for the main model
lr_min: 0.0001 # minimum learning rate for the main model
lr_decay_steps: 120000 # number of steps for the learning rate decay
lr_warmup_steps: 1000 # number of warmup steps for the learning rate scheduler
lr_max_day: 0.005 # maximum learning rate for the day-specific input layers
lr_min_day: 0.0001 # minimum learning rate for the day-specific input layers
lr_decay_steps_day: 120000 # number of steps for the learning rate decay for the day-specific input layers
lr_warmup_steps_day: 1000 # number of warmup steps for the learning rate scheduler for the day-specific input layers
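# Illustration (added; assumes the standard linear-warmup-then-cosine form, which is defined in the
# training code rather than here): after lr_warmup_steps linear warmup steps, at decay step t
#   lr(t) = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / lr_decay_steps))
# so the main-model learning rate falls from 0.005 to 0.0001 over the 120000 decay steps.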

beta0: 0.9 # beta0 parameter for the Adam optimizer
beta1: 0.999 # beta1 parameter for the Adam optimizer
epsilon: 0.1 # epsilon parameter for the Adam optimizer
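# Interpretation (not stated in this file): given the values, beta0 and beta1 presumably map to Adam's
# first- and second-moment decay rates (beta_1 = 0.9, beta_2 = 0.999), and epsilon is the denominator
# constant (note it is far larger than the common default of 1e-8).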
weight_decay: 0.001 # weight decay for the main model
weight_decay_day: 0 # weight decay for the day-specific input layers
seed: 10 # random seed for reproducibility
grad_norm_clip_value: 10 # gradient norm clipping value

batches_per_train_log: 200 # number of batches per training log
batches_per_val_step: 2000 # number of batches per validation step
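# Worked numbers (added): with num_training_batches = 120000, training logs metrics 120000 / 200 = 600
# times and runs a validation pass 120000 / 2000 = 60 times; since early_stopping is false, the
# early_stopping_val_steps patience of 20 validation steps is not applied.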

batches_per_save: 0 # number of batches per save
log_individual_day_val_PER: true # whether to log individual day validation performance
log_val_skip_logs: false # whether to skip logging validation metrics
save_val_logits: true # whether to save validation logits
save_val_data: false # whether to save validation data

dataset:
  data_transforms:
    white_noise_std: 1.0 # standard deviation of the white noise added to the data
    constant_offset_std: 0.2 # standard deviation of the constant offset added to the data
    random_walk_std: 0.0 # standard deviation of the random walk added to the data
    random_walk_axis: -1 # axis along which the random walk is applied
    static_gain_std: 0.0 # standard deviation of the static gain applied to the data
    random_cut: 3 # number of time steps to randomly cut from the beginning of each batch of trials
    smooth_kernel_size: 100 # size of the smoothing kernel applied to the data
    smooth_data: true # whether to smooth the data
    smooth_kernel_std: 2 # standard deviation of the smoothing kernel applied to the data

  neural_dim: 512 # dimensionality of the neural data
  batch_size: 64 # batch size for training
  n_classes: 41 # number of classes (phonemes) in the dataset
  max_seq_elements: 500 # maximum number of sequence elements (phonemes) for any trial
  days_per_batch: 4 # number of randomly-selected days to include in each batch
  seed: 1 # random seed for reproducibility
  num_dataloader_workers: 4 # number of workers for the data loader
  loader_shuffle: false # whether to shuffle the data loader
  must_include_days: null # specific days to include in the dataset
  test_percentage: 0.1 # percentage of data to use for testing
  feature_subset: null # specific features to include in the dataset

  dataset_dir: ../data/hdf5_data_final # directory containing the dataset
  bad_trials_dict: null # dictionary of bad trials to exclude from the dataset
  sessions: # list of sessions to include in the dataset
  - t15.2023.08.11
  - t15.2023.08.13
  - t15.2023.08.18
  - t15.2023.08.20
  - t15.2023.08.25
  - t15.2023.08.27
  - t15.2023.09.01
  - t15.2023.09.03
  - t15.2023.09.24
  - t15.2023.09.29
  - t15.2023.10.01
  - t15.2023.10.06
  - t15.2023.10.08
  - t15.2023.10.13
  - t15.2023.10.15
  - t15.2023.10.20
  - t15.2023.10.22
  - t15.2023.11.03
  - t15.2023.11.04
  - t15.2023.11.17
  - t15.2023.11.19
  - t15.2023.11.26
  - t15.2023.12.03
  - t15.2023.12.08
  - t15.2023.12.10
  - t15.2023.12.17
  - t15.2023.12.29
  - t15.2024.02.25
  - t15.2024.03.03
  - t15.2024.03.08
  - t15.2024.03.15
  - t15.2024.03.17
  - t15.2024.04.25
  - t15.2024.04.28
  - t15.2024.05.10
  - t15.2024.06.14
  - t15.2024.07.19
  - t15.2024.07.21
  - t15.2024.07.28
  - t15.2025.01.10
  - t15.2025.01.12
  - t15.2025.03.14
  - t15.2025.03.16
  - t15.2025.03.30
  - t15.2025.04.13
  dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
  - 0 # no val or test data from this day
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 0 # no val or test data from this day
  - 1
  - 1
  - 1
  - 0 # no val or test data from this day
  - 0 # no val or test data from this day
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
  - 1
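  # Added note: dataset_probability_val has one entry per session listed above (45 entries for 45
  # sessions); assuming the i-th entry applies to the i-th session, the four 0 entries correspond to
  # t15.2023.08.11, t15.2024.03.03, t15.2024.04.25 and t15.2024.04.28.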