This commit is contained in:
Zchen
2025-10-12 21:31:07 +08:00
parent 00c94fd48b
commit 580648c058
2 changed files with 93 additions and 1 deletions

View File

@@ -19,7 +19,7 @@ import torchaudio.functional as F # for edit distance
from omegaconf import OmegaConf
# Import Accelerate for TPU support
from accelerate import Accelerator
from accelerate import Accelerator, DataLoaderConfiguration
from accelerate.utils import set_seed
torch.set_float32_matmul_precision('high') # makes float32 matmuls faster on some GPUs
@@ -40,12 +40,18 @@ class BrainToTextDecoder_Trainer:
args : dictionary of training arguments
'''
# Configure DataLoader behavior for TPU compatibility
dataloader_config = DataLoaderConfiguration(
even_batches=False # Required for batch_size=None DataLoaders on TPU
)
# Initialize Accelerator for TPU/multi-device support
self.accelerator = Accelerator(
mixed_precision='bf16' if args.get('use_amp', True) else 'no',
gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
log_with=None, # We'll use our own logging
project_dir=args.get('output_dir', './output'),
dataloader_config=dataloader_config,
)