# ====================
# 单元格: XLA版本兼容性检查和修复
# (Cell: PyTorch/XLA version compatibility check and fixes)
# ====================

# Core PyTorch imports (nn is bound for use by later cells).
import torch
from torch import nn

print("🔧 PyTorch XLA版本兼容性检查...")

# torch_xla's core module; if torch_xla is not installed this raises ImportError
# and the cell stops here, before the success message below.
from torch_xla.core import xla_model as xm

print("✅ PyTorch XLA导入成功!")

# Compatibility helpers: each wraps a torch_xla API whose name differs across releases.

def get_xla_world_size():
    """Return the XLA world size, tolerating API differences between torch_xla versions.

    Tries, in order: ``xm.xrt_world_size()``, ``xm.get_world_size()`` and
    ``torch_xla.runtime.world_size()``. Returns 1 when none of them exists.

    Only the *lookup* of each function is guarded. An AttributeError raised
    while one of them runs is a genuine error and propagates; it is not
    mistaken for "this API does not exist in the installed version".

    Returns:
        int: number of participating XLA processes (1 if it cannot be determined).
    """
    for name in ("xrt_world_size", "get_world_size"):
        world_size_fn = getattr(xm, name, None)
        if callable(world_size_fn):
            return world_size_fn()

    # Neither legacy xla_model helper exists: try the runtime module.
    try:
        import torch_xla.runtime as xr
    except ImportError:
        return 1  # default: assume a single process
    world_size_fn = getattr(xr, "world_size", None)
    return world_size_fn() if callable(world_size_fn) else 1


def get_xla_ordinal():
    """Return this process's XLA ordinal, tolerating torch_xla API differences.

    Tries ``xm.get_ordinal()`` first, then ``torch_xla.runtime.global_ordinal()``.
    Returns 0 when neither exists.

    Only the lookup is guarded, so an AttributeError raised *inside* the call
    propagates instead of being mistaken for a missing API.

    Returns:
        int: the ordinal (0 if it cannot be determined).
    """
    get_ordinal = getattr(xm, "get_ordinal", None)
    if callable(get_ordinal):
        return get_ordinal()

    # Fallback for releases that only expose the runtime module.
    try:
        import torch_xla.runtime as xr
    except ImportError:
        return 0  # default: assume the first/only process
    global_ordinal = getattr(xr, "global_ordinal", None)
    return global_ordinal() if callable(global_ordinal) else 0


def xla_mark_step():
    """Flush pending lazy XLA work, tolerating torch_xla API differences.

    Preference order:
      1. ``xm.mark_step()``: cuts the traced graph and dispatches it.
      2. ``torch_xla.sync()``: the name newer torch_xla releases expose for the
         same operation.
      3. ``xm.wait_device_ops()``: only waits for work that was already
         dispatched and does not cut the graph. It is kept as the last resort
         the original code used.

    If none of these exists, a RuntimeWarning is issued instead of silently
    doing nothing. Only attribute lookups are guarded; errors raised while the
    chosen function runs propagate.
    """
    step = getattr(xm, "mark_step", None)
    if not callable(step):
        try:
            import torch_xla
        except ImportError:
            torch_xla = None
        # getattr on None simply yields the default.
        step = getattr(torch_xla, "sync", None)
    if callable(step):
        step()
        return

    wait = getattr(xm, "wait_device_ops", None)
    if callable(wait):
        wait()
        return

    import warnings
    warnings.warn(
        "xla_mark_step: torch_xla provides none of mark_step/sync/wait_device_ops; skipping",
        RuntimeWarning,
        stacklevel=2,
    )


def check_xla_device():
    """Probe the default XLA device and report which hardware backs it.

    Prints the device, world size and ordinal, then classifies the device.

    Returns:
        tuple[bool, str]: ``(available, device_type)``. ``device_type`` is one of:
          - ``"TPU"``;
          - ``"XLA_CPU"`` (XLA running on the host CPU);
          - another hardware name reported by torch_xla (e.g. ``"GPU"``);
          - ``"CPU"`` when the device is not an XLA device at all;
          - ``"ERROR"`` when probing raised.
    """
    try:
        device = xm.xla_device()
        print(f"📱 XLA设备: {device}")

        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()
        print(f"🌍 World Size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")

        # Not an XLA device at all.
        if 'xla' not in str(device):
            print("❌ 未检测到XLA设备")
            return False, "CPU"

        hardware = _xla_hardware_type(device)
        if hardware == "TPU":
            print("✅ 检测到TPU设备")
            return True, "TPU"
        if hardware == "CPU":
            print("⚠️ XLA CPU模拟模式")
            return True, "XLA_CPU"
        # Other XLA backends (e.g. GPU) are no longer mislabelled as TPU.
        print(f"✅ 检测到XLA设备: {hardware}")
        return True, hardware

    except Exception as e:
        # Diagnostic boundary: report any failure instead of aborting the cell.
        print(f"❌ XLA设备检查失败: {e}")
        return False, "ERROR"


def _xla_hardware_type(device):
    """Return the upper-cased hardware type behind XLA *device* (e.g. "TPU", "CPU", "GPU").

    Uses ``xm.xla_device_hw`` when the installed torch_xla provides it. On
    PJRT-based releases the device string is typically ``xla:0`` whatever the
    backend, so the string alone cannot tell TPU from CPU.

    Otherwise falls back to the original string heuristic: "CPU" if the device
    string contains 'cpu', else "TPU".
    """
    device_hw = getattr(xm, "xla_device_hw", None)
    if callable(device_hw):
        return str(device_hw(device)).upper()
    return "CPU" if 'cpu' in str(device) else "TPU"


# Run the compatibility check.
device_available, device_type = check_xla_device()

if device_available:
    print(f"✅ XLA环境正常,设备类型: {device_type}")

    # Smoke-test a basic XLA computation.
    print("🧪 测试基本XLA操作...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        y = torch.matmul(x, x)

        # Exercise the sync helper.
        xla_mark_step()

        # Copy the result to the host. This forces execution to finish, so
        # device-side errors surface inside this try, and lets us check the
        # value against a CPU reference. The tolerance is loose because
        # accelerator matmuls may run at reduced precision by default.
        x_cpu = x.cpu()
        if not torch.allclose(y.cpu(), x_cpu @ x_cpu, rtol=1e-2, atol=5e-2):
            raise RuntimeError("matmul result does not match the CPU reference")

        print("✅ 基本XLA操作测试成功")
    except Exception as e:
        print(f"❌ XLA操作测试失败: {e}")
else:
    print("❌ XLA环境不可用")

print("\n💡 兼容性检查完成,可以运行后续单元格")