From e7947f310cd5b98bb843b45fcabd47df0055a89c Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 15 Oct 2025 14:33:49 +0800
Subject: [PATCH] tpu

---
 model_training_nnn_tpu/check_xla_threads.py   | 148 ----------------
 .../jupyter_debug_full_model.py               | 124 --------------
 .../jupyter_xla_compatibility.py              |  93 ++++++++++
 model_training_nnn_tpu/jupyter_xla_monitor.py |  38 ++++-
 model_training_nnn_tpu/jupyter_xla_setup.py   |  45 -----
 model_training_nnn_tpu/launch_tpu_training.py | 161 ------------------
 .../monitor_xla_compilation.py                | 100 ----------
 model_training_nnn_tpu/start_tpu_training.sh  |  27 ---
 8 files changed, 130 insertions(+), 606 deletions(-)
 delete mode 100644 model_training_nnn_tpu/check_xla_threads.py
 delete mode 100644 model_training_nnn_tpu/jupyter_debug_full_model.py
 create mode 100644 model_training_nnn_tpu/jupyter_xla_compatibility.py
 delete mode 100644 model_training_nnn_tpu/jupyter_xla_setup.py
 delete mode 100644 model_training_nnn_tpu/launch_tpu_training.py
 delete mode 100644 model_training_nnn_tpu/monitor_xla_compilation.py
 delete mode 100644 model_training_nnn_tpu/start_tpu_training.sh

diff --git a/model_training_nnn_tpu/check_xla_threads.py b/model_training_nnn_tpu/check_xla_threads.py
deleted file mode 100644
index ceaf3ec..0000000
--- a/model_training_nnn_tpu/check_xla_threads.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-"""
-XLA Multi-threading Diagnostic Script
-Checks whether XLA compilation is actually using multiple CPU cores
-"""
-
-import os
-import psutil
-import time
-import threading
-from concurrent.futures import ThreadPoolExecutor
-
-def set_xla_environment():
-    """Set the XLA environment variables"""
-    cpu_count = os.cpu_count()
-
-    # Configure XLA environment variables
-    os.environ['XLA_FLAGS'] = (
-        '--xla_cpu_multi_thread_eigen=true '
-        '--xla_cpu_enable_fast_math=true '
-        f'--xla_force_host_platform_device_count={cpu_count}'
-    )
-    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
-
-    print(f"🔧 Setting XLA environment variables:")
-    print(f"   CPU cores: {cpu_count}")
-    print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
-    print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
-    print("-" * 60)
-
-def monitor_cpu_usage(duration=30, interval=1):
-    """Monitor per-core CPU usage"""
-    print(f"🔍 Monitoring CPU usage for {duration} seconds...")
-
-    cpu_usage_data = []
-    start_time = time.time()
-
-    while time.time() - start_time < duration:
-        # Sample the usage of every CPU core
-        cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
-        cpu_usage_data.append(cpu_percent_per_core)
-
-        # Live display
-        active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
-        print(f"Active cores: {active_cores}/{len(cpu_percent_per_core)}, "
-              f"average usage: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
-              end='\r')
-
-    print()  # newline
-
-    # Analyze the samples
-    if cpu_usage_data:
-        avg_usage_per_core = [
-            sum(core_data) / len(cpu_usage_data)
-            for core_data in zip(*cpu_usage_data)
-        ]
-
-        active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
-        max_usage = max(avg_usage_per_core)
-
-        print(f"📊 CPU usage analysis:")
-        print(f"   Active CPU cores: {active_cores}/{len(avg_usage_per_core)}")
-        print(f"   Highest average usage: {max_usage:.1f}%")
-
-        for i, usage in enumerate(avg_usage_per_core):
-            status = "🟢" if usage > 10 else "🔴" if usage > 5 else "⚫"
-            print(f"   CPU core {i}: {usage:.1f}% {status}")
-
-        return active_cores > 1
-
-    return False
-
-def test_xla_compilation():
-    """Exercise XLA compilation"""
-    print(f"🚀 Starting XLA compilation test...")
-
-    try:
-        import torch
-        import torch_xla.core.xla_model as xm
-
-        print(f"✅ PyTorch XLA imported successfully")
-        print(f"   XLA device: {xm.xla_device()}")
-        print(f"   XLA world size: {xm.xrt_world_size()}")
-
-        # Build a small computation graph to trigger compilation
-        device = xm.xla_device()
-
-        print(f"🔄 Creating test computation graph...")
-        x = torch.randn(100, 100, device=device)
-        y = torch.randn(100, 100, device=device)
-
-        print(f"🔄 Running matrix ops (this triggers XLA compilation)...")
-
-        # Start the CPU monitor in the background
-        monitor_thread = threading.Thread(
-            target=lambda: monitor_cpu_usage(20, 0.5),
-            daemon=True
-        )
-        monitor_thread.start()
-
-        # Run the computation, triggering compilation
-        for i in range(10):
-            z = torch.matmul(x, y)
-            z = torch.relu(z)
-            z = torch.matmul(z, x.T)
-            if i == 0:
-                print(f"🔄 First pass done (XLA compilation should be in progress)...")
-            elif i == 5:
-                print(f"🔄 Sixth pass done...")
-
-        # Wait for the monitor to finish
-        monitor_thread.join(timeout=25)
-
-        print(f"✅ XLA test finished")
-
-        return True
-
-    except ImportError as e:
-        print(f"❌ Failed to import PyTorch XLA: {e}")
-        return False
-    except Exception as e:
-        print(f"❌ XLA test failed: {e}")
-        return False
-
-def main():
-    print("=" * 60)
-    print("🧪 XLA Multi-threaded Compilation Diagnostic Tool")
-    print("=" * 60)
-
-    # 1. Set up the environment
-    set_xla_environment()
-
-    # 2. Test XLA compilation while monitoring the CPU
-    success = test_xla_compilation()
-
-    print("=" * 60)
-    if success:
-        print("✅ Diagnostics complete")
-        print("💡 If multiple CPU cores were active, XLA multi-threading is working")
-        print("💡 If only 1-2 cores were active, other optimizations may be needed")
-    else:
-        print("❌ Diagnostics failed")
-        print("💡 Check the PyTorch XLA installation and the TPU environment")
-
-    print("=" * 60)
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
{xm.xrt_world_size()}") - - # 创建一个简单的计算图进行编译 - device = xm.xla_device() - - print(f"🔄 创建测试计算图...") - x = torch.randn(100, 100, device=device) - y = torch.randn(100, 100, device=device) - - print(f"🔄 执行矩阵运算 (将触发XLA编译)...") - - # 启动CPU监控 - monitor_thread = threading.Thread( - target=lambda: monitor_cpu_usage(20, 0.5), - daemon=True - ) - monitor_thread.start() - - # 执行计算,触发编译 - for i in range(10): - z = torch.matmul(x, y) - z = torch.relu(z) - z = torch.matmul(z, x.T) - if i == 0: - print(f"🔄 首次计算完成 (XLA编译应该正在进行)...") - elif i == 5: - print(f"🔄 第6次计算完成...") - - # 等待监控完成 - monitor_thread.join(timeout=25) - - print(f"✅ XLA测试完成") - - return True - - except ImportError as e: - print(f"❌ PyTorch XLA导入失败: {e}") - return False - except Exception as e: - print(f"❌ XLA测试失败: {e}") - return False - -def main(): - print("=" * 60) - print("🧪 XLA多线程编译诊断工具") - print("=" * 60) - - # 1. 设置环境 - set_xla_environment() - - # 2. 测试XLA编译并监控CPU - success = test_xla_compilation() - - print("=" * 60) - if success: - print("✅ 诊断完成") - print("💡 如果看到多个CPU核心被激活,说明XLA多线程工作正常") - print("💡 如果只有1-2个核心活跃,可能需要其他优化方法") - else: - print("❌ 诊断失败") - print("💡 请检查PyTorch XLA安装和TPU环境") - - print("=" * 60) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/model_training_nnn_tpu/jupyter_debug_full_model.py b/model_training_nnn_tpu/jupyter_debug_full_model.py deleted file mode 100644 index 1b7deeb..0000000 --- a/model_training_nnn_tpu/jupyter_debug_full_model.py +++ /dev/null @@ -1,124 +0,0 @@ -# ==================== -# 单元格4: 逐步调试完整模型编译 -# ==================== - -# 如果单元格3测试通过,运行这个单元格 -print("🔧 逐步测试完整TripleGRUDecoder模型...") - -# 导入完整模型 -import sys -sys.path.append('.') # 确保能导入本地模块 - -try: - from rnn_model import TripleGRUDecoder - print("✅ TripleGRUDecoder导入成功") -except ImportError as e: - print(f"❌ 模型导入失败: {e}") - print("请确保rnn_model.py在当前目录中") - -# 分阶段测试 -def test_model_compilation_stages(): - """分阶段测试模型编译""" - device = xm.xla_device() - - # 阶段1: 测试NoiseModel单独编译 - print("\n🔬 阶段1: 测试NoiseModel...") - try: - from rnn_model import NoiseModel - noise_model = NoiseModel( - neural_dim=512, - n_units=384, # 减小参数 - n_days=4, - patch_size=8 # 减小patch size - ).to(device) - - x = torch.randn(2, 20, 512, device=device) - day_idx = torch.tensor([0, 1], device=device) - - start_time = time.time() - with torch.no_grad(): - output, states = noise_model(x, day_idx) - compile_time = time.time() - start_time - - print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}秒") - print(f" 参数数量: {sum(p.numel() for p in noise_model.parameters()):,}") - - return True, compile_time - - except Exception as e: - print(f"❌ NoiseModel编译失败: {e}") - return False, 0 - - # 阶段2: 测试CleanSpeechModel - print("\n🔬 阶段2: 测试CleanSpeechModel...") - try: - from rnn_model import CleanSpeechModel - clean_model = CleanSpeechModel( - neural_dim=512, - n_units=384, - n_days=4, - n_classes=41, - patch_size=8 - ).to(device) - - start_time = time.time() - with torch.no_grad(): - output = clean_model(x, day_idx) - compile_time = time.time() - start_time - - print(f"✅ CleanSpeechModel编译成功! 
    # Stage 3: test the full TripleGRUDecoder
-    print("\n🔬 Stage 3: testing TripleGRUDecoder...")
-    try:
-        model = TripleGRUDecoder(
-            neural_dim=512,
-            n_units=384,   # smaller than the original 768
-            n_days=4,      # fewer days
-            n_classes=41,
-            patch_size=8   # smaller patch size
-        ).to(device)
-
-        print(f"📊 Full model parameters: {sum(p.numel() for p in model.parameters()):,}")
-
-        # Start the compilation monitor
-        compilation_monitor.start_monitoring()
-
-        start_time = time.time()
-        with torch.no_grad():
-            # Test inference mode
-            logits = model(x, day_idx, None, False, 'inference')
-        compile_time = time.time() - start_time
-
-        compilation_monitor.complete_monitoring()
-
-        print(f"✅ TripleGRUDecoder compiled successfully! Took {compile_time:.2f}s")
-        print(f"📤 Output shape: {logits.shape}")
-
-        return True, compile_time
-
-    except Exception as e:
-        compilation_monitor.complete_monitoring()
-        print(f"❌ TripleGRUDecoder compilation failed: {e}")
-        return False, 0
-
-# Run the staged tests
-stage_results = test_model_compilation_stages()
-
-if stage_results:
-    print(f"\n🎉 All compilation tests finished!")
-    print("💡 Next steps to try:")
-    print("   1. Train with the simplified configuration")
-    print("   2. Scale the model complexity back up gradually")
-    print("   3. Monitor TPU resource usage")
-else:
-    print(f"\n⚠️ The compilation tests ran into problems")
-    print("💡 Suggestions:")
-    print("   1. Reduce the model parameters further")
-    print("   2. Check memory usage")
-    print("   3. Debug in CPU mode")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/jupyter_xla_compatibility.py b/model_training_nnn_tpu/jupyter_xla_compatibility.py
new file mode 100644
index 0000000..ac275af
--- /dev/null
+++ b/model_training_nnn_tpu/jupyter_xla_compatibility.py
@@ -0,0 +1,93 @@
+# ====================
+# Cell: XLA version compatibility check and fixes
+# ====================
+
+import torch
+import torch.nn as nn
+print("🔧 Checking PyTorch XLA version compatibility...")
+
+# Import XLA
+import torch_xla.core.xla_model as xm
+
+print("✅ PyTorch XLA imported successfully!")
+
+# Compatibility helpers
+def get_xla_world_size():
+    """Get the XLA world size across torch_xla versions"""
+    try:
+        return xm.xrt_world_size()
+    except AttributeError:
+        try:
+            return xm.get_world_size()
+        except AttributeError:
+            return 1  # default to 1
+
+def get_xla_ordinal():
+    """Get the XLA ordinal across torch_xla versions"""
+    try:
+        return xm.get_ordinal()
+    except AttributeError:
+        return 0  # default to 0
+
+def xla_mark_step():
+    """XLA mark step across torch_xla versions"""
+    try:
+        xm.mark_step()
+    except AttributeError:
+        try:
+            xm.wait_device_ops()
+        except AttributeError:
+            pass  # skip if neither is available
+
+def check_xla_device():
+    """Check the XLA device status"""
+    try:
+        device = xm.xla_device()
+        print(f"📱 XLA device: {device}")
+
+        world_size = get_xla_world_size()
+        ordinal = get_xla_ordinal()
+
+        print(f"🌍 World size: {world_size}")
+        print(f"🔢 Ordinal: {ordinal}")
+
+        # Detect the device type
+        device_str = str(device)
+        if 'xla' in device_str and 'cpu' not in device_str:
+            print("✅ TPU device detected")
+            return True, "TPU"
+        elif 'xla' in device_str and 'cpu' in device_str:
+            print("⚠️ XLA CPU simulation mode")
+            return True, "XLA_CPU"
+        else:
+            print("❌ No XLA device detected")
+            return False, "CPU"
+
+    except Exception as e:
+        print(f"❌ XLA device check failed: {e}")
+        return False, "ERROR"
+
+# Run the compatibility check
+device_available, device_type = check_xla_device()
+
+if device_available:
+    print(f"✅ XLA environment OK, device type: {device_type}")
+
+    # Exercise basic XLA operations
+    print("🧪 Testing basic XLA operations...")
+    try:
+        device = xm.xla_device()
+        x = torch.randn(2, 2, device=device)
+        y = torch.matmul(x, x)
+
+        # Test the sync helper
+        xla_mark_step()
+
+        print("✅ Basic XLA operations passed")
+
+    except Exception as e:
+        print(f"❌ XLA operation test failed: {e}")
+else:
+    print("❌ XLA environment unavailable")
+
+print("\n💡 Compatibility check done; the remaining cells can now run")
\ No newline at end of file
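For illustration, a minimal sketch of how the compatibility shims in the new cell might be called from a later training cell. `model`, `optimizer`, `train_loader`, and `loss_fn` are hypothetical placeholders, not names from this patch; `xla_mark_step` is the helper defined above.

```python
# Hypothetical usage sketch (not part of this patch). Assumes the
# compatibility cell above has already run, so xla_mark_step is defined.
def train_one_epoch(model, optimizer, train_loader, loss_fn):
    model.train()
    for features, targets in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(model(features), targets)
        loss.backward()
        optimizer.step()
        # XLA executes lazily; marking a step cuts the graph and dispatches it.
        # The shim calls whichever sync API this torch_xla version provides.
        xla_mark_step()
```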
diff --git a/model_training_nnn_tpu/jupyter_xla_monitor.py b/model_training_nnn_tpu/jupyter_xla_monitor.py
index e02ece1..e194752 100644
--- a/model_training_nnn_tpu/jupyter_xla_monitor.py
+++ b/model_training_nnn_tpu/jupyter_xla_monitor.py
@@ -15,7 +15,43 @@ import torch_xla.core.xla_model as xm
 
 print(f"✅ XLA imported successfully!")
 print(f"   TPU device: {xm.xla_device()}")
-print(f"   World Size: {xm.xrt_world_size()}")
+
+# Compatibility with newer PyTorch XLA releases
+try:
+    world_size = xm.xrt_world_size()
+    print(f"   World size (old API): {world_size}")
+except AttributeError:
+    try:
+        world_size = xm.get_world_size()
+        print(f"   World size (new API): {world_size}")
+    except AttributeError:
+        print("   World size: unavailable (possibly in CPU mode)")
+
+# Check XLA API compatibility
+print("🔍 Checking XLA API compatibility:")
+api_available = []
+api_deprecated = []
+
+# Probe the various APIs
+test_apis = [
+    ('xrt_world_size', 'xrt_world_size()'),
+    ('get_world_size', 'get_world_size()'),
+    ('mark_step', 'mark_step()'),
+    ('wait_device_ops', 'wait_device_ops()'),
+    ('get_ordinal', 'get_ordinal()'),
+    ('xla_device_count', 'xla_device_count()')
+]
+
+for api_name, api_desc in test_apis:
+    if hasattr(xm, api_name):
+        api_available.append(api_desc)
+    else:
+        api_deprecated.append(api_desc)
+
+if api_available:
+    print(f"   ✅ Available APIs: {', '.join(api_available)}")
+if api_deprecated:
+    print(f"   ❌ Unavailable APIs: {', '.join(api_deprecated)}")
 
 # Create the compilation progress monitor
 class JupyterCompilationMonitor:
diff --git a/model_training_nnn_tpu/jupyter_xla_setup.py b/model_training_nnn_tpu/jupyter_xla_setup.py
deleted file mode 100644
index d37296e..0000000
--- a/model_training_nnn_tpu/jupyter_xla_setup.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# ====================
-# Cell 1: environment setup (must run first!)
-# ====================
-
-import os
-import time
-import psutil
-from IPython.display import display, HTML, clear_output
-import ipywidgets as widgets
-
-# ⚠️ Critical: set the environment variables BEFORE importing torch_xla
-print("🔧 Setting XLA environment variables...")
-
-# Get the CPU core count
-cpu_count = os.cpu_count()
-print(f"Detected {cpu_count} CPU cores")
-
-# Set the XLA compilation optimization variables
-os.environ['XLA_FLAGS'] = (
-    '--xla_cpu_multi_thread_eigen=true '
-    '--xla_cpu_enable_fast_math=true '
-    f'--xla_force_host_platform_device_count={cpu_count}'
-)
-os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
-os.environ['XLA_USE_BF16'] = '1'
-
-# Show the resulting settings
-print("✅ XLA environment variables set:")
-print(f"   CPU cores: {cpu_count}")
-print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
-print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
-
-# System resource check
-memory_info = psutil.virtual_memory()
-print(f"\n💾 System memory:")
-print(f"   Total: {memory_info.total / (1024**3):.1f} GB")
-print(f"   Available: {memory_info.available / (1024**3):.1f} GB")
-print(f"   Usage: {memory_info.percent:.1f}%")
-
-if memory_info.available < 8 * (1024**3):  # less than 8 GB
-    print("⚠️ Warning: less than 8 GB of free memory; XLA compilation may be slow")
-else:
-    print("✅ Memory is sufficient")
-
-print("\n🎯 Environment setup done! You can run the next cell now")
\ No newline at end of file
diff --git a/model_training_nnn_tpu/launch_tpu_training.py b/model_training_nnn_tpu/launch_tpu_training.py
deleted file mode 100644
index 1beb755..0000000
--- a/model_training_nnn_tpu/launch_tpu_training.py
+++ /dev/null
@@ -1,161 +0,0 @@
-#!/usr/bin/env python3
-"""
-TPU Training Launch Script for Brain-to-Text RNN Model
-
-This script provides easy TPU training setup using the Accelerate library.
-Supports both single TPU core and multi-core (8 cores) training.
-
-Usage:
-    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
-
-Requirements:
-    - PyTorch XLA installed
-    - Accelerate library installed
-    - TPU runtime available
-"""
-
-import argparse
-import yaml
-import os
-import sys
-from pathlib import Path
-
-def update_config_for_tpu(config_path, num_cores=8):
-    """
-    Update configuration file to enable TPU training
-    """
-    with open(config_path, 'r') as f:
-        config = yaml.safe_load(f)
-
-    # Enable TPU settings
-    config['use_tpu'] = True
-    config['num_tpu_cores'] = num_cores
-    config['dataloader_num_workers'] = 0  # Required for TPU
-    config['use_amp'] = True  # Enable mixed precision with bfloat16
-
-    # Adjust batch size and gradient accumulation for multi-core TPU
-    if num_cores > 1:
-        # Distribute batch size across cores
-        original_batch_size = config['dataset']['batch_size']
-        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
-        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
-
-        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
-        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")
-
-    # Save updated config
-    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
-    with open(tpu_config_path, 'w') as f:
-        yaml.dump(config, f, default_flow_style=False)
-
-    print(f"TPU configuration saved to: {tpu_config_path}")
-    return tpu_config_path
-
-def check_tpu_environment():
-    """
-    Check if TPU environment is properly set up
-    """
-    try:
-        import torch_xla
-        import torch_xla.core.xla_model as xm
-
-        # Check if TPUs are available
-        device = xm.xla_device()
-        print(f"TPU device available: {device}")
-        print(f"TPU ordinal: {xm.get_ordinal()}")
-        print(f"TPU world size: {xm.xrt_world_size()}")
-
-        return True
-    except ImportError:
-        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
-        return False
-    except Exception as e:
-        print(f"ERROR: TPU not available - {e}")
-        return False
-
-def run_tpu_training(config_path, num_cores=8):
-    """
-    Launch TPU training using accelerate
-    """
-    # Check TPU environment
-    if not check_tpu_environment():
-        sys.exit(1)
-
-    # Update config for TPU
-    tpu_config_path = update_config_for_tpu(config_path, num_cores)
-
-    # Set TPU environment variables BEFORE launching training
-    os.environ['TPU_CORES'] = str(num_cores)
-    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16
-
-    # Critical XLA multi-threading settings - must be set before torch_xla import
-    cpu_count = os.cpu_count()
-    os.environ['XLA_FLAGS'] = (
-        '--xla_cpu_multi_thread_eigen=true '
-        '--xla_cpu_enable_fast_math=true '
-        f'--xla_force_host_platform_device_count={cpu_count}'
-    )
-    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
-
-    print(f"Set XLA compilation to use {cpu_count} CPU threads")
-    print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
-    print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
-
-    # Launch training with accelerate using subprocess to ensure environment variables are passed
-    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"
-
-    print(f"Launching TPU training with command:")
-    print(f"  {cmd}")
-    print(f"Using {num_cores} TPU cores")
-    print("-" * 60)
-
-    # Use subprocess to ensure environment variables are properly inherited
-    import subprocess
-
-    # Create environment with our XLA settings
-    env = os.environ.copy()
-    env.update({
-        'TPU_CORES': str(num_cores),
-        'XLA_USE_BF16': '1',
-        'XLA_FLAGS': (
-            '--xla_cpu_multi_thread_eigen=true '
-            '--xla_cpu_enable_fast_math=true '
-            f'--xla_force_host_platform_device_count={cpu_count}'
-        ),
-        'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
-    })
-
-    print(f"Environment variables set for subprocess:")
-    print(f"  XLA_FLAGS: {env['XLA_FLAGS']}")
-    print(f"  PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
-    print("-" * 60)
-
-    # Execute training with proper environment
-    result = subprocess.run(cmd.split(), env=env)
-    return result.returncode
-
-def main():
-    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
-    parser.add_argument('--config', default='rnn_args.yaml',
-                        help='Path to configuration file (default: rnn_args.yaml)')
-    parser.add_argument('--num_cores', type=int, default=8,
-                        help='Number of TPU cores to use (default: 8)')
-    parser.add_argument('--check_only', action='store_true',
-                        help='Only check TPU environment, do not launch training')
-
-    args = parser.parse_args()
-
-    # Verify config file exists
-    if not os.path.exists(args.config):
-        print(f"ERROR: Configuration file {args.config} not found")
-        sys.exit(1)
-
-    if args.check_only:
-        check_tpu_environment()
-        return
-
-    # Run TPU training
-    run_tpu_training(args.config, args.num_cores)
-
-if __name__ == "__main__":
-    main()
\ No newline at end of file
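To make the per-core batch-size arithmetic in the deleted launcher concrete, here is a worked example with assumed numbers; a global batch size of 32 is illustrative, not taken from rnn_args.yaml.

```python
# Worked example of update_config_for_tpu's per-core split, with assumed values.
num_cores = 8
original_batch_size = 32            # hypothetical global batch size
per_core = max(1, original_batch_size // num_cores)
print(per_core)                     # -> 4 samples per core per step
# Note the floor division: a global batch of 30 would also give 3 per core,
# shrinking the effective global batch to 24 unless accumulation compensates.
```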
diff --git a/model_training_nnn_tpu/monitor_xla_compilation.py b/model_training_nnn_tpu/monitor_xla_compilation.py
deleted file mode 100644
index 09d8838..0000000
--- a/model_training_nnn_tpu/monitor_xla_compilation.py
+++ /dev/null
@@ -1,100 +0,0 @@
-#!/usr/bin/env python3
-"""
-XLA compilation progress monitor
-"""
-
-import os
-import time
-import threading
-import psutil
-from contextlib import contextmanager
-
-def monitor_compilation_progress():
-    """Monitor XLA compilation progress"""
-    print("🔍 XLA compilation progress monitor started...")
-
-    start_time = time.time()
-    dots = 0
-
-    while True:
-        elapsed = time.time() - start_time
-        minutes = int(elapsed // 60)
-        seconds = int(elapsed % 60)
-
-        # Sample CPU usage
-        cpu_percent = psutil.cpu_percent(interval=1)
-        memory_percent = psutil.virtual_memory().percent
-
-        # Animated display
-        dots = (dots + 1) % 4
-        dot_str = "." * dots + " " * (3 - dots)
-
-        print(f"\r🔄 Compiling XLA{dot_str} "
-              f"⏱️ {minutes:02d}:{seconds:02d} "
-              f"🖥️ CPU: {cpu_percent:5.1f}% "
-              f"💾 Memory: {memory_percent:5.1f}%", end="", flush=True)
-
-        time.sleep(1)
-
-@contextmanager
-def compilation_monitor():
-    """Context manager for compilation monitoring"""
-    print("🚀 Starting XLA compilation monitoring...")
-
-    # Start the monitor thread
-    monitor_thread = threading.Thread(target=monitor_compilation_progress, daemon=True)
-    monitor_thread.start()
-
-    start_time = time.time()
-
-    try:
-        yield
-    finally:
-        elapsed = time.time() - start_time
-        print(f"\n✅ XLA compilation finished! Total time: {elapsed:.2f}s")
-
-# How to hook the monitor into the trainer
-def add_compilation_monitoring_to_trainer():
-    """Example code for adding compilation monitoring to the trainer"""
-    example_code = '''
-# In the train method of rnn_trainer.py, add:
-
-def train(self):
-    from monitor_xla_compilation import compilation_monitor
-
-    self.model.train()
-    train_losses = []
-    # ... other initialization code ...
-
-    self.logger.info("Starting training loop - XLA compilation monitoring enabled...")
-
-    # Use the compilation monitor
-    with compilation_monitor():
-        for i, batch in enumerate(self.train_loader):
-            # The first batch triggers XLA compilation
-            # and the monitor shows compilation progress
-
-            # ... training code ...
-
-            # Monitoring stops automatically once compilation is done
-            break  # only the first batch is needed to trigger compilation
-
-    # Continue with the normal training loop
-    for i, batch in enumerate(self.train_loader):
-        # ... normal training code ...
-    '''
-
-    print("📝 How to use the compilation monitor in the trainer:")
-    print(example_code)
-
-if __name__ == "__main__":
-    print("🧪 XLA compilation monitoring tool")
-    print("=" * 50)
-
-    # Show how to use it
-    print("📖 Usage:")
-    print("1. Import this file into your training script")
-    print("2. Wrap the first model call in the compilation_monitor() context manager")
-    print("3. It shows live compilation progress and system resource usage")
-
-    add_compilation_monitoring_to_trainer()
\ No newline at end of file
diff --git a/model_training_nnn_tpu/start_tpu_training.sh b/model_training_nnn_tpu/start_tpu_training.sh
deleted file mode 100644
index a09022a..0000000
--- a/model_training_nnn_tpu/start_tpu_training.sh
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/bin/bash
-
-# TPU XLA Multi-threading Environment Setup
-# Set these BEFORE starting Python to ensure they take effect
-
-echo "Setting up XLA multi-threading environment..."
-
-# Get CPU core count
-CPU_CORES=$(nproc)
-echo "Detected $CPU_CORES CPU cores"
-
-# Set XLA compilation flags
-export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
-export PYTORCH_XLA_COMPILATION_THREADS=$CPU_CORES
-
-# Additional XLA optimizations
-export XLA_USE_BF16=1
-export TPU_CORES=8
-
-# Print current settings
-echo "XLA_FLAGS: $XLA_FLAGS"
-echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
-echo "XLA_USE_BF16: $XLA_USE_BF16"
-
-# Start training
-echo "Starting TPU training..."
-python train_model.py --config_path rnn_args.yaml
\ No newline at end of file
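Both the deleted shell launcher and the Python launcher depend on these variables being exported before torch_xla is imported, since XLA reads XLA_FLAGS at import time. A minimal sketch of a sanity check, placed at the very top of train_model.py before any torch_xla import, could catch a missing export early; this is a suggestion, not part of this patch.

```python
# Suggested sanity check (not part of this patch): confirm the XLA environment
# variables reached this process. Must run before importing torch_xla.
import os

for var in ('XLA_FLAGS', 'PYTORCH_XLA_COMPILATION_THREADS', 'XLA_USE_BF16'):
    value = os.environ.get(var)
    if value is None:
        raise RuntimeError(f"{var} is not set; export it in the shell "
                           f"before starting Python")
    print(f"{var}={value}")
```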