This commit is contained in:
Zchen
2025-10-15 14:33:49 +08:00
parent 56fa336af0
commit e7947f310c
8 changed files with 130 additions and 606 deletions

View File

@@ -1,148 +0,0 @@
#!/usr/bin/env python3
"""
XLA Multi-threading Diagnostic Script
检查XLA编译是否正确使用多CPU核心
"""
import os
import psutil
import time
import threading
from concurrent.futures import ThreadPoolExecutor
def set_xla_environment():
"""设置XLA环境变量"""
cpu_count = os.cpu_count()
# 设置XLA环境变量
os.environ['XLA_FLAGS'] = (
'--xla_cpu_multi_thread_eigen=true '
'--xla_cpu_enable_fast_math=true '
f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
print(f"🔧 设置XLA环境变量:")
print(f" CPU核心数: {cpu_count}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
print("-" * 60)
def monitor_cpu_usage(duration=30, interval=1):
"""监控CPU使用情况"""
print(f"🔍 监控CPU使用情况 {duration}秒...")
cpu_usage_data = []
start_time = time.time()
while time.time() - start_time < duration:
# 获取每个CPU核心的使用率
cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
cpu_usage_data.append(cpu_percent_per_core)
# 实时显示
active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
print(f"活跃核心数: {active_cores}/{len(cpu_percent_per_core)}, "
f"平均使用率: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
end='\r')
print() # 换行
# 分析结果
if cpu_usage_data:
avg_usage_per_core = [
sum(core_data) / len(cpu_usage_data)
for core_data in zip(*cpu_usage_data)
]
active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
max_usage = max(avg_usage_per_core)
print(f"📊 CPU使用分析:")
print(f" 活跃的CPU核心: {active_cores}/{len(avg_usage_per_core)}")
print(f" 最高平均使用率: {max_usage:.1f}%")
for i, usage in enumerate(avg_usage_per_core):
status = "🟢" if usage > 10 else "🔴" if usage > 5 else ""
print(f" CPU核心 {i}: {usage:.1f}% {status}")
return active_cores > 1
return False
def test_xla_compilation():
"""测试XLA编译"""
print(f"🚀 开始XLA编译测试...")
try:
import torch
import torch_xla.core.xla_model as xm
print(f"✅ PyTorch XLA导入成功")
print(f" XLA设备: {xm.xla_device()}")
print(f" XLA world size: {xm.xrt_world_size()}")
# 创建一个简单的计算图进行编译
device = xm.xla_device()
print(f"🔄 创建测试计算图...")
x = torch.randn(100, 100, device=device)
y = torch.randn(100, 100, device=device)
print(f"🔄 执行矩阵运算 (将触发XLA编译)...")
# 启动CPU监控
monitor_thread = threading.Thread(
target=lambda: monitor_cpu_usage(20, 0.5),
daemon=True
)
monitor_thread.start()
# 执行计算,触发编译
for i in range(10):
z = torch.matmul(x, y)
z = torch.relu(z)
z = torch.matmul(z, x.T)
if i == 0:
print(f"🔄 首次计算完成 (XLA编译应该正在进行)...")
elif i == 5:
print(f"🔄 第6次计算完成...")
# 等待监控完成
monitor_thread.join(timeout=25)
print(f"✅ XLA测试完成")
return True
except ImportError as e:
print(f"❌ PyTorch XLA导入失败: {e}")
return False
except Exception as e:
print(f"❌ XLA测试失败: {e}")
return False
def main():
print("=" * 60)
print("🧪 XLA多线程编译诊断工具")
print("=" * 60)
# 1. 设置环境
set_xla_environment()
# 2. 测试XLA编译并监控CPU
success = test_xla_compilation()
print("=" * 60)
if success:
print("✅ 诊断完成")
print("💡 如果看到多个CPU核心被激活说明XLA多线程工作正常")
print("💡 如果只有1-2个核心活跃可能需要其他优化方法")
else:
print("❌ 诊断失败")
print("💡 请检查PyTorch XLA安装和TPU环境")
print("=" * 60)
if __name__ == "__main__":
main()

View File

@@ -1,124 +0,0 @@
# ====================
# Cell 4: Step-by-step debugging of the full model compilation
# ====================
# Run this cell if the test in cell 3 passed.
# (torch, time, xm and compilation_monitor come from the earlier cells.)

print("🔧 Testing the full TripleGRUDecoder step by step...")

# Import the full model
import sys
sys.path.append('.')  # make sure local modules can be imported

try:
    from rnn_model import TripleGRUDecoder
    print("✅ TripleGRUDecoder imported successfully")
except ImportError as e:
    print(f"❌ Model import failed: {e}")
    print("Make sure rnn_model.py is in the current directory")


def test_model_compilation_stages():
    """Test model compilation stage by stage; returns True only if all stages pass."""
    device = xm.xla_device()
    results = {}

    # Shared test inputs
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    # Stage 1: compile NoiseModel on its own
    print("\n🔬 Stage 1: Testing NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced parameter count
            n_days=4,
            patch_size=8   # reduced patch size
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time

        print(f"✅ NoiseModel compiled! Took {compile_time:.2f}s")
        print(f"   Parameter count: {sum(p.numel() for p in noise_model.parameters()):,}")
        results['noise'] = True
    except Exception as e:
        print(f"❌ NoiseModel compilation failed: {e}")
        results['noise'] = False

    # Stage 2: CleanSpeechModel
    print("\n🔬 Stage 2: Testing CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time

        print(f"✅ CleanSpeechModel compiled! Took {compile_time:.2f}s")
        results['clean'] = True
    except Exception as e:
        print(f"❌ CleanSpeechModel compilation failed: {e}")
        results['clean'] = False

    # Stage 3: the full TripleGRUDecoder
    print("\n🔬 Stage 3: Testing TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer days
            n_classes=41,
            patch_size=8   # reduced patch size
        ).to(device)
        print(f"📊 Full model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Start compilation monitoring
        compilation_monitor.start_monitoring()

        start_time = time.time()
        with torch.no_grad():
            # Test inference mode
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time

        compilation_monitor.complete_monitoring()
        print(f"✅ TripleGRUDecoder compiled! Took {compile_time:.2f}s")
        print(f"📤 Output shape: {logits.shape}")
        results['full'] = True
    except Exception as e:
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder compilation failed: {e}")
        results['full'] = False

    return all(results.values())


# Run the staged tests
stage_results = test_model_compilation_stages()

if stage_results:
    print("\n🎉 All compilation tests passed!")
    print("💡 Next steps to try:")
    print("   1. Train with the simplified configuration")
    print("   2. Gradually increase model complexity")
    print("   3. Monitor TPU resource usage")
else:
    print("\n⚠️ The compilation tests found problems")
    print("💡 Suggestions:")
    print("   1. Reduce model parameters further")
    print("   2. Check memory usage")
    print("   3. Debug in CPU mode")

View File

@@ -0,0 +1,93 @@
# ====================
# Cell: XLA version compatibility check and fixes
# ====================
import torch
import torch.nn as nn

print("🔧 PyTorch XLA version compatibility check...")

# Import XLA
import torch_xla.core.xla_model as xm
print("✅ PyTorch XLA imported successfully!")


# Compatibility helpers
def get_xla_world_size():
    """Get the XLA world size across different torch_xla versions."""
    try:
        return xm.xrt_world_size()
    except AttributeError:
        try:
            return xm.get_world_size()
        except AttributeError:
            return 1  # fall back to 1


def get_xla_ordinal():
    """Get the XLA ordinal across different torch_xla versions."""
    try:
        return xm.get_ordinal()
    except AttributeError:
        return 0  # fall back to 0


def xla_mark_step():
    """Cut the lazy graph (mark_step) across different torch_xla versions."""
    try:
        xm.mark_step()
    except AttributeError:
        try:
            # Best-effort fallback: blocks until pending device ops finish
            xm.wait_device_ops()
        except AttributeError:
            pass  # skip if neither API is available


def check_xla_device():
    """Check the XLA device status."""
    try:
        device = xm.xla_device()
        print(f"📱 XLA device: {device}")

        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()
        print(f"🌍 World Size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")

        # Detect the device type
        device_str = str(device)
        if 'xla' in device_str and 'cpu' not in device_str:
            print("✅ TPU device detected")
            return True, "TPU"
        elif 'xla' in device_str and 'cpu' in device_str:
            print("⚠️ XLA CPU simulation mode")
            return True, "XLA_CPU"
        else:
            print("❌ No XLA device detected")
            return False, "CPU"
    except Exception as e:
        print(f"❌ XLA device check failed: {e}")
        return False, "ERROR"


# Run the compatibility check
device_available, device_type = check_xla_device()

if device_available:
    print(f"✅ XLA environment OK, device type: {device_type}")

    # Test basic XLA operations
    print("🧪 Testing basic XLA operations...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        y = torch.matmul(x, x)

        # Test the synchronization helper
        xla_mark_step()
        print("✅ Basic XLA operations succeeded")
    except Exception as e:
        print(f"❌ XLA operation test failed: {e}")
else:
    print("❌ XLA environment unavailable")

print("\n💡 Compatibility check finished; the following cells can now be run")

View File

@@ -15,7 +15,43 @@ import torch_xla.core.xla_model as xm
print(f"✅ XLA导入成功!")
print(f" TPU设备: {xm.xla_device()}")
print(f" World Size: {xm.xrt_world_size()}")
# 兼容新版本PyTorch XLA
try:
world_size = xm.xrt_world_size()
print(f" World Size (旧API): {world_size}")
except AttributeError:
try:
world_size = xm.get_world_size()
print(f" World Size (新API): {world_size}")
except AttributeError:
print(" World Size: 无法获取 (可能在CPU模式)")
# 检查XLA版本兼容性
print("🔍 检查XLA API兼容性:")
api_available = []
api_deprecated = []
# 检查各种API
test_apis = [
('xrt_world_size', 'xrt_world_size()'),
('get_world_size', 'get_world_size()'),
('mark_step', 'mark_step()'),
('wait_device_ops', 'wait_device_ops()'),
('get_ordinal', 'get_ordinal()'),
('xla_device_count', 'xla_device_count()')
]
for api_name, api_desc in test_apis:
if hasattr(xm, api_name):
api_available.append(api_desc)
else:
api_deprecated.append(api_desc)
if api_available:
print(f" ✅ 可用API: {', '.join(api_available)}")
if api_deprecated:
print(f" ❌ 不可用API: {', '.join(api_deprecated)}")
# 创建编译进度监控器
class JupyterCompilationMonitor:

View File

@@ -1,45 +0,0 @@
# ====================
# Cell 1: Environment setup (must run first!)
# ====================
import os
import time
import psutil
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# ⚠️ Critical: set the environment variables BEFORE importing torch_xla
print("🔧 Setting XLA environment variables...")

# Get the CPU core count
cpu_count = os.cpu_count()
print(f"Detected {cpu_count} CPU cores")

# Set the XLA compilation optimization variables
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
os.environ['XLA_USE_BF16'] = '1'

# Show the resulting settings
print("✅ XLA environment variables set:")
print(f"   CPU cores: {cpu_count}")
print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

# System resource check
memory_info = psutil.virtual_memory()
print(f"\n💾 System memory:")
print(f"   Total: {memory_info.total / (1024**3):.1f} GB")
print(f"   Available: {memory_info.available / (1024**3):.1f} GB")
print(f"   Used: {memory_info.percent:.1f}%")

if memory_info.available < 8 * (1024**3):  # less than 8 GB
    print("⚠️ Warning: less than 8 GB of memory available; XLA compilation may be slow")
else:
    print("✅ Sufficient memory")

print("\n🎯 Environment setup complete! You can now run the next cell")

View File

@@ -1,161 +0,0 @@
#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model
This script provides an easy TPU training setup using the Accelerate library.
Supports both single TPU core and multi-core (8 cores) training.
Usage:
python launch_tpu_training.py --config rnn_args.yaml --num_cores 8
Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""
import argparse
import yaml
import os
import sys
def update_config_for_tpu(config_path, num_cores=8):
    """
    Update the configuration file to enable TPU training.
    """
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # Required for TPU
    config['use_amp'] = True  # Enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # Distribute the batch size across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))
        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save the updated config
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path
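
# Worked example of the split above (hypothetical numbers): with
# dataset.batch_size = 32 and num_cores = 8, each core processes
# 32 // 8 = 4 samples per step, so the effective global batch stays
# 4 * 8 = 32. Batch sizes smaller than num_cores are clamped to 1 per
# core, which raises the effective global batch and may warrant a
# learning-rate adjustment.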
def check_tpu_environment():
    """
    Check whether the TPU environment is properly set up.
    """
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check if TPUs are available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
        print(f"TPU world size: {xm.xrt_world_size()}")
        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False


def run_tpu_training(config_path, num_cores=8):
    """
    Launch TPU training using Accelerate.
    """
    # Check the TPU environment
    if not check_tpu_environment():
        sys.exit(1)

    # Update the config for TPU
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Set TPU environment variables BEFORE launching training
    os.environ['TPU_CORES'] = str(num_cores)
    os.environ['XLA_USE_BF16'] = '1'  # Enable bfloat16

    # Critical XLA multi-threading settings - must be set before torch_xla import
    cpu_count = os.cpu_count()
    os.environ['XLA_FLAGS'] = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)

    print(f"Set XLA compilation to use {cpu_count} CPU threads")
    print(f"XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f"PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

    # Launch training with accelerate; use a subprocess so the environment
    # variables above are reliably inherited by the training process
    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"

    print("Launching TPU training with command:")
    print(f"  {cmd}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    import subprocess

    # Create an environment with our XLA settings
    env = os.environ.copy()
    env.update({
        'TPU_CORES': str(num_cores),
        'XLA_USE_BF16': '1',
        'XLA_FLAGS': (
            '--xla_cpu_multi_thread_eigen=true '
            '--xla_cpu_enable_fast_math=true '
            f'--xla_force_host_platform_device_count={cpu_count}'
        ),
        'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
    })

    print("Environment variables set for subprocess:")
    print(f"  XLA_FLAGS: {env['XLA_FLAGS']}")
    print(f"  PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)

    # Execute training with the prepared environment
    result = subprocess.run(cmd.split(), env=env)
    return result.returncode


def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check TPU environment, do not launch training')

    args = parser.parse_args()

    # Verify the config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training
    run_tpu_training(args.config, args.num_cores)


if __name__ == "__main__":
    main()

View File

@@ -1,100 +0,0 @@
#!/usr/bin/env python3
"""
XLA编译进度监控脚本
"""
import os
import time
import threading
import psutil
from contextlib import contextmanager
def monitor_compilation_progress():
"""监控XLA编译进度"""
print("🔍 XLA编译进度监控已启动...")
start_time = time.time()
dots = 0
while True:
elapsed = time.time() - start_time
minutes = int(elapsed // 60)
seconds = int(elapsed % 60)
# 获取CPU使用率
cpu_percent = psutil.cpu_percent(interval=1)
memory_percent = psutil.virtual_memory().percent
# 动态显示
dots = (dots + 1) % 4
dot_str = "." * dots + " " * (3 - dots)
print(f"\r🔄 XLA编译中{dot_str} "
f"⏱️ {minutes:02d}:{seconds:02d} "
f"🖥️ CPU: {cpu_percent:5.1f}% "
f"💾 内存: {memory_percent:5.1f}%", end="", flush=True)
time.sleep(1)
@contextmanager
def compilation_monitor():
"""编译监控上下文管理器"""
print("🚀 开始XLA编译监控...")
# 启动监控线程
monitor_thread = threading.Thread(target=monitor_compilation_progress, daemon=True)
monitor_thread.start()
start_time = time.time()
try:
yield
finally:
elapsed = time.time() - start_time
print(f"\n✅ XLA编译完成! 总耗时: {elapsed:.2f}")
# 修改trainer中的使用方法
def add_compilation_monitoring_to_trainer():
"""给trainer添加编译监控的示例代码"""
example_code = '''
# 在rnn_trainer.py的train方法中添加
def train(self):
from monitor_xla_compilation import compilation_monitor
self.model.train()
train_losses = []
# ... 其他初始化代码 ...
self.logger.info("Starting training loop - XLA compilation monitoring enabled...")
# 使用编译监控
with compilation_monitor():
for i, batch in enumerate(self.train_loader):
# 第一个batch会触发XLA编译
# 监控会显示编译进度
# ... 训练代码 ...
# 编译完成后会自动停止监控
break # 只需要第一个batch来触发编译
# 继续正常训练循环
for i, batch in enumerate(self.train_loader):
# ... 正常训练代码 ...
'''
print("📝 如何在trainer中使用编译监控:")
print(example_code)
if __name__ == "__main__":
print("🧪 XLA编译监控工具")
print("=" * 50)
# 演示如何使用
print("📖 使用方法:")
print("1. 将此文件导入到你的训练脚本中")
print("2. 在第一次模型调用前使用 compilation_monitor() 上下文管理器")
print("3. 会实时显示编译进度和系统资源使用情况")
add_compilation_monitoring_to_trainer()

View File

@@ -1,27 +0,0 @@
#!/bin/bash
# TPU XLA Multi-threading Environment Setup
# Set these BEFORE starting Python to ensure they take effect
echo "Setting up XLA multi-threading environment..."
# Get CPU core count
CPU_CORES=$(nproc)
echo "Detected $CPU_CORES CPU cores"
# Set XLA compilation flags
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
export PYTORCH_XLA_COMPILATION_THREADS=$CPU_CORES
# Additional XLA optimizations
export XLA_USE_BF16=1
export TPU_CORES=8
# Print current settings
echo "XLA_FLAGS: $XLA_FLAGS"
echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
echo "XLA_USE_BF16: $XLA_USE_BF16"
# Start training
echo "Starting TPU training..."
python train_model.py --config_path rnn_args.yaml
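
A quick way to confirm these exports actually reach the launched process (a hypothetical check, run before torch_xla is imported):

# Hypothetical sanity check: verify the exported XLA settings are visible
# inside the launched Python process.
import os

flags = os.environ.get('XLA_FLAGS', '')
assert '--xla_cpu_multi_thread_eigen=true' in flags, "XLA_FLAGS not inherited"
print("compilation threads:", os.environ.get('PYTORCH_XLA_COMPILATION_THREADS'))
print("bf16 enabled:", os.environ.get('XLA_USE_BF16') == '1')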