tpu
This commit is contained in:
131
model_training_nnn_tpu/jupyter_xla_monitor.py
Normal file
131
model_training_nnn_tpu/jupyter_xla_monitor.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# ====================
|
||||
# 单元格2: XLA编译进度监控
|
||||
# ====================
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import time
|
||||
import threading
|
||||
from IPython.display import display, HTML, clear_output
|
||||
import ipywidgets as widgets
|
||||
|
||||
# 导入XLA (环境变量已在单元格1中设置)
|
||||
print("🚀 导入PyTorch XLA...")
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
print(f"✅ XLA导入成功!")
|
||||
print(f" TPU设备: {xm.xla_device()}")
|
||||
print(f" World Size: {xm.xrt_world_size()}")
|
||||
|
||||
# 创建编译进度监控器
|
||||
class JupyterCompilationMonitor:
|
||||
def __init__(self):
|
||||
self.start_time = None
|
||||
self.is_monitoring = False
|
||||
|
||||
# 创建输出widget
|
||||
self.output_widget = widgets.Output()
|
||||
|
||||
# 创建进度条
|
||||
self.progress_bar = widgets.IntProgress(
|
||||
value=0,
|
||||
min=0,
|
||||
max=100,
|
||||
description='XLA编译:',
|
||||
bar_style='info',
|
||||
style={'bar_color': '#1f77b4'},
|
||||
orientation='horizontal'
|
||||
)
|
||||
|
||||
# 创建状态标签
|
||||
self.status_label = widgets.HTML(
|
||||
value="<b>准备开始编译...</b>"
|
||||
)
|
||||
|
||||
# 创建CPU使用率显示
|
||||
self.cpu_label = widgets.HTML(
|
||||
value="CPU: ---%"
|
||||
)
|
||||
|
||||
self.memory_label = widgets.HTML(
|
||||
value="内存: ---%"
|
||||
)
|
||||
|
||||
# 组合界面
|
||||
self.monitor_box = widgets.VBox([
|
||||
widgets.HTML("<h3>🔄 XLA编译监控</h3>"),
|
||||
self.progress_bar,
|
||||
self.status_label,
|
||||
widgets.HBox([self.cpu_label, self.memory_label]),
|
||||
self.output_widget
|
||||
])
|
||||
|
||||
def start_monitoring(self):
|
||||
"""开始监控"""
|
||||
self.start_time = time.time()
|
||||
self.is_monitoring = True
|
||||
|
||||
display(self.monitor_box)
|
||||
|
||||
# 启动监控线程
|
||||
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self.monitor_thread.start()
|
||||
|
||||
def _monitor_loop(self):
|
||||
"""监控循环"""
|
||||
while self.is_monitoring:
|
||||
try:
|
||||
elapsed = time.time() - self.start_time
|
||||
minutes = int(elapsed // 60)
|
||||
seconds = int(elapsed % 60)
|
||||
|
||||
# 更新进度条 (模拟进度)
|
||||
progress = min(int(elapsed / 10 * 100), 95) # 10秒内达到95%
|
||||
self.progress_bar.value = progress
|
||||
|
||||
# 获取系统资源
|
||||
cpu_percent = psutil.cpu_percent(interval=0.1)
|
||||
memory_percent = psutil.virtual_memory().percent
|
||||
|
||||
# 更新显示
|
||||
self.status_label.value = f"<b>编译进行中... ⏱️ {minutes:02d}:{seconds:02d}</b>"
|
||||
self.cpu_label.value = f"<b>🖥️ CPU: {cpu_percent:5.1f}%</b>"
|
||||
self.memory_label.value = f"<b>💾 内存: {memory_percent:5.1f}%</b>"
|
||||
|
||||
# 检测是否编译完成 (CPU使用率突然下降)
|
||||
if elapsed > 10 and cpu_percent < 20: # 编译通常CPU使用率很高
|
||||
self.complete_monitoring()
|
||||
break
|
||||
|
||||
time.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
with self.output_widget:
|
||||
print(f"监控错误: {e}")
|
||||
break
|
||||
|
||||
def complete_monitoring(self):
|
||||
"""完成监控"""
|
||||
if self.is_monitoring:
|
||||
self.is_monitoring = False
|
||||
elapsed = time.time() - self.start_time
|
||||
|
||||
self.progress_bar.value = 100
|
||||
self.progress_bar.bar_style = 'success'
|
||||
self.status_label.value = f"<b style='color: green'>✅ 编译完成! 总耗时: {elapsed:.2f}秒</b>"
|
||||
|
||||
with self.output_widget:
|
||||
print(f"\n🎉 XLA编译成功完成!")
|
||||
print(f"⏱️ 总耗时: {elapsed:.2f}秒")
|
||||
if elapsed < 60:
|
||||
print("✅ 编译速度正常")
|
||||
elif elapsed < 300:
|
||||
print("⚠️ 编译稍慢,但可接受")
|
||||
else:
|
||||
print("❌ 编译过慢,建议检查设置")
|
||||
|
||||
# 创建全局监控器
|
||||
compilation_monitor = JupyterCompilationMonitor()
|
||||
|
||||
print("✅ 编译监控器已准备就绪!")
|
||||
print("💡 运行下一个单元格开始XLA编译测试")
|
Reference in New Issue
Block a user