commit 00c94fd48b
parent c6fc211b00
Author: Zchen
Date:   2025-10-12 21:20:08 +08:00

2 changed files with 48 additions and 125 deletions


@@ -20,7 +20,7 @@ use_amp: true # whether to use automatic mixed precision (AMP) for training
 # TPU and distributed training settings
 use_tpu: true # whether to use TPU for training (set to true for TPU)
-num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
+num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8)
 gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
 output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
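
For context, a minimal sketch of how these keys typically feed an Accelerator: use_amp maps to the mixed_precision setting ('bf16' is the usual choice on TPU), and gradient_accumulation_steps passes straight through, while num_tpu_cores is normally consumed by the launcher (for example accelerate's notebook_launcher(num_processes=8)) rather than by the Accelerator object itself. The function name and config access below are assumptions, not this repository's code.

from accelerate import Accelerator

def build_accelerator(args):
    # 'bf16' is the usual mixed-precision mode on TPU; 'no' disables AMP entirely
    mixed_precision = 'bf16' if args.get('use_amp', False) else 'no'
    return Accelerator(
        mixed_precision=mixed_precision,
        gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
    )

accelerator = build_accelerator({'use_amp': True, 'gradient_accumulation_steps': 1})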


@@ -183,34 +183,14 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        # Use TPU-optimized dataloader settings if TPU is enabled
-        num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
-
-        # For TPU environments, we need to be more careful about DataLoader configuration
-        use_tpu = self.args.get('use_tpu', False)
-
-        if use_tpu:
-            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
-            # TPU requires specific DataLoader configuration to avoid batch_sampler issues
-            self.train_loader = DataLoader(
-                self.train_dataset,
-                batch_size = None,      # None because our Dataset returns batches
-                sampler = None,         # Disable sampler to avoid batch_sampler conflicts
-                batch_sampler = None,   # Explicitly set to None
-                shuffle = False,        # Can't shuffle with custom batching
-                num_workers = num_workers,
-                pin_memory = False,     # TPU doesn't support pin_memory
-                collate_fn = lambda x: x if isinstance(x, dict) else x[0]  # Handle both dict and list formats
-            )
-        else:
-            # Standard GPU/CPU configuration
-            self.train_loader = DataLoader(
-                self.train_dataset,
-                batch_size = None,      # Dataset.__getitem__() already returns batches
-                shuffle = self.args['dataset']['loader_shuffle'],
-                num_workers = num_workers,
-                pin_memory = True
-            )
+        # Standard DataLoader configuration - let Accelerator handle device-specific optimizations
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size = None,  # Dataset.__getitem__() already returns batches
+            shuffle = self.args['dataset']['loader_shuffle'],
+            num_workers = self.args['dataset']['num_dataloader_workers'],
+            pin_memory = True
+        )

         # val dataset and dataloader
         self.val_dataset = BrainToTextDataset(
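
The batch_size=None setting only works because the Dataset already returns complete batches, so the DataLoader must not re-batch them. A toy sketch of that pattern follows; the field names and shapes are assumptions that mirror the batch keys used later in the trainer.

import torch
from torch.utils.data import Dataset, DataLoader

class ToyBatchDataset(Dataset):
    """Each item is a full pre-assembled batch (a dict of tensors)."""
    def __init__(self, n_batches=10, batch_size=32, n_features=512, n_steps=200):
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.n_features = n_features
        self.n_steps = n_steps

    def __len__(self):
        return self.n_batches

    def __getitem__(self, idx):
        return {
            'input_features': torch.randn(self.batch_size, self.n_steps, self.n_features),
            'n_time_steps': torch.full((self.batch_size,), self.n_steps),
            'day_indicies': torch.zeros(self.batch_size, dtype=torch.long),
        }

# batch_size=None disables automatic batching: each dict passes through unchanged
loader = DataLoader(ToyBatchDataset(), batch_size=None, shuffle=True, num_workers=0)
for batch in loader:
    assert isinstance(batch, dict)
    break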
@@ -223,27 +203,14 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        if use_tpu:
-            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
-            self.val_loader = DataLoader(
-                self.val_dataset,
-                batch_size = None,      # None because our Dataset returns batches
-                sampler = None,         # Disable sampler to avoid batch_sampler conflicts
-                batch_sampler = None,   # Explicitly set to None
-                shuffle = False,
-                num_workers = 0,        # Keep validation dataloader single-threaded for consistency
-                pin_memory = False,     # TPU doesn't support pin_memory
-                collate_fn = lambda x: x if isinstance(x, dict) else x[0]  # Handle both dict and list formats
-            )
-        else:
-            # Standard GPU/CPU configuration
-            self.val_loader = DataLoader(
-                self.val_dataset,
-                batch_size = None,  # Dataset.__getitem__() already returns batches
-                shuffle = False,
-                num_workers = 0,    # Keep validation dataloader single-threaded for consistency
-                pin_memory = True
-            )
+        # Standard validation DataLoader configuration
+        self.val_loader = DataLoader(
+            self.val_dataset,
+            batch_size = None,  # Dataset.__getitem__() already returns batches
+            shuffle = False,
+            num_workers = 0,    # Keep validation dataloader single-threaded for consistency
+            pin_memory = True
+        )

         self.logger.info("Successfully initialized datasets")
@@ -278,36 +245,20 @@ class BrainToTextDecoder_Trainer:
                 param.requires_grad = False

         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
-        # For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues
-        use_tpu = self.args.get('use_tpu', False)
-
-        if use_tpu:
-            # On TPU, only prepare model, optimizer, and scheduler
-            (
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-            ) = self.accelerator.prepare(
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-            )
-            # DataLoaders remain unprepared but will work with our custom configuration
-        else:
-            # Standard GPU/CPU preparation including DataLoaders
-            (
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            ) = self.accelerator.prepare(
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            )
+        # Let Accelerator handle everything automatically for both GPU and TPU
+        (
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        )

         self.logger.info("Prepared model and dataloaders with Accelerator")
@@ -578,22 +529,12 @@ class BrainToTextDecoder_Trainer:
             # Train step
             start_time = time.time()

-            # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
-            use_tpu = self.args.get('use_tpu', False)
-            if use_tpu:
-                # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
-                features = batch['input_features'].to(self.device)
-                labels = batch['seq_class_ids'].to(self.device)
-                n_time_steps = batch['n_time_steps'].to(self.device)
-                phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-                day_indicies = batch['day_indicies'].to(self.device)
-            else:
-                # For GPU/CPU, data is automatically moved to device by Accelerator
-                features = batch['input_features']
-                labels = batch['seq_class_ids']
-                n_time_steps = batch['n_time_steps']
-                phone_seq_lens = batch['phone_seq_lens']
-                day_indicies = batch['day_indicies']
+            # Data is automatically moved to device by Accelerator
+            features = batch['input_features']
+            labels = batch['seq_class_ids']
+            n_time_steps = batch['n_time_steps']
+            phone_seq_lens = batch['phone_seq_lens']
+            day_indicies = batch['day_indicies']

             # Use Accelerator's autocast (mixed precision handled by Accelerator init)
             with self.accelerator.autocast():
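
The resulting device-agnostic train step looks roughly like the sketch below, continuing the prepared objects from the previous sketch. The objective is a placeholder (the trainer's real loss is not part of this hunk), and accelerator.backward() plus accelerator.clip_grad_norm_() stand in for the usual manual calls:

for batch in train_loader:
    features = batch['input_features']            # already on accelerator.device
    optimizer.zero_grad()
    with accelerator.autocast():                  # mixed precision per Accelerator init
        logits = model(features)
        loss = logits.float().pow(2).mean()       # placeholder objective
    accelerator.backward(loss)                    # replaces loss.backward()
    accelerator.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    break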
@@ -757,22 +698,12 @@ class BrainToTextDecoder_Trainer:
         for i, batch in enumerate(loader):

-            # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
-            use_tpu = self.args.get('use_tpu', False)
-            if use_tpu:
-                # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
-                features = batch['input_features'].to(self.device)
-                labels = batch['seq_class_ids'].to(self.device)
-                n_time_steps = batch['n_time_steps'].to(self.device)
-                phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-                day_indicies = batch['day_indicies'].to(self.device)
-            else:
-                # For GPU/CPU, data is automatically moved to device by Accelerator
-                features = batch['input_features']
-                labels = batch['seq_class_ids']
-                n_time_steps = batch['n_time_steps']
-                phone_seq_lens = batch['phone_seq_lens']
-                day_indicies = batch['day_indicies']
+            # Data is automatically moved to device by Accelerator
+            features = batch['input_features']
+            labels = batch['seq_class_ids']
+            n_time_steps = batch['n_time_steps']
+            phone_seq_lens = batch['phone_seq_lens']
+            day_indicies = batch['day_indicies']

             # Determine if we should perform validation on this batch
             day = day_indicies[0].item()
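
Since validation now also runs through Accelerator-prepared loaders, per-process statistics can be combined across the 8 TPU cores with gather_for_metrics. The tally below is an illustration continuing the assumed names from the earlier sketches, not the trainer's own aggregation logic:

import torch

total_errors = torch.zeros(1, device=accelerator.device)
total_lengths = torch.zeros(1, device=accelerator.device)

for batch in val_loader:
    with torch.no_grad(), accelerator.autocast():
        logits = model(batch['input_features'])
    preds = logits.argmax(dim=-1)
    # Placeholder tally: real code would CTC-decode and compute an edit distance
    # against batch['seq_class_ids'] / batch['phone_seq_lens']
    total_errors += (preds != 0).sum().float()
    total_lengths += float(preds.numel())

errors = accelerator.gather_for_metrics(total_errors).sum()
lengths = accelerator.gather_for_metrics(total_lengths).sum()
if accelerator.is_main_process:
    print('aggregate error rate:', (errors / lengths).item())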
@@ -869,22 +800,14 @@ class BrainToTextDecoder_Trainer:
     def inference_batch(self, batch, mode='inference'):
         '''
-        TPU-compatible inference method for processing a full batch
+        Inference method for processing a full batch
         '''
         self.model.eval()

-        # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
-        use_tpu = self.args.get('use_tpu', False)
-        if use_tpu:
-            # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
-            features = batch['input_features'].to(self.device)
-            day_indicies = batch['day_indicies'].to(self.device)
-            n_time_steps = batch['n_time_steps'].to(self.device)
-        else:
-            # For GPU/CPU, data is automatically moved to device by Accelerator
-            features = batch['input_features']
-            day_indicies = batch['day_indicies']
-            n_time_steps = batch['n_time_steps']
+        # Data is automatically moved to device by Accelerator
+        features = batch['input_features']
+        day_indicies = batch['day_indicies']
+        n_time_steps = batch['n_time_steps']

         with torch.no_grad():
             with self.accelerator.autocast():