#!/usr/bin/env python3
"""
XLA Multi-threading Diagnostic Script

Checks whether XLA compilation correctly uses multiple CPU cores.
"""

import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor

import psutil

def set_xla_environment():
    """Configure XLA environment variables for multi-threaded CPU compilation.

    Exports XLA_FLAGS (multi-threaded Eigen, fast math, one host device per
    CPU core) and PYTORCH_XLA_COMPILATION_THREADS, then echoes the settings.
    Must run before torch_xla is imported for the flags to take effect.
    """
    n_cores = os.cpu_count()

    # Assemble the XLA flag string; joined with single spaces.
    flags = [
        '--xla_cpu_multi_thread_eigen=true',
        '--xla_cpu_enable_fast_math=true',
        f'--xla_force_host_platform_device_count={n_cores}',
    ]
    os.environ['XLA_FLAGS'] = ' '.join(flags)
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(n_cores)

    print("🔧 设置XLA环境变量:")
    print(f"   CPU核心数: {n_cores}")
    print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)
def monitor_cpu_usage(duration=30, interval=1):
    """Sample per-core CPU utilisation for *duration* seconds and report it.

    Polls psutil every *interval* seconds, printing a live one-line summary,
    then prints a per-core breakdown averaged over all samples.

    Returns:
        True when more than one core averaged above 5% utilisation,
        False otherwise (including the no-samples case).
    """
    print(f"🔍 监控CPU使用情况 {duration}秒...")

    samples = []
    deadline = time.time() + duration
    while time.time() < deadline:
        # One utilisation reading per core; this call blocks for `interval`.
        per_core = psutil.cpu_percent(interval=interval, percpu=True)
        samples.append(per_core)

        # Live status line, overwritten in place via carriage return.
        busy_now = sum(1 for pct in per_core if pct > 10)
        mean_now = sum(per_core) / len(per_core)
        print(f"活跃核心数: {busy_now}/{len(per_core)}, "
              f"平均使用率: {mean_now:.1f}%",
              end='\r')

    print()  # move past the carriage-return status line

    if not samples:
        return False

    # Average each core's usage across all collected samples.
    per_core_avg = [sum(series) / len(samples) for series in zip(*samples)]
    busy_cores = sum(1 for avg in per_core_avg if avg > 5)

    print(f"📊 CPU使用分析:")
    print(f"   活跃的CPU核心: {busy_cores}/{len(per_core_avg)}")
    print(f"   最高平均使用率: {max(per_core_avg):.1f}%")

    for idx, avg in enumerate(per_core_avg):
        marker = "🟢" if avg > 10 else "🔴" if avg > 5 else "⚫"
        print(f"   CPU核心 {idx}: {avg:.1f}% {marker}")

    return busy_cores > 1
def test_xla_compilation():
    """Run a small matmul workload on the XLA device while watching CPU load.

    Imports torch/torch_xla lazily, executes repeated matrix ops (the first
    iteration forces XLA compilation) and monitors per-core CPU usage from a
    background daemon thread.

    Returns:
        True when the workload completes, False when the torch_xla import or
        the computation fails.
    """
    print(f"🚀 开始XLA编译测试...")

    try:
        import torch
        import torch_xla.core.xla_model as xm

        print(f"✅ PyTorch XLA导入成功")
        print(f"   XLA设备: {xm.xla_device()}")
        # NOTE(review): xm.xrt_world_size() is deprecated in recent torch_xla
        # releases — confirm against the installed version.
        print(f"   XLA world size: {xm.xrt_world_size()}")

        device = xm.xla_device()

        print(f"🔄 创建测试计算图...")
        lhs = torch.randn(100, 100, device=device)
        rhs = torch.randn(100, 100, device=device)

        print(f"🔄 执行矩阵运算 (将触发XLA编译)...")

        # Watch CPU utilisation in the background while the ops compile/run.
        watcher = threading.Thread(
            target=monitor_cpu_usage,
            args=(20, 0.5),
            daemon=True,
        )
        watcher.start()

        # Repeat the computation; iteration 0 triggers the XLA compile.
        for step in range(10):
            out = torch.matmul(lhs, rhs)
            out = torch.relu(out)
            out = torch.matmul(out, lhs.T)
            if step == 0:
                print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
            elif step == 5:
                print(f"🔄 第6次计算完成...")

        # Give the monitor time to finish its 20-second sampling window.
        watcher.join(timeout=25)

        print(f"✅ XLA测试完成")
        return True

    except ImportError as err:
        print(f"❌ PyTorch XLA导入失败: {err}")
        return False
    except Exception as err:  # surface any runtime failure as a diagnostic
        print(f"❌ XLA测试失败: {err}")
        return False
def main():
    """Entry point: configure XLA env vars, run the compile test, report."""
    banner = "=" * 60
    print(banner)
    print("🧪 XLA多线程编译诊断工具")
    print(banner)

    # Step 1: export the XLA threading environment variables.
    set_xla_environment()

    # Step 2: run the XLA workload while monitoring per-core CPU usage.
    success = test_xla_compilation()

    print(banner)
    if success:
        print("✅ 诊断完成")
        print("💡 如果看到多个CPU核心被激活,说明XLA多线程工作正常")
        print("💡 如果只有1-2个核心活跃,可能需要其他优化方法")
    else:
        print("❌ 诊断失败")
        print("💡 请检查PyTorch XLA安装和TPU环境")

    print(banner)
# Allow the diagnostics to be run directly as a script.
if __name__ == "__main__":
    main()