Files
b2txt25/model_training_nnn_tpu/jupyter_xla_monitor.py

131 lines
4.2 KiB
Python
Raw Normal View History

2025-10-15 14:26:11 +08:00
# ====================
# 单元格2: XLA编译进度监控
# ====================
import torch
import torch.nn as nn
import time
import threading
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
# Import XLA (the required environment variables were already set in cell 1)
print("🚀 导入PyTorch XLA...")
import torch_xla.core.xla_model as xm

# `xm.xrt_world_size` is a legacy API that newer torch_xla releases no longer
# provide. Keep using it where it exists (unchanged behaviour on older
# installs) and fall back to the runtime module only when it is missing.
_legacy_world_size = getattr(xm, "xrt_world_size", None)
if _legacy_world_size is not None:
    _xla_world_size = _legacy_world_size()
else:
    import torch_xla.runtime as xr
    _xla_world_size = xr.world_size()

print("✅ XLA导入成功!")
print(f" TPU设备: {xm.xla_device()}")
print(f" World Size: {_xla_world_size}")
# Compilation progress monitor
class JupyterCompilationMonitor:
    """Notebook widget showing elapsed time and host load while XLA compiles.

    The progress bar is *simulated*: it advances linearly with wall-clock time
    and is capped at 95% until completion is reported. Completion is reported
    either explicitly through :meth:`complete_monitoring` or heuristically by
    the background thread once CPU usage drops below ``idle_cpu_threshold``
    after ``expected_seconds`` have elapsed.

    Args:
        expected_seconds: Seconds after which the simulated bar reaches its
            95% cap and idle detection becomes active. Must be positive.
        idle_cpu_threshold: CPU percentage below which compilation is assumed
            to have finished.
        poll_interval: Seconds to sleep between widget refreshes.

    Raises:
        ValueError: If ``expected_seconds`` is not positive.
    """

    def __init__(self, expected_seconds=10.0, idle_cpu_threshold=20.0,
                 poll_interval=1.0):
        if expected_seconds <= 0:
            raise ValueError("expected_seconds must be positive")
        self.expected_seconds = expected_seconds
        self.idle_cpu_threshold = idle_cpu_threshold
        self.poll_interval = poll_interval

        self.start_time = None
        self.is_monitoring = False
        # Defined up front so start_monitoring() can inspect it on first use.
        self.monitor_thread = None

        # Output area for printed messages (errors, final summary)
        self.output_widget = widgets.Output()
        # Progress bar
        self.progress_bar = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='XLA编译:',
            bar_style='info',
            style={'bar_color': '#1f77b4'},
            orientation='horizontal'
        )
        # Status line
        self.status_label = widgets.HTML(
            value="<b>准备开始编译...</b>"
        )
        # CPU / memory usage read-outs (placeholders until the first sample)
        self.cpu_label = widgets.HTML(
            value="CPU: ---%"
        )
        self.memory_label = widgets.HTML(
            value="内存: ---%"
        )
        # Assemble the UI
        self.monitor_box = widgets.VBox([
            widgets.HTML("<h3>🔄 XLA编译监控</h3>"),
            self.progress_bar,
            self.status_label,
            widgets.HBox([self.cpu_label, self.memory_label]),
            self.output_widget
        ])

    @staticmethod
    def _load_psutil():
        """Return the ``psutil`` module, or ``None`` when it is not installed."""
        try:
            import psutil
        except ImportError:
            return None
        return psutil

    @staticmethod
    def _simulated_progress(elapsed, expected_seconds):
        """Map elapsed seconds to a bar value in 0..95 (linear, then capped)."""
        return min(int(elapsed / expected_seconds * 100), 95)

    def start_monitoring(self):
        """Display the widget and start the background refresh thread.

        Calling this while a monitor thread is already running is a no-op, so
        re-running a cell does not stack several threads on the same widgets.
        """
        worker = self.monitor_thread
        if self.is_monitoring and worker is not None and worker.is_alive():
            return

        # Reset the widgets so a finished monitor can be reused; otherwise the
        # bar would stay green at 100% from the previous run.
        self.progress_bar.value = 0
        self.progress_bar.bar_style = 'info'
        self.status_label.value = "<b>准备开始编译...</b>"

        self.start_time = time.time()
        self.is_monitoring = True
        display(self.monitor_box)
        # Daemon thread: it must never keep the kernel alive on shutdown.
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()

    def _monitor_loop(self):
        """Refresh the widgets until monitoring stops or an error occurs.

        Runs on the background thread. Without ``psutil`` only the timer and
        the simulated bar are updated; the CPU/memory labels keep their
        placeholders and completion must be signalled explicitly through
        :meth:`complete_monitoring`.
        """
        psutil = self._load_psutil()
        while self.is_monitoring:
            try:
                cpu_percent = memory_percent = None
                if psutil is not None:
                    # Blocks for ~0.1 s, so sample *before* touching widgets.
                    cpu_percent = psutil.cpu_percent(interval=0.1)
                    memory_percent = psutil.virtual_memory().percent

                # complete_monitoring() may have run on another thread while
                # we were sampling; do not overwrite its final state.
                if not self.is_monitoring:
                    break

                elapsed = time.time() - self.start_time
                minutes, seconds = divmod(int(elapsed), 60)

                # Simulated progress: reaches the 95% cap after expected_seconds
                self.progress_bar.value = self._simulated_progress(
                    elapsed, self.expected_seconds
                )
                self.status_label.value = f"<b>编译进行中... ⏱️ {minutes:02d}:{seconds:02d}</b>"

                if cpu_percent is not None:
                    self.cpu_label.value = f"<b>🖥️ CPU: {cpu_percent:5.1f}%</b>"
                    self.memory_label.value = f"<b>💾 内存: {memory_percent:5.1f}%</b>"
                    # Heuristic: compilation keeps the CPU busy, so a drop
                    # below the threshold is treated as "finished".
                    if (elapsed > self.expected_seconds
                            and cpu_percent < self.idle_cpu_threshold):
                        self.complete_monitoring()
                        break

                time.sleep(self.poll_interval)
            except Exception as e:
                # Top-level boundary of the worker thread: report inside the
                # widget instead of letting the thread die silently.
                with self.output_widget:
                    print(f"监控错误: {e}")
                break

    def complete_monitoring(self):
        """Mark compilation as finished and print a timing summary.

        Safe to call more than once; only the first call after a start has
        any effect.
        """
        if not self.is_monitoring:
            return
        self.is_monitoring = False
        elapsed = time.time() - self.start_time

        self.progress_bar.value = 100
        self.progress_bar.bar_style = 'success'
        self.status_label.value = f"<b style='color: green'>✅ 编译完成! 总耗时: {elapsed:.2f}秒</b>"

        with self.output_widget:
            print("\n🎉 XLA编译成功完成!")
            # Unit added so the summary matches the status label above.
            print(f"⏱️ 总耗时: {elapsed:.2f}秒")
            if elapsed < 60:
                print("✅ 编译速度正常")
            elif elapsed < 300:
                print("⚠️ 编译稍慢,但可接受")
            else:
                print("❌ 编译过慢,建议检查设置")
# Build the shared, module-level monitor that later notebook cells use
compilation_monitor = JupyterCompilationMonitor()
# Announce readiness and tell the user what to run next (one line each)
print(
    "✅ 编译监控器已准备就绪!",
    "💡 运行下一个单元格开始XLA编译测试",
    sep="\n",
)