diff --git a/model_training_nnn/rnn_args.yaml b/model_training_nnn/rnn_args.yaml
index e824035..306d8d2 100644
--- a/model_training_nnn/rnn_args.yaml
+++ b/model_training_nnn/rnn_args.yaml
@@ -19,7 +19,7 @@ mode: train
 use_amp: true # whether to use automatic mixed precision (AMP) for training
 
 # TPU and distributed training settings
-use_tpu: false # whether to use TPU for training (set to true for TPU)
+use_tpu: true # whether to use TPU for training (set to true for TPU)
 num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
 gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
 dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index fbafc29..4194051 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -46,7 +46,6 @@ class BrainToTextDecoder_Trainer:
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None, # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
-            even_batches=False, # Required for batch_size=None DataLoaders
         )
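
For context, a minimal sketch of how the flags in rnn_args.yaml might feed into the Accelerator constructed in BrainToTextDecoder_Trainer. This is an illustration, not the repository's code: the build_accelerator helper and the use_amp-to-bf16 mapping are assumptions, and the rnn_trainer.py hunk above elides part of the real constructor call. Dropping even_batches is consistent with recent Accelerate releases, which configure it through DataLoaderConfiguration rather than as a direct Accelerator keyword.

```python
# Illustrative sketch only -- not the trainer's actual code.
# Assumes `args` is the parsed rnn_args.yaml dict; `build_accelerator`
# is a hypothetical helper, not a function in the repository.
from accelerate import Accelerator


def build_accelerator(args: dict) -> Accelerator:
    # bf16 is the usual mixed-precision mode on TPU; mapping use_amp to
    # a specific dtype is an assumption here.
    precision = 'bf16' if args.get('use_amp', False) else 'no'
    return Accelerator(
        mixed_precision=precision,
        gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
        log_with=None,  # the trainer does its own logging
        project_dir=args.get('output_dir', './output'),
        # even_batches is intentionally not passed, mirroring the diff.
    )
```

Note that num_tpu_cores does not appear in the constructor: on TPU, process launch is typically handled outside this call, e.g. by `accelerate launch` or `notebook_launcher(main, num_processes=8)`.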