tpu
This commit is contained in:
@@ -200,7 +200,7 @@ class BrainToTextDecoder_Trainer:
|
||||
shuffle = False, # Can't shuffle with custom batching
|
||||
num_workers = num_workers,
|
||||
pin_memory = False, # TPU doesn't support pin_memory
|
||||
collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
|
||||
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
|
||||
)
|
||||
else:
|
||||
# Standard GPU/CPU configuration
|
||||
@@ -233,7 +233,7 @@ class BrainToTextDecoder_Trainer:
|
||||
shuffle = False,
|
||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||
pin_memory = False, # TPU doesn't support pin_memory
|
||||
collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
|
||||
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
|
||||
)
|
||||
else:
|
||||
# Standard GPU/CPU configuration
|
||||
|
Reference in New Issue
Block a user