TPU support fix
@@ -48,9 +48,7 @@ class BrainToTextDecoder_Trainer:
             project_dir=args.get('output_dir', './output'),
         )
 
-        # Set even_batches to False to handle batch_size=None in DataLoaders
-        # For TPU, we need to handle the batch_sampler issue more carefully
-        self.accelerator.even_batches = False
+        # Note: even_batches is handled automatically by Accelerator based on our DataLoader configuration
 
         # Trainer fields
        self.args = args
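For context on the assignment removed above: in Hugging Face accelerate, even_batches controls whether accelerate pads or duplicates batches so every process sees the same number of them, and it only takes effect when accelerate can wrap a batch sampler. With the batch_size=None loaders used below there is no batch sampler to wrap, so the explicit flag was dead code. A minimal sketch of how the flag is set explicitly when it is needed, assuming a recent accelerate release where it lives on DataLoaderConfiguration (not code from this repo):

from accelerate import Accelerator
from accelerate.utils import DataLoaderConfiguration

# even_batches only matters when accelerate shards a batch sampler across
# processes; batch_size=None loaders have none, so the default is fine here.
accelerator = Accelerator(
    dataloader_config=DataLoaderConfiguration(even_batches=False)
)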
@@ -192,15 +190,28 @@ class BrainToTextDecoder_Trainer:
         # For TPU environments, we need to be more careful about DataLoader configuration
         use_tpu = self.args.get('use_tpu', False)
 
-        # TPU doesn't handle batch_size=None well, so use batch_size=1 for TPU
-        batch_size_setting = 1 if use_tpu else None
-        self.train_loader = DataLoader(
-            self.train_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
-            shuffle = self.args['dataset']['loader_shuffle'],
-            num_workers = num_workers,
-            pin_memory = not use_tpu # TPU doesn't support pin_memory
-        )
+        if use_tpu:
+            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
+            # TPU requires specific DataLoader configuration to avoid batch_sampler issues
+            from torch.utils.data import DataLoader
+            self.train_loader = DataLoader(
+                self.train_dataset,
+                batch_size = None, # None because our Dataset returns batches
+                sampler = None, # Disable sampler to avoid batch_sampler conflicts
+                batch_sampler = None, # Explicitly set to None
+                shuffle = False, # Can't shuffle with custom batching
+                num_workers = num_workers,
+                pin_memory = False, # TPU doesn't support pin_memory
+                collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
+            )
+        else:
+            # Standard GPU/CPU configuration
+            self.train_loader = DataLoader(
+                self.train_dataset,
+                batch_size = None, # Dataset.__getitem__() already returns batches
+                shuffle = self.args['dataset']['loader_shuffle'],
+                num_workers = num_workers,
+                pin_memory = True
+            )
 
         # val dataset and dataloader
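A self-contained sketch of the batch-returning Dataset pattern these loaders rely on (toy dataset, not this repo's BrainToTextDataset). The relevant PyTorch semantics: with batch_size=None automatic batching is disabled and each __getitem__ result is yielded as-is, while with batch_size=1 the default collate stacks a one-element list, which a passthrough collate_fn can unwrap:

import torch
from torch.utils.data import DataLoader, Dataset

class BatchReturningDataset(Dataset):
    """Toy stand-in: __getitem__ returns a whole pre-assembled batch."""
    def __init__(self, n_batches=4, batch_size=8, n_feats=16):
        self.batches = [torch.randn(batch_size, n_feats) for _ in range(n_batches)]
    def __len__(self):
        return len(self.batches)
    def __getitem__(self, idx):
        return self.batches[idx]

ds = BatchReturningDataset()

# batch_size=None: automatic batching off, each batch tensor yielded as-is
loader_plain = DataLoader(ds, batch_size=None, shuffle=False)

# batch_size=1: default collate would stack to (1, 8, 16); a passthrough
# collate_fn unwraps the one-element list instead
loader_unwrap = DataLoader(ds, batch_size=1, shuffle=False, collate_fn=lambda x: x[0])

assert next(iter(loader_plain)).shape == (8, 16)
assert next(iter(loader_unwrap)).shape == (8, 16)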
@@ -214,12 +225,26 @@ class BrainToTextDecoder_Trainer:
             random_seed = self.args['dataset']['seed'],
             feature_subset = feature_subset
         )
-        self.val_loader = DataLoader(
-            self.val_dataset,
-            batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches, but TPU needs batch_size=1
-            shuffle = False,
-            num_workers = 0, # Keep validation dataloader single-threaded for consistency
-            pin_memory = not use_tpu # TPU doesn't support pin_memory
-        )
+        if use_tpu:
+            # For TPU, create a custom DataLoader that properly handles our batch-returning Dataset
+            self.val_loader = DataLoader(
+                self.val_dataset,
+                batch_size = None, # None because our Dataset returns batches
+                sampler = None, # Disable sampler to avoid batch_sampler conflicts
+                batch_sampler = None, # Explicitly set to None
+                shuffle = False,
+                num_workers = 0, # Keep validation dataloader single-threaded for consistency
+                pin_memory = False, # TPU doesn't support pin_memory
+                collate_fn = lambda x: x[0] # Since Dataset returns batch, just pass it through
+            )
+        else:
+            # Standard GPU/CPU configuration
+            self.val_loader = DataLoader(
+                self.val_dataset,
+                batch_size = None, # Dataset.__getitem__() already returns batches
+                shuffle = False,
+                num_workers = 0, # Keep validation dataloader single-threaded for consistency
+                pin_memory = True
+            )
 
         self.logger.info("Successfully initialized datasets")
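The pin_memory split above follows standard PyTorch practice: pinned host memory enables fast asynchronous copies to a CUDA device but, as the diff's comments note, is unsupported on TPU. A small generic illustration of what pin_memory buys on CUDA (not this trainer's code):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(32, 16))
# Pin host memory only when a CUDA device will consume the batches
loader = DataLoader(ds, batch_size=8, pin_memory=torch.cuda.is_available())

for (batch,) in loader:
    if torch.cuda.is_available():
        # non_blocking copies overlap with compute only from pinned memory
        batch = batch.to("cuda", non_blocking=True)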
@@ -459,12 +484,7 @@ class BrainToTextDecoder_Trainer:
         Performing augmentations is much faster on GPU than CPU
         '''
 
-        # Handle TPU case where DataLoader with batch_size=1 adds an extra dimension
-        use_tpu = self.args.get('use_tpu', False)
-        if use_tpu and features.dim() == 4 and features.size(0) == 1:
-            features = features.squeeze(0) # Remove the extra batch dimension added by DataLoader
-            if isinstance(n_time_steps, torch.Tensor) and n_time_steps.dim() == 2:
-                n_time_steps = n_time_steps.squeeze(0)
+        # TPU and GPU should now handle data consistently with our improved DataLoader configuration
 
         data_shape = features.shape
         batch_size = data_shape[0]
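The workaround removed here compensated for the extra leading dimension that a batch_size=1 DataLoader adds via default collation; with the passthrough collate_fn introduced above, batches arrive 3-D and no squeeze is needed. A toy reproduction of both behaviors (assumed shapes for illustration, not this repo's data):

import torch
from torch.utils.data import DataLoader, Dataset

class OneBatchDataset(Dataset):
    def __len__(self):
        return 1
    def __getitem__(self, idx):
        return torch.zeros(8, 100, 16)  # a pre-built (batch, time, features) tensor

# Default collate stacks the single item: shape becomes (1, 8, 100, 16)
wrapped = next(iter(DataLoader(OneBatchDataset(), batch_size=1)))
assert wrapped.shape == (1, 8, 100, 16)
assert wrapped.squeeze(0).shape == (8, 100, 16)  # what the removed code undid

# Passthrough collate returns the batch unchanged: no squeeze required
unwrapped = next(iter(DataLoader(OneBatchDataset(), batch_size=1,
                                 collate_fn=lambda x: x[0])))
assert unwrapped.shape == (8, 100, 16)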