Files
b2txt25/model_training_nnn_tpu/jupyter_xla_setup.py

45 lines
1.5 KiB
Python
Raw Normal View History

2025-10-15 14:26:11 +08:00
# ====================
# 单元格1: 环境设置 (必须第一个运行!)
# ====================
import os
import time
import psutil
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
# ⚠️ 关键: 在导入torch_xla之前设置环境变量
print("🔧 设置XLA环境变量...")
# 获取CPU核心数
cpu_count = os.cpu_count()
print(f"检测到 {cpu_count} 个CPU核心")
# 设置XLA编译优化环境变量
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true '
f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
os.environ['XLA_USE_BF16'] = '1'
# 显示设置结果
print("✅ XLA环境变量已设置:")
print(f" CPU核心数: {cpu_count}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
# 系统资源检查
memory_info = psutil.virtual_memory()
print(f"\n💾 系统内存信息:")
print(f" 总内存: {memory_info.total / (1024**3):.1f} GB")
print(f" 可用内存: {memory_info.available / (1024**3):.1f} GB")
print(f" 使用率: {memory_info.percent:.1f}%")
if memory_info.available < 8 * (1024**3): # 小于8GB
print("⚠️ 警告: 可用内存不足8GB可能影响XLA编译速度")
else:
print("✅ 内存充足")
print("\n🎯 环境设置完成! 现在可以运行下一个单元格")