# ====================
# Cell: XLA version compatibility check and fixes
# ====================

import torch
import torch.nn as nn

print("🔧 PyTorch XLA版本兼容性检查...")

# Import XLA. An ImportError here is deliberately left unhandled: every helper
# below depends on `xm`, so there is nothing useful to fall back to.
import torch_xla.core.xla_model as xm

print("✅ PyTorch XLA导入成功!")


# ---- Compatibility helpers ----

def get_xla_world_size():
    """Return the XLA world size, tolerating differences between versions.

    Lookup order:
      1. ``torch_xla.runtime.world_size()``
      2. ``xm.xrt_world_size()``
      3. ``xm.get_world_size()``

    Returns:
        The value reported by the first helper that exists, or ``1`` when
        none of them is available.
    """
    try:
        # Imported locally so that installs without this module still work.
        import torch_xla.runtime as xr

        return xr.world_size()
    except (ImportError, AttributeError):
        pass

    for getter_name in ("xrt_world_size", "get_world_size"):
        try:
            return getattr(xm, getter_name)()
        except AttributeError:
            continue
    return 1  # No helper available: default to 1.


def get_xla_ordinal():
    """Return the XLA ordinal of this process, tolerating version differences.

    Lookup order:
      1. ``torch_xla.runtime.global_ordinal()``
      2. ``xm.get_ordinal()``

    Returns:
        The value reported by the first helper that exists, or ``0`` when
        none of them is available.
    """
    try:
        import torch_xla.runtime as xr

        return xr.global_ordinal()
    except (ImportError, AttributeError):
        pass

    try:
        return xm.get_ordinal()
    except AttributeError:
        return 0  # No helper available: default to 0.


def xla_mark_step():
    """Run an XLA step barrier using whichever API this install provides.

    Tries ``xm.mark_step()``, then ``torch_xla.sync()``, and finally
    ``xm.wait_device_ops()`` as a last resort. If none of them exists the
    call is a no-op, which keeps this helper best-effort.
    """
    try:
        xm.mark_step()
        return
    except AttributeError:
        pass

    try:
        import torch_xla

        torch_xla.sync()
        return
    except (ImportError, AttributeError):
        pass

    try:
        xm.wait_device_ops()
    except AttributeError:
        pass  # Nothing usable on this version: skip silently.


def check_xla_device():
    """Report the XLA device, world size and ordinal, and classify the hardware.

    The hardware kind is taken from ``xm.xla_device_hw(device)`` rather than
    from ``str(device)``, because the device string is not a reliable
    indicator of what backs the device. The string heuristic is kept only as
    a fallback for installs that lack ``xla_device_hw``.

    Returns:
        A ``(available, device_type)`` tuple:
          * ``(True, "TPU")``      - a TPU backs the XLA device
          * ``(True, "XLA_CPU")``  - XLA is running on the CPU
          * ``(True, "XLA_<HW>")`` - any other hardware kind that was reported
          * ``(False, "CPU")``     - the device is not an XLA device
          * ``(False, "ERROR")``   - any exception occurred during the check
    """
    try:
        device = xm.xla_device()
        print(f"📱 XLA设备: {device}")

        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()
        print(f"🌍 World Size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")

        # Detect the device type.
        device_str = str(device)
        if 'xla' not in device_str:
            print("❌ 未检测到XLA设备")
            return False, "CPU"

        try:
            hardware = str(xm.xla_device_hw(device)).upper()
        except AttributeError:
            # Legacy heuristic, used only when xla_device_hw is unavailable.
            hardware = "CPU" if 'cpu' in device_str else "TPU"

        if hardware == "CPU":
            print("⚠️ XLA CPU模拟模式")
            return True, "XLA_CPU"
        if hardware == "TPU":
            print("✅ 检测到TPU设备")
            return True, "TPU"

        # Some other accelerator: report it as what it is, not as a TPU.
        print(f"✅ 检测到XLA加速设备: {hardware}")
        return True, f"XLA_{hardware}"
    except Exception as e:
        # Top-level diagnostic boundary: report the failure instead of
        # aborting the notebook cell.
        print(f"❌ XLA设备检查失败: {e}")
        return False, "ERROR"


# ---- Run the compatibility check ----

device_available, device_type = check_xla_device()

if device_available:
    print(f"✅ XLA环境正常,设备类型: {device_type}")

    # Smoke-test a basic XLA operation.
    print("🧪 测试基本XLA操作...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        # Keep a reference to the result so it is still alive at the barrier.
        y = torch.matmul(x, x)

        # Exercise the synchronisation helper.
        xla_mark_step()
        print("✅ 基本XLA操作测试成功")
    except Exception as e:
        print(f"❌ XLA操作测试失败: {e}")
else:
    print("❌ XLA环境不可用")

print("\n💡 兼容性检查完成,可以运行后续单元格")