b2txt25/model_training_nnn_tpu/check_tpu_memory.py
#!/usr/bin/env python3
"""
TPU training memory monitoring tool - focused on real-time memory and MXU monitoring during training
Intended for TPU v5e-8 environments
"""
import tensorflow as tf
import time
import numpy as np


def monitor_tpu_during_training():
    """Real-time TPU memory and MXU monitoring during training."""
    print("📊 TPU training real-time monitoring tool")
    print("=" * 50)

    # Discover TPU devices
    try:
        tpu_devices = tf.config.list_logical_devices('TPU')
        print(f"📍 TPU devices found: {len(tpu_devices)}")
        if not tpu_devices:
            print("❌ No TPU devices found")
            return
    except Exception as e:
        print(f"❌ Unable to detect TPU devices: {e}")
        return
    def get_detailed_memory_snapshot():
        """Take a detailed memory snapshot covering every TPU core."""
        snapshot = {}
        total_current = 0
        total_peak = 0
        active_cores = 0
        core_details = []
        for i, device in enumerate(tpu_devices):
            try:
                memory_info = tf.config.experimental.get_memory_info(device.name)
                if memory_info and 'current' in memory_info:
                    # get_memory_info reports bytes; convert to MB
                    current_mb = memory_info['current'] // (1024 * 1024)
                    peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
                    if current_mb > 1:  # treat >1MB as an active core
                        active_cores += 1
                        total_current += current_mb
                        total_peak += peak_mb
                        core_details.append(f"Core{i}:{current_mb}MB")
                    snapshot[f'core_{i}'] = {
                        'current': current_mb,
                        'peak': peak_mb,
                        'device': device.name
                    }
                else:
                    snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name}
            except Exception as e:
                snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name, 'error': str(e)}
        snapshot['summary'] = {
            'total_current': total_current,
            'total_peak': total_peak,
            'active_cores': active_cores,
            'total_cores': len(tpu_devices),
            'core_details': core_details
        }
        return snapshot
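
    # Layout of the snapshot returned above (values are illustrative, not measured):
    #   snapshot['core_0'] == {'current': 128, 'peak': 256, 'device': '/device:TPU:0'}
    #   snapshot['summary'] == {'total_current': ..., 'total_peak': ..., 'active_cores': ...,
    #                           'total_cores': 8, 'core_details': ['Core0:128MB', ...]}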
    def test_mxu_performance():
        """Benchmark MXU throughput with bfloat16 matrix multiplications."""
        print("\n🧮 MXU compute performance test:")
        mxu_results = []
        try:
            with tf.device(tpu_devices[0].name):
                # Matrix sizes to benchmark
                test_configs = [
                    (2000, "2K×2K", tf.bfloat16),
                    (4000, "4K×4K", tf.bfloat16),
                    (6000, "6K×6K", tf.bfloat16),
                ]
                for size, desc, dtype in test_configs:
                    try:
                        # Memory before the run
                        pre_mem = get_detailed_memory_snapshot()
                        start_time = time.time()
                        # Build the operands and run an MXU-heavy computation
                        matrix_a = tf.random.normal([size, size], dtype=dtype)
                        matrix_b = tf.random.normal([size, size], dtype=dtype)

                        @tf.function
                        def mxu_operation():
                            # Back-to-back matmuls to keep the MXU busy
                            result = tf.matmul(matrix_a, matrix_b)
                            result = tf.matmul(result, matrix_a)
                            return tf.reduce_sum(result)

                        result = mxu_operation()
                        # Pull the scalar to host to force execution
                        _ = result.numpy()
                        end_time = time.time()
                        # Memory after the run
                        post_mem = get_detailed_memory_snapshot()
                        duration = end_time - start_time
                        # FLOPs: two matmuls at 2*n^3 each
                        flops = 2 * (2 * size**3)
                        tflops = flops / duration / 1e12
                        memory_used = post_mem['summary']['total_current'] - pre_mem['summary']['total_current']
                        print(f"   {desc} ({dtype.name}): {duration:.3f}s, {tflops:.1f}TFLOPS, memory +{memory_used}MB")
                        mxu_results.append({
                            'size': size,
                            'tflops': tflops,
                            'duration': duration,
                            'memory_used': memory_used
                        })
                    except Exception as e:
                        print(f"   {desc}: test failed - {str(e)[:50]}")

            # MXU performance analysis
            if mxu_results:
                max_tflops = max(r['tflops'] for r in mxu_results)
                total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0)
                # Theoretical single-core performance figure used for TPU v5e-8
                theoretical_tflops = 275  # bf16 peak performance
                efficiency = (max_tflops / theoretical_tflops) * 100
                print(f"\n   📊 MXU performance summary:")
                print(f"      Peak measured: {max_tflops:.1f} TFLOPS")
                print(f"      Theoretical peak: {theoretical_tflops} TFLOPS")
                print(f"      MXU efficiency: {efficiency:.1f}%")
                print(f"      Compute memory footprint: {total_memory}MB")
                if efficiency > 80:
                    status = "🟢 excellent"
                elif efficiency > 50:
                    status = "🟡 good"
                elif efficiency > 20:
                    status = "🟠 moderate"
                else:
                    status = "🔴 needs optimization"
                print(f"      Performance rating: {status}")
        except Exception as e:
            print(f"   MXU test failed: {e}")
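
    # Worked example (illustrative numbers): a measured peak of 110 TFLOPS against the
    # 275 TFLOPS reference above gives (110 / 275) * 100 = 40.0% MXU efficiency,
    # which falls in the 20-50% band and rates "🟠 moderate".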

    try:
        print("🎯 Starting TPU training monitoring...")

        # 1. Baseline state
        print("\n📸 Initial TPU state:")
        baseline_snapshot = get_detailed_memory_snapshot()
        print(f"   Total memory in use: {baseline_snapshot['summary']['total_current']}MB")
        print(f"   Active cores: {baseline_snapshot['summary']['active_cores']}/{baseline_snapshot['summary']['total_cores']}")
        # Detailed per-core state
        for i in range(len(tpu_devices)):
            core = baseline_snapshot[f'core_{i}']
            if core['current'] > 0 or core['peak'] > 0:
                print(f"   Core{i}: current {core['current']}MB, peak {core['peak']}MB")

        # 2. MXU performance baseline
        test_mxu_performance()

        # 3. Create the distribution strategy using the project's validated TPU initialization
        print(f"\n🔄 Using the project's standard TPU initialization...")
        try:
            # Disable GPUs to avoid CUDA conflicts
            try:
                tf.config.set_visible_devices([], 'GPU')
                print("🚫 GPUs disabled to avoid CUDA conflicts")
            except Exception:
                pass
            # Standard TPU initialization flow
            print("🚀 Using the official TensorFlow TPU initialization...")
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
            # Verify TPU devices
            tpu_devices_check = tf.config.list_logical_devices('TPU')
            print(f"✅ TPU device check: found {len(tpu_devices_check)} devices")
            # Create the TPU strategy
            strategy = tf.distribute.TPUStrategy(resolver)
            print(f"✅ TPU strategy created: {strategy.num_replicas_in_sync} replicas")
            use_distributed = True
        except Exception as e:
            print(f"⚠️ Distribution strategy failed: {str(e)[:80]}")
            print("   Falling back to single-device simulation")
            use_distributed = False
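
        # Environment assumption: TPUClusterResolver() with no arguments resolves the TPU
        # from the environment (e.g. the TPU_NAME variable or VM metadata); depending on
        # the runtime, a Cloud TPU VM may instead need TPUClusterResolver(tpu='local').
        # The script does not verify this.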

        # 4. Simulate a Brain-to-Text training scenario
        print(f"\n🧠 Simulating a Brain-to-Text training scenario...")
        if use_distributed:
            # Distributed training simulation
            with strategy.scope():
                print("📦 Creating distributed model parameters...")
                # Parameters sized close to the real Brain-to-Text model (dimensions fixed to match)
                model_components = {
                    # GRU layer weights: the first layer takes 512-dim input, later layers take 256-dim
                    'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'),
                    'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'),
                    'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'),
                    'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'),
                    # Day-specific layer simulation (512-dim input, 512-dim output)
                    'day_weights': [tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}') for i in range(8)]
                }

                # Memory after the model variables are created
                after_model = get_detailed_memory_snapshot()
                model_memory = after_model['summary']['total_current'] - baseline_snapshot['summary']['total_current']
                print(f"🧠 Model loaded: +{model_memory}MB, {after_model['summary']['active_cores']} active cores")

                # Training-loop simulation
                print(f"\n🔄 Starting training-loop monitoring...")
                for step in range(10):
                    step_start_time = time.time()

                    # NOTE: defining the tf.function inside the loop retraces it every
                    # iteration; acceptable for a short monitoring run.
                    @tf.function
                    def distributed_training_step():
                        # Batch shape modelled on the real training data
                        batch_size = 32
                        seq_length = 1000
                        features = 512
                        # Input data
                        neural_data = tf.random.normal([batch_size, seq_length, features])
                        targets = tf.random.uniform([batch_size, seq_length], maxval=41, dtype=tf.int32)
                        # Simulated forward pass
                        x = neural_data
                        # Day-specific transformation (simplified to avoid complex dimension handling):
                        # apply the same transform to every time step
                        day_weight = model_components['day_weights'][0]  # simplified: always use the first day weight
                        # Transform the last dimension: [batch, seq, 512] @ [512, 512] -> [batch, seq, 512]
                        x = tf.matmul(x, day_weight)
                        # Simulated target usage for a CTC-style loss
                        target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1)
                        # Simplified CTC-related quantity
                        batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32))
                        # GRU layers simulation
                        for i in range(3):
                            layer_name = f'gru_layer_{i}'
                            weight = model_components[layer_name]
                            # Handle tensor rank: the first layer receives a 3D input, later layers 2D
                            if i == 0:
                                # First layer: keep the last time step, [batch, seq, features] -> [batch, features]
                                if len(x.shape) == 3:
                                    x = x[:, -1, :]
                                x = tf.nn.tanh(tf.matmul(x, weight))
                            else:
                                # Later layers operate on the 2D tensor: [batch, features] -> [batch, features]
                                x = tf.nn.tanh(tf.matmul(x, weight))
                        # Output projection
                        logits = tf.matmul(x, model_components['output_projection'])
                        # Simulated CTC loss, weighted by batch_loss_weight
                        base_loss = tf.reduce_mean(tf.square(logits))
                        loss = base_loss * batch_loss_weight
                        return loss

                    # Run the training step on all replicas
                    per_replica_loss = strategy.run(distributed_training_step)
                    # Aggregate the per-replica results
                    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
                    step_duration = time.time() - step_start_time
                    # Current memory state
                    current_snapshot = get_detailed_memory_snapshot()
                    step_memory = current_snapshot['summary']['total_current']
                    memory_delta = step_memory - baseline_snapshot['summary']['total_current']
                    # Detailed per-step training status
                    active_cores_info = f"({', '.join(current_snapshot['summary']['core_details'])})" if current_snapshot['summary']['core_details'] else "(none active)"
                    print(f"   Step {step:2d}: loss={float(loss.numpy()):.4f}, "
                          f"time={step_duration:.3f}s, "
                          f"memory={step_memory}MB(+{memory_delta}), "
                          f"active={current_snapshot['summary']['active_cores']}/{current_snapshot['summary']['total_cores']} {active_cores_info}")
                    # Report peak memory every 5 steps
                    if step % 5 == 0:
                        peak_info = f"Peak: {current_snapshot['summary']['total_peak']}MB"
                        print(f"      {peak_info}")
                    time.sleep(0.2)  # brief pause to observe
        else:
            # Single-device training simulation (improved version)
            print("🔸 Single-device training simulation...")
            with tf.device(tpu_devices[0].name):
                # Create a smaller set of model parameters
                simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net')
                for step in range(8):
                    step_start = time.time()
                    # Create a fairly large data batch
                    batch_data = tf.random.normal([64, 1000, 512])  # enlarged batch size

                    # Simulate a compute-heavy step
                    @tf.function
                    def compute_step():
                        x = tf.reshape(batch_data, [-1, 512])
                        result = tf.matmul(x, simple_weights)
                        result = tf.nn.relu(result)
                        return tf.reduce_mean(result)

                    result = compute_step()
                    step_duration = time.time() - step_start
                    # Memory state
                    snapshot = get_detailed_memory_snapshot()
                    memory_change = snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
                    print(f"   Step {step}: result={result.numpy():.4f}, "
                          f"time={step_duration:.3f}s, "
                          f"memory change=+{memory_change}MB, "
                          f"peak={snapshot['summary']['total_peak']}MB")

        # 5. Final analysis report
        final_snapshot = get_detailed_memory_snapshot()
        total_growth = final_snapshot['summary']['total_current'] - baseline_snapshot['summary']['total_current']
        peak_usage = final_snapshot['summary']['total_peak']
        print(f"\n📈 Training monitoring report:")
        print(f"   Total memory growth: +{total_growth}MB")
        print(f"   Peak memory usage: {peak_usage}MB ({peak_usage/1024:.2f}GB)")
        print(f"   Final active cores: {final_snapshot['summary']['active_cores']}/{final_snapshot['summary']['total_cores']}")
        # Final state of each core
        print(f"   Final per-core state:")
        has_changes = False
        for i in range(len(tpu_devices)):
            final_core = final_snapshot[f'core_{i}']
            baseline_core = baseline_snapshot[f'core_{i}']
            current_change = final_core['current'] - baseline_core['current']
            peak_change = final_core['peak'] - baseline_core['peak']
            if current_change != 0 or peak_change != 0:
                has_changes = True
                print(f"      Core{i}: current {final_core['current']}MB(+{current_change}), peak {final_core['peak']}MB(+{peak_change})")
        if not has_changes:
            print(f"      No notable memory change on any core")
        # Distribution diagnostics
        if final_snapshot['summary']['active_cores'] == 1:
            print(f"\n⚠️ Distribution problem detected:")
            print(f"   Only 1 core is active; the other 7 are idle")
            print(f"   Likely cause: TPU strategy misconfiguration, or the model is not actually distributed")
            print(f"   Suggestion: check the distribution strategy and model sharding")
        elif final_snapshot['summary']['active_cores'] > 4:
            print(f"\n✅ Distribution looks healthy:")
            print(f"   {final_snapshot['summary']['active_cores']} cores active; multi-core parallelism is working")
        else:
            print(f"\n🟡 Distribution partially working:")
            print(f"   {final_snapshot['summary']['active_cores']} cores active; there may be a load imbalance")

        print("✅ TPU training monitoring complete")
    except Exception as e:
        print(f"❌ Training monitoring failed: {e}")
        import traceback
        print(f"Details: {traceback.format_exc()[:300]}")


if __name__ == "__main__":
    print("🚀 TPU training memory monitoring tool")
    print("Focused on real-time memory and performance monitoring during training")
    print("Intended for TPU v5e-8 environments")
    print()
    monitor_tpu_during_training()
    print(f"\n🎯 Monitoring checklist:")
    print(f"   1. Confirm whether all 8 TPU cores are active")
    print(f"   2. Track memory growth patterns and peak usage")
    print(f"   3. Measure MXU compute performance and efficiency")
    print(f"   4. Verify that the distribution strategy works correctly")
    print(f"   5. Identify possible memory leaks or performance bottlenecks")
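
# A minimal way to run this check (assumes a TensorFlow build with TPU support on the v5e-8 VM):
#   python check_tpu_memory.py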