#!/usr/bin/env python3
"""XLA multi-threading diagnostic script.

Checks whether XLA compilation correctly uses multiple CPU cores:
sets the relevant environment variables, triggers an XLA compilation
with a small matmul workload, and monitors per-core CPU usage while
the compilation/execution runs.
"""

import os
import psutil
import time
import threading
from concurrent.futures import ThreadPoolExecutor  # NOTE(review): unused; kept to avoid breaking external expectations


def set_xla_environment():
    """Configure XLA environment variables for multi-threaded compilation.

    Bug fix vs. original: appends to a pre-existing XLA_FLAGS value
    instead of clobbering it, so flags exported by the caller/launcher
    are preserved.
    """
    cpu_count = os.cpu_count()

    new_flags = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )
    existing = os.environ.get('XLA_FLAGS', '').strip()
    os.environ['XLA_FLAGS'] = f'{existing} {new_flags}' if existing else new_flags
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)

    print(f"🔧 设置XLA环境变量:")
    print(f" CPU核心数: {cpu_count}")
    print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)


def monitor_cpu_usage(duration=30, interval=1):
    """Sample per-core CPU usage for `duration` seconds.

    Args:
        duration: total monitoring time in seconds.
        interval: blocking sample interval passed to psutil.

    Returns:
        True if more than one core averaged above 5% usage, else False.
    """
    print(f"🔍 监控CPU使用情况 {duration}秒...")
    cpu_usage_data = []
    start_time = time.time()

    while time.time() - start_time < duration:
        # psutil blocks for `interval` seconds and returns one usage
        # percentage per logical core.
        cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
        cpu_usage_data.append(cpu_percent_per_core)

        # Live status line (overwritten in place via \r).
        active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
        print(f"活跃核心数: {active_cores}/{len(cpu_percent_per_core)}, "
              f"平均使用率: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
              end='\r')

    print()  # newline to terminate the \r status line

    # Aggregate the samples: average usage per core across the whole run.
    if cpu_usage_data:
        avg_usage_per_core = [
            sum(core_data) / len(cpu_usage_data)
            for core_data in zip(*cpu_usage_data)
        ]
        active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
        max_usage = max(avg_usage_per_core)

        print(f"📊 CPU使用分析:")
        print(f" 活跃的CPU核心: {active_cores}/{len(avg_usage_per_core)}")
        print(f" 最高平均使用率: {max_usage:.1f}%")
        for i, usage in enumerate(avg_usage_per_core):
            status = "🟢" if usage > 10 else "🔴" if usage > 5 else "⚫"
            print(f" CPU核心 {i}: {usage:.1f}% {status}")
        return active_cores > 1
    return False


def test_xla_compilation():
    """Run a small matmul workload on the XLA device while monitoring CPU.

    Returns:
        True on success; False if torch_xla is unavailable or the run fails.
    """
    print(f"🚀 开始XLA编译测试...")

    try:
        import torch
        import torch_xla.core.xla_model as xm

        print(f"✅ PyTorch XLA导入成功")
        print(f" XLA设备: {xm.xla_device()}")
        # NOTE(review): xrt_world_size() is deprecated/removed in recent
        # torch_xla releases (replaced by torch_xla.runtime.world_size);
        # if it raises here, the broad except below reports the failure.
        print(f" XLA world size: {xm.xrt_world_size()}")

        device = xm.xla_device()
        print(f"🔄 创建测试计算图...")
        x = torch.randn(100, 100, device=device)
        y = torch.randn(100, 100, device=device)

        print(f"🔄 执行矩阵运算 (将触发XLA编译)...")

        # Monitor CPU usage in the background while the workload runs.
        monitor_thread = threading.Thread(
            target=lambda: monitor_cpu_usage(20, 0.5),
            daemon=True
        )
        monitor_thread.start()

        for i in range(10):
            z = torch.matmul(x, y)
            z = torch.relu(z)
            z = torch.matmul(z, x.T)
            # Bug fix: XLA tensors are lazy — without mark_step() the graph
            # is never materialized, so no compilation/execution actually
            # happens and the CPU monitor has nothing to observe.
            xm.mark_step()
            if i == 0:
                print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
            elif i == 5:
                print(f"🔄 第6次计算完成...")

        # Wait for the monitor to finish (bounded, since it is a daemon).
        monitor_thread.join(timeout=25)

        print(f"✅ XLA测试完成")
        return True

    except ImportError as e:
        print(f"❌ PyTorch XLA导入失败: {e}")
        return False
    except Exception as e:
        print(f"❌ XLA测试失败: {e}")
        return False


def main():
    """Entry point: configure the environment, then run the diagnostic."""
    print("=" * 60)
    print("🧪 XLA多线程编译诊断工具")
    print("=" * 60)

    # 1. Set up the XLA environment variables.
    set_xla_environment()

    # 2. Trigger XLA compilation and monitor CPU usage concurrently.
    success = test_xla_compilation()

    print("=" * 60)
    if success:
        print("✅ 诊断完成")
        print("💡 如果看到多个CPU核心被激活,说明XLA多线程工作正常")
        print("💡 如果只有1-2个核心活跃,可能需要其他优化方法")
    else:
        print("❌ 诊断失败")
        print("💡 请检查PyTorch XLA安装和TPU环境")
    print("=" * 60)


if __name__ == "__main__":
    main()