From c6fc211b00335adfd13ffa996bd0cbafb833b45e Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Sun, 12 Oct 2025 21:08:15 +0800 Subject: [PATCH] tpu --- model_training_nnn/rnn_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_training_nnn/rnn_trainer.py b/model_training_nnn/rnn_trainer.py index 4194051..2ee00bf 100644 --- a/model_training_nnn/rnn_trainer.py +++ b/model_training_nnn/rnn_trainer.py @@ -200,7 +200,7 @@ class BrainToTextDecoder_Trainer: shuffle = False, # Can't shuffle with custom batching num_workers = num_workers, pin_memory = False, # TPU doesn't support pin_memory - collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through + collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats ) else: # Standard GPU/CPU configuration @@ -233,7 +233,7 @@ class BrainToTextDecoder_Trainer: shuffle = False, num_workers = 0, # Keep validation dataloader single-threaded for consistency pin_memory = False, # TPU doesn't support pin_memory - collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through + collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats ) else: # Standard GPU/CPU configuration