TPU manual dataloader
This commit is contained in:
@@ -49,7 +49,7 @@ class BrainToTextDecoder_Trainer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set even_batches to False after initialization - required for batch_size=None DataLoaders
|
# Set even_batches to False after initialization - required for batch_size=None DataLoaders
|
||||||
self.accelerator.even_batches = False
|
# Note: This may not be settable in all Accelerate versions, but we handle it in DataLoader config
|
||||||
|
|
||||||
# Trainer fields
|
# Trainer fields
|
||||||
self.args = args
|
self.args = args
|
||||||
@@ -280,19 +280,36 @@ class BrainToTextDecoder_Trainer:
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
|
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
|
||||||
(
|
# For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues
|
||||||
self.model,
|
use_tpu = self.args.get('use_tpu', False)
|
||||||
self.optimizer,
|
|
||||||
self.learning_rate_scheduler,
|
if use_tpu:
|
||||||
self.train_loader,
|
# On TPU, only prepare model, optimizer, and scheduler
|
||||||
self.val_loader,
|
(
|
||||||
) = self.accelerator.prepare(
|
self.model,
|
||||||
self.model,
|
self.optimizer,
|
||||||
self.optimizer,
|
self.learning_rate_scheduler,
|
||||||
self.learning_rate_scheduler,
|
) = self.accelerator.prepare(
|
||||||
self.train_loader,
|
self.model,
|
||||||
self.val_loader,
|
self.optimizer,
|
||||||
)
|
self.learning_rate_scheduler,
|
||||||
|
)
|
||||||
|
# DataLoaders remain unprepared but will work with our custom configuration
|
||||||
|
else:
|
||||||
|
# Standard GPU/CPU preparation including DataLoaders
|
||||||
|
(
|
||||||
|
self.model,
|
||||||
|
self.optimizer,
|
||||||
|
self.learning_rate_scheduler,
|
||||||
|
self.train_loader,
|
||||||
|
self.val_loader,
|
||||||
|
) = self.accelerator.prepare(
|
||||||
|
self.model,
|
||||||
|
self.optimizer,
|
||||||
|
self.learning_rate_scheduler,
|
||||||
|
self.train_loader,
|
||||||
|
self.val_loader,
|
||||||
|
)
|
||||||
|
|
||||||
self.logger.info("Prepared model and dataloaders with Accelerator")
|
self.logger.info("Prepared model and dataloaders with Accelerator")
|
||||||
|
|
||||||
@@ -561,14 +578,24 @@ class BrainToTextDecoder_Trainer:
|
|||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
# Train step
|
# Train step
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Data is automatically moved to device by Accelerator
|
# Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
|
||||||
features = batch['input_features']
|
use_tpu = self.args.get('use_tpu', False)
|
||||||
labels = batch['seq_class_ids']
|
if use_tpu:
|
||||||
n_time_steps = batch['n_time_steps']
|
# Manual data movement for TPU since DataLoaders are not prepared by Accelerator
|
||||||
phone_seq_lens = batch['phone_seq_lens']
|
features = batch['input_features'].to(self.device)
|
||||||
day_indicies = batch['day_indicies']
|
labels = batch['seq_class_ids'].to(self.device)
|
||||||
|
n_time_steps = batch['n_time_steps'].to(self.device)
|
||||||
|
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
|
||||||
|
day_indicies = batch['day_indicies'].to(self.device)
|
||||||
|
else:
|
||||||
|
# For GPU/CPU, data is automatically moved to device by Accelerator
|
||||||
|
features = batch['input_features']
|
||||||
|
labels = batch['seq_class_ids']
|
||||||
|
n_time_steps = batch['n_time_steps']
|
||||||
|
phone_seq_lens = batch['phone_seq_lens']
|
||||||
|
day_indicies = batch['day_indicies']
|
||||||
|
|
||||||
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
|
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
|
||||||
with self.accelerator.autocast():
|
with self.accelerator.autocast():
|
||||||
@@ -732,12 +759,22 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
for i, batch in enumerate(loader):
|
for i, batch in enumerate(loader):
|
||||||
|
|
||||||
# Data is automatically moved to device by Accelerator
|
# Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
|
||||||
features = batch['input_features']
|
use_tpu = self.args.get('use_tpu', False)
|
||||||
labels = batch['seq_class_ids']
|
if use_tpu:
|
||||||
n_time_steps = batch['n_time_steps']
|
# Manual data movement for TPU since DataLoaders are not prepared by Accelerator
|
||||||
phone_seq_lens = batch['phone_seq_lens']
|
features = batch['input_features'].to(self.device)
|
||||||
day_indicies = batch['day_indicies']
|
labels = batch['seq_class_ids'].to(self.device)
|
||||||
|
n_time_steps = batch['n_time_steps'].to(self.device)
|
||||||
|
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
|
||||||
|
day_indicies = batch['day_indicies'].to(self.device)
|
||||||
|
else:
|
||||||
|
# For GPU/CPU, data is automatically moved to device by Accelerator
|
||||||
|
features = batch['input_features']
|
||||||
|
labels = batch['seq_class_ids']
|
||||||
|
n_time_steps = batch['n_time_steps']
|
||||||
|
phone_seq_lens = batch['phone_seq_lens']
|
||||||
|
day_indicies = batch['day_indicies']
|
||||||
|
|
||||||
# Determine if we should perform validation on this batch
|
# Determine if we should perform validation on this batch
|
||||||
day = day_indicies[0].item()
|
day = day_indicies[0].item()
|
||||||
@@ -838,10 +875,18 @@ class BrainToTextDecoder_Trainer:
|
|||||||
'''
|
'''
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# Data is automatically moved to device by Accelerator
|
# Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
|
||||||
features = batch['input_features']
|
use_tpu = self.args.get('use_tpu', False)
|
||||||
day_indicies = batch['day_indicies']
|
if use_tpu:
|
||||||
n_time_steps = batch['n_time_steps']
|
# Manual data movement for TPU since DataLoaders are not prepared by Accelerator
|
||||||
|
features = batch['input_features'].to(self.device)
|
||||||
|
day_indicies = batch['day_indicies'].to(self.device)
|
||||||
|
n_time_steps = batch['n_time_steps'].to(self.device)
|
||||||
|
else:
|
||||||
|
# For GPU/CPU, data is automatically moved to device by Accelerator
|
||||||
|
features = batch['input_features']
|
||||||
|
day_indicies = batch['day_indicies']
|
||||||
|
n_time_steps = batch['n_time_steps']
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with self.accelerator.autocast():
|
with self.accelerator.autocast():
|
||||||
|
Reference in New Issue
Block a user