Files
b2txt25/model_training_nnn_tpu/jupyter_xla_compatibility.py
Zchen e7947f310c tpu
2025-10-15 14:33:49 +08:00

93 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ====================
# 单元格: XLA版本兼容性检查和修复
# ====================
import torch
import torch.nn as nn
print("🔧 PyTorch XLA版本兼容性检查...")
# 导入XLA
import torch_xla.core.xla_model as xm
print("✅ PyTorch XLA导入成功!")
# 定义兼容性函数
def get_xla_world_size():
"""获取XLA world size兼容不同版本"""
try:
return xm.xrt_world_size()
except AttributeError:
try:
return xm.get_world_size()
except AttributeError:
return 1 # 默认返回1
def get_xla_ordinal():
"""获取XLA ordinal兼容不同版本"""
try:
return xm.get_ordinal()
except AttributeError:
return 0 # 默认返回0
def xla_mark_step():
"""XLA mark step兼容不同版本"""
try:
xm.mark_step()
except AttributeError:
try:
xm.wait_device_ops()
except AttributeError:
pass # 如果都不可用,则跳过
def check_xla_device():
"""检查XLA设备状态"""
try:
device = xm.xla_device()
print(f"📱 XLA设备: {device}")
world_size = get_xla_world_size()
ordinal = get_xla_ordinal()
print(f"🌍 World Size: {world_size}")
print(f"🔢 Ordinal: {ordinal}")
# 检测设备类型
device_str = str(device)
if 'xla' in device_str and 'cpu' not in device_str:
print("✅ 检测到TPU设备")
return True, "TPU"
elif 'xla' in device_str and 'cpu' in device_str:
print("⚠️ XLA CPU模拟模式")
return True, "XLA_CPU"
else:
print("❌ 未检测到XLA设备")
return False, "CPU"
except Exception as e:
print(f"❌ XLA设备检查失败: {e}")
return False, "ERROR"
# 执行兼容性检查
device_available, device_type = check_xla_device()
if device_available:
print(f"✅ XLA环境正常设备类型: {device_type}")
# 测试基本XLA操作
print("🧪 测试基本XLA操作...")
try:
device = xm.xla_device()
x = torch.randn(2, 2, device=device)
y = torch.matmul(x, x)
# 测试同步函数
xla_mark_step()
print("✅ 基本XLA操作测试成功")
except Exception as e:
print(f"❌ XLA操作测试失败: {e}")
else:
print("❌ XLA环境不可用")
print("\n💡 兼容性检查完成,可以运行后续单元格")