Files
b2txt25/model_training_nnn_tpu/check_xla_threads.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

148 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
XLA Multi-threading Diagnostic Script

Checks whether XLA compilation correctly uses multiple CPU cores.
"""
import os
import psutil
import time
import threading
from concurrent.futures import ThreadPoolExecutor
def set_xla_environment():
    """Configure XLA environment variables for multi-core CPU compilation.

    Exports ``XLA_FLAGS`` (multi-threaded Eigen, fast math, one host
    platform device per CPU core) and ``PYTORCH_XLA_COMPILATION_THREADS``,
    then echoes the resulting settings to stdout.
    """
    core_total = os.cpu_count()
    # Assemble the flag string; the joined result matches the literal
    # concatenation used elsewhere in this project.
    flag_parts = [
        '--xla_cpu_multi_thread_eigen=true',
        '--xla_cpu_enable_fast_math=true',
        f'--xla_force_host_platform_device_count={core_total}',
    ]
    os.environ['XLA_FLAGS'] = ' '.join(flag_parts)
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(core_total)
    print(f"🔧 设置XLA环境变量:")
    print(f" CPU核心数: {core_total}")
    print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)
def monitor_cpu_usage(duration=30, interval=1):
    """Sample per-core CPU utilisation for ``duration`` seconds.

    Prints a live one-line summary while sampling, then a per-core
    breakdown, and returns True when more than one core averaged above
    5% utilisation over the whole window.
    """
    print(f"🔍 监控CPU使用情况 {duration}秒...")
    samples = []
    deadline = time.time() + duration
    while time.time() < deadline:
        # psutil blocks for `interval` seconds and reports one value per core.
        per_core = psutil.cpu_percent(interval=interval, percpu=True)
        samples.append(per_core)
        busy_now = sum(1 for usage in per_core if usage > 10)
        print(f"活跃核心数: {busy_now}/{len(per_core)}, "
              f"平均使用率: {sum(per_core)/len(per_core):.1f}%",
              end='\r')
    print()  # step past the carriage-return status line
    if not samples:
        return False
    # Average each core's series across all samples (zip transposes the data).
    per_core_averages = [sum(series) / len(samples) for series in zip(*samples)]
    busy_cores = sum(1 for avg in per_core_averages if avg > 5)
    print(f"📊 CPU使用分析:")
    print(f" 活跃的CPU核心: {busy_cores}/{len(per_core_averages)}")
    print(f" 最高平均使用率: {max(per_core_averages):.1f}%")
    for core_idx, avg in enumerate(per_core_averages):
        # Green above 10%, red between 5% and 10%, blank otherwise.
        marker = "🟢" if avg > 10 else "🔴" if avg > 5 else ""
        print(f" CPU核心 {core_idx}: {avg:.1f}% {marker}")
    return busy_cores > 1
def test_xla_compilation():
    """Run a small matmul workload on the XLA device while watching CPU usage.

    Imports torch/torch_xla lazily, runs a few 100x100 matrix operations on
    the XLA device (the first iteration triggers XLA compilation), and
    samples per-core CPU activity from a background daemon thread.

    Returns:
        True when the XLA workload completed, False when the torch_xla
        import or the XLA computation failed.
    """
    print(f"🚀 开始XLA编译测试...")
    try:
        import torch
        import torch_xla.core.xla_model as xm
        print(f"✅ PyTorch XLA导入成功")
        print(f" XLA设备: {xm.xla_device()}")
        # FIX: xm.xrt_world_size() was deprecated and later removed from
        # torch_xla; on modern installs it raised AttributeError, which fell
        # into the generic `except Exception` below and made the whole
        # diagnostic report failure. Prefer torch_xla.runtime.world_size()
        # and fall back to the legacy API for older versions.
        try:
            import torch_xla.runtime as xr
            world_size = xr.world_size()
        except (ImportError, AttributeError):
            world_size = xm.xrt_world_size()
        print(f" XLA world size: {world_size}")
        # Build a simple compute graph to compile.
        device = xm.xla_device()
        print(f"🔄 创建测试计算图...")
        x = torch.randn(100, 100, device=device)
        y = torch.randn(100, 100, device=device)
        print(f"🔄 执行矩阵运算 (将触发XLA编译)...")
        # Sample CPU usage in the background while compilation runs.
        monitor_thread = threading.Thread(
            target=lambda: monitor_cpu_usage(20, 0.5),
            daemon=True
        )
        monitor_thread.start()
        # Execute the graph repeatedly; iteration 0 triggers compilation.
        for i in range(10):
            z = torch.matmul(x, y)
            z = torch.relu(z)
            z = torch.matmul(z, x.T)
            if i == 0:
                print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
            elif i == 5:
                print(f"🔄 第6次计算完成...")
        # Give the monitor time to finish its 20-second window.
        monitor_thread.join(timeout=25)
        print(f"✅ XLA测试完成")
        return True
    except ImportError as e:
        print(f"❌ PyTorch XLA导入失败: {e}")
        return False
    except Exception as e:
        print(f"❌ XLA测试失败: {e}")
        return False
def main():
    """Entry point: configure the XLA environment, then run the diagnostic."""
    banner = "=" * 60
    print(banner)
    print("🧪 XLA多线程编译诊断工具")
    print(banner)
    # Step 1: export the XLA tuning variables (must happen before torch_xla
    # is imported by the test below).
    set_xla_environment()
    # Step 2: trigger an XLA compilation while monitoring CPU activity.
    success = test_xla_compilation()
    print(banner)
    if success:
        print("✅ 诊断完成")
        print("💡 如果看到多个CPU核心被激活说明XLA多线程工作正常")
        print("💡 如果只有1-2个核心活跃可能需要其他优化方法")
    else:
        print("❌ 诊断失败")
        print("💡 请检查PyTorch XLA安装和TPU环境")
    print(banner)


if __name__ == "__main__":
    main()