tpu manual dataloader
This commit is contained in:
		| @@ -49,7 +49,7 @@ class BrainToTextDecoder_Trainer: | ||||
|         ) | ||||
|  | ||||
|         # Set even_batches to False after initialization - required for batch_size=None DataLoaders | ||||
|         self.accelerator.even_batches = False | ||||
|         # Note: This may not be settable in all Accelerate versions, but we handle it in DataLoader config | ||||
|  | ||||
|         # Trainer fields | ||||
|         self.args = args | ||||
| @@ -280,19 +280,36 @@ class BrainToTextDecoder_Trainer: | ||||
|                 param.requires_grad = False | ||||
|  | ||||
|         # Prepare model, optimizer, scheduler, and dataloaders for distributed training | ||||
|         ( | ||||
|             self.model, | ||||
|             self.optimizer, | ||||
|             self.learning_rate_scheduler, | ||||
|             self.train_loader, | ||||
|             self.val_loader, | ||||
|         ) = self.accelerator.prepare( | ||||
|             self.model, | ||||
|             self.optimizer, | ||||
|             self.learning_rate_scheduler, | ||||
|             self.train_loader, | ||||
|             self.val_loader, | ||||
|         ) | ||||
|         # For TPU, don't prepare DataLoaders with Accelerator to avoid batch_sampler issues | ||||
|         use_tpu = self.args.get('use_tpu', False) | ||||
|  | ||||
|         if use_tpu: | ||||
|             # On TPU, only prepare model, optimizer, and scheduler | ||||
|             ( | ||||
|                 self.model, | ||||
|                 self.optimizer, | ||||
|                 self.learning_rate_scheduler, | ||||
|             ) = self.accelerator.prepare( | ||||
|                 self.model, | ||||
|                 self.optimizer, | ||||
|                 self.learning_rate_scheduler, | ||||
|             ) | ||||
|             # DataLoaders remain unprepared but will work with our custom configuration | ||||
|         else: | ||||
|             # Standard GPU/CPU preparation including DataLoaders | ||||
|             ( | ||||
|                 self.model, | ||||
|                 self.optimizer, | ||||
|                 self.learning_rate_scheduler, | ||||
|                 self.train_loader, | ||||
|                 self.val_loader, | ||||
|             ) = self.accelerator.prepare( | ||||
|                 self.model, | ||||
|                 self.optimizer, | ||||
|                 self.learning_rate_scheduler, | ||||
|                 self.train_loader, | ||||
|                 self.val_loader, | ||||
|             ) | ||||
|  | ||||
|         self.logger.info("Prepared model and dataloaders with Accelerator") | ||||
|  | ||||
| @@ -561,14 +578,24 @@ class BrainToTextDecoder_Trainer: | ||||
|             self.optimizer.zero_grad() | ||||
|              | ||||
|             # Train step | ||||
|             start_time = time.time()  | ||||
|             start_time = time.time() | ||||
|  | ||||
|             # Data is automatically moved to device by Accelerator | ||||
|             features = batch['input_features'] | ||||
|             labels = batch['seq_class_ids'] | ||||
|             n_time_steps = batch['n_time_steps'] | ||||
|             phone_seq_lens = batch['phone_seq_lens'] | ||||
|             day_indicies = batch['day_indicies'] | ||||
|             # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator | ||||
|             use_tpu = self.args.get('use_tpu', False) | ||||
|             if use_tpu: | ||||
|                 # Manual data movement for TPU since DataLoaders are not prepared by Accelerator | ||||
|                 features = batch['input_features'].to(self.device) | ||||
|                 labels = batch['seq_class_ids'].to(self.device) | ||||
|                 n_time_steps = batch['n_time_steps'].to(self.device) | ||||
|                 phone_seq_lens = batch['phone_seq_lens'].to(self.device) | ||||
|                 day_indicies = batch['day_indicies'].to(self.device) | ||||
|             else: | ||||
|                 # For GPU/CPU, data is automatically moved to device by Accelerator | ||||
|                 features = batch['input_features'] | ||||
|                 labels = batch['seq_class_ids'] | ||||
|                 n_time_steps = batch['n_time_steps'] | ||||
|                 phone_seq_lens = batch['phone_seq_lens'] | ||||
|                 day_indicies = batch['day_indicies'] | ||||
|  | ||||
|             # Use Accelerator's autocast (mixed precision handled by Accelerator init) | ||||
|             with self.accelerator.autocast(): | ||||
| @@ -732,12 +759,22 @@ class BrainToTextDecoder_Trainer: | ||||
|  | ||||
|         for i, batch in enumerate(loader): | ||||
|  | ||||
|             # Data is automatically moved to device by Accelerator | ||||
|             features = batch['input_features'] | ||||
|             labels = batch['seq_class_ids'] | ||||
|             n_time_steps = batch['n_time_steps'] | ||||
|             phone_seq_lens = batch['phone_seq_lens'] | ||||
|             day_indicies = batch['day_indicies'] | ||||
|             # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator | ||||
|             use_tpu = self.args.get('use_tpu', False) | ||||
|             if use_tpu: | ||||
|                 # Manual data movement for TPU since DataLoaders are not prepared by Accelerator | ||||
|                 features = batch['input_features'].to(self.device) | ||||
|                 labels = batch['seq_class_ids'].to(self.device) | ||||
|                 n_time_steps = batch['n_time_steps'].to(self.device) | ||||
|                 phone_seq_lens = batch['phone_seq_lens'].to(self.device) | ||||
|                 day_indicies = batch['day_indicies'].to(self.device) | ||||
|             else: | ||||
|                 # For GPU/CPU, data is automatically moved to device by Accelerator | ||||
|                 features = batch['input_features'] | ||||
|                 labels = batch['seq_class_ids'] | ||||
|                 n_time_steps = batch['n_time_steps'] | ||||
|                 phone_seq_lens = batch['phone_seq_lens'] | ||||
|                 day_indicies = batch['day_indicies'] | ||||
|  | ||||
|             # Determine if we should perform validation on this batch | ||||
|             day = day_indicies[0].item() | ||||
| @@ -838,10 +875,18 @@ class BrainToTextDecoder_Trainer: | ||||
|         ''' | ||||
|         self.model.eval() | ||||
|  | ||||
|         # Data is automatically moved to device by Accelerator | ||||
|         features = batch['input_features'] | ||||
|         day_indicies = batch['day_indicies'] | ||||
|         n_time_steps = batch['n_time_steps'] | ||||
|         # Handle data movement - for TPU, manually move to device since DataLoader wasn't prepared by Accelerator | ||||
|         use_tpu = self.args.get('use_tpu', False) | ||||
|         if use_tpu: | ||||
|             # Manual data movement for TPU since DataLoaders are not prepared by Accelerator | ||||
|             features = batch['input_features'].to(self.device) | ||||
|             day_indicies = batch['day_indicies'].to(self.device) | ||||
|             n_time_steps = batch['n_time_steps'].to(self.device) | ||||
|         else: | ||||
|             # For GPU/CPU, data is automatically moved to device by Accelerator | ||||
|             features = batch['input_features'] | ||||
|             day_indicies = batch['day_indicies'] | ||||
|             n_time_steps = batch['n_time_steps'] | ||||
|  | ||||
|         with torch.no_grad(): | ||||
|             with self.accelerator.autocast(): | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Zchen
					Zchen