tpu manual dataloader

This commit is contained in:
Zchen
2025-10-12 20:43:43 +08:00
parent bc015f5efb
commit 7cc9c41b7f

View File

@@ -49,7 +49,7 @@ class BrainToTextDecoder_Trainer:
) )
# Set even_batches to False after initialization - required for batch_size=None DataLoaders # Set even_batches to False after initialization - required for batch_size=None DataLoaders
self.accelerator.even_batches = False # Note: This may not be settable in all Accelerate versions, but we handle it in DataLoader config
# Trainer fields # Trainer fields
self.args = args self.args = args
@@ -280,19 +280,36 @@ class BrainToTextDecoder_Trainer:
param.requires_grad = False param.requires_grad = False
# Prepare model, optimizer, scheduler, and dataloaders for distributed training # Prepare model, optimizer, scheduler, and dataloaders for distributed training
( # For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues
self.model, use_tpu = self.args.get('use_tpu', False)
self.optimizer,
self.learning_rate_scheduler, if use_tpu:
self.train_loader, # On TPU, only prepare model, optimizer, and scheduler
self.val_loader, (
) = self.accelerator.prepare( self.model,
self.model, self.optimizer,
self.optimizer, self.learning_rate_scheduler,
self.learning_rate_scheduler, ) = self.accelerator.prepare(
self.train_loader, self.model,
self.val_loader, self.optimizer,
) self.learning_rate_scheduler,
)
# DataLoaders remain unprepared but will work with our custom configuration
else:
# Standard GPU/CPU preparation including DataLoaders
(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
) = self.accelerator.prepare(
self.model,
self.optimizer,
self.learning_rate_scheduler,
self.train_loader,
self.val_loader,
)
self.logger.info("Prepared model and dataloaders with Accelerator") self.logger.info("Prepared model and dataloaders with Accelerator")
@@ -563,12 +580,22 @@ class BrainToTextDecoder_Trainer:
# Train step # Train step
start_time = time.time() start_time = time.time()
# Data is automatically moved to device by Accelerator # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
features = batch['input_features'] use_tpu = self.args.get('use_tpu', False)
labels = batch['seq_class_ids'] if use_tpu:
n_time_steps = batch['n_time_steps'] # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
phone_seq_lens = batch['phone_seq_lens'] features = batch['input_features'].to(self.device)
day_indicies = batch['day_indicies'] labels = batch['seq_class_ids'].to(self.device)
n_time_steps = batch['n_time_steps'].to(self.device)
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
day_indicies = batch['day_indicies'].to(self.device)
else:
# For GPU/CPU, data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Use Accelerator's autocast (mixed precision handled by Accelerator init) # Use Accelerator's autocast (mixed precision handled by Accelerator init)
with self.accelerator.autocast(): with self.accelerator.autocast():
@@ -732,12 +759,22 @@ class BrainToTextDecoder_Trainer:
for i, batch in enumerate(loader): for i, batch in enumerate(loader):
# Data is automatically moved to device by Accelerator # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
features = batch['input_features'] use_tpu = self.args.get('use_tpu', False)
labels = batch['seq_class_ids'] if use_tpu:
n_time_steps = batch['n_time_steps'] # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
phone_seq_lens = batch['phone_seq_lens'] features = batch['input_features'].to(self.device)
day_indicies = batch['day_indicies'] labels = batch['seq_class_ids'].to(self.device)
n_time_steps = batch['n_time_steps'].to(self.device)
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
day_indicies = batch['day_indicies'].to(self.device)
else:
# For GPU/CPU, data is automatically moved to device by Accelerator
features = batch['input_features']
labels = batch['seq_class_ids']
n_time_steps = batch['n_time_steps']
phone_seq_lens = batch['phone_seq_lens']
day_indicies = batch['day_indicies']
# Determine if we should perform validation on this batch # Determine if we should perform validation on this batch
day = day_indicies[0].item() day = day_indicies[0].item()
@@ -838,10 +875,18 @@ class BrainToTextDecoder_Trainer:
''' '''
self.model.eval() self.model.eval()
# Data is automatically moved to device by Accelerator # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
features = batch['input_features'] use_tpu = self.args.get('use_tpu', False)
day_indicies = batch['day_indicies'] if use_tpu:
n_time_steps = batch['n_time_steps'] # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
features = batch['input_features'].to(self.device)
day_indicies = batch['day_indicies'].to(self.device)
n_time_steps = batch['n_time_steps'].to(self.device)
else:
# For GPU/CPU, data is automatically moved to device by Accelerator
features = batch['input_features']
day_indicies = batch['day_indicies']
n_time_steps = batch['n_time_steps']
with torch.no_grad(): with torch.no_grad():
with self.accelerator.autocast(): with self.accelerator.autocast():