tpu
@@ -19,7 +19,7 @@ mode: train
 use_amp: true # whether to use automatic mixed precision (AMP) for training
 
 # TPU and distributed training settings
-use_tpu: false # whether to use TPU for training (set to true for TPU)
+use_tpu: true # whether to use TPU for training (set to true for TPU)
 num_tpu_cores: 8 # number of TPU cores to use (typically 8 for v3-8 or v4-8)
 gradient_accumulation_steps: 1 # number of gradient accumulation steps for distributed training
 dataloader_num_workers: 0 # set to 0 for TPU to avoid multiprocessing issues
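For context, a minimal sketch of how these YAML flags might be consumed at startup. The config filename, the device-selection branching, and the use of torch_xla are assumptions for illustration; the diff only shows the flag values themselves.

```python
# Hedged sketch (not code from this repo) of consuming the TPU flags above.
import yaml

with open('rnn_args.yaml') as f:  # hypothetical config path
    args = yaml.safe_load(f)

if args.get('use_tpu', False):
    # torch_xla exposes TPU cores as XLA devices
    import torch_xla.core.xla_model as xm
    device = xm.xla_device()
    # matches dataloader_num_workers: 0 to avoid multiprocessing issues on TPU
    num_workers = 0
else:
    import torch
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_workers = args.get('dataloader_num_workers', 0)
```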
@@ -46,7 +46,6 @@ class BrainToTextDecoder_Trainer:
             gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
             log_with=None, # We'll use our own logging
             project_dir=args.get('output_dir', './output'),
-            even_batches=False, # Required for batch_size=None DataLoaders
         )
 
 
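This hunk drops the `even_batches=False` argument from what appears to be a Hugging Face `accelerate.Accelerator(...)` call. A minimal reconstruction of the call after this commit; the method name and assignment target are assumptions inferred from the hunk header, only the keyword arguments come from the diff.

```python
# Sketch of the Accelerator construction after this commit.
from accelerate import Accelerator

class BrainToTextDecoder_Trainer:
    def __init__(self, args: dict):
        self.accelerator = Accelerator(
            gradient_accumulation_steps=args.get('gradient_accumulation_steps', 1),
            log_with=None,  # we'll use our own logging
            project_dir=args.get('output_dir', './output'),
            # even_batches=False (previously noted as required for
            # batch_size=None DataLoaders) is removed by this commit,
            # so the library default (even_batches=True) now applies.
        )
```

With the argument gone, `Accelerator` falls back to its default `even_batches=True`, under which accelerate duplicates samples as needed so every process receives the same number of batches.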