数据加载器多线程加速
This commit is contained in:
@@ -43,6 +43,9 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
self.args = args
|
self.args = args
|
||||||
self.logger = None
|
self.logger = None
|
||||||
|
|
||||||
|
# Optimize CPU utilization for data pipeline (利用224核心)
|
||||||
|
self._configure_cpu_optimization()
|
||||||
|
|
||||||
# Initialize TPU strategy
|
# Initialize TPU strategy
|
||||||
self.strategy = create_tpu_strategy()
|
self.strategy = create_tpu_strategy()
|
||||||
print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
|
print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
|
||||||
@@ -123,6 +126,28 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
if self.mixed_precision:
|
if self.mixed_precision:
|
||||||
self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
|
self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
|
||||||
|
|
||||||
|
def _configure_cpu_optimization(self):
|
||||||
|
"""Configure CPU utilization to make use of 224 cores for data pipeline"""
|
||||||
|
import multiprocessing
|
||||||
|
|
||||||
|
# Get available CPU cores
|
||||||
|
available_cores = multiprocessing.cpu_count()
|
||||||
|
print(f"💻 Available CPU cores: {available_cores}")
|
||||||
|
|
||||||
|
# Optimize for data pipeline parallelism
|
||||||
|
# Use ~1/4 of cores for inter-op (between operations)
|
||||||
|
# Use ~1/8 of cores for intra-op (within operations)
|
||||||
|
inter_op_threads = min(32, available_cores // 4)
|
||||||
|
intra_op_threads = min(16, available_cores // 8)
|
||||||
|
|
||||||
|
tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
|
||||||
|
tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
|
||||||
|
|
||||||
|
print(f"🔧 CPU optimization configured:")
|
||||||
|
print(f" Inter-op parallelism: {inter_op_threads} threads")
|
||||||
|
print(f" Intra-op parallelism: {intra_op_threads} threads")
|
||||||
|
print(f" This will accelerate data loading and preprocessing while TPU handles training")
|
||||||
|
|
||||||
def _initialize_datasets(self):
|
def _initialize_datasets(self):
|
||||||
"""Initialize training and validation datasets"""
|
"""Initialize training and validation datasets"""
|
||||||
# Create file paths
|
# Create file paths
|
||||||
|
Reference in New Issue
Block a user