Files
b2txt25/model_training_nnn_tpu/check_xla_threads.py

148 lines
4.4 KiB
Python
Raw Normal View History

2025-10-15 14:26:11 +08:00
#!/usr/bin/env python3
"""
XLA Multi-threading Diagnostic Script
检查XLA编译是否正确使用多CPU核心
"""
import os
import psutil
import time
import threading
from concurrent.futures import ThreadPoolExecutor
def set_xla_environment():
    """Export the XLA environment variables that enable multi-core CPU
    compilation, sized to the machine's CPU count, and echo them."""
    n_cores = os.cpu_count()
    # Assemble XLA_FLAGS from its individual options, then join with spaces.
    flags = [
        '--xla_cpu_multi_thread_eigen=true',
        '--xla_cpu_enable_fast_math=true',
        f'--xla_force_host_platform_device_count={n_cores}',
    ]
    os.environ['XLA_FLAGS'] = ' '.join(flags)
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(n_cores)
    print(f"🔧 设置XLA环境变量:")
    print(f" CPU核心数: {n_cores}")
    print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)
def monitor_cpu_usage(duration=30, interval=1):
    """Sample per-core CPU utilization for *duration* seconds (one blocking
    sample every *interval* seconds), print a live one-line summary while
    sampling and a per-core breakdown afterwards.

    Returns True when more than one core averaged above 5% utilization,
    False otherwise (including when no samples were collected).
    """
    print(f"🔍 监控CPU使用情况 {duration}秒...")
    samples = []
    deadline = time.time() + duration
    while time.time() < deadline:
        # psutil blocks for `interval` and returns one percentage per core.
        per_core = psutil.cpu_percent(interval=interval, percpu=True)
        samples.append(per_core)
        # Live status line, overwritten in place via carriage return.
        busy_now = len([u for u in per_core if u > 10])
        print(f"活跃核心数: {busy_now}/{len(per_core)}, "
              f"平均使用率: {sum(per_core)/len(per_core):.1f}%",
              end='\r')
    print()  # step past the carriage-return status line
    if not samples:
        return False
    # Transpose samples so each entry is one core's reading series, then
    # average that series over the whole monitoring window.
    per_core_avg = [sum(series) / len(samples) for series in zip(*samples)]
    busy_cores = len([avg for avg in per_core_avg if avg > 5])
    print(f"📊 CPU使用分析:")
    print(f" 活跃的CPU核心: {busy_cores}/{len(per_core_avg)}")
    print(f" 最高平均使用率: {max(per_core_avg):.1f}%")
    for core_idx, avg in enumerate(per_core_avg):
        marker = "🟢" if avg > 10 else "🔴" if avg > 5 else ""
        print(f" CPU核心 {core_idx}: {avg:.1f}% {marker}")
    return busy_cores > 1
def test_xla_compilation():
    """Exercise the XLA device with a short matmul/relu workload to trigger
    XLA compilation while a background thread monitors per-core CPU load.

    Returns:
        bool: True when the workload ran to completion; False when
        PyTorch/XLA cannot be imported or the run raised.
    """
    print(f"🚀 开始XLA编译测试...")
    try:
        import torch
        import torch_xla.core.xla_model as xm
        print(f"✅ PyTorch XLA导入成功")
        print(f" XLA设备: {xm.xla_device()}")
        # xrt_world_size() was deprecated and then removed in recent
        # torch_xla releases; fall back to the runtime API so this probe
        # doesn't die (via the broad `except Exception`) on newer wheels.
        try:
            world_size = xm.xrt_world_size()
        except AttributeError:
            import torch_xla.runtime as xr
            world_size = xr.world_size()
        print(f" XLA world size: {world_size}")
        device = xm.xla_device()
        print(f"🔄 创建测试计算图...")
        x = torch.randn(100, 100, device=device)
        y = torch.randn(100, 100, device=device)
        print(f"🔄 执行矩阵运算 (将触发XLA编译)...")
        # Watch per-core CPU usage in the background while XLA compiles.
        monitor_thread = threading.Thread(
            target=lambda: monitor_cpu_usage(20, 0.5),
            daemon=True
        )
        monitor_thread.start()
        for i in range(10):
            z = torch.matmul(x, y)
            z = torch.relu(z)
            z = torch.matmul(z, x.T)
            # XLA tensors are lazy: without a step marker the traced graph
            # is never materialized, so nothing would actually compile or
            # execute — and the monitor would have nothing to observe.
            xm.mark_step()
            if i == 0:
                print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
            elif i == 5:
                print(f"🔄 第6次计算完成...")
        # Give the monitor its full sampling window before declaring success.
        monitor_thread.join(timeout=25)
        print(f"✅ XLA测试完成")
        return True
    except ImportError as e:
        print(f"❌ PyTorch XLA导入失败: {e}")
        return False
    except Exception as e:
        print(f"❌ XLA测试失败: {e}")
        return False
def main():
    """Entry point: export XLA env vars, run the compilation probe, and
    print a verdict based on whether the probe completed."""
    divider = "=" * 60
    print(divider)
    print("🧪 XLA多线程编译诊断工具")
    print(divider)
    # Step 1: configure the environment before touching XLA at all.
    set_xla_environment()
    # Step 2: trigger an XLA compile while watching per-core CPU load.
    ok = test_xla_compilation()
    print(divider)
    if ok:
        for line in ("✅ 诊断完成",
                     "💡 如果看到多个CPU核心被激活说明XLA多线程工作正常",
                     "💡 如果只有1-2个核心活跃可能需要其他优化方法"):
            print(line)
    else:
        print("❌ 诊断失败")
        print("💡 请检查PyTorch XLA安装和TPU环境")
    print(divider)


if __name__ == "__main__":
    main()