45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
# ====================
|
||
# 单元格1: 环境设置 (必须第一个运行!)
|
||
# ====================
|
||
|
||
import os
|
||
import time
|
||
import psutil
|
||
from IPython.display import display, HTML, clear_output
|
||
import ipywidgets as widgets
|
||
|
||
# ⚠️ 关键: 在导入torch_xla之前设置环境变量
|
||
print("🔧 设置XLA环境变量...")
|
||
|
||
# 获取CPU核心数
|
||
cpu_count = os.cpu_count()
|
||
print(f"检测到 {cpu_count} 个CPU核心")
|
||
|
||
# 设置XLA编译优化环境变量
|
||
os.environ['XLA_FLAGS'] = (
|
||
'--xla_cpu_multi_thread_eigen=true '
|
||
'--xla_cpu_enable_fast_math=true '
|
||
f'--xla_force_host_platform_device_count={cpu_count}'
|
||
)
|
||
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
|
||
os.environ['XLA_USE_BF16'] = '1'
|
||
|
||
# 显示设置结果
|
||
print("✅ XLA环境变量已设置:")
|
||
print(f" CPU核心数: {cpu_count}")
|
||
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
|
||
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
|
||
|
||
# 系统资源检查
|
||
memory_info = psutil.virtual_memory()
|
||
print(f"\n💾 系统内存信息:")
|
||
print(f" 总内存: {memory_info.total / (1024**3):.1f} GB")
|
||
print(f" 可用内存: {memory_info.available / (1024**3):.1f} GB")
|
||
print(f" 使用率: {memory_info.percent:.1f}%")
|
||
|
||
if memory_info.available < 8 * (1024**3): # 小于8GB
|
||
print("⚠️ 警告: 可用内存不足8GB,可能影响XLA编译速度")
|
||
else:
|
||
print("✅ 内存充足")
|
||
|
||
print("\n🎯 环境设置完成! 现在可以运行下一个单元格") |