131 lines
4.2 KiB
Python
131 lines
4.2 KiB
Python
![]() |
# ====================
|
||
|
# 单元格2: XLA编译进度监控
|
||
|
# ====================
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import time
|
||
|
import threading
|
||
|
from IPython.display import display, HTML, clear_output
|
||
|
import ipywidgets as widgets
|
||
|
|
||
|
# Import XLA (the environment variables were already set in cell 1).
print("🚀 导入PyTorch XLA...")
import torch_xla.core.xla_model as xm

# Report success first, then query the device and the world size one at a
# time so each query runs immediately before the line that prints it.
print("✅ XLA导入成功!")
_xla_device = xm.xla_device()
print(f" TPU设备: {_xla_device}")
_xla_world_size = xm.xrt_world_size()
print(f" World Size: {_xla_world_size}")
|
||
|
|
||
|
# 创建编译进度监控器
|
||
|
class JupyterCompilationMonitor:
    """Jupyter widget panel that shows the state of an XLA compilation run.

    The panel contains a progress bar, a status line, CPU and memory
    read-outs and an output area.  ``start_monitoring`` displays the panel
    and starts a daemon thread that refreshes it; ``complete_monitoring``
    marks the run as finished and prints a summary.

    The progress bar is simulated from elapsed wall-clock time; it is not
    driven by the compiler.  Completion is either signalled explicitly via
    ``complete_monitoring`` or inferred by the background thread when CPU
    utilisation drops below ``IDLE_CPU_PERCENT`` after
    ``MIN_SECONDS_BEFORE_IDLE_CHECK`` seconds.
    """

    # Seconds over which the simulated progress bar ramps from 0 to 100.
    SIMULATED_PROGRESS_SECONDS = 10
    # The simulated bar is capped here until completion is reported.
    MAX_SIMULATED_PROGRESS = 95
    # Elapsed seconds that must pass before the low-CPU heuristic may fire.
    MIN_SECONDS_BEFORE_IDLE_CHECK = 10
    # CPU utilisation (percent) below which the run is treated as finished.
    IDLE_CPU_PERCENT = 20
    # Pause between two refreshes of the panel, in seconds.
    POLL_INTERVAL_SECONDS = 1

    def __init__(self):
        """Build the widgets; nothing is displayed or started yet."""
        self.start_time = None
        self.is_monitoring = False
        # Set by start_monitoring(); kept here so the attribute always exists.
        self.monitor_thread = None

        # Output area used for error messages and the final summary.
        self.output_widget = widgets.Output()

        # Progress bar.
        self.progress_bar = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='XLA编译:',
            bar_style='info',
            style={'bar_color': '#1f77b4'},
            orientation='horizontal'
        )

        # Status line.
        self.status_label = widgets.HTML(
            value="<b>准备开始编译...</b>"
        )

        # CPU and memory utilisation read-outs.
        self.cpu_label = widgets.HTML(
            value="CPU: ---%"
        )

        self.memory_label = widgets.HTML(
            value="内存: ---%"
        )

        # Assemble the panel.
        self.monitor_box = widgets.VBox([
            widgets.HTML("<h3>🔄 XLA编译监控</h3>"),
            self.progress_bar,
            self.status_label,
            widgets.HBox([self.cpu_label, self.memory_label]),
            self.output_widget
        ])

    def start_monitoring(self):
        """Display the panel and start the background refresh thread."""
        self.start_time = time.time()
        self.is_monitoring = True

        # Reset the bar so that reusing the monitor for another run does not
        # start from the previous run's green, 100% state.
        self.progress_bar.value = 0
        self.progress_bar.bar_style = 'info'

        display(self.monitor_box)

        # Daemon thread: it must never keep the interpreter alive.
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def _monitor_loop(self):
        """Refresh the panel until monitoring stops or an error occurs.

        Runs on the thread started by ``start_monitoring``.  Any exception
        raised while refreshing is printed into the output widget and ends
        the loop.
        """
        # psutil is imported here because this file's import block does not
        # provide it; without this the first refresh fails with NameError
        # unless some other code has put psutil into the global namespace.
        try:
            import psutil
        except ImportError as e:
            with self.output_widget:
                print(f"监控错误: {e}")
            return

        while self.is_monitoring:
            try:
                elapsed = time.time() - self.start_time
                minutes = int(elapsed // 60)
                seconds = int(elapsed % 60)

                # Simulated progress: linear in elapsed time, capped below 100.
                progress = min(
                    int(elapsed / self.SIMULATED_PROGRESS_SECONDS * 100),
                    self.MAX_SIMULATED_PROGRESS,
                )
                self.progress_bar.value = progress

                # System resource usage.
                cpu_percent = psutil.cpu_percent(interval=0.1)
                memory_percent = psutil.virtual_memory().percent

                # Update the read-outs.
                self.status_label.value = f"<b>编译进行中... ⏱️ {minutes:02d}:{seconds:02d}</b>"
                self.cpu_label.value = f"<b>🖥️ CPU: {cpu_percent:5.1f}%</b>"
                self.memory_label.value = f"<b>💾 内存: {memory_percent:5.1f}%</b>"

                # Heuristic: once the warm-up period has passed, a low CPU
                # reading is taken to mean that compilation has finished.
                if (elapsed > self.MIN_SECONDS_BEFORE_IDLE_CHECK
                        and cpu_percent < self.IDLE_CPU_PERCENT):
                    self.complete_monitoring()
                    break

                time.sleep(self.POLL_INTERVAL_SECONDS)

            except Exception as e:
                with self.output_widget:
                    print(f"监控错误: {e}")
                break

    def complete_monitoring(self):
        """Mark the run as finished and print a timing summary.

        Does nothing when monitoring is not active, so calling it again
        after completion has no effect.
        """
        if not self.is_monitoring:
            return

        self.is_monitoring = False
        elapsed = time.time() - self.start_time

        self.progress_bar.value = 100
        self.progress_bar.bar_style = 'success'
        self.status_label.value = f"<b style='color: green'>✅ 编译完成! 总耗时: {elapsed:.2f}秒</b>"

        with self.output_widget:
            print(f"\n🎉 XLA编译成功完成!")
            print(f"⏱️ 总耗时: {elapsed:.2f}秒")
            # Verdict thresholds: under 1 minute, under 5 minutes, slower.
            if elapsed < 60:
                print("✅ 编译速度正常")
            elif elapsed < 300:
                print("⚠️ 编译稍慢,但可接受")
            else:
                print("❌ 编译过慢,建议检查设置")
|
||
|
|
||
|
# Create the global monitor instance.
compilation_monitor = JupyterCompilationMonitor()

# Both status lines are emitted by one call; the newline separator yields the
# same text as two consecutive prints.
print(
    "✅ 编译监控器已准备就绪!",
    "💡 运行下一个单元格开始XLA编译测试",
    sep="\n",
)
|