diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index de03d2c..89e35a6 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -192,9 +192,12 @@ class BrainToTextDecoder_Trainer:
         # For TPU environments, we need to be more careful about DataLoader configuration
         use_tpu = self.args.get('use_tpu', False)
 
+        # TPU doesn't handle batch_size=None well, so use batch_size=1 for TPU
+        batch_size_setting = 1 if use_tpu else None
+
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size = None, # Dataset.__getitem__() already returns batches
+            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
             shuffle = self.args['dataset']['loader_shuffle'],
             num_workers = num_workers,
             pin_memory = not use_tpu # TPU doesn't support pin_memory
@@ -213,7 +216,7 @@ class BrainToTextDecoder_Trainer:
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size = None, # Dataset.__getitem__() already returns batches
+            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
             shuffle = False,
             num_workers = 0, # Keep validation dataloader single-threaded for consistency
             pin_memory = not use_tpu # TPU doesn't support pin_memory
@@ -252,29 +255,19 @@ class BrainToTextDecoder_Trainer:
                 param.requires_grad = False
 
         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
-        # For TPU environments, we may need special handling of DataLoaders
-        if use_tpu:
-            # On TPU, prepare DataLoaders separately to avoid batch_sampler issues
-            self.model, self.optimizer, self.learning_rate_scheduler = self.accelerator.prepare(
-                self.model, self.optimizer, self.learning_rate_scheduler
-            )
-            # Manually move DataLoaders to device if needed - TPU should handle this automatically
-            # through the Accelerator during training/validation loops
-        else:
-            # Standard preparation for GPU/CPU
-            (
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            ) = self.accelerator.prepare(
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            )
+        (
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        )
 
         self.logger.info("Prepared model and dataloaders with Accelerator")
 
@@ -466,6 +459,13 @@ class BrainToTextDecoder_Trainer:
         Performing augmentations is much faster on GPU than CPU
         '''
 
+        # Handle TPU case where DataLoader with batch_size=1 adds an extra dimension
+        use_tpu = self.args.get('use_tpu', False)
+        if use_tpu and features.dim() == 4 and features.size(0) == 1:
+            features = features.squeeze(0) # Remove the extra batch dimension added by DataLoader
+            if isinstance(n_time_steps, torch.Tensor) and n_time_steps.dim() == 2:
+                n_time_steps = n_time_steps.squeeze(0)
+
         data_shape = features.shape
         batch_size = data_shape[0]
         channels = data_shape[-1]
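
Note (not part of the patch): a minimal sketch of why the `batch_size_setting` change forces the later `squeeze(0)`. With a map-style dataset whose `__getitem__` already returns a full batch, `batch_size=None` makes the DataLoader yield each item untouched, while `batch_size=1` runs `default_collate`, which stacks the single item and prepends a dimension of size 1. The `ToyBatchDataset` and its shapes below are hypothetical and only illustrate the behavior.

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyBatchDataset(Dataset):
    """Hypothetical dataset: each __getitem__ already returns a full batch."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {
            'features': torch.randn(32, 100, 256),                    # (batch, time, channels)
            'n_time_steps': torch.full((32,), 100, dtype=torch.long), # (batch,)
        }

# batch_size=None: no collation, the batch dict comes through untouched.
loader_none = DataLoader(ToyBatchDataset(), batch_size=None)
batch = next(iter(loader_none))
print(batch['features'].shape)       # torch.Size([32, 100, 256])

# batch_size=1: default_collate stacks the single item, adding a leading dim of 1.
loader_one = DataLoader(ToyBatchDataset(), batch_size=1)
batch = next(iter(loader_one))
print(batch['features'].shape)       # torch.Size([1, 32, 100, 256])
print(batch['n_time_steps'].shape)   # torch.Size([1, 32])

# The patch undoes this inside transform_data(): squeeze(0) restores the original shapes.
features = batch['features'].squeeze(0)
n_time_steps = batch['n_time_steps'].squeeze(0)
print(features.shape, n_time_steps.shape)  # torch.Size([32, 100, 256]) torch.Size([32])
```

The guard in the patch (`features.dim() == 4 and features.size(0) == 1`) makes the squeeze a no-op when the loader was built with `batch_size=None`, so the same `transform_data()` path serves both the GPU/CPU and TPU configurations.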