tpu
@@ -49,6 +49,7 @@ class BrainToTextDecoder_Trainer:
         )
 
         # Set even_batches to False to handle batch_size=None in DataLoaders
+        # For TPU, we need to handle the batch_sampler issue more carefully
         self.accelerator.even_batches = False
 
         # Trainer fields
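Why this flag matters: both DataLoaders in this trainer are built with batch_size=None because the Dataset already returns complete batches, and Accelerate's even-batch logic expects a batch sampler it can rebalance across processes. A minimal sketch of the setup, mirroring the attribute assignment used in this file (the rest of the trainer is omitted and the variable names are illustrative):

from accelerate import Accelerator

accelerator = Accelerator()

# The loaders yield pre-built batches (batch_size=None), so there is no batch
# sampler for Accelerate to pad or truncate; disabling even_batches keeps it
# from trying to equalize the number of batches seen by each process.
accelerator.even_batches = False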
@@ -188,16 +189,15 @@ class BrainToTextDecoder_Trainer:
         # Use TPU-optimized dataloader settings if TPU is enabled
         num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
 
-        # For TPU environments or when batch_size=None causes issues, use batch_size=1
-        # since our dataset already returns complete batches
-        batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None
+        # For TPU environments, we need to be more careful about DataLoader configuration
+        use_tpu = self.args.get('use_tpu', False)
 
         self.train_loader = DataLoader(
             self.train_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
+            batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = self.args['dataset']['loader_shuffle'],
             num_workers = num_workers,
-            pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
+            pin_memory = not use_tpu # TPU doesn't support pin_memory
         )
 
         # val dataset and dataloader
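The batch_size=None setting relies on PyTorch's opt-out of automatic batching: with batch_size=None the DataLoader skips collation entirely and yields whatever Dataset.__getitem__() returns, which is why a dataset that already emits whole batches works here. A minimal self-contained sketch under that assumption; the dataset class, shapes, and config flag below are illustrative, not the repo's actual ones:

import torch
from torch.utils.data import Dataset, DataLoader

class PreBatchedDataset(Dataset):
    # Each index maps to one complete batch rather than one example.
    def __init__(self, n_batches=10, batch_size=32, n_features=512):
        self.batches = [torch.randn(batch_size, n_features) for _ in range(n_batches)]

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        return self.batches[idx]  # already a [batch_size, n_features] tensor

use_tpu = False  # in the trainer this comes from args['use_tpu']
loader = DataLoader(
    PreBatchedDataset(),
    batch_size=None,         # disable automatic batching/collation
    shuffle=True,
    num_workers=0,
    pin_memory=not use_tpu,  # page-locked host memory is not useful on TPU/XLA
)

for batch in loader:
    assert batch.shape == (32, 512)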
@@ -213,10 +213,10 @@ class BrainToTextDecoder_Trainer:
         )
         self.val_loader = DataLoader(
             self.val_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
+            batch_size = None, # Dataset.__getitem__() already returns batches
             shuffle = False,
             num_workers = 0, # Keep validation dataloader single-threaded for consistency
-            pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
+            pin_memory = not use_tpu # TPU doesn't support pin_memory
         )
 
         self.logger.info("Successfully initialized datasets")
@@ -252,6 +252,16 @@ class BrainToTextDecoder_Trainer:
                 param.requires_grad = False
 
         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
+        # For TPU environments, we may need special handling of DataLoaders
+        if use_tpu:
+            # On TPU, prepare DataLoaders separately to avoid batch_sampler issues
+            self.model, self.optimizer, self.learning_rate_scheduler = self.accelerator.prepare(
+                self.model, self.optimizer, self.learning_rate_scheduler
+            )
+            # Manually move DataLoaders to device if needed - TPU should handle this automatically
+            # through the Accelerator during training/validation loops
+        else:
+            # Standard preparation for GPU/CPU
         (
             self.model,
             self.optimizer,
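The new branch keeps the DataLoaders out of accelerator.prepare() on TPU, which the commit attributes to batch_sampler issues when batch_size=None, and prepares only the model, optimizer, and scheduler there; on GPU/CPU everything still goes through prepare(). A minimal sketch of that control flow; the loop body and the dict-of-tensors batch format are assumptions, and the variable names are shortened from the trainer's attributes:

if use_tpu:
    # Prepare only the compute objects; leave the batch_size=None loaders
    # unwrapped so Accelerate never inspects their (absent) batch sampler.
    model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
    prepared_train_loader = train_loader
else:
    # Standard path: Accelerate wraps the loaders as well.
    model, optimizer, scheduler, prepared_train_loader = accelerator.prepare(
        model, optimizer, scheduler, train_loader
    )

for batch in prepared_train_loader:
    # An unprepared loader yields host tensors, so move them to the
    # accelerator's device (an XLA device when running on TPU).
    batch = {k: v.to(accelerator.device) for k, v in batch.items()}
    # ... forward/backward as usual ...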