Files
b2txt25/model_training_nnn_tpu/monitor_xla_compilation.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

100 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
XLA编译进度监控脚本
"""
import os
import time
import threading
import psutil
from contextlib import contextmanager
def monitor_compilation_progress(stop_event=None):
    """Print a live, single-line XLA compilation status until asked to stop.

    The status line is redrawn in place (leading carriage return, no trailing
    newline) and shows an animated ellipsis, the elapsed wall-clock time as
    MM:SS, and the CPU and virtual-memory percentages reported by psutil.

    Args:
        stop_event: Optional ``threading.Event``. When provided, the loop
            exits shortly after the event is set. When ``None`` (the
            default, matching the historical signature) the loop never
            returns, so the function should only be run in a daemon thread.
    """
    print("🔍 XLA编译进度监控已启动...")
    start_time = time.time()
    dots = 0

    while stop_event is None or not stop_event.is_set():
        elapsed = time.time() - start_time
        minutes = int(elapsed // 60)
        seconds = int(elapsed % 60)

        # Resource usage; cpu_percent(interval=1) blocks for one second
        # while it samples.
        cpu_percent = psutil.cpu_percent(interval=1)
        memory_percent = psutil.virtual_memory().percent

        # A stop request may have arrived while sampling; do not redraw the
        # "compiling" line after the caller has declared compilation done.
        if stop_event is not None and stop_event.is_set():
            break

        # Animated ellipsis, padded to a fixed width so the line length is
        # stable between redraws.
        dots = (dots + 1) % 4
        dot_str = "." * dots + " " * (3 - dots)
        print(f"\r🔄 XLA编译中{dot_str} "
              f"⏱️ {minutes:02d}:{seconds:02d} "
              f"🖥️ CPU: {cpu_percent:5.1f}% "
              f"💾 内存: {memory_percent:5.1f}%", end="", flush=True)

        if stop_event is None:
            time.sleep(1)
        else:
            # Same one-second pause, but wakes immediately on a stop request.
            stop_event.wait(1)


@contextmanager
def compilation_monitor():
    """Context manager that shows compilation progress while its body runs.

    On entry a daemon thread running ``monitor_compilation_progress`` is
    started. On exit (normal or via an exception, since the cleanup lives in
    ``finally``) the thread is told to stop and the total elapsed time of the
    body is printed.

    Yields:
        None. The managed body is expected to trigger the XLA compilation.
    """
    print("🚀 开始XLA编译监控...")

    # The event lets us end the progress thread instead of leaving it
    # redrawing the status line for the remaining lifetime of the process.
    stop_event = threading.Event()
    monitor_thread = threading.Thread(
        target=monitor_compilation_progress,
        args=(stop_event,),
        daemon=True,
    )
    monitor_thread.start()

    start_time = time.time()
    try:
        yield
    finally:
        # Measure before waiting on the thread so the join is not counted.
        elapsed = time.time() - start_time
        stop_event.set()
        # Bounded wait: the thread may be inside the one-second CPU sample.
        # The timeout keeps a stalled monitor from blocking training.
        monitor_thread.join(timeout=2.0)
        print(f"\n✅ XLA编译完成! 总耗时: {elapsed:.2f}")
# How to adapt the trainer to use the monitor
def add_compilation_monitoring_to_trainer():
    """Print an example of wrapping a trainer's first batch in the monitor.

    The example is illustrative text only; nothing in it is executed. It is
    printed to stdout after a one-line heading.

    Returns:
        None.
    """
    # The snippet is indented as real Python so it reads correctly when
    # printed and can be adapted by copy-and-paste.
    example_code = '''
# 在rnn_trainer.py的train方法中添加
def train(self):
    from monitor_xla_compilation import compilation_monitor
    self.model.train()
    train_losses = []
    # ... 其他初始化代码 ...
    self.logger.info("Starting training loop - XLA compilation monitoring enabled...")
    # 使用编译监控
    with compilation_monitor():
        for i, batch in enumerate(self.train_loader):
            # 第一个batch会触发XLA编译
            # 监控会显示编译进度
            # ... 训练代码 ...
            # 编译完成后会自动停止监控
            break # 只需要第一个batch来触发编译
    # 继续正常训练循环
    for i, batch in enumerate(self.train_loader):
        # ... 正常训练代码 ...
'''
    print("📝 如何在trainer中使用编译监控:")
    print(example_code)
if __name__ == "__main__":
    # When run as a script: show a banner and short usage notes, then print
    # the trainer integration example.
    usage_lines = (
        "🧪 XLA编译监控工具",
        "=" * 50,
        "📖 使用方法:",
        "1. 将此文件导入到你的训练脚本中",
        "2. 在第一次模型调用前使用 compilation_monitor() 上下文管理器",
        "3. 会实时显示编译进度和系统资源使用情况",
    )
    for usage_line in usage_lines:
        print(usage_line)

    add_compilation_monitoring_to_trainer()