diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 201cabc..4428d11 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -46,6 +46,7 @@ class BrainToTextDecoder_Trainer:
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None,  # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
+            even_batches=False,  # Required for batch_size=None DataLoaders
         )
         # Note: even_batches is handled automatically by Accelerator based on our DataLoader configuration
 
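
For context, a minimal sketch of the configuration this diff changes, separate from the project's trainer: the dataset name, tensor shapes, and directory below are hypothetical placeholders, and it assumes an Accelerate version that still accepts `even_batches` directly on `Accelerator` (newer releases take it via `DataLoaderConfiguration(even_batches=False)` passed as `dataloader_config`).

```python
# Sketch only, not the project's trainer: shows why even_batches=False is paired
# with DataLoaders created with batch_size=None.
import torch
from torch.utils.data import DataLoader, Dataset
from accelerate import Accelerator


class PreBatchedTrials(Dataset):
    """Hypothetical dataset whose __getitem__ already returns a whole batch."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Placeholder shape: (trials_per_batch, feature_dim)
        return torch.randn(4, 16)


accelerator = Accelerator(
    gradient_accumulation_steps=1,
    log_with=None,
    project_dir='./output',
    # With batch_size=None there is no fixed batch size for Accelerate to pad or
    # duplicate across processes, so even batch dispatch must be turned off.
    even_batches=False,
)

# batch_size=None tells PyTorch the dataset already yields ready-made batches,
# so the DataLoader performs no collation of its own.
loader = DataLoader(PreBatchedTrials(), batch_size=None, shuffle=False)
loader = accelerator.prepare(loader)

for batch in loader:
    pass  # the training step would consume `batch` here
```

With this pairing, the dataset controls batch composition entirely; Accelerate only distributes the pre-built batches instead of trying to equalize their sizes across processes.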