diff --git a/model_training_nnn/rnn_args.yaml b/model_training_nnn/rnn_args.yaml
index 306d8d2..6e55e64 100644
--- a/model_training_nnn/rnn_args.yaml
+++ b/model_training_nnn/rnn_args.yaml
@@ -22,7 +22,6 @@
 use_amp: true # whether to use automatic mixed precision (AMP) for training
 use_tpu: true # whether to use TPU for training (set to true for TPU)
 num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
 gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
-dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
 output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
 checkpoint_dir: trained_models/baseline_rnn/checkpoint # directory to save checkpoints during training
@@ -82,6 +81,7 @@ dataset:
   days_per_batch: 4 # number of randomly-selected days to include in each batch
   seed: 1 # random seed for reproducibility
   num_dataloader_workers: 4 # number of workers for the data loader
+  dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
   loader_shuffle: false # whether to shuffle the data loader
   must_include_days: null # specific days to include in the dataset
   test_percentage: 0.1 # percentage of data to use for testing
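
A minimal sketch of how the relocated key might be consumed after this change, assuming the training script loads rnn_args.yaml with PyYAML and passes the value to a PyTorch DataLoader. The YAML keys are the ones shown in the diff; the script itself, including the placeholder dataset, is hypothetical and not part of this repository.

# Hypothetical consumer of the config after this diff is applied.
import yaml
import torch
from torch.utils.data import DataLoader, TensorDataset

with open("model_training_nnn/rnn_args.yaml") as f:
    args = yaml.safe_load(f)

ds_cfg = args["dataset"]

# The key now lives under `dataset:`, so it is read from the nested
# mapping instead of from the top level of the config.
num_workers = ds_cfg["dataloader_num_workers"]  # 0 on TPU, per the YAML comment

# Placeholder standing in for the project's real Dataset class.
dataset = TensorDataset(torch.zeros(16, 8))

# num_workers=0 keeps data loading in the main process, which sidesteps
# the multiprocessing issues the config comment warns about on TPU.
loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=ds_cfg["loader_shuffle"],
    num_workers=num_workers,
)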