# 2025-10-15 14:26:11 +08:00
|
|
|
# ====================
|
|
|
|
# 单元格2: XLA编译进度监控
|
|
|
|
# ====================
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import time
|
|
|
|
import threading
|
|
|
|
from IPython.display import display, HTML, clear_output
|
|
|
|
import ipywidgets as widgets
|
|
|
|
|
|
|
|
# Import torch_xla; the XLA environment variables are expected to have been
# set already in cell 1.
print("🚀 导入PyTorch XLA...")
import torch_xla.core.xla_model as xm  # imported after the banner on purpose

print("✅ XLA导入成功!")
# print() separates its arguments with a single space, reproducing the
# original " TPU设备: <device>" line.
print(" TPU设备:", xm.xla_device())
|
# 2025-10-15 14:33:49 +08:00
|
|
|
|
|
|
|
# Query the XLA world size in a way that works with both older and newer
# torch_xla releases (the old name is tried first, then the new one).
# `world_size` is bound up front so it is always defined afterwards: it stays
# None when neither API exists, instead of raising NameError in later cells.
world_size = None
try:
    world_size = xm.xrt_world_size()
    print(f" World Size (旧API): {world_size}")
except AttributeError:
    try:
        world_size = xm.get_world_size()
        print(f" World Size (新API): {world_size}")
    except AttributeError:
        print(" World Size: 无法获取 (可能在CPU模式)")
|
|
|
|
|
|
|
|
# Report which torch_xla.core.xla_model functions exist in the installed version.
print("🔍 检查XLA API兼容性:")

# Pairs of (attribute name looked up on xm, label used in the report).
test_apis = [
    ('xrt_world_size', 'xrt_world_size()'),
    ('get_world_size', 'get_world_size()'),
    ('mark_step', 'mark_step()'),
    ('wait_device_ops', 'wait_device_ops()'),
    ('get_ordinal', 'get_ordinal()'),
    ('xla_device_count', 'xla_device_count()'),
]

# Labels of APIs present on xm, and of those missing from it. The second list
# is named "deprecated", but it simply collects every attribute that hasattr()
# did not find.
api_available, api_deprecated = [], []
for api_name, api_desc in test_apis:
    (api_available if hasattr(xm, api_name) else api_deprecated).append(api_desc)

if api_available:
    print(" ✅ 可用API: " + ", ".join(api_available))
if api_deprecated:
    print(" ❌ 不可用API: " + ", ".join(api_deprecated))
|
# 2025-10-15 14:26:11 +08:00
|
|
|
|
|
|
|
# Compilation progress monitor widget
class JupyterCompilationMonitor:
    """Live XLA-compilation monitor widget for Jupyter notebooks.

    Shows a progress bar, an elapsed-time label and CPU / memory usage. A
    daemon thread refreshes these roughly once per second.

    The progress bar is *simulated*. It is driven purely by elapsed time and
    capped at 95%. It only jumps to 100% when ``complete_monitoring()`` runs.
    Completion happens in one of two ways:

    * the caller signals it explicitly via ``complete_monitoring()``, or
    * the background thread guesses it: CPU usage falls below
      ``cpu_idle_threshold`` percent after ``min_compile_seconds`` have elapsed.

    CPU / memory readings come from ``psutil`` when it can be imported.
    Without it, the labels show ``---`` and the CPU-based auto-completion is
    disabled (``complete_monitoring()`` must then be called explicitly).

    Args:
        ramp_seconds: seconds for the simulated bar to reach 100% of its
            ramp. It is clamped at 95% until completion.
        min_compile_seconds: minimum elapsed time before the CPU-idle
            heuristic may declare completion.
        cpu_idle_threshold: CPU percentage below which compilation is assumed
            to have finished.

    All three defaults reproduce the previously hard-coded behaviour.
    """

    def __init__(self, ramp_seconds=10.0, min_compile_seconds=10.0,
                 cpu_idle_threshold=20.0):
        self.ramp_seconds = ramp_seconds
        self.min_compile_seconds = min_compile_seconds
        self.cpu_idle_threshold = cpu_idle_threshold

        self.start_time = None
        self.is_monitoring = False
        self.monitor_thread = None
        # Serializes widget updates between the background thread and
        # complete_monitoring(), which may be called from the notebook's main
        # thread. This keeps a late refresh from overwriting the final
        # "completed" state.
        self._lock = threading.Lock()

        # Output area for the final report and error messages
        self.output_widget = widgets.Output()

        # Progress bar
        self.progress_bar = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='XLA编译:',
            bar_style='info',
            style={'bar_color': '#1f77b4'},
            orientation='horizontal',
        )

        # Status label
        self.status_label = widgets.HTML(value="<b>准备开始编译...</b>")

        # CPU / memory usage labels
        self.cpu_label = widgets.HTML(value="CPU: ---%")
        self.memory_label = widgets.HTML(value="内存: ---%")

        # Assemble the layout
        self.monitor_box = widgets.VBox([
            widgets.HTML("<h3>🔄 XLA编译监控</h3>"),
            self.progress_bar,
            self.status_label,
            widgets.HBox([self.cpu_label, self.memory_label]),
            self.output_widget,
        ])

    def start_monitoring(self):
        """Display the monitor and start the background refresh thread.

        Calling this while a run is already in progress is a no-op, so
        re-executing the cell does not spawn duplicate update threads.
        """
        if self.is_monitoring:
            return

        # A thread from a previous, already-completed run may still be inside
        # its 1 s sleep. Let it exit so it doesn't resume alongside the new run.
        previous = self.monitor_thread
        if previous is not None and previous.is_alive():
            previous.join(timeout=2)

        with self._lock:
            self.start_time = time.time()
            self.is_monitoring = True
            # Reset visual state so the same monitor can be reused.
            self.progress_bar.value = 0
            self.progress_bar.bar_style = 'info'
            self.status_label.value = "<b>准备开始编译...</b>"

        display(self.monitor_box)

        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def _simulated_progress(self, elapsed):
        """Map elapsed seconds to a fake progress value in the range [0, 95]."""
        if self.ramp_seconds <= 0:
            return 95
        return min(int(elapsed / self.ramp_seconds * 100), 95)

    @staticmethod
    def _read_resources(psutil_module):
        """Return (cpu_percent, memory_percent), or (None, None) without psutil."""
        if psutil_module is None:
            return None, None
        # cpu_percent(interval=0.1) blocks for ~0.1 s to take its own sample.
        return (psutil_module.cpu_percent(interval=0.1),
                psutil_module.virtual_memory().percent)

    @staticmethod
    def _format_percent(value):
        """Format a percentage like the original labels, or '---' when unknown."""
        return "---" if value is None else f"{value:5.1f}"

    def _monitor_loop(self):
        """Background loop: refresh the widgets about once per second until stopped."""
        try:
            # Optional dependency, already required by the original code;
            # without it the loop still updates the timer and progress bar.
            import psutil
        except ImportError:
            psutil = None

        while self.is_monitoring:
            try:
                elapsed = time.time() - self.start_time
                minutes, seconds = divmod(int(elapsed), 60)
                cpu_percent, memory_percent = self._read_resources(psutil)

                with self._lock:
                    # complete_monitoring() may have run while we were
                    # sampling; never overwrite its final state.
                    if not self.is_monitoring:
                        break
                    self.progress_bar.value = self._simulated_progress(elapsed)
                    self.status_label.value = f"<b>编译进行中... ⏱️ {minutes:02d}:{seconds:02d}</b>"
                    self.cpu_label.value = f"<b>🖥️ CPU: {self._format_percent(cpu_percent)}%</b>"
                    self.memory_label.value = f"<b>💾 内存: {self._format_percent(memory_percent)}%</b>"

                # Heuristic: compilation keeps the CPU busy, so a drop in CPU
                # usage after the minimum time is taken as "finished".
                if (cpu_percent is not None
                        and elapsed > self.min_compile_seconds
                        and cpu_percent < self.cpu_idle_threshold):
                    self.complete_monitoring()
                    break

                time.sleep(1)

            except Exception as e:
                # Don't let the thread die silently and leave the UI stuck
                # "in progress": stop, flag the bar, and report the error.
                with self._lock:
                    self.is_monitoring = False
                    self.progress_bar.bar_style = 'danger'
                self.output_widget.append_stdout(f"监控错误: {e}\n")
                break

    def complete_monitoring(self):
        """Mark compilation as finished and write a timing summary.

        Safe to call from either the notebook thread or the monitor thread.
        Only the first call during a run has any effect.
        """
        with self._lock:
            if not self.is_monitoring:
                return
            self.is_monitoring = False
            elapsed = time.time() - self.start_time

            self.progress_bar.value = 100
            self.progress_bar.bar_style = 'success'
            self.status_label.value = f"<b style='color: green'>✅ 编译完成! 总耗时: {elapsed:.2f}秒</b>"

        if elapsed < 60:
            verdict = "✅ 编译速度正常"
        elif elapsed < 300:
            verdict = "⚠️ 编译稍慢,但可接受"
        else:
            verdict = "❌ 编译过慢,建议检查设置"

        # Use append_stdout rather than `with self.output_widget: print(...)`.
        # Output capture through the context manager is unreliable from a
        # background thread, which is exactly where auto-detection calls this.
        self.output_widget.append_stdout(
            f"\n🎉 XLA编译成功完成!\n⏱️ 总耗时: {elapsed:.2f}秒\n{verdict}\n"
        )
|
|
|
|
|
|
|
|
# Notebook-wide monitor instance, picked up by the next cell.
compilation_monitor = JupyterCompilationMonitor()

# One print call; the embedded newline yields the same two output lines.
print("✅ 编译监控器已准备就绪!\n💡 运行下一个单元格开始XLA编译测试")
|