From 7cc9c41b7f9f47d1ad07905d85227d909907d1f3 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 12 Oct 2025 20:43:43 +0800
Subject: [PATCH] tpu manual dataloader

---
 model_training_nnn/rnn_trainer.py | 107 +++++++++++++++++++++---------
 1 file changed, 76 insertions(+), 31 deletions(-)

diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py
index 3755234..a862607 100644
--- a/model_training_nnn/rnn_trainer.py
+++ b/model_training_nnn/rnn_trainer.py
@@ -49,7 +49,7 @@ class BrainToTextDecoder_Trainer:
         )
 
         # Set even_batches to False after initialization - required for batch_size=None DataLoaders
-        self.accelerator.even_batches = False
+        # Note: This may not be settable in all Accelerate versions, but we handle it in DataLoader config
 
         # Trainer fields
         self.args = args
@@ -280,19 +280,36 @@ class BrainToTextDecoder_Trainer:
                 param.requires_grad = False
 
         # Prepare model, optimizer, scheduler, and dataloaders for distributed training
-        (
-            self.model,
-            self.optimizer,
-            self.learning_rate_scheduler,
-            self.train_loader,
-            self.val_loader,
-        ) = self.accelerator.prepare(
-            self.model,
-            self.optimizer,
-            self.learning_rate_scheduler,
-            self.train_loader,
-            self.val_loader,
-        )
+        # For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues
+        use_tpu = self.args.get('use_tpu', False)
+
+        if use_tpu:
+            # On TPU, only prepare model, optimizer, and scheduler
+            (
+                self.model,
+                self.optimizer,
+                self.learning_rate_scheduler,
+            ) = self.accelerator.prepare(
+                self.model,
+                self.optimizer,
+                self.learning_rate_scheduler,
+            )
+            # DataLoaders remain unprepared but will work with our custom configuration
+        else:
+            # Standard GPU/CPU preparation including DataLoaders
+            (
+                self.model,
+                self.optimizer,
+                self.learning_rate_scheduler,
+                self.train_loader,
+                self.val_loader,
+            ) = self.accelerator.prepare(
+                self.model,
+                self.optimizer,
+                self.learning_rate_scheduler,
+                self.train_loader,
+                self.val_loader,
+            )
 
         self.logger.info("Prepared model and dataloaders with Accelerator")
 
@@ -561,14 +578,24 @@ class BrainToTextDecoder_Trainer:
             self.optimizer.zero_grad()
 
             # Train step
-            start_time = time.time()
+            start_time = time.time()
 
-            # Data is automatically moved to device by Accelerator
-            features = batch['input_features']
-            labels = batch['seq_class_ids']
-            n_time_steps = batch['n_time_steps']
-            phone_seq_lens = batch['phone_seq_lens']
-            day_indicies = batch['day_indicies']
+            # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
+            use_tpu = self.args.get('use_tpu', False)
+            if use_tpu:
+                # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
+                features = batch['input_features'].to(self.device)
+                labels = batch['seq_class_ids'].to(self.device)
+                n_time_steps = batch['n_time_steps'].to(self.device)
+                phone_seq_lens = batch['phone_seq_lens'].to(self.device)
+                day_indicies = batch['day_indicies'].to(self.device)
+            else:
+                # For GPU/CPU, data is automatically moved to device by Accelerator
+                features = batch['input_features']
+                labels = batch['seq_class_ids']
+                n_time_steps = batch['n_time_steps']
+                phone_seq_lens = batch['phone_seq_lens']
+                day_indicies = batch['day_indicies']
 
             # Use Accelerator's autocast (mixed precision handled by Accelerator init)
             with self.accelerator.autocast():
@@ -732,12 +759,22 @@ class BrainToTextDecoder_Trainer:
 
         for i, batch in enumerate(loader):
 
-            # Data is automatically moved to device by Accelerator
-            features = batch['input_features']
-            labels = batch['seq_class_ids']
-            n_time_steps = batch['n_time_steps']
-            phone_seq_lens = batch['phone_seq_lens']
-            day_indicies = batch['day_indicies']
+            # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
+            use_tpu = self.args.get('use_tpu', False)
+            if use_tpu:
+                # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
+                features = batch['input_features'].to(self.device)
+                labels = batch['seq_class_ids'].to(self.device)
+                n_time_steps = batch['n_time_steps'].to(self.device)
+                phone_seq_lens = batch['phone_seq_lens'].to(self.device)
+                day_indicies = batch['day_indicies'].to(self.device)
+            else:
+                # For GPU/CPU, data is automatically moved to device by Accelerator
+                features = batch['input_features']
+                labels = batch['seq_class_ids']
+                n_time_steps = batch['n_time_steps']
+                phone_seq_lens = batch['phone_seq_lens']
+                day_indicies = batch['day_indicies']
 
             # Determine if we should perform validation on this batch
            day = day_indicies[0].item()
@@ -838,10 +875,18 @@ class BrainToTextDecoder_Trainer:
         '''
         self.model.eval()
 
-        # Data is automatically moved to device by Accelerator
-        features = batch['input_features']
-        day_indicies = batch['day_indicies']
-        n_time_steps = batch['n_time_steps']
+        # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator
+        use_tpu = self.args.get('use_tpu', False)
+        if use_tpu:
+            # Manual data movement for TPU since DataLoaders are not prepared by Accelerator
+            features = batch['input_features'].to(self.device)
+            day_indicies = batch['day_indicies'].to(self.device)
+            n_time_steps = batch['n_time_steps'].to(self.device)
+        else:
+            # For GPU/CPU, data is automatically moved to device by Accelerator
+            features = batch['input_features']
+            day_indicies = batch['day_indicies']
+            n_time_steps = batch['n_time_steps']
 
         with torch.no_grad():
             with self.accelerator.autocast():
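
Below is an illustrative sketch, separate from the patch itself, of the data path this commit relies on: a batch_size=None DataLoader that is kept out of accelerator.prepare() and whose batches are moved to the device by hand. It assumes a dataset whose __getitem__ already returns a full batch as a dict of tensors and a 'use_tpu' flag, mirroring the diff above; TrainBatchDataset, build_loader, and move_batch_to_device are hypothetical names, not code from the repository.

import torch
from torch.utils.data import DataLoader, Dataset


class TrainBatchDataset(Dataset):
    # Hypothetical stand-in for a dataset whose __getitem__ already returns a full
    # batch (dict of tensors), which is why the DataLoader uses batch_size=None.
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return {
            'input_features': torch.randn(8, 100, 512),
            'day_indicies': torch.zeros(8, dtype=torch.long),
            'n_time_steps': torch.full((8,), 100, dtype=torch.long),
        }


def build_loader():
    # batch_size=None disables automatic batching, so no batch_sampler is created;
    # per the patch, such loaders are left out of accelerator.prepare() on TPU.
    return DataLoader(TrainBatchDataset(), batch_size=None, shuffle=False, num_workers=0)


def move_batch_to_device(batch, device):
    # Mirrors the manual .to(self.device) calls added in the TPU branch of the patch.
    return {k: v.to(device) for k, v in batch.items()}


if __name__ == '__main__':
    use_tpu = False  # set True only if torch_xla is installed and a TPU is attached
    if use_tpu:
        import torch_xla.core.xla_model as xm
        device = xm.xla_device()
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for batch in build_loader():
        batch = move_batch_to_device(batch, device)
        print({k: v.device for k, v in batch.items()})
        break

Because the loader is never handed to accelerator.prepare(), Accelerate does not rebuild it around a batch sampler, which, per the commit message and comments, is where the TPU issue arose; the trade-off is that device placement becomes the trainer's responsibility, as in the .to(self.device) calls above.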