diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index 4428d11..3755234 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -46,10 +46,10 @@ class BrainToTextDecoder_Trainer: gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1), log_with=None, # We'll use our own logging project_dir=args.get('output_dir', './output'), - even_batches=False, # Required for batch_size=None DataLoaders ) - # Note: even_batches is handled automatically by Accelerator based on our DataLoader configuration + # Set even_batches to False after initialization - required for batch_size=None DataLoaders + self.accelerator.even_batches = False # Trainer fields self.args = args