From 9a146102ffc995a94c78a73cc493d14f80188705 Mon Sep 17 00:00:00 2001
From: hinata kaga
Date: Sun, 20 Jul 2025 18:48:29 +0900
Subject: [PATCH] Fix GPU device ID mismatch causing CUDA invalid device error in train_model.py

---
 model_training/rnn_trainer.py | 29 ++++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/model_training/rnn_trainer.py b/model_training/rnn_trainer.py
index c253ccb..8149dc4 100644
--- a/model_training/rnn_trainer.py
+++ b/model_training/rnn_trainer.py
@@ -83,12 +83,35 @@ class BrainToTextDecoder_Trainer:
 
         # Configure device pytorch will use
         if torch.cuda.is_available():
-            self.device = f"cuda:{self.args['gpu_number']}"
-        else:
-            self.device = "cpu"
+            # Default to GPU 0 when gpu_number is absent from the config
+            gpu_num = self.args.get('gpu_number', 0)
+            try:
+                gpu_num = int(gpu_num)
+            except (TypeError, ValueError):
+                # int(None) raises TypeError, non-numeric strings raise ValueError
+                self.logger.warning(f"Invalid gpu_number value: {gpu_num}. Using 0 instead.")
+                gpu_num = 0
+
+            # Reject any index outside 0..device_count()-1, including negatives
+            max_gpu_index = torch.cuda.device_count() - 1
+            if not 0 <= gpu_num <= max_gpu_index:
+                self.logger.warning(f"Requested GPU {gpu_num} not available. Using GPU 0 instead.")
+                gpu_num = 0
+
+            try:
+                self.device = torch.device(f"cuda:{gpu_num}")
+                # Run a tiny op so an unusable device fails here, not mid-training
+                test_tensor = torch.tensor([1.0]).to(self.device)
+                test_tensor = test_tensor * 2
+            except Exception as e:
+                self.logger.error(f"Error initializing CUDA device {gpu_num}: {str(e)}")
+                self.logger.info("Falling back to CPU")
+                self.device = torch.device("cpu")
+        else:
+            self.device = torch.device("cpu")
 
         self.logger.info(f'Using device: {self.device}')
 
         # Set seed if provided
         if self.args['seed'] != -1:
             np.random.seed(self.args['seed'])