Files
b2txt25/model_training_nnn_tpu/jupyter_xla_setup.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

45 lines
1.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ====================
# 单元格1: 环境设置 (必须第一个运行!)
# ====================
import os
import time
import psutil
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
# ⚠️ 关键: 在导入torch_xla之前设置环境变量
print("🔧 设置XLA环境变量...")
# 获取CPU核心数
cpu_count = os.cpu_count()
print(f"检测到 {cpu_count} 个CPU核心")
# 设置XLA编译优化环境变量
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true '
f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
os.environ['XLA_USE_BF16'] = '1'
# 显示设置结果
print("✅ XLA环境变量已设置:")
print(f" CPU核心数: {cpu_count}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
# 系统资源检查
memory_info = psutil.virtual_memory()
print(f"\n💾 系统内存信息:")
print(f" 总内存: {memory_info.total / (1024**3):.1f} GB")
print(f" 可用内存: {memory_info.available / (1024**3):.1f} GB")
print(f" 使用率: {memory_info.percent:.1f}%")
if memory_info.available < 8 * (1024**3): # 小于8GB
print("⚠️ 警告: 可用内存不足8GB可能影响XLA编译速度")
else:
print("✅ 内存充足")
print("\n🎯 环境设置完成! 现在可以运行下一个单元格")