TPU support (tpu支持)

This commit is contained in:
Zchen
2025-10-12 18:41:26 +08:00
parent 1a906d3248
commit 40e4d00576
5 changed files with 456 additions and 226 deletions

View File

@@ -72,11 +72,11 @@ class BrainToTextDecoder_Trainer:
# Create output directory
if args['mode'] == 'train':
os.makedirs(self.args['output_dir'], exist_ok=False)
os.makedirs(self.args['output_dir'], exist_ok=True)
# Create checkpoint directory
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=False)
if args['save_best_checkpoint'] or args['save_all_val_steps'] or args['save_final_model']:
os.makedirs(self.args['checkpoint_dir'], exist_ok=True)
# Set up logging
self.logger = logging.getLogger(__name__)
@@ -188,12 +188,16 @@ class BrainToTextDecoder_Trainer:
# Use TPU-optimized dataloader settings if TPU is enabled
num_workers = self.args['dataset']['dataloader_num_workers'] if self.args.get('use_tpu', False) else self.args['dataset']['num_dataloader_workers']
# For TPU environments or when batch_size=None causes issues, use batch_size=1
# since our dataset already returns complete batches
batch_size_setting = 1 if (self.args.get('use_tpu', False) or self.accelerator.device.type == 'xla') else None
self.train_loader = DataLoader(
self.train_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
shuffle = self.args['dataset']['loader_shuffle'],
num_workers = num_workers,
pin_memory = True
pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
)
# val dataset and dataloader
@@ -209,10 +213,10 @@ class BrainToTextDecoder_Trainer:
)
self.val_loader = DataLoader(
self.val_dataset,
batch_size = None, # Dataset.__getitem__() already returns batches
batch_size = batch_size_setting, # Dataset.__getitem__() already returns batches
shuffle = False,
num_workers = 0, # Keep validation dataloader single-threaded for consistency
pin_memory = True
pin_memory = True if self.accelerator.device.type != 'xla' else False # TPU doesn't support pin_memory
)
self.logger.info("Successfully initialized datasets")