tpu
This commit is contained in:
@@ -15,7 +15,43 @@ import torch_xla.core.xla_model as xm
|
||||
|
||||
print(f"✅ XLA导入成功!")
|
||||
print(f" TPU设备: {xm.xla_device()}")
|
||||
print(f" World Size: {xm.xrt_world_size()}")
|
||||
|
||||
# 兼容新版本PyTorch XLA
|
||||
try:
|
||||
world_size = xm.xrt_world_size()
|
||||
print(f" World Size (旧API): {world_size}")
|
||||
except AttributeError:
|
||||
try:
|
||||
world_size = xm.get_world_size()
|
||||
print(f" World Size (新API): {world_size}")
|
||||
except AttributeError:
|
||||
print(" World Size: 无法获取 (可能在CPU模式)")
|
||||
|
||||
# 检查XLA版本兼容性
|
||||
print("🔍 检查XLA API兼容性:")
|
||||
api_available = []
|
||||
api_deprecated = []
|
||||
|
||||
# 检查各种API
|
||||
test_apis = [
|
||||
('xrt_world_size', 'xrt_world_size()'),
|
||||
('get_world_size', 'get_world_size()'),
|
||||
('mark_step', 'mark_step()'),
|
||||
('wait_device_ops', 'wait_device_ops()'),
|
||||
('get_ordinal', 'get_ordinal()'),
|
||||
('xla_device_count', 'xla_device_count()')
|
||||
]
|
||||
|
||||
for api_name, api_desc in test_apis:
|
||||
if hasattr(xm, api_name):
|
||||
api_available.append(api_desc)
|
||||
else:
|
||||
api_deprecated.append(api_desc)
|
||||
|
||||
if api_available:
|
||||
print(f" ✅ 可用API: {', '.join(api_available)}")
|
||||
if api_deprecated:
|
||||
print(f" ❌ 不可用API: {', '.join(api_deprecated)}")
|
||||
|
||||
# 创建编译进度监控器
|
||||
class JupyterCompilationMonitor:
|
||||
|
Reference in New Issue
Block a user