#!/usr/bin/env python3 """ XLA编译进度监控脚本 """ import os import time import threading import psutil from contextlib import contextmanager def monitor_compilation_progress(): """监控XLA编译进度""" print("🔍 XLA编译进度监控已启动...") start_time = time.time() dots = 0 while True: elapsed = time.time() - start_time minutes = int(elapsed // 60) seconds = int(elapsed % 60) # 获取CPU使用率 cpu_percent = psutil.cpu_percent(interval=1) memory_percent = psutil.virtual_memory().percent # 动态显示 dots = (dots + 1) % 4 dot_str = "." * dots + " " * (3 - dots) print(f"\r🔄 XLA编译中{dot_str} " f"⏱️ {minutes:02d}:{seconds:02d} " f"🖥️ CPU: {cpu_percent:5.1f}% " f"💾 内存: {memory_percent:5.1f}%", end="", flush=True) time.sleep(1) @contextmanager def compilation_monitor(): """编译监控上下文管理器""" print("🚀 开始XLA编译监控...") # 启动监控线程 monitor_thread = threading.Thread(target=monitor_compilation_progress, daemon=True) monitor_thread.start() start_time = time.time() try: yield finally: elapsed = time.time() - start_time print(f"\n✅ XLA编译完成! 总耗时: {elapsed:.2f}秒") # 修改trainer中的使用方法 def add_compilation_monitoring_to_trainer(): """给trainer添加编译监控的示例代码""" example_code = ''' # 在rnn_trainer.py的train方法中添加: def train(self): from monitor_xla_compilation import compilation_monitor self.model.train() train_losses = [] # ... 其他初始化代码 ... self.logger.info("Starting training loop - XLA compilation monitoring enabled...") # 使用编译监控 with compilation_monitor(): for i, batch in enumerate(self.train_loader): # 第一个batch会触发XLA编译 # 监控会显示编译进度 # ... 训练代码 ... # 编译完成后会自动停止监控 break # 只需要第一个batch来触发编译 # 继续正常训练循环 for i, batch in enumerate(self.train_loader): # ... 正常训练代码 ... ''' print("📝 如何在trainer中使用编译监控:") print(example_code) if __name__ == "__main__": print("🧪 XLA编译监控工具") print("=" * 50) # 演示如何使用 print("📖 使用方法:") print("1. 将此文件导入到你的训练脚本中") print("2. 在第一次模型调用前使用 compilation_monitor() 上下文管理器") print("3. 会实时显示编译进度和系统资源使用情况") add_compilation_monitoring_to_trainer()