tpu
This commit is contained in:
@@ -200,7 +200,7 @@ class BrainToTextDecoder_Trainer:
|
|||||||
shuffle = False, # Can't shuffle with custom batching
|
shuffle = False, # Can't shuffle with custom batching
|
||||||
num_workers = num_workers,
|
num_workers = num_workers,
|
||||||
pin_memory = False, # TPU doesn't support pin_memory
|
pin_memory = False, # TPU doesn't support pin_memory
|
||||||
collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
|
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Standard GPU/CPU configuration
|
# Standard GPU/CPU configuration
|
||||||
@@ -233,7 +233,7 @@ class BrainToTextDecoder_Trainer:
|
|||||||
shuffle = False,
|
shuffle = False,
|
||||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||||
pin_memory = False, # TPU doesn't support pin_memory
|
pin_memory = False, # TPU doesn't support pin_memory
|
||||||
collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
|
collate_fn = lambda x: x if isinstance(x, dict) else x[0] # Handle both dict and list formats
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Standard GPU/CPU configuration
|
# Standard GPU/CPU configuration
|
||||||
|
Reference in New Issue
Block a user