tpu
This commit is contained in:
45
model_training_nnn_tpu/jupyter_xla_setup.py
Normal file
45
model_training_nnn_tpu/jupyter_xla_setup.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# ====================
|
||||
# 单元格1: 环境设置 (必须第一个运行!)
|
||||
# ====================
|
||||
|
||||
import os
|
||||
import time
|
||||
import psutil
|
||||
from IPython.display import display, HTML, clear_output
|
||||
import ipywidgets as widgets
|
||||
|
||||
# ⚠️ Critical: these environment variables must be set before torch_xla is
# imported, which is why this cell has to be the first one executed.
print("🔧 设置XLA环境变量...")

# Number of logical CPU cores. os.cpu_count() returns None when the count
# cannot be determined; fall back to 1 so the settings below never end up
# containing the literal string "None".
cpu_count = os.cpu_count() or 1
print(f"检测到 {cpu_count} 个CPU核心")

# XLA compilation tuning flags (space-separated inside a single variable).
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={cpu_count}'
)
# One compilation thread per detected core.
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
# NOTE(review): XLA_USE_BF16 appears to be deprecated in recent torch_xla
# releases — confirm it still has an effect on the installed version.
os.environ['XLA_USE_BF16'] = '1'
|
||||
|
||||
# Echo the values that were just applied so they can be checked at a glance.
# The lines are collected first and emitted with a single print call.
_settings_report = (
    "✅ XLA环境变量已设置:",
    f" CPU核心数: {cpu_count}",
    f" XLA_FLAGS: {os.environ['XLA_FLAGS']}",
    f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}",
)
print("\n".join(_settings_report))
|
||||
|
||||
# System resource check: report total / available memory (in GiB) and the
# usage percentage as returned by psutil.virtual_memory().
memory_info = psutil.virtual_memory()
print(f"\n💾 系统内存信息:")
print(f" 总内存: {memory_info.total / (1024**3):.1f} GB")
print(f" 可用内存: {memory_info.available / (1024**3):.1f} GB")
print(f" 使用率: {memory_info.percent:.1f}%")

# Warn when less than 8 GiB is available; otherwise confirm memory is fine.
print(
    "⚠️ 警告: 可用内存不足8GB,可能影响XLA编译速度"
    if memory_info.available < 8 * (1024**3)
    else "✅ 内存充足"
)

# Final notice: setup is done and the next cell may be run.
print("\n🎯 环境设置完成! 现在可以运行下一个单元格")
|
||||
Reference in New Issue
Block a user