# ====================
# Cell 2: XLA compilation progress monitor
# ====================
import torch
import torch.nn as nn
import time
import threading

import psutil  # required by the monitor loop for CPU / memory readings
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# Import XLA (the environment variables are expected to be set in cell 1).
print("🚀 导入PyTorch XLA...")
import torch_xla.core.xla_model as xm

print("✅ XLA导入成功!")
print(f" TPU设备: {xm.xla_device()}")

# World-size lookup that works across PyTorch/XLA releases: try the older
# accessor first, then the newer one, and finally give up gracefully.
try:
    world_size = xm.xrt_world_size()
    print(f" World Size (旧API): {world_size}")
except AttributeError:
    try:
        world_size = xm.get_world_size()
        print(f" World Size (新API): {world_size}")
    except AttributeError:
        # Keep the name bound so later code can test `world_size is None`
        # instead of hitting a NameError.
        world_size = None
        print(" World Size: 无法获取 (可能在CPU模式)")

# Report which xla_model entry points exist in the installed version.
print("🔍 检查XLA API兼容性:")
api_available = []
api_deprecated = []

# (attribute name on `xm`, label shown to the user)
test_apis = [
    ('xrt_world_size', 'xrt_world_size()'),
    ('get_world_size', 'get_world_size()'),
    ('mark_step', 'mark_step()'),
    ('wait_device_ops', 'wait_device_ops()'),
    ('get_ordinal', 'get_ordinal()'),
    ('xla_device_count', 'xla_device_count()'),
]

for api_name, api_desc in test_apis:
    if hasattr(xm, api_name):
        api_available.append(api_desc)
    else:
        api_deprecated.append(api_desc)

if api_available:
    print(f" ✅ 可用API: {', '.join(api_available)}")
if api_deprecated:
    print(f" ❌ 不可用API: {', '.join(api_deprecated)}")


class JupyterCompilationMonitor:
    """Widget-based progress display for an XLA compilation run in Jupyter.

    A daemon thread refreshes an elapsed-time label, a simulated progress
    bar, and CPU / memory usage about once per second.  Monitoring ends when
    ``complete_monitoring()`` is called, or automatically once the CPU-usage
    heuristic in ``_monitor_loop`` fires.
    """

    # Seconds over which the simulated progress bar ramps up.
    _RAMP_SECONDS = 10
    # The simulated bar never exceeds this value until completion is reported.
    _MAX_SIMULATED_PROGRESS = 95
    # After the ramp, CPU usage below this percentage is treated as "done".
    _IDLE_CPU_PERCENT = 20

    def __init__(self) -> None:
        """Build the widgets; nothing is displayed or started yet."""
        self.start_time = None
        self.is_monitoring = False

        # Sink for text output (completion summary, monitor errors).
        self.output_widget = widgets.Output()

        # Progress bar.
        self.progress_bar = widgets.IntProgress(
            value=0,
            min=0,
            max=100,
            description='XLA编译:',
            bar_style='info',
            style={'bar_color': '#1f77b4'},
            orientation='horizontal',
        )

        # Status line.
        self.status_label = widgets.HTML(value="准备开始编译...")

        # Resource usage read-outs.
        self.cpu_label = widgets.HTML(value="CPU: ---%")
        self.memory_label = widgets.HTML(value="内存: ---%")

        # Assemble the panel.
        # NOTE(review): this header literal was split across physical lines in
        # the source, which is not valid for a normal quoted string; it looks
        # like heading markup was stripped. The text is kept exactly as found,
        # with the line breaks escaped - confirm the intended markup.
        self.monitor_box = widgets.VBox([
            widgets.HTML("\n\n🔄 XLA编译监控\n\n"),
            self.progress_bar,
            self.status_label,
            widgets.HBox([self.cpu_label, self.memory_label]),
            self.output_widget,
        ])

    def start_monitoring(self) -> None:
        """Display the panel and start the background refresh thread."""
        self.start_time = time.time()
        self.is_monitoring = True

        # Clear any finished state left over from a previous run so a reused
        # monitor does not start out looking complete.
        self.progress_bar.value = 0
        self.progress_bar.bar_style = 'info'

        display(self.monitor_box)

        # Daemon thread: it must never keep the kernel alive on shutdown.
        self.monitor_thread = threading.Thread(
            target=self._monitor_loop, daemon=True
        )
        self.monitor_thread.start()

    def _monitor_loop(self) -> None:
        """Refresh the widgets until monitoring is stopped or an error occurs."""
        while self.is_monitoring:
            try:
                elapsed = time.time() - self.start_time
                minutes = int(elapsed // 60)
                seconds = int(elapsed % 60)

                # Simulated progress: linear ramp, capped below 100 so the bar
                # only fills completely when completion is actually reported.
                progress = min(
                    int(elapsed / self._RAMP_SECONDS * 100),
                    self._MAX_SIMULATED_PROGRESS,
                )
                self.progress_bar.value = progress

                # System resource readings (cpu_percent samples for 0.1 s).
                cpu_percent = psutil.cpu_percent(interval=0.1)
                memory_percent = psutil.virtual_memory().percent

                self.status_label.value = (
                    f"编译进行中... ⏱️ {minutes:02d}:{seconds:02d}"
                )
                self.cpu_label.value = f"🖥️ CPU: {cpu_percent:5.1f}%"
                self.memory_label.value = f"💾 内存: {memory_percent:5.1f}%"

                # Heuristic completion check: once the ramp is over, a low CPU
                # reading is taken to mean the compiler has gone idle.
                if (
                    elapsed > self._RAMP_SECONDS
                    and cpu_percent < self._IDLE_CPU_PERCENT
                ):
                    self.complete_monitoring()
                    break

                time.sleep(1)
            except Exception as e:
                # Top of a background thread: report inside the panel rather
                # than letting the traceback disappear, then stop refreshing.
                with self.output_widget:
                    print(f"监控错误: {e}")
                break

    def complete_monitoring(self) -> None:
        """Mark the run as finished and print a timing summary.

        Does nothing when monitoring is not active, so a second call (for
        example one from the loop and one from the caller) is harmless.
        """
        if not self.is_monitoring:
            return

        self.is_monitoring = False
        elapsed = time.time() - self.start_time

        self.progress_bar.value = 100
        self.progress_bar.bar_style = 'success'
        self.status_label.value = f"✅ 编译完成! 总耗时: {elapsed:.2f}秒"

        with self.output_widget:
            print("\n🎉 XLA编译成功完成!")
            print(f"⏱️ 总耗时: {elapsed:.2f}秒")

            # Rough verdict: under 1 minute, under 5 minutes, or slower.
            if elapsed < 60:
                print("✅ 编译速度正常")
            elif elapsed < 300:
                print("⚠️ 编译稍慢,但可接受")
            else:
                print("❌ 编译过慢,建议检查设置")


# Global monitor instance shared with the following notebook cells.
compilation_monitor = JupyterCompilationMonitor()

print("✅ 编译监控器已准备就绪!")
print("💡 运行下一个单元格开始XLA编译测试")