#!/usr/bin/env python3
"""XLA compilation progress monitoring script."""
|
||
import os
|
||
import time
|
||
import threading
|
||
import psutil
|
||
from contextlib import contextmanager
|
||
|
||
def monitor_compilation_progress(stop_event=None, interval=1.0):
    """Print a live, single-line status report while XLA compilation runs.

    Every ``interval`` seconds the current console line is rewritten (using a
    leading carriage return) with an animated ellipsis, the elapsed time as
    MM:SS, and the CPU and virtual-memory percentages reported by ``psutil``.

    Args:
        stop_event: Optional ``threading.Event``. The loop exits as soon as
            the event is set. When omitted, the function never returns, so it
            should only be run in a daemon thread.
        interval: Seconds between two refreshes of the status line. Must be
            positive.

    Raises:
        ValueError: If ``interval`` is not positive.
    """
    if interval <= 0:
        # A non-positive wait would turn the loop below into a busy loop.
        raise ValueError("interval must be positive")

    print("🔍 XLA编译进度监控已启动...")

    if stop_event is None:
        # Nobody can signal us: keep the historical "run forever" behaviour.
        stop_event = threading.Event()
    if stop_event.is_set():
        return

    # monotonic() is immune to wall-clock adjustments, unlike time.time().
    start_time = time.monotonic()
    dots = 0

    # With interval=None psutil compares against the previous call, so the
    # first result is meaningless; make that throw-away call up front.
    psutil.cpu_percent(interval=None)

    # Event.wait() doubles as the sleep between refreshes and returns True
    # immediately once the event is set, so stopping is not delayed.
    while not stop_event.wait(interval):
        elapsed = time.monotonic() - start_time
        minutes, seconds = divmod(int(elapsed), 60)

        # CPU usage since the previous sample, and current memory usage.
        cpu_percent = psutil.cpu_percent(interval=None)
        memory_percent = psutil.virtual_memory().percent

        # Animated ellipsis, padded to a fixed width of three characters.
        dots = (dots + 1) % 4
        dot_str = "." * dots + " " * (3 - dots)

        print(f"\r🔄 XLA编译中{dot_str} "
              f"⏱️ {minutes:02d}:{seconds:02d} "
              f"🖥️ CPU: {cpu_percent:5.1f}% "
              f"💾 内存: {memory_percent:5.1f}%", end="", flush=True)


@contextmanager
def compilation_monitor():
    """Context manager that shows compilation progress for the wrapped code.

    A daemon thread running ``monitor_compilation_progress`` is started on
    entry. On exit, whether normal or through an exception, the thread is
    told to stop and the total elapsed time is printed.
    """
    print("🚀 开始XLA编译监控...")

    # Start the monitor thread together with the event used to stop it.
    stop_event = threading.Event()
    monitor_thread = threading.Thread(
        target=monitor_compilation_progress,
        kwargs={"stop_event": stop_event},
        daemon=True,
    )
    monitor_thread.start()

    start_time = time.monotonic()

    try:
        yield
    finally:
        elapsed = time.monotonic() - start_time
        # Stop the status line before printing the summary so the monitor
        # cannot overwrite it or keep printing for the rest of the process.
        stop_event.set()
        # Bounded join: a stuck monitor must never block the caller.
        monitor_thread.join(timeout=2.0)
        # NOTE(review): this message is also printed when the wrapped code
        # raised; kept as-is to preserve the existing output.
        print(f"\n✅ XLA编译完成! 总耗时: {elapsed:.2f}秒")
|
||
|
||
# How to wire the monitor into the trainer
def add_compilation_monitoring_to_trainer():
    """Print example code showing how to add compilation monitoring to a trainer.

    The function only writes to stdout: a header line followed by a sample
    ``train`` method that wraps the first batch in ``compilation_monitor()``.
    """
    header = "📝 如何在trainer中使用编译监控:"
    snippet = '''
# 在rnn_trainer.py的train方法中添加:

def train(self):
    from monitor_xla_compilation import compilation_monitor

    self.model.train()
    train_losses = []
    # ... 其他初始化代码 ...

    self.logger.info("Starting training loop - XLA compilation monitoring enabled...")

    # 使用编译监控
    with compilation_monitor():
        for i, batch in enumerate(self.train_loader):
            # 第一个batch会触发XLA编译
            # 监控会显示编译进度

            # ... 训练代码 ...

            # 编译完成后会自动停止监控
            break  # 只需要第一个batch来触发编译

    # 继续正常训练循环
    for i, batch in enumerate(self.train_loader):
        # ... 正常训练代码 ...
'''

    # One call emits the header and the snippet on consecutive lines.
    print(header, snippet, sep="\n")
|
||
|
||
if __name__ == "__main__":
    # Banner followed by the usage instructions, one console line per entry.
    usage_lines = (
        "🧪 XLA编译监控工具",
        "=" * 50,
        "📖 使用方法:",
        "1. 将此文件导入到你的训练脚本中",
        "2. 在第一次模型调用前使用 compilation_monitor() 上下文管理器",
        "3. 会实时显示编译进度和系统资源使用情况",
    )
    for usage_line in usage_lines:
        print(usage_line)

    # Finish with the trainer integration example.
    add_compilation_monitoring_to_trainer()