# ====================
# 单元格: XLA版本兼容性检查和修复 (Cell: XLA version compatibility check and fixes)
# ====================
import torch
import torch.nn as nn

print("🔧 PyTorch XLA版本兼容性检查...")

# Import torch_xla. A failure here almost always means torch_xla is missing or
# was built against a different PyTorch version, so print an actionable hint
# before re-raising (the cell still stops, as before).
try:
    import torch_xla.core.xla_model as xm
except ImportError as e:
    print(f"❌ 无法导入torch_xla(请确认已安装且与PyTorch {torch.__version__} 版本匹配): {e}")
    raise

print("✅ PyTorch XLA导入成功!")
# Version-compatibility helpers


def get_xla_world_size():
    """Return the XLA world size, working across torch_xla versions.

    Tries, in order, ``xm.xrt_world_size`` (older releases, deprecated in
    newer ones), ``xm.get_world_size`` and ``torch_xla.runtime.world_size``
    (the runtime API in newer releases).

    Returns:
        int: the value reported by the first available API, or 1 when none
        of them exists.
    """
    world_size_fn = getattr(xm, "xrt_world_size", None) or getattr(
        xm, "get_world_size", None
    )
    if world_size_fn is None:
        try:
            import torch_xla.runtime as xr
        except ImportError:
            xr = None
        world_size_fn = getattr(xr, "world_size", None)
    if world_size_fn is None:
        return 1  # no world-size API available: assume a single process
    # Called outside any try/except so an AttributeError raised *inside* the
    # API is reported instead of being mistaken for "API missing".
    return world_size_fn()
def get_xla_ordinal():
    """Return this process's XLA ordinal, working across torch_xla versions.

    Uses ``xm.get_ordinal`` when present, otherwise
    ``torch_xla.runtime.global_ordinal`` (the runtime API in newer releases).

    Returns:
        int: the value reported by the first available API, or 0 when none
        of them exists.
    """
    ordinal_fn = getattr(xm, "get_ordinal", None)
    if ordinal_fn is None:
        try:
            import torch_xla.runtime as xr
        except ImportError:
            xr = None
        ordinal_fn = getattr(xr, "global_ordinal", None)
    if ordinal_fn is None:
        return 0  # no ordinal API available: assume the first process
    # Called outside any try/except so an AttributeError raised *inside* the
    # API is reported instead of being mistaken for "API missing".
    return ordinal_fn()


def xla_mark_step():
    """Cut the pending lazy XLA graph and launch its execution.

    Uses ``xm.mark_step`` when present, otherwise ``torch_xla.sync`` (its
    replacement in newer releases). ``xm.wait_device_ops`` is deliberately
    not used as a fallback: it only waits for already-launched work and does
    not execute the pending graph. If no sync API exists, a RuntimeWarning is
    issued rather than silently doing nothing.
    """
    step_fn = getattr(xm, "mark_step", None)
    if step_fn is None:
        try:
            import torch_xla
        except ImportError:
            torch_xla = None
        step_fn = getattr(torch_xla, "sync", None)
    if step_fn is None:
        import warnings

        warnings.warn(
            "⚠️ 未找到 xm.mark_step 或 torch_xla.sync,未执行XLA同步",
            RuntimeWarning,
            stacklevel=2,
        )
        return
    # Called outside any try/except so errors raised while executing the
    # graph are not hidden.
    step_fn()


def check_xla_device():
    """Check the XLA device and report which hardware backs it.

    Prints the device, world size and ordinal, then classifies the device.

    Returns:
        tuple[bool, str]: ``(available, kind)``. ``kind`` is ``"TPU"``,
        ``"XLA_CPU"`` or ``"XLA_<HW>"`` (e.g. ``"XLA_GPU"``) when an XLA
        device is available; ``"CPU"`` when the device is not an XLA device;
        ``"ERROR"`` when the check itself raised (the error is printed).
    """
    try:
        device = xm.xla_device()
        print(f"📱 XLA设备: {device}")

        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()

        print(f"🌍 World Size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")

        # Detect the device type.
        device_str = str(device)
        if "xla" not in device_str:
            print("❌ 未检测到XLA设备")
            return False, "CPU"

        # str(device) is "xla:N" whatever the backing hardware, so it cannot
        # distinguish TPU from CPU/GPU. Ask torch_xla for the hardware type.
        hw_lookup = getattr(xm, "xla_device_hw", None)
        if hw_lookup is not None:
            hw = str(hw_lookup(device)).upper()
        else:
            # Older torch_xla without xla_device_hw: keep the string heuristic.
            hw = "CPU" if "cpu" in device_str else "TPU"

        if hw == "TPU":
            print("✅ 检测到TPU设备")
            return True, "TPU"
        if hw == "CPU":
            print("⚠️ XLA CPU模拟模式")
            return True, "XLA_CPU"
        print(f"✅ 检测到XLA {hw}设备")
        return True, f"XLA_{hw}"

    except Exception as e:
        # Diagnostic boundary: report the failure instead of crashing the cell.
        print(f"❌ XLA设备检查失败: {e}")
        return False, "ERROR"


# Run the compatibility check. device_available / device_type stay as
# module-level names so later cells can read them.
device_available, device_type = check_xla_device()

if device_available:
    print(f"✅ XLA环境正常,设备类型: {device_type}")

    # Smoke-test a basic XLA computation.
    print("🧪 测试基本XLA操作...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        y = torch.matmul(x, x)

        # Exercise the version-compatible sync helper.
        xla_mark_step()

        # XLA tensors are lazy: copying the result to the CPU blocks until the
        # graph has actually run, so execution errors surface inside this try
        # block instead of in a later cell.
        y.cpu()

        print("✅ 基本XLA操作测试成功")

    except Exception as e:
        print(f"❌ XLA操作测试失败: {e}")
else:
    print("❌ XLA环境不可用")

print("\n💡 兼容性检查完成,可以运行后续单元格")