From 00c94fd48b4aacd7a499d52aabad7199c56c3e43 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 12 Oct 2025 21:20:08 +0800
Subject: [PATCH] tpu: let Accelerate prepare DataLoaders and handle device
 placement

---
 model_training_nnn/rnn_args.yaml  |   2 +-
 model_training_nnn/rnn_trainer.py | 171 ++++++++------------------------
 2 files changed, 48 insertions(+), 125 deletions(-)

diff --git a/model_training_nnn/rnn_args.yaml b/model_training_nnn/rnn_args.yaml
index 6e55e64..54c4f79 100644
--- a/model_training_nnn/rnn_args.yaml
+++ b/model_training_nnn/rnn_args.yaml
@@ -20,7 +20,7 @@ use_amp: true # whether to use automatic mixed precision (AMP) for training
 
 # TPU and distributed training settings
 use_tpu: true # whether to use TPU for training (set to true for TPU)
-num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
+num_tpu_cores: 8 # number of TPU cores to use (full TPU v3-8)
 gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
 
 output_dir: trained_models/baseline_rnn # directory to save the trained model and logs
diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 2ee00bf..9a99998 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -183,34 +183,14 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        # Use TPU-optimized dataloader settings if TPU is enabled
-        num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
-
-        # For TPU environments, we need to be more careful about DataLoader configuration
-        use_tpu = self.args.get('use_tpu', False)
-
-        if use_tpu:
-            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
-            # TPU requires specific DataLoader configuration to avoid batch_sampler issues
-            self.train_loader = DataLoader(
-                self.train_dataset,
-                batch_size = None, # None because our Dataset returns batches
-                sampler = None, # Disable sampler to avoid batch_sampler conflicts
-                batch_sampler = None, # Explicitly set to None
-                shuffle = False, # Can't shuffle with custom batching
-                num_workers = num_workers,
-                pin_memory = False, # TPU doesn't support pin_memory
-                collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
-            )
-        else:
-            # Standard GPU/CPU configuration
-            self.train_loader = DataLoader(
-                self.train_dataset,
-                batch_size = None, # Dataset.__getitem__() already returns batches
-                shuffle = self.args['dataset']['loader_shuffle'],
-                num_workers = num_workers,
-                pin_memory = True
-            )
+        # Standard DataLoader configuration - let Accelerator handle device-specific optimizations
+        self.train_loader = DataLoader(
+            self.train_dataset,
+            batch_size = None, # Dataset.__getitem__() already returns batches
+            shuffle = self.args['dataset']['loader_shuffle'],
+            num_workers = self.args['dataset']['num_dataloader_workers'],
+            pin_memory = True
+        )
 
         # val dataset and dataloader
         self.val_dataset = BrainToTextDataset(
@@ -223,27 +203,14 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        if use_tpu:
-            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
-            self.val_loader = DataLoader(
-                self.val_dataset,
-                batch_size = None, # None because our Dataset returns batches
-                sampler = None, # Disable sampler to avoid batch_sampler conflicts
-                batch_sampler = None, # Explicitly set to None
-                shuffle = False,
-                num_workers = 0, # Keep validation dataloader single-threaded for consistency
-                pin_memory = False, # TPU doesn't support pin_memory
-                collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
-            )
-        else:
-            # Standard GPU/CPU configuration
-            self.val_loader = DataLoader(
-                self.val_dataset,
-                batch_size = None, # Dataset.__getitem__() already returns batches
-                shuffle = False,
-                num_workers = 0, # Keep validation dataloader single-threaded for consistency
-                pin_memory = True
-            )
+        # Standard validation DataLoader configuration
+        self.val_loader = DataLoader(
+            self.val_dataset,
+            batch_size = None, # Dataset.__getitem__() already returns batches
+            shuffle = False,
+            num_workers = 0, # Keep validation dataloader single-threaded for consistency
+            pin_memory = True
+        )
 
         self.logger.info("Successfully initialized datasets")
 
@@ -278,36 +245,20 @@ class BrainToTextDecoder_Trainer:
                     param.requires_grad = False
 
         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
-        # For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues
-        use_tpu = self.args.get('use_tpu', False)
-
-        if use_tpu:
-            # On TPU, only prepare model, optimizer, and scheduler
-            (
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-            ) = self.accelerator.prepare(
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-            )
-            # DataLoaders remain unprepared but will work with our custom configuration
-        else:
-            # Standard GPU/CPU preparation including DataLoaders
-            (
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            ) = self.accelerator.prepare(
-                self.model,
-                self.optimizer,
-                self.learning_rate_scheduler,
-                self.train_loader,
-                self.val_loader,
-            )
+        # Let Accelerator handle everything automatically for both GPU and TPU
+        (
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        ) = self.accelerator.prepare(
+            self.model,
+            self.optimizer,
+            self.learning_rate_scheduler,
+            self.train_loader,
+            self.val_loader,
+        )
 
         self.logger.info("Prepared model and dataloaders with Accelerator")
 
@@ -578,22 +529,12 @@ class BrainToTextDecoder_Trainer:
                 # Train step
                 start_time = time.time()
 
-                # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
-                use_tpu = self.args.get('use_tpu', False)
-                if use_tpu:
-                    # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
-                    features = batch['input_features'].to(self.device)
-                    labels = batch['seq_class_ids'].to(self.device)
-                    n_time_steps = batch['n_time_steps'].to(self.device)
-                    phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-                    day_indicies = batch['day_indicies'].to(self.device)
-                else:
-                    # For GPU/CPU, data is automatically moved to device by Accelerator
-                    features = batch['input_features']
-                    labels = batch['seq_class_ids']
-                    n_time_steps = batch['n_time_steps']
-                    phone_seq_lens = batch['phone_seq_lens']
-                    day_indicies = batch['day_indicies']
+                # Data is automatically moved to device by Accelerator
+                features = batch['input_features']
+                labels = batch['seq_class_ids']
+                n_time_steps = batch['n_time_steps']
+                phone_seq_lens = batch['phone_seq_lens']
+                day_indicies = batch['day_indicies']
 
                 # Use Accelerator's autocast (mixed precision handled by Accelerator init)
                 with self.accelerator.autocast():
@@ -757,22 +698,12 @@ class BrainToTextDecoder_Trainer:
 
             for i, batch in enumerate(loader):
 
-                # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
-                use_tpu = self.args.get('use_tpu', False)
-                if use_tpu:
-                    # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
-                    features = batch['input_features'].to(self.device)
-                    labels = batch['seq_class_ids'].to(self.device)
-                    n_time_steps = batch['n_time_steps'].to(self.device)
-                    phone_seq_lens = batch['phone_seq_lens'].to(self.device)
-                    day_indicies = batch['day_indicies'].to(self.device)
-                else:
-                    # For GPU/CPU, data is automatically moved to device by Accelerator
-                    features = batch['input_features']
-                    labels = batch['seq_class_ids']
-                    n_time_steps = batch['n_time_steps']
-                    phone_seq_lens = batch['phone_seq_lens']
-                    day_indicies = batch['day_indicies']
+                # Data is automatically moved to device by Accelerator
+                features = batch['input_features']
+                labels = batch['seq_class_ids']
+                n_time_steps = batch['n_time_steps']
+                phone_seq_lens = batch['phone_seq_lens']
+                day_indicies = batch['day_indicies']
 
                 # Determine if we should perform validation on this batch
                 day = day_indicies[0].item()
@@ -869,22 +800,14 @@ class BrainToTextDecoder_Trainer:
 
     def inference_batch(self, batch, mode='inference'):
         '''
-        TPU-compatible inference method for processing a full batch
+        Inference method for processing a full batch
        '''
         self.model.eval()
 
-        # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
-        use_tpu = self.args.get('use_tpu', False)
-        if use_tpu:
-            # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
-            features = batch['input_features'].to(self.device)
-            day_indicies = batch['day_indicies'].to(self.device)
-            n_time_steps = batch['n_time_steps'].to(self.device)
-        else:
-            # For GPU/CPU, data is automatically moved to device by Accelerator
-            features = batch['input_features']
-            day_indicies = batch['day_indicies']
-            n_time_steps = batch['n_time_steps']
+        # Data is automatically moved to device by Accelerator
+        features = batch['input_features']
+        day_indicies = batch['day_indicies']
+        n_time_steps = batch['n_time_steps']
 
         with torch.no_grad():
             with self.accelerator.autocast():
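
Note (not part of the patch): everything above hinges on accelerate's
prepare() accepting a DataLoader whose Dataset already returns whole batches
(batch_size=None) and moving each yielded batch to the active device, on TPU
as well as GPU. Below is a minimal, self-contained sketch of that unified
path; SyntheticBatches and the linear model are hypothetical stand-ins for
BrainToTextDataset and the RNN, and the sketch assumes an Accelerate release
that supports batch_size=None loaders.

import torch
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator

class SyntheticBatches(Dataset):
    """Map-style dataset whose __getitem__ returns a full batch dict,
    mirroring how the trainer's Dataset pre-batches its trials."""
    def __len__(self):
        return 4  # four pre-built batches
    def __getitem__(self, idx):
        return {
            'input_features': torch.randn(8, 100, 512),  # (batch, time, features)
            'day_indicies': torch.zeros(8, dtype=torch.long),
        }

accelerator = Accelerator()  # detects CUDA, XLA (TPU), or CPU on its own
model = torch.nn.Linear(512, 41)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loader = DataLoader(
    SyntheticBatches(),
    batch_size=None,   # dataset already returns batches; no collation
    shuffle=True,      # shuffles the order of the pre-built batches
    pin_memory=True,   # a CUDA-side hint; harmless on other backends
)

# One prepare() call covers every backend; no use_tpu branching needed.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for batch in loader:
    # A prepared loader yields tensors already on accelerator.device,
    # which is why the trainer can drop its manual .to(self.device) calls.
    with accelerator.autocast():
        logits = model(batch['input_features'])  # (8, 100, 41)
        loss = logits.float().mean()
    accelerator.backward(loss)  # replaces loss.backward(); handles AMP/XLA
    optimizer.step()
    optimizer.zero_grad()

The design trade-off: the old dual-path code pinned every device decision in
the trainer; delegating to prepare() keeps one code path but makes the
trainer dependent on Accelerate handling batch_size=None loaders correctly,
which is worth verifying on the target Accelerate version before a long TPU
run.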