diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 253a0db..e14e47f 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -43,6 +43,9 @@ class BrainToTextDecoderTrainerTF:
         self.args = args
         self.logger = None
 
+        # Optimize CPU utilization for the data pipeline (use the 224 host cores)
+        self._configure_cpu_optimization()
+
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
         print(f"Training on {self.strategy.num_replicas_in_sync} TPU cores")
@@ -123,6 +126,28 @@ class BrainToTextDecoderTrainerTF:
         if self.mixed_precision:
             self.logger.info('Mixed precision (bfloat16) enabled for TPU training')
 
+    def _configure_cpu_optimization(self):
+        """Configure CPU thread pools so the data pipeline can use the 224 host cores"""
+        import multiprocessing
+
+        # Get available CPU cores
+        available_cores = multiprocessing.cpu_count()
+        print(f"💻 Available CPU cores: {available_cores}")
+
+        # Optimize for data pipeline parallelism:
+        # use ~1/4 of cores for inter-op parallelism (between operations)
+        # and ~1/8 of cores for intra-op parallelism (within operations)
+        inter_op_threads = min(32, available_cores // 4)
+        intra_op_threads = min(16, available_cores // 8)
+
+        tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
+        tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
+
+        print(f"🔧 CPU optimization configured:")
+        print(f"   Inter-op parallelism: {inter_op_threads} threads")
+        print(f"   Intra-op parallelism: {intra_op_threads} threads")
+        print(f"   This accelerates data loading and preprocessing while the TPU handles training")
+
     def _initialize_datasets(self):
         """Initialize training and validation datasets"""
         # Create file paths
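
For context, below is a minimal sketch of the kind of host-side tf.data input pipeline that the thread-pool settings in this diff are intended to speed up. It is not part of trainer_tf.py: the file pattern, the parse_example helper, the neural_data feature name, and the feature length are hypothetical placeholders.

# Minimal sketch (assumed, not from trainer_tf.py) of a host-side tf.data
# pipeline whose CPU-bound parsing/preprocessing is the work the diff's
# thread-pool configuration aims to accelerate.
import tensorflow as tf

def parse_example(serialized):
    # Hypothetical parser: decode one serialized record into a feature tensor.
    features = tf.io.parse_single_example(
        serialized,
        {"neural_data": tf.io.FixedLenFeature([512], tf.float32)},
    )
    return features["neural_data"]

def build_input_pipeline(file_pattern, batch_size=32):
    files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    dataset = tf.data.TFRecordDataset(files, num_parallel_reads=tf.data.AUTOTUNE)
    # Parallel map/batch/prefetch keeps the CPU cores busy with preprocessing
    # so batches are ready when the TPU finishes each training step.
    dataset = dataset.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset.prefetch(tf.data.AUTOTUNE)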