# ==================== # 单元格2: XLA编译进度监控 # ==================== import torch import torch.nn as nn import time import threading from IPython.display import display, HTML, clear_output import ipywidgets as widgets # 导入XLA (环境变量已在单元格1中设置) print("🚀 导入PyTorch XLA...") import torch_xla.core.xla_model as xm print(f"✅ XLA导入成功!") print(f" TPU设备: {xm.xla_device()}") # 兼容新版本PyTorch XLA try: world_size = xm.xrt_world_size() print(f" World Size (旧API): {world_size}") except AttributeError: try: world_size = xm.get_world_size() print(f" World Size (新API): {world_size}") except AttributeError: print(" World Size: 无法获取 (可能在CPU模式)") # 检查XLA版本兼容性 print("🔍 检查XLA API兼容性:") api_available = [] api_deprecated = [] # 检查各种API test_apis = [ ('xrt_world_size', 'xrt_world_size()'), ('get_world_size', 'get_world_size()'), ('mark_step', 'mark_step()'), ('wait_device_ops', 'wait_device_ops()'), ('get_ordinal', 'get_ordinal()'), ('xla_device_count', 'xla_device_count()') ] for api_name, api_desc in test_apis: if hasattr(xm, api_name): api_available.append(api_desc) else: api_deprecated.append(api_desc) if api_available: print(f" ✅ 可用API: {', '.join(api_available)}") if api_deprecated: print(f" ❌ 不可用API: {', '.join(api_deprecated)}") # 创建编译进度监控器 class JupyterCompilationMonitor: def __init__(self): self.start_time = None self.is_monitoring = False # 创建输出widget self.output_widget = widgets.Output() # 创建进度条 self.progress_bar = widgets.IntProgress( value=0, min=0, max=100, description='XLA编译:', bar_style='info', style={'bar_color': '#1f77b4'}, orientation='horizontal' ) # 创建状态标签 self.status_label = widgets.HTML( value="准备开始编译..." ) # 创建CPU使用率显示 self.cpu_label = widgets.HTML( value="CPU: ---%" ) self.memory_label = widgets.HTML( value="内存: ---%" ) # 组合界面 self.monitor_box = widgets.VBox([ widgets.HTML("