This commit is contained in:
Zchen
2025-10-12 21:08:15 +08:00
parent a5ff1b4c8e
commit c6fc211b00

View File

@@ -200,7 +200,7 @@ class BrainToTextDecoder_Trainer:
shuffle = False, # Can't shuffle with custom batching
num_workers = num_workers,
pin_memory = False, # TPU doesn't support pin_memory
collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
)
else:
# Standard GPU/CPU configuration
@@ -233,7 +233,7 @@ class BrainToTextDecoder_Trainer:
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = False, # TPU doesn't support pin_memory
collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
)
else:
# Standard GPU/CPU configuration