diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 89e35a6..fd69cfa 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -48,9 +48,7 @@ class BrainToTextDecoder_Trainer:
             project_dir=args.get('output_dir', './output'),
         )
 
-        # Set even_batches to False to handle batch_size=None in DataLoaders
-        # For TPU, we need to handle the batch_sampler issue more carefully
-        self.accelerator.even_batches = False
+        # Note: even_batches is handled automatically by Accelerator based on our DataLoader configuration
 
         # Trainer fields
         self.args = args
@@ -192,16 +190,29 @@ class BrainToTextDecoder_Trainer:
         # For TPU environments, we need to be more careful about DataLoader configuration
         use_tpu = self.args.get('use_tpu', False)
 
-        # TPU doesn't handle batch_size=None well, so use batch_size=1 for TPU
-        batch_size_setting = 1 if use_tpu else None
-
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
-            shuffle = self.args['dataset']['loader_shuffle'],
-            num_workers = num_workers,
-            pin_memory = not use_tpu # TPU doesn't support pin_memory
-        )
+        if use_tpu:
+            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
+            # TPU requires specific DataLoader configuration to avoid batch_sampler issues
+            from torch.utils.data import DataLoader
+            self.train_loader = DataLoader(
+                self.train_dataset,
+                batch_size = None, # None because our Dataset returns batches
+                sampler = None, # Disable sampler to avoid batch_sampler conflicts
+                batch_sampler = None, # Explicitly set to None
+                shuffle = False, # Keep batch order fixed on TPU
+                num_workers = num_workers,
+                pin_memory = False, # TPU doesn't support pin_memory
+                collate_fn = lambda x: x # batch_size=None disables automatic batching, so collate_fn receives the batch itself; pass it through unchanged
+            )
+        else:
+            # Standard GPU/CPU configuration
+            self.train_loader = DataLoader(
+                self.train_dataset,
+                batch_size = None, # Dataset.__getitem__() already returns batches
+                shuffle = self.args['dataset']['loader_shuffle'],
+                num_workers = num_workers,
+                pin_memory = True
+            )
 
         # val dataset and dataloader
         self.val_dataset = BrainToTextDataset(
@@ -214,13 +225,27 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        self.val_loader = DataLoader(
-            self.val_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
-            shuffle = False,
-            num_workers = 0, # Keep validation dataloader single-threaded for consistency
-            pin_memory = not use_tpu # TPU doesn't support pin_memory
-        )
+        if use_tpu:
+            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
+            self.val_loader = DataLoader(
+                self.val_dataset,
+                batch_size = None, # None because our Dataset returns batches
+                sampler = None, # Disable sampler to avoid batch_sampler conflicts
+                batch_sampler = None, # Explicitly set to None
+                shuffle = False,
+                num_workers = 0, # Keep validation dataloader single-threaded for consistency
+                pin_memory = False, # TPU doesn't support pin_memory
+                collate_fn = lambda x: x # batch_size=None disables automatic batching; pass the batch through unchanged
+            )
+        else:
+            # Standard GPU/CPU configuration
+            self.val_loader = DataLoader(
+                self.val_dataset,
+                batch_size = None, # Dataset.__getitem__() already returns batches
+                shuffle = False,
+                num_workers = 0, # Keep validation dataloader single-threaded for consistency
+                pin_memory = True
+            )
 
         self.logger.info("Successfully initialized datasets")
 
@@ -459,12 +484,7 @@ class BrainToTextDecoder_Trainer:
         Performing augmentations is much faster on GPU than CPU
         '''
 
-        # Handle TPU case where DataLoader with batch_size=1 adds an extra dimension
-        use_tpu = self.args.get('use_tpu', False)
-        if use_tpu and features.dim() == 4 and features.size(0) == 1:
-            features = features.squeeze(0) # Remove the extra batch dimension added by DataLoader
-            if isinstance(n_time_steps, torch.Tensor) and n_time_steps.dim() == 2:
-                n_time_steps = n_time_steps.squeeze(0)
+        # TPU and GPU should now handle data consistently with our improved DataLoader configuration
 
         data_shape = features.shape
         batch_size = data_shape[0]
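
A note on the collate_fn setting above: when batch_size=None, PyTorch disables automatic batching, fetches dataset[idx] one index at a time, and applies collate_fn to that single object rather than to a list of samples. An identity pass-through is therefore what "just pass the batch through" means here. Below is a minimal, self-contained sketch of the pattern; ToyBatchDataset and its field names are hypothetical stand-ins for BrainToTextDataset, whose exact return format is not shown in this diff.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyBatchDataset(Dataset):
    """Hypothetical stand-in for a batch-returning Dataset: each
    __getitem__ call returns an entire batch (a dict of tensors)."""
    def __init__(self, n_batches=4, batch_size=8, n_features=16):
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.n_features = n_features

    def __len__(self):
        return self.n_batches

    def __getitem__(self, idx):
        return {
            'input_features': torch.randn(self.batch_size, 100, self.n_features),
            'n_time_steps': torch.full((self.batch_size,), 100, dtype=torch.long),
        }

# batch_size=None disables automatic batching: the loader yields
# collate_fn(dataset[idx]) for each index, so collate_fn receives the
# batch dict itself, not a one-element list.
loader = DataLoader(
    ToyBatchDataset(),
    batch_size=None,
    shuffle=False,
    num_workers=0,
    collate_fn=lambda x: x,  # identity pass-through; x is already the batch
)

for batch in loader:
    # No extra leading dimension is added, unlike the old batch_size=1 path
    assert batch['input_features'].shape == (8, 100, 16)

Under the old batch_size=1 configuration, default_collate would wrap each fetched batch in a list and stack it, yielding tensors of shape (1, 8, 100, 16); that leading singleton dimension is exactly what the squeeze(0) logic removed from transform_data existed to strip.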