TPU support
This commit is contained in:
@@ -72,11 +72,11 @@ class BrainToTextDecoder_Trainer:
|
||||
|
||||
# Create output directory
|
||||
if args['mode'] == 'train':
|
||||
os.makedirs(self.args['output_dir'], exist_ok=False)
|
||||
os.makedirs(self.args['output_dir'], exist_ok=True)
|
||||
|
||||
# Create checkpoint directory
|
||||
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
|
||||
os.makedirs(self.args['checkpoint_dir'], exist_ok=False)
|
||||
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
|
||||
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
|
||||
|
||||
# Set up logging
|
||||
self.logger = logging.getLogger(__name__)
|
||||
@@ -188,12 +188,16 @@ class BrainToTextDecoder_Trainer:
|
||||
# Use TPU-optimized dataloader settings if TPU is enabled
|
||||
num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
|
||||
|
||||
# For TPU environments or when batch_size=None causes issues, use batch_size=1
|
||||
# since our dataset already returns complete batches
|
||||
batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None
|
||||
|
||||
self.train_loader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||
batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
|
||||
shuffle = self.args['dataset']['loader_shuffle'],
|
||||
num_workers = num_workers,
|
||||
pin_memory = True
|
||||
pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
|
||||
)
|
||||
|
||||
# val dataset and dataloader
|
||||
@@ -209,10 +213,10 @@ class BrainToTextDecoder_Trainer:
|
||||
)
|
||||
self.val_loader = DataLoader(
|
||||
self.val_dataset,
|
||||
batch_size = None, # Dataset.__getitem__() already returns batches
|
||||
batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
|
||||
shuffle = False,
|
||||
num_workers = 0, # Keep validation dataloader single-threaded for consistency
|
||||
pin_memory = True
|
||||
pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
|
||||
)
|
||||
|
||||
self.logger.info("Successfully initialized datasets")
|
||||
|
Reference in New Issue
Block a user