#!/usr/bin/env python3
"""
TPU training memory monitor - real-time memory and MXU monitoring during training.
Intended for a TPU v5e-8 environment.
"""
import time

import tensorflow as tf


def monitor_tpu_during_training():
    """Monitor TPU memory and MXU performance in real time during training."""
    print("📊 TPU training monitor")
    print("=" * 50)

    # Discover TPU devices
    try:
        tpu_devices = tf.config.list_logical_devices('TPU')
        print(f"📍 Found {len(tpu_devices)} TPU device(s)")
        if not tpu_devices:
            print("❌ No TPU devices found")
            return
    except Exception as e:
        print(f"❌ Failed to detect TPU devices: {e}")
        return

    def get_detailed_memory_snapshot():
        """Take a detailed memory snapshot covering every TPU core."""
        snapshot = {}
        total_current = 0
        total_peak = 0
        active_cores = 0
        core_details = []

        for i, device in enumerate(tpu_devices):
            try:
                memory_info = tf.config.experimental.get_memory_info(device.name)
                if memory_info and 'current' in memory_info:
                    current_mb = memory_info['current'] // (1024 * 1024)
                    peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)

                    if current_mb > 1:  # treat >1 MB as an active core
                        active_cores += 1
                        total_current += current_mb
                        total_peak += peak_mb
                        core_details.append(f"Core{i}:{current_mb}MB")

                    snapshot[f'core_{i}'] = {
                        'current': current_mb,
                        'peak': peak_mb,
                        'device': device.name,
                    }
                else:
                    snapshot[f'core_{i}'] = {'current': 0, 'peak': 0, 'device': device.name}
            except Exception as e:
                snapshot[f'core_{i}'] = {
                    'current': 0, 'peak': 0, 'device': device.name, 'error': str(e),
                }

        snapshot['summary'] = {
            'total_current': total_current,
            'total_peak': total_peak,
            'active_cores': active_cores,
            'total_cores': len(tpu_devices),
            'core_details': core_details,
        }
        return snapshot

    def test_mxu_performance():
        """Benchmark MXU throughput with matrix multiplications of increasing size."""
        print("\n🧮 MXU compute benchmark:")
        mxu_results = []

        try:
            with tf.device(tpu_devices[0].name):
                # Matrix sizes to benchmark
                test_configs = [
                    (2000, "2K×2K", tf.bfloat16),
                    (4000, "4K×4K", tf.bfloat16),
                    (6000, "6K×6K", tf.bfloat16),
                ]

                for size, desc, dtype in test_configs:
                    try:
                        # Memory before the benchmark
                        pre_mem = get_detailed_memory_snapshot()
                        start_time = time.time()

                        # Build operands and run an MXU-heavy computation
                        matrix_a = tf.random.normal([size, size], dtype=dtype)
                        matrix_b = tf.random.normal([size, size], dtype=dtype)

                        @tf.function
                        def mxu_operation():
                            # Chained matmuls keep the MXU busy
                            result = tf.matmul(matrix_a, matrix_b)
                            result = tf.matmul(result, matrix_a)
                            return tf.reduce_sum(result)

                        result = mxu_operation()
                        _ = result.numpy()  # block until the TPU finishes

                        end_time = time.time()
                        # Memory after the benchmark
                        post_mem = get_detailed_memory_snapshot()

                        duration = end_time - start_time
                        # Two size×size×size matmuls at 2·size³ FLOPs each
                        flops = 2 * (2 * size**3)
                        tflops = flops / duration / 1e12
                        memory_used = (post_mem['summary']['total_current']
                                       - pre_mem['summary']['total_current'])

                        print(f"  {desc} ({dtype.name}): {duration:.3f}s, "
                              f"{tflops:.1f}TFLOPS, memory +{memory_used}MB")
                        mxu_results.append({
                            'size': size,
                            'tflops': tflops,
                            'duration': duration,
                            'memory_used': memory_used,
                        })
                    except Exception as e:
                        print(f"  {desc}: benchmark failed - {str(e)[:50]}")

            # Summarize MXU performance
            if mxu_results:
                max_tflops = max(r['tflops'] for r in mxu_results)
                total_memory = sum(r['memory_used'] for r in mxu_results if r['memory_used'] > 0)

                # Per-chip bf16 peak for TPU v5e per Google's spec sheet
                # (the benchmark runs on a single device)
                theoretical_tflops = 197
                efficiency = (max_tflops / theoretical_tflops) * 100

                print("\n  📊 MXU performance summary:")
                print(f"    Peak measured: {max_tflops:.1f} TFLOPS")
                print(f"    Theoretical peak: {theoretical_tflops} TFLOPS")
                print(f"    MXU efficiency: {efficiency:.1f}%")
                print(f"    Memory used by benchmark: {total_memory}MB")

                if efficiency > 80:
                    status = "🟢 excellent"
                elif efficiency > 50:
                    status = "🟡 good"
                elif efficiency > 20:
                    status = "🟠 moderate"
                else:
                    status = "🔴 needs tuning"
                print(f"    Rating: {status}")
        except Exception as e:
            print(f"  MXU benchmark failed: {e}")
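    # NOTE: illustrative helper, not part of the original flow. The first call
    # to a tf.function includes tracing and XLA compilation, which can dominate
    # the timings measured above; a common pattern is to warm up once and then
    # time the subsequent calls. `warmup_and_time` is a hypothetical name.
    def warmup_and_time(fn, iterations=3):
        """Run `fn` once to compile, then return the mean wall time of later calls."""
        fn().numpy()  # warm-up call: triggers tracing + XLA compilation
        start = time.time()
        for _ in range(iterations):
            fn().numpy()  # .numpy() blocks until the TPU result is ready
        return (time.time() - start) / iterations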
    try:
        print("🎯 Starting TPU training monitoring...")

        # 1. Capture the initial state
        print("\n📸 Initial TPU state:")
        baseline_snapshot = get_detailed_memory_snapshot()
        print(f"  Total memory in use: {baseline_snapshot['summary']['total_current']}MB")
        print(f"  Active cores: {baseline_snapshot['summary']['active_cores']}"
              f"/{baseline_snapshot['summary']['total_cores']}")

        # Per-core detail
        for i in range(len(tpu_devices)):
            core = baseline_snapshot[f'core_{i}']
            if core['current'] > 0 or core['peak'] > 0:
                print(f"  Core{i}: current {core['current']}MB, peak {core['peak']}MB")

        # 2. MXU performance baseline
        test_mxu_performance()

        # 3. Create the distribution strategy, using the TPU initialization
        #    flow already validated in this project
        print("\n🔄 Running the project's standard TPU initialization...")
        try:
            # Hide GPUs to avoid CUDA conflicts
            try:
                tf.config.set_visible_devices([], 'GPU')
                print("🚫 GPUs disabled to avoid CUDA conflicts")
            except Exception:
                pass

            # Standard TensorFlow TPU initialization
            print("🚀 Initializing the TPU system via the official TensorFlow flow...")
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)

            # Verify TPU devices after initialization
            tpu_devices_check = tf.config.list_logical_devices('TPU')
            print(f"✅ TPU device check: found {len(tpu_devices_check)} device(s)")

            # Create the TPU strategy
            strategy = tf.distribute.TPUStrategy(resolver)
            print(f"✅ TPU strategy created: {strategy.num_replicas_in_sync} replicas")
            use_distributed = True
        except Exception as e:
            print(f"⚠️ Distribution strategy failed: {str(e)[:80]}")
            print("  Falling back to single-device simulation")
            use_distributed = False

        # 4. Simulate a Brain-to-Text training workload
        print("\n🧠 Simulating a Brain-to-Text training workload...")

        if use_distributed:
            # Distributed training simulation
            with strategy.scope():
                print("📦 Creating distributed model parameters...")
                # Parameters sized close to the real Brain-to-Text model
                # (dimensions chosen so the matmuls line up)
                model_components = {
                    # GRU weights: the first layer takes 512-dim input,
                    # later layers take 256-dim input
                    'gru_layer_0': tf.Variable(tf.random.normal([512, 256]), name='gru_0'),
                    'gru_layer_1': tf.Variable(tf.random.normal([256, 256]), name='gru_1'),
                    'gru_layer_2': tf.Variable(tf.random.normal([256, 256]), name='gru_2'),
                    'output_projection': tf.Variable(tf.random.normal([256, 41]), name='output'),
                    # Day-specific layers (512-dim in, 512-dim out)
                    'day_weights': [
                        tf.Variable(tf.random.normal([512, 512]), name=f'day_{i}')
                        for i in range(8)
                    ],
                }

                # Memory after model creation
                after_model = get_detailed_memory_snapshot()
                model_memory = (after_model['summary']['total_current']
                                - baseline_snapshot['summary']['total_current'])
                print(f"🧠 Model created: +{model_memory}MB, "
                      f"{after_model['summary']['active_cores']} active core(s)")

                @tf.function
                def distributed_training_step():
                    # Defined once, outside the training loop, so it is traced
                    # and compiled once instead of being retraced every step
                    # Roughly realistic batch dimensions
                    batch_size = 32
                    seq_length = 1000
                    features = 512

                    # Input data
                    neural_data = tf.random.normal([batch_size, seq_length, features])
                    targets = tf.random.uniform([batch_size, seq_length],
                                                maxval=41, dtype=tf.int32)

                    # Simulated forward pass
                    x = neural_data

                    # Day-specific transformation, simplified: apply the same
                    # weight at every time step using the first day's matrix
                    day_weight = model_components['day_weights'][0]
                    # [batch, seq, 512] @ [512, 512] -> [batch, seq, 512]
                    x = tf.matmul(x, day_weight)

                    # Use the targets so a CTC-style loss term has an input
                    target_length = tf.reduce_sum(tf.cast(targets > 0, tf.int32), axis=1)
                    batch_loss_weight = tf.reduce_mean(tf.cast(target_length, tf.float32))

                    # GRU layer simulation
                    for i in range(3):
                        weight = model_components[f'gru_layer_{i}']
                        if i == 0:
                            # First layer: keep only the last time step,
                            # [batch, seq, features] -> [batch, features]
                            x = x[:, -1, :]
                        x = tf.nn.tanh(tf.matmul(x, weight))

                    # Output projection
                    logits = tf.matmul(x, model_components['output_projection'])

                    # Stand-in for a CTC loss, weighted by mean target length
                    base_loss = tf.reduce_mean(tf.square(logits))
                    return base_loss * batch_loss_weight

                # Training-loop monitoring
                print("\n🔄 Entering the monitored training loop...")
                for step in range(10):
                    step_start_time = time.time()

                    # Run the step on all replicas and aggregate the result
                    per_replica_loss = strategy.run(distributed_training_step)
                    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                           per_replica_loss, axis=None)
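                    # NOTE: illustrative check, not in the original script. To
                    # confirm every replica actually ran the step, inspect the
                    # per-replica values; on a healthy v5e-8 this should report
                    # 8 local results.
                    if step == 0:
                        local_losses = strategy.experimental_local_results(per_replica_loss)
                        print(f"    Per-replica results on step 0: {len(local_losses)}")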
                    step_duration = time.time() - step_start_time

                    # Current memory state
                    current_snapshot = get_detailed_memory_snapshot()
                    step_memory = current_snapshot['summary']['total_current']
                    memory_delta = step_memory - baseline_snapshot['summary']['total_current']

                    # Detailed per-step status
                    details = current_snapshot['summary']['core_details']
                    active_cores_info = f"({', '.join(details)})" if details else "(none active)"
                    print(f"  Step {step:2d}: loss={float(loss.numpy()):.4f}, "
                          f"time={step_duration:.3f}s, "
                          f"memory={step_memory}MB(+{memory_delta}), "
                          f"active={current_snapshot['summary']['active_cores']}"
                          f"/{current_snapshot['summary']['total_cores']} {active_cores_info}")

                    # Report peak memory every 5 steps
                    if step % 5 == 0:
                        print(f"    Peak: {current_snapshot['summary']['total_peak']}MB")

                    time.sleep(0.2)  # brief pause between steps for observation
        else:
            # Single-device training simulation (fallback)
            print("🔸 Single-device training simulation...")
            with tf.device(tpu_devices[0].name):
                # A smaller parameter set
                simple_weights = tf.Variable(tf.random.normal([512, 256]), name='simple_net')

                @tf.function
                def compute_step(batch_data):
                    # Takes the batch as an argument so the function is
                    # compiled once rather than retraced every step
                    x = tf.reshape(batch_data, [-1, 512])
                    result = tf.matmul(x, simple_weights)
                    result = tf.nn.relu(result)
                    return tf.reduce_mean(result)

                for step in range(8):
                    step_start = time.time()

                    # A reasonably large batch to exercise memory
                    batch_data = tf.random.normal([64, 1000, 512])
                    result = compute_step(batch_data)
                    step_duration = time.time() - step_start

                    # Memory state after the step
                    snapshot = get_detailed_memory_snapshot()
                    memory_change = (snapshot['summary']['total_current']
                                     - baseline_snapshot['summary']['total_current'])

                    print(f"  Step {step}: result={result.numpy():.4f}, "
                          f"time={step_duration:.3f}s, "
                          f"memory change=+{memory_change}MB, "
                          f"peak={snapshot['summary']['total_peak']}MB")

        # 5. Final report
        final_snapshot = get_detailed_memory_snapshot()
        total_growth = (final_snapshot['summary']['total_current']
                        - baseline_snapshot['summary']['total_current'])
        peak_usage = final_snapshot['summary']['total_peak']

        print("\n📈 Training monitoring report:")
        print(f"  Total memory growth: +{total_growth}MB")
        print(f"  Peak memory usage: {peak_usage}MB ({peak_usage/1024:.2f}GB)")
        print(f"  Active cores at end: {final_snapshot['summary']['active_cores']}"
              f"/{final_snapshot['summary']['total_cores']}")

        # Final per-core state
        print("  Per-core final state:")
        has_changes = False
        for i in range(len(tpu_devices)):
            final_core = final_snapshot[f'core_{i}']
            baseline_core = baseline_snapshot[f'core_{i}']
            current_change = final_core['current'] - baseline_core['current']
            peak_change = final_core['peak'] - baseline_core['peak']
            if current_change != 0 or peak_change != 0:
                has_changes = True
                print(f"    Core{i}: current {final_core['current']}MB(+{current_change}), "
                      f"peak {final_core['peak']}MB(+{peak_change})")
        if not has_changes:
            print("    No significant memory change on any core")

        # Distribution health check
        active = final_snapshot['summary']['active_cores']
        if active == 1:
            print("\n⚠️ Distribution problem detected:")
            print("  Only 1 core is active; the other 7 are idle")
            print("  Likely cause: TPU strategy misconfiguration or the model was not distributed")
            print("  Suggestion: check the distribution strategy and model sharding")
        elif active > 4:
            print("\n✅ Distribution looks healthy:")
            print(f"  {active} cores active; multi-core parallelism is working")
        else:
            print("\n🟡 Distribution partially working:")
            print(f"  {active} cores active; the load may be unbalanced")

        print("✅ TPU training monitoring complete")

    except Exception as e:
        print(f"❌ Training monitoring failed: {e}")
        import traceback
        print(f"Details: {traceback.format_exc()[:300]}")
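# NOTE: a minimal leak-check sketch, not part of the original flow. `snapshots`
# is assumed to be a list of the per-step dicts produced by
# get_detailed_memory_snapshot(); monotonic growth of 'total_current' after a
# warm-up period is a common (though not conclusive) leak signal.
def looks_like_memory_leak(snapshots, warmup_steps=2, min_growth_mb=1):
    """Heuristic: True if total memory grows on every post-warm-up step."""
    totals = [s['summary']['total_current'] for s in snapshots[warmup_steps:]]
    deltas = [b - a for a, b in zip(totals, totals[1:])]
    return bool(deltas) and all(d >= min_growth_mb for d in deltas)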
if __name__ == "__main__":
    print("🚀 TPU training memory monitor")
    print("Real-time memory and performance monitoring during training")
    print("Intended for TPU v5e-8 environments")
    print()

    monitor_tpu_during_training()

    print("\n🎯 Monitoring checklist:")
    print("  1. Confirm that all 8 TPU cores are active")
    print("  2. Watch memory growth patterns and peak usage")
    print("  3. Measure MXU throughput and efficiency")
    print("  4. Verify that the distribution strategy is working")
    print("  5. Identify potential memory leaks or performance bottlenecks")
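# Usage (assumed environment: a Cloud TPU v5e-8 VM with a TPU-enabled
# TensorFlow build installed; run whatever filename this script is saved as):
#
#   $ python3 tpu_training_monitor.py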