diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 822c775..de03d2c 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -49,6 +49,7 @@ class BrainToTextDecoder_Trainer:
         )
 
         # Set even_batches to False to handle batch_size=None in DataLoaders
+        # For TPU, we need to handle the batch_sampler issue more carefully
         self.accelerator.even_batches = False
 
         # Trainer fields
@@ -188,16 +189,15 @@ class BrainToTextDecoder_Trainer:
         # Use TPU-optimized dataloader settings if TPU is enabled
         num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
 
-        # For TPU environments or when batch_size=None causes issues, use batch_size=1
-        # since our dataset already returns complete batches
-        batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None
+        # For TPU environments, we need to be more careful about DataLoader configuration
+        use_tpu = self.args.get('use_tpu', False)
 
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
+            batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = self.args['dataset']['loader_shuffle'],
             num_workers = num_workers,
-            pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
+            pin_memory = not use_tpu # TPU doesn't support pin_memory
         )
 
         # val dataset and dataloader
@@ -213,10 +213,10 @@ class BrainToTextDecoder_Trainer:
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
+            batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = False,
             num_workers = 0, # Keep validation dataloader single-threaded for consistency
-            pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
+            pin_memory = not use_tpu # TPU doesn't support pin_memory
         )
 
         self.logger.info("Successfully initialized datasets")
@@ -252,19 +252,29 @@ class BrainToTextDecoder_Trainer:
                 param.requires_grad = False
 
         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
-        (
-            self.model,
-            self.optimizer,
-            self.learning_rate_scheduler,
-            self.train_loader,
-            self.val_loader,
-        ) = self.accelerator.prepare(
-            self.model,
-            self.optimizer,
-            self.learning_rate_scheduler,
-            self.train_loader,
-            self.val_loader,
-        )
+        # For TPU environments, we may need special handling of DataLoaders
+        if use_tpu:
+            # On TPU, prepare DataLoaders separately to avoid batch_sampler issues
+            self.model, self.optimizer, self.learning_rate_scheduler = self.accelerator.prepare(
+                self.model, self.optimizer, self.learning_rate_scheduler
+            )
+            # Manually move DataLoaders to device if needed - TPU should handle this automatically
+            # through the Accelerator during training/validation loops
+        else:
+            # Standard preparation for GPU/CPU
+            (
+                self.model,
+                self.optimizer,
+                self.learning_rate_scheduler,
+                self.train_loader,
+                self.val_loader,
+            ) = self.accelerator.prepare(
+                self.model,
+                self.optimizer,
+                self.learning_rate_scheduler,
+                self.train_loader,
+                self.val_loader,
+            )
 
         self.logger.info("Prepared model and dataloaders with Accelerator")