tpu
This commit is contained in:
@@ -20,7 +20,7 @@ use_amp: true # whether to use automatic mixed precision (AMP) for training
|
|||||||
|
|
||||||
# TPU and distributed training settings
|
# TPU and distributed training settings
|
||||||
use_tpu: true # whether to use TPU for training (set to true for TPU)
|
use_tpu: true # whether to use TPU for training (set to true for TPU)
|
||||||
num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
|
num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8)
|
||||||
gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
|
gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
|
||||||
|
|
||||||
output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
|
output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
|
||||||
|
@@ -183,34 +183,14 @@ class BrainToTextDecoder_Trainer:
|
|||||||
random_seed = self.args['dataset']['seed'],
|
random_seed = self.args['dataset']['seed'],
|
||||||
feature_subset = feature_subset
|
feature_subset = feature_subset
|
||||||
)
|
)
|
||||||
# Use TPU-optimized dataloader settings if TPU is enabled
|
# Standard DataLoader configuration - let Accelerator handle device-specific optimizations
|
||||||
num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
|
self.train_loader = DataLoader(
|
||||||
|
self.train_dataset,
|
||||||
# For TPU environments, we need to be more careful about DataLoader configuration
|
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||||
use_tpu = self.args.get('use_tpu', False)
|
shuffle = self.args['dataset']['loader_shuffle'],
|
||||||
|
num_workers = self.args['dataset']['num_dataloader_workers'],
|
||||||
if use_tpu:
|
pin_memory = True
|
||||||
# For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
|
)
|
||||||
# TPU requires specific DataLoader configuration to avoid batch_sampler issues
|
|
||||||
self.train_loader = DataLoader(
|
|
||||||
self.train_dataset,
|
|
||||||
batch_size = None, # None because our Dataset returns batches
|
|
||||||
sampler = None, # Disable sampler to avoid batch_sampler conflicts
|
|
||||||
batch_sampler = None, # Explicitly set to None
|
|
||||||
shuffle = False, # Can't shuffle with custom batching
|
|
||||||
num_workers = num_workers,
|
|
||||||
pin_memory = False, # TPU doesn't support pin_memory
|
|
||||||
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Standard GPU/CPU configuration
|
|
||||||
self.train_loader = DataLoader(
|
|
||||||
self.train_dataset,
|
|
||||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
|
||||||
shuffle = self.args['dataset']['loader_shuffle'],
|
|
||||||
num_workers = num_workers,
|
|
||||||
pin_memory = True
|
|
||||||
)
|
|
||||||
|
|
||||||
# val dataset and dataloader
|
# val dataset and dataloader
|
||||||
self.val_dataset = BrainToTextDataset(
|
self.val_dataset = BrainToTextDataset(
|
||||||
@@ -223,27 +203,14 @@ class BrainToTextDecoder_Trainer:
|
|||||||
random_seed = self.args['dataset']['seed'],
|
random_seed = self.args['dataset']['seed'],
|
||||||
feature_subset = feature_subset
|
feature_subset = feature_subset
|
||||||
)
|
)
|
||||||
if use_tpu:
|
# Standard validation DataLoader configuration
|
||||||
# For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
|
self.val_loader = DataLoader(
|
||||||
self.val_loader = DataLoader(
|
self.val_dataset,
|
||||||
self.val_dataset,
|
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||||
batch_size = None, # None because our Dataset returns batches
|
shuffle = False,
|
||||||
sampler = None, # Disable sampler to avoid batch_sampler conflicts
|
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||||
batch_sampler = None, # Explicitly set to None
|
pin_memory = True
|
||||||
shuffle = False,
|
)
|
||||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
|
||||||
pin_memory = False, # TPU doesn't support pin_memory
|
|
||||||
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Standard GPU/CPU configuration
|
|
||||||
self.val_loader = DataLoader(
|
|
||||||
self.val_dataset,
|
|
||||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
|
||||||
shuffle = False,
|
|
||||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
|
||||||
pin_memory = True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger.info("Successfully initialized datasets")
|
self.logger.info("Successfully initialized datasets")
|
||||||
|
|
||||||
@@ -278,36 +245,20 @@ class BrainToTextDecoder_Trainer:
|
|||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
|
|
||||||
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
|
# Prepare model, optimizer, scheduler, and dataloaders for distributed training
|
||||||
# For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues
|
# Let Accelerator handle everything automatically for both GPU and TPU
|
||||||
use_tpu = self.args.get('use_tpu', False)
|
(
|
||||||
|
self.model,
|
||||||
if use_tpu:
|
self.optimizer,
|
||||||
# On TPU, only prepare model, optimizer, and scheduler
|
self.learning_rate_scheduler,
|
||||||
(
|
self.train_loader,
|
||||||
self.model,
|
self.val_loader,
|
||||||
self.optimizer,
|
) = self.accelerator.prepare(
|
||||||
self.learning_rate_scheduler,
|
self.model,
|
||||||
) = self.accelerator.prepare(
|
self.optimizer,
|
||||||
self.model,
|
self.learning_rate_scheduler,
|
||||||
self.optimizer,
|
self.train_loader,
|
||||||
self.learning_rate_scheduler,
|
self.val_loader,
|
||||||
)
|
)
|
||||||
# DataLoaders remain unprepared but will work with our custom configuration
|
|
||||||
else:
|
|
||||||
# Standard GPU/CPU preparation including DataLoaders
|
|
||||||
(
|
|
||||||
self.model,
|
|
||||||
self.optimizer,
|
|
||||||
self.learning_rate_scheduler,
|
|
||||||
self.train_loader,
|
|
||||||
self.val_loader,
|
|
||||||
) = self.accelerator.prepare(
|
|
||||||
self.model,
|
|
||||||
self.optimizer,
|
|
||||||
self.learning_rate_scheduler,
|
|
||||||
self.train_loader,
|
|
||||||
self.val_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger.info("Prepared model and dataloaders with Accelerator")
|
self.logger.info("Prepared model and dataloaders with Accelerator")
|
||||||
|
|
||||||
@@ -578,22 +529,12 @@ class BrainToTextDecoder_Trainer:
|
|||||||
# Train step
|
# Train step
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
|
# Data is automatically moved to device by Accelerator
|
||||||
use_tpu = self.args.get('use_tpu', False)
|
features = batch['input_features']
|
||||||
if use_tpu:
|
labels = batch['seq_class_ids']
|
||||||
# Manual data movement for TPU since DataLoaders are not prepared by Accelerator
|
n_time_steps = batch['n_time_steps']
|
||||||
features = batch['input_features'].to(self.device)
|
phone_seq_lens = batch['phone_seq_lens']
|
||||||
labels = batch['seq_class_ids'].to(self.device)
|
day_indicies = batch['day_indicies']
|
||||||
n_time_steps = batch['n_time_steps'].to(self.device)
|
|
||||||
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
|
|
||||||
day_indicies = batch['day_indicies'].to(self.device)
|
|
||||||
else:
|
|
||||||
# For GPU/CPU, data is automatically moved to device by Accelerator
|
|
||||||
features = batch['input_features']
|
|
||||||
labels = batch['seq_class_ids']
|
|
||||||
n_time_steps = batch['n_time_steps']
|
|
||||||
phone_seq_lens = batch['phone_seq_lens']
|
|
||||||
day_indicies = batch['day_indicies']
|
|
||||||
|
|
||||||
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
|
# Use Accelerator's autocast (mixed precision handled by Accelerator init)
|
||||||
with self.accelerator.autocast():
|
with self.accelerator.autocast():
|
||||||
@@ -757,22 +698,12 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
for i, batch in enumerate(loader):
|
for i, batch in enumerate(loader):
|
||||||
|
|
||||||
# Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
|
# Data is automatically moved to device by Accelerator
|
||||||
use_tpu = self.args.get('use_tpu', False)
|
features = batch['input_features']
|
||||||
if use_tpu:
|
labels = batch['seq_class_ids']
|
||||||
# Manual data movement for TPU since DataLoaders are not prepared by Accelerator
|
n_time_steps = batch['n_time_steps']
|
||||||
features = batch['input_features'].to(self.device)
|
phone_seq_lens = batch['phone_seq_lens']
|
||||||
labels = batch['seq_class_ids'].to(self.device)
|
day_indicies = batch['day_indicies']
|
||||||
n_time_steps = batch['n_time_steps'].to(self.device)
|
|
||||||
phone_seq_lens = batch['phone_seq_lens'].to(self.device)
|
|
||||||
day_indicies = batch['day_indicies'].to(self.device)
|
|
||||||
else:
|
|
||||||
# For GPU/CPU, data is automatically moved to device by Accelerator
|
|
||||||
features = batch['input_features']
|
|
||||||
labels = batch['seq_class_ids']
|
|
||||||
n_time_steps = batch['n_time_steps']
|
|
||||||
phone_seq_lens = batch['phone_seq_lens']
|
|
||||||
day_indicies = batch['day_indicies']
|
|
||||||
|
|
||||||
# Determine if we should perform validation on this batch
|
# Determine if we should perform validation on this batch
|
||||||
day = day_indicies[0].item()
|
day = day_indicies[0].item()
|
||||||
@@ -869,22 +800,14 @@ class BrainToTextDecoder_Trainer:
|
|||||||
|
|
||||||
def inference_batch(self, batch, mode='inference'):
|
def inference_batch(self, batch, mode='inference'):
|
||||||
'''
|
'''
|
||||||
TPU-compatible inference method for processing a full batch
|
Inference method for processing a full batch
|
||||||
'''
|
'''
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
# Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
|
# Data is automatically moved to device by Accelerator
|
||||||
use_tpu = self.args.get('use_tpu', False)
|
features = batch['input_features']
|
||||||
if use_tpu:
|
day_indicies = batch['day_indicies']
|
||||||
# Manual data movement for TPU since DataLoaders are not prepared by Accelerator
|
n_time_steps = batch['n_time_steps']
|
||||||
features = batch['input_features'].to(self.device)
|
|
||||||
day_indicies = batch['day_indicies'].to(self.device)
|
|
||||||
n_time_steps = batch['n_time_steps'].to(self.device)
|
|
||||||
else:
|
|
||||||
# For GPU/CPU, data is automatically moved to device by Accelerator
|
|
||||||
features = batch['input_features']
|
|
||||||
day_indicies = batch['day_indicies']
|
|
||||||
n_time_steps = batch['n_time_steps']
|
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with self.accelerator.autocast():
|
with self.accelerator.autocast():
|
||||||
|
Reference in New Issue
Block a user