Zchen
2025-10-12 19:49:47 +08:00
parent 3f91f2042f
commit bc9aa17e00


@@ -192,9 +192,12 @@ class BrainToTextDecoder_Trainer:
+        # For TPU environments, we need to be more careful about DataLoader configuration
         use_tpu = self.args.get('use_tpu', False)
+        # TPU doesn't handle batch_size=None well, so use batch_size=1 for TPU
+        batch_size_setting = 1 if use_tpu else None
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size = None, # Dataset.__getitem__() already returns batches
+            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
             shuffle = self.args['dataset']['loader_shuffle'],
             num_workers = num_workers,
             pin_memory = not use_tpu # TPU doesn't support pin_memory
@@ -213,7 +216,7 @@ class BrainToTextDecoder_Trainer:
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size = None, # Dataset.__getitem__() already returns batches
+            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
             shuffle = False,
             num_workers = 0, # Keep validation dataloader single-threaded for consistency
             pin_memory = not use_tpu # TPU doesn't support pin_memory
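
The reason `batch_size=None` is swapped out on TPU: the dataset's `__getitem__()` already returns a full batch, and the two settings treat such an item very differently. Below is a minimal sketch (not from this repo; `PreBatchedDataset` and the tensor shapes are illustrative assumptions) showing that `batch_size=None` disables automatic batching and yields the item unchanged, while `batch_size=1` re-collates it and prepends a singleton dimension.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    """Hypothetical stand-in: each item is already a full batch of trials."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        # (batch, time, channels) -- shapes are illustrative only
        return torch.zeros(32, 100, 512)

ds = PreBatchedDataset()

# batch_size=None disables automatic batching: items pass through unchanged.
print(next(iter(DataLoader(ds, batch_size=None))).shape)  # torch.Size([32, 100, 512])

# batch_size=1 re-collates each item, adding a leading singleton dimension.
print(next(iter(DataLoader(ds, batch_size=1))).shape)     # torch.Size([1, 32, 100, 512])
```

That extra leading dimension is exactly what the later transform/augmentation change strips off again.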
@@ -252,29 +255,19 @@ class BrainToTextDecoder_Trainer:
             param.requires_grad = False
         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
-        # For TPU environments, we may need special handling of DataLoaders
-        if use_tpu:
-            # On TPU, prepare DataLoaders separately to avoid batch_sampler issues
-            self.model, self.optimizer, self.learning_rate_scheduler = self.accelerator.prepare(
-                self.model, self.optimizer, self.learning_rate_scheduler
-            )
-            # Manually move DataLoaders to device if needed - TPU should handle this automatically
-            # through the Accelerator during training/validation loops
-        else:
-            # Standard preparation for GPU/CPU
-            (
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            ) = self.accelerator.prepare(
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            )
+        (
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        )
         self.logger.info("Prepared model and dataloaders with Accelerator")
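
With `batch_size=1` on TPU, the DataLoaders no longer need a separate code path, so everything goes through one `accelerator.prepare(...)` call on every backend. A minimal sketch of that single-call pattern, with placeholder model/optimizer/scheduler/loader objects standing in for the trainer's real ones (all names below are illustrative):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # backend (TPU/GPU/CPU) is chosen by the launch environment

# Placeholder objects; the trainer's actual model, optimizer, scheduler, and loaders differ.
model = torch.nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000)
train_loader = DataLoader(TensorDataset(torch.zeros(64, 512)), batch_size=8)
val_loader = DataLoader(TensorDataset(torch.zeros(16, 512)), batch_size=8)

# One prepare() call handles device placement and any distributed wrapping,
# so no per-backend branching around the DataLoaders is needed.
model, optimizer, scheduler, train_loader, val_loader = accelerator.prepare(
    model, optimizer, scheduler, train_loader, val_loader
)
```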
@@ -466,6 +459,13 @@ class BrainToTextDecoder_Trainer:
         Performing augmentations is much faster on GPU than CPU
         '''
+        # Handle TPU case where DataLoader with batch_size=1 adds an extra dimension
+        use_tpu = self.args.get('use_tpu', False)
+        if use_tpu and features.dim() == 4 and features.size(0) == 1:
+            features = features.squeeze(0) # Remove the extra batch dimension added by DataLoader
+            if isinstance(n_time_steps, torch.Tensor) and n_time_steps.dim() == 2:
+                n_time_steps = n_time_steps.squeeze(0)
         data_shape = features.shape
         batch_size = data_shape[0]
         channels = data_shape[-1]
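
A small sketch of this un-batching step in isolation, assuming illustrative shapes of `(batch=32, time=100, channels=512)` for `features` and a 2-D per-trial length tensor for `n_time_steps` (both are assumptions for the example, not values from the repo):

```python
import torch

# Shapes as produced by a batch_size=1 DataLoader wrapping an already-batched item.
features = torch.zeros(1, 32, 100, 512)   # (1, batch, time, channels)
n_time_steps = torch.full((1, 32), 100)   # (1, batch)

if features.dim() == 4 and features.size(0) == 1:
    features = features.squeeze(0)          # -> (batch, time, channels)
if isinstance(n_time_steps, torch.Tensor) and n_time_steps.dim() == 2:
    n_time_steps = n_time_steps.squeeze(0)  # -> (batch,)

print(features.shape, n_time_steps.shape)   # torch.Size([32, 100, 512]) torch.Size([32])
```

After the squeeze, `data_shape[0]` is the true trial count again, so the downstream augmentation code can stay unchanged on GPU/CPU.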