数据加载器多线程加速

This commit is contained in:
Zchen
2025-10-16 01:17:36 +08:00
parent f84d6254e3
commit 69a7285886

View File

@@ -43,6 +43,9 @@ class BrainToTextDecoderTrainerTF:
self.args = args
self.logger = None
# Optimize CPU utilization for data pipeline (make use of 224 cores)
self._configure_cpu_optimization()
# Initialize TPU strategy
self.strategy = create_tpu_strategy()
print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
@@ -123,6 +126,28 @@ class BrainToTextDecoderTrainerTF:
if self.mixed_precision:
self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
def _configure_cpu_optimization(self):
    """Tune TensorFlow's CPU thread pools for the input data pipeline.

    Splits the host's cores between inter-op parallelism (concurrent
    independent ops, capped at 32 threads) and intra-op parallelism
    (threads inside a single op, capped at 16) so data loading and
    preprocessing run in parallel on the CPU while the TPU trains.

    Must be called before TensorFlow executes any op, since the
    threading configuration is fixed once the runtime initializes.
    """
    import multiprocessing

    # Total logical cores reported by the OS.
    available_cores = multiprocessing.cpu_count()
    print(f"💻 Available CPU cores: {available_cores}")

    # Roughly 1/4 of cores for inter-op and 1/8 for intra-op.
    # Clamp to at least 1: on hosts with < 4 (or < 8) cores the
    # floor division yields 0, and passing 0 tells TensorFlow to
    # pick its own defaults, silently ignoring this configuration.
    inter_op_threads = max(1, min(32, available_cores // 4))
    intra_op_threads = max(1, min(16, available_cores // 8))

    tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
    tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)

    print(f"🔧 CPU optimization configured:")
    print(f" Inter-op parallelism: {inter_op_threads} threads")
    print(f" Intra-op parallelism: {intra_op_threads} threads")
    print(f" This will accelerate data loading and preprocessing while TPU handles training")
def _initialize_datasets(self):
"""Initialize training and validation datasets"""
# Create file paths