tpu
@@ -1,148 +0,0 @@
#!/usr/bin/env python3
"""
XLA Multi-threading Diagnostic Script

Checks whether XLA compilation actually uses multiple CPU cores.
"""

import os
import psutil
import time
import threading


def set_xla_environment():
    """Set the XLA environment variables."""
    cpu_count = os.cpu_count()

    # XLA compilation flags
    os.environ['XLA_FLAGS'] = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )
    os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)

    print("🔧 Setting XLA environment variables:")
    print(f"   CPU cores: {cpu_count}")
    print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
    print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)


def monitor_cpu_usage(duration=30, interval=1):
    """Monitor per-core CPU usage for `duration` seconds."""
    print(f"🔍 Monitoring CPU usage for {duration} seconds...")

    cpu_usage_data = []
    start_time = time.time()

    while time.time() - start_time < duration:
        # Utilization of each CPU core
        cpu_percent_per_core = psutil.cpu_percent(interval=interval, percpu=True)
        cpu_usage_data.append(cpu_percent_per_core)

        # Live status line
        active_cores = sum(1 for usage in cpu_percent_per_core if usage > 10)
        print(f"Active cores: {active_cores}/{len(cpu_percent_per_core)}, "
              f"mean usage: {sum(cpu_percent_per_core)/len(cpu_percent_per_core):.1f}%",
              end='\r')

    print()  # newline after the live status line

    # Analyze the results
    if cpu_usage_data:
        avg_usage_per_core = [
            sum(core_data) / len(cpu_usage_data)
            for core_data in zip(*cpu_usage_data)
        ]

        active_cores = sum(1 for avg in avg_usage_per_core if avg > 5)
        max_usage = max(avg_usage_per_core)

        print("📊 CPU usage analysis:")
        print(f"   Active CPU cores: {active_cores}/{len(avg_usage_per_core)}")
        print(f"   Highest mean usage: {max_usage:.1f}%")

        for i, usage in enumerate(avg_usage_per_core):
            status = "🟢" if usage > 10 else "🔴" if usage > 5 else "⚫"
            print(f"   CPU core {i}: {usage:.1f}% {status}")

        return active_cores > 1

    return False


def test_xla_compilation():
    """Run a small workload that triggers XLA compilation."""
    print("🚀 Starting XLA compilation test...")

    try:
        import torch
        import torch_xla.core.xla_model as xm

        print("✅ PyTorch XLA imported successfully")
        print(f"   XLA device: {xm.xla_device()}")
        print(f"   XLA world size: {xm.xrt_world_size()}")

        # Build a simple computation graph to compile
        device = xm.xla_device()

        print("🔄 Creating test computation graph...")
        x = torch.randn(100, 100, device=device)
        y = torch.randn(100, 100, device=device)

        print("🔄 Running matrix ops (this triggers XLA compilation)...")

        # Start CPU monitoring in the background
        monitor_thread = threading.Thread(
            target=lambda: monitor_cpu_usage(20, 0.5),
            daemon=True
        )
        monitor_thread.start()

        # Run the computation to trigger compilation
        for i in range(10):
            z = torch.matmul(x, y)
            z = torch.relu(z)
            z = torch.matmul(z, x.T)
            if i == 0:
                print("🔄 First pass done (XLA compilation should be running)...")
            elif i == 5:
                print("🔄 Sixth pass done...")

        # Wait for monitoring to finish
        monitor_thread.join(timeout=25)

        print("✅ XLA test finished")

        return True

    except ImportError as e:
        print(f"❌ Failed to import PyTorch XLA: {e}")
        return False
    except Exception as e:
        print(f"❌ XLA test failed: {e}")
        return False


def main():
    print("=" * 60)
    print("🧪 XLA Multi-threading Compilation Diagnostic Tool")
    print("=" * 60)

    # 1. Set up the environment
    set_xla_environment()

    # 2. Test XLA compilation while monitoring CPU usage
    success = test_xla_compilation()

    print("=" * 60)
    if success:
        print("✅ Diagnostics finished")
        print("💡 If several CPU cores were active, XLA multi-threading is working")
        print("💡 If only 1-2 cores were active, other optimizations may be needed")
    else:
        print("❌ Diagnostics failed")
        print("💡 Check the PyTorch XLA installation and the TPU environment")

    print("=" * 60)


if __name__ == "__main__":
    main()
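A quick way to confirm that the --xla_force_host_platform_device_count flag above took effect is to list the visible XLA devices after the import; a minimal sketch, assuming your torch_xla version still provides xm.get_xla_supported_devices:

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'  # must happen before the import

import torch_xla.core.xla_model as xm
# Expect one entry per forced host device, e.g. xla:0 ... xla:7.
print(xm.get_xla_supported_devices())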
@@ -1,124 +0,0 @@
# ====================
# Cell 4: Step-by-step debugging of full-model compilation
# ====================

# Run this cell only if Cell 3's test passed.
# (torch, time, xm and compilation_monitor are defined in earlier cells.)
print("🔧 Testing the full TripleGRUDecoder model step by step...")

# Import the full model
import sys
sys.path.append('.')  # make local modules importable

try:
    from rnn_model import TripleGRUDecoder
    print("✅ TripleGRUDecoder imported successfully")
except ImportError as e:
    print(f"❌ Model import failed: {e}")
    print("Make sure rnn_model.py is in the current directory")


def test_model_compilation_stages():
    """Compile the model in stages; return True only if every stage passes."""
    device = xm.xla_device()

    # Shared test inputs for all three stages
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    # Stage 1: compile NoiseModel on its own
    print("\n🔬 Stage 1: testing NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced parameter count
            n_days=4,
            patch_size=8   # reduced patch size
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time

        print(f"✅ NoiseModel compiled! Took {compile_time:.2f}s")
        print(f"   Parameter count: {sum(p.numel() for p in noise_model.parameters()):,}")
    except Exception as e:
        print(f"❌ NoiseModel compilation failed: {e}")
        return False

    # Stage 2: CleanSpeechModel
    print("\n🔬 Stage 2: testing CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time

        print(f"✅ CleanSpeechModel compiled! Took {compile_time:.2f}s")
    except Exception as e:
        print(f"❌ CleanSpeechModel compilation failed: {e}")
        return False

    # Stage 3: the full TripleGRUDecoder
    print("\n🔬 Stage 3: testing TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer days
            n_classes=41,
            patch_size=8   # reduced patch size
        ).to(device)

        print(f"📊 Full-model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Start the compilation monitor (defined in an earlier cell)
        compilation_monitor.start_monitoring()

        start_time = time.time()
        with torch.no_grad():
            # Test inference mode
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time

        compilation_monitor.complete_monitoring()

        print(f"✅ TripleGRUDecoder compiled! Took {compile_time:.2f}s")
        print(f"📤 Output shape: {logits.shape}")
        return True
    except Exception as e:
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder compilation failed: {e}")
        return False


# Run the staged tests
all_passed = test_model_compilation_stages()

if all_passed:
    print("\n🎉 All compilation tests passed!")
    print("💡 Next steps to try:")
    print("   1. Train with the simplified configuration")
    print("   2. Gradually increase model complexity")
    print("   3. Monitor TPU resource usage")
else:
    print("\n⚠️ The compilation tests hit a problem")
    print("💡 Suggestions:")
    print("   1. Reduce the model parameters further")
    print("   2. Check memory usage")
    print("   3. Debug in CPU mode")
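Because XLA caches compiled graphs keyed by input shapes, a second call with the same shapes skips compilation; a small follow-up sketch (assuming the stage-3 model, x and day_idx above are still in scope) that separates the one-time compile cost from steady-state step time:

# Re-run the same shapes: the compiled graph is reused, so this timing
# approximates steady-state step time without compilation.
start = time.time()
with torch.no_grad():
    _ = model(x, day_idx, None, False, 'inference')
xm.mark_step()  # flush pending XLA work before reading the clock
print(f"Second call: {time.time() - start:.2f}s (compile cost excluded)")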
model_training_nnn_tpu/jupyter_xla_compatibility.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# ====================
# Cell: XLA version compatibility check and fixes
# ====================

import torch
import torch.nn as nn

print("🔧 Checking PyTorch XLA version compatibility...")

# Import XLA
import torch_xla.core.xla_model as xm

print("✅ PyTorch XLA imported successfully!")


# Compatibility helpers
def get_xla_world_size():
    """Get the XLA world size across torch_xla versions."""
    try:
        return xm.xrt_world_size()
    except AttributeError:
        try:
            return xm.get_world_size()
        except AttributeError:
            return 1  # sensible default


def get_xla_ordinal():
    """Get the XLA ordinal across torch_xla versions."""
    try:
        return xm.get_ordinal()
    except AttributeError:
        return 0  # sensible default


def xla_mark_step():
    """Run XLA mark_step across torch_xla versions."""
    try:
        xm.mark_step()
    except AttributeError:
        try:
            xm.wait_device_ops()
        except AttributeError:
            pass  # neither API available; skip the sync


def check_xla_device():
    """Check the XLA device status."""
    try:
        device = xm.xla_device()
        print(f"📱 XLA device: {device}")

        world_size = get_xla_world_size()
        ordinal = get_xla_ordinal()

        print(f"🌍 World size: {world_size}")
        print(f"🔢 Ordinal: {ordinal}")

        # Detect the device type
        device_str = str(device)
        if 'xla' in device_str and 'cpu' not in device_str:
            print("✅ TPU device detected")
            return True, "TPU"
        elif 'xla' in device_str and 'cpu' in device_str:
            print("⚠️ XLA CPU emulation mode")
            return True, "XLA_CPU"
        else:
            print("❌ No XLA device detected")
            return False, "CPU"

    except Exception as e:
        print(f"❌ XLA device check failed: {e}")
        return False, "ERROR"


# Run the compatibility check
device_available, device_type = check_xla_device()

if device_available:
    print(f"✅ XLA environment OK, device type: {device_type}")

    # Exercise basic XLA operations
    print("🧪 Testing basic XLA operations...")
    try:
        device = xm.xla_device()
        x = torch.randn(2, 2, device=device)
        y = torch.matmul(x, x)

        # Test the synchronization shim
        xla_mark_step()

        print("✅ Basic XLA operations succeeded")

    except Exception as e:
        print(f"❌ XLA operation test failed: {e}")
else:
    print("❌ XLA environment unavailable")

print("\n💡 Compatibility check complete; you can run the next cells")
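The per-function try/except shims above can also be collapsed into one generic probe; a sketch of that alternative (the helper name _xm_call is ours, not part of the original file):

def _xm_call(names, default, *args, **kwargs):
    """Call the first xm attribute from `names` that exists, else return default."""
    for name in names:
        fn = getattr(xm, name, None)
        if fn is not None:
            return fn(*args, **kwargs)
    return default

# Equivalent to get_xla_world_size() above:
world_size = _xm_call(['xrt_world_size', 'get_world_size'], 1)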
@@ -15,7 +15,43 @@ import torch_xla.core.xla_model as xm
print("✅ XLA imported successfully!")
print(f"   TPU device: {xm.xla_device()}")

# Compatible with newer PyTorch XLA versions
try:
    world_size = xm.xrt_world_size()
    print(f"   World size (old API): {world_size}")
except AttributeError:
    try:
        world_size = xm.get_world_size()
        print(f"   World size (new API): {world_size}")
    except AttributeError:
        print("   World size: unavailable (possibly running in CPU mode)")

# Check XLA API compatibility
print("🔍 Checking XLA API compatibility:")
api_available = []
api_deprecated = []

# Probe the APIs of interest
test_apis = [
    ('xrt_world_size', 'xrt_world_size()'),
    ('get_world_size', 'get_world_size()'),
    ('mark_step', 'mark_step()'),
    ('wait_device_ops', 'wait_device_ops()'),
    ('get_ordinal', 'get_ordinal()'),
    ('xla_device_count', 'xla_device_count()')
]

for api_name, api_desc in test_apis:
    if hasattr(xm, api_name):
        api_available.append(api_desc)
    else:
        api_deprecated.append(api_desc)

if api_available:
    print(f"   ✅ Available APIs: {', '.join(api_available)}")
if api_deprecated:
    print(f"   ❌ Unavailable APIs: {', '.join(api_deprecated)}")

# Compilation progress monitor
class JupyterCompilationMonitor:
@@ -1,45 +0,0 @@
# ====================
# Cell 1: Environment setup (must be run first!)
# ====================

import os
import time
import psutil
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets

# ⚠️ Critical: set the environment variables BEFORE importing torch_xla
print("🔧 Setting XLA environment variables...")

# CPU core count
cpu_count = os.cpu_count()
print(f"Detected {cpu_count} CPU cores")

# XLA compilation optimization flags
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={cpu_count}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(cpu_count)
os.environ['XLA_USE_BF16'] = '1'

# Show the resulting settings
print("✅ XLA environment variables set:")
print(f"   CPU cores: {cpu_count}")
print(f"   XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f"   PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

# System resource check
memory_info = psutil.virtual_memory()
print("\n💾 System memory:")
print(f"   Total: {memory_info.total / (1024**3):.1f} GB")
print(f"   Available: {memory_info.available / (1024**3):.1f} GB")
print(f"   In use: {memory_info.percent:.1f}%")

if memory_info.available < 8 * (1024**3):  # less than 8 GB
    print("⚠️ Warning: less than 8 GB of memory available; XLA compilation may be slow")
else:
    print("✅ Memory is sufficient")

print("\n🎯 Environment setup complete! You can run the next cell")
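Since XLA reads these variables when torch_xla initializes, setting them after the import silently does nothing; a minimal guard that could be prepended to this cell (our addition, not in the original) fails fast in that case:

import sys

# If torch_xla was already imported in an earlier cell, the XLA_FLAGS set
# below would be ignored - restart the kernel instead of continuing silently.
assert 'torch_xla' not in sys.modules, (
    "torch_xla already imported; restart the kernel and run this cell first"
)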
@@ -1,161 +0,0 @@
#!/usr/bin/env python3
"""
TPU Training Launch Script for Brain-to-Text RNN Model

This script provides easy TPU training setup using the Accelerate library.
Supports both single TPU core and multi-core (8 cores) training.

Usage:
    python launch_tpu_training.py --config rnn_args.yaml --num_cores 8

Requirements:
- PyTorch XLA installed
- Accelerate library installed
- TPU runtime available
"""

import argparse
import os
import subprocess
import sys

import yaml


def update_config_for_tpu(config_path, num_cores=8):
    """Update the configuration file to enable TPU training."""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)

    # Enable TPU settings
    config['use_tpu'] = True
    config['num_tpu_cores'] = num_cores
    config['dataloader_num_workers'] = 0  # required for TPU
    config['use_amp'] = True  # enable mixed precision with bfloat16

    # Adjust batch size and gradient accumulation for multi-core TPU
    if num_cores > 1:
        # Distribute the batch across cores
        original_batch_size = config['dataset']['batch_size']
        config['dataset']['batch_size'] = max(1, original_batch_size // num_cores)
        config['gradient_accumulation_steps'] = max(1, config.get('gradient_accumulation_steps', 1))

        print(f"Adjusted batch size from {original_batch_size} to {config['dataset']['batch_size']} per core")
        print(f"Gradient accumulation steps: {config['gradient_accumulation_steps']}")

    # Save the updated config
    tpu_config_path = config_path.replace('.yaml', '_tpu.yaml')
    with open(tpu_config_path, 'w') as f:
        yaml.dump(config, f, default_flow_style=False)

    print(f"TPU configuration saved to: {tpu_config_path}")
    return tpu_config_path


def check_tpu_environment():
    """Check that the TPU environment is properly set up."""
    try:
        import torch_xla
        import torch_xla.core.xla_model as xm

        # Check whether a TPU is available
        device = xm.xla_device()
        print(f"TPU device available: {device}")
        print(f"TPU ordinal: {xm.get_ordinal()}")
        print(f"TPU world size: {xm.xrt_world_size()}")

        return True
    except ImportError:
        print("ERROR: torch_xla not installed. Please install PyTorch XLA for TPU support.")
        return False
    except Exception as e:
        print(f"ERROR: TPU not available - {e}")
        return False


def run_tpu_training(config_path, num_cores=8):
    """Launch TPU training via accelerate."""
    # Check the TPU environment
    if not check_tpu_environment():
        sys.exit(1)

    # Update the config for TPU
    tpu_config_path = update_config_for_tpu(config_path, num_cores)

    # Critical XLA multi-threading settings - these must be in the child
    # process's environment before it imports torch_xla
    cpu_count = os.cpu_count()
    xla_flags = (
        '--xla_cpu_multi_thread_eigen=true '
        '--xla_cpu_enable_fast_math=true '
        f'--xla_force_host_platform_device_count={cpu_count}'
    )

    # Launch training with accelerate via subprocess so the environment
    # variables are reliably inherited
    cmd = f"accelerate launch --tpu --num_processes {num_cores} train_model.py --config_path {tpu_config_path}"

    print("Launching TPU training with command:")
    print(f"  {cmd}")
    print(f"Using {num_cores} TPU cores")
    print("-" * 60)

    # Child environment with our XLA settings
    env = os.environ.copy()
    env.update({
        'TPU_CORES': str(num_cores),
        'XLA_USE_BF16': '1',  # enable bfloat16
        'XLA_FLAGS': xla_flags,
        'PYTORCH_XLA_COMPILATION_THREADS': str(cpu_count)
    })

    print("Environment variables set for the subprocess:")
    print(f"  XLA_FLAGS: {env['XLA_FLAGS']}")
    print(f"  PYTORCH_XLA_COMPILATION_THREADS: {env['PYTORCH_XLA_COMPILATION_THREADS']}")
    print("-" * 60)

    # Execute training with the prepared environment
    result = subprocess.run(cmd.split(), env=env)
    return result.returncode


def main():
    parser = argparse.ArgumentParser(description='Launch TPU training for Brain-to-Text RNN')
    parser.add_argument('--config', default='rnn_args.yaml',
                        help='Path to configuration file (default: rnn_args.yaml)')
    parser.add_argument('--num_cores', type=int, default=8,
                        help='Number of TPU cores to use (default: 8)')
    parser.add_argument('--check_only', action='store_true',
                        help='Only check TPU environment, do not launch training')

    args = parser.parse_args()

    # Verify the config file exists
    if not os.path.exists(args.config):
        print(f"ERROR: Configuration file {args.config} not found")
        sys.exit(1)

    if args.check_only:
        check_tpu_environment()
        return

    # Run TPU training
    run_tpu_training(args.config, args.num_cores)


if __name__ == "__main__":
    main()
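As a sanity check of the batch-size split above, with assumed example numbers (a global batch of 64 on 8 cores):

num_cores = 8
original_batch_size = 64                               # assumed example value
per_core = max(1, original_batch_size // num_cores)    # -> 8
grad_accum = 1
effective_global = per_core * num_cores * grad_accum   # -> 64
print(per_core, effective_global)  # the global batch size is preserved

Note the integer division: a global batch size that is not a multiple of num_cores shrinks slightly after the split.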
@@ -1,100 +0,0 @@
#!/usr/bin/env python3
"""
XLA compilation progress monitor.
"""

import time
import threading
import psutil
from contextlib import contextmanager


def monitor_compilation_progress():
    """Show live XLA compilation progress."""
    print("🔍 XLA compilation progress monitor started...")

    start_time = time.time()
    dots = 0

    while True:
        elapsed = time.time() - start_time
        minutes = int(elapsed // 60)
        seconds = int(elapsed % 60)

        # CPU and memory usage
        cpu_percent = psutil.cpu_percent(interval=1)
        memory_percent = psutil.virtual_memory().percent

        # Animated status line
        dots = (dots + 1) % 4
        dot_str = "." * dots + " " * (3 - dots)

        print(f"\r🔄 XLA compiling{dot_str} "
              f"⏱️ {minutes:02d}:{seconds:02d} "
              f"🖥️ CPU: {cpu_percent:5.1f}% "
              f"💾 Memory: {memory_percent:5.1f}%", end="", flush=True)

        time.sleep(1)


@contextmanager
def compilation_monitor():
    """Context manager that monitors compilation while its body runs."""
    print("🚀 Starting XLA compilation monitoring...")

    # Start the monitor thread (daemon, so it dies with the process)
    monitor_thread = threading.Thread(target=monitor_compilation_progress, daemon=True)
    monitor_thread.start()

    start_time = time.time()

    try:
        yield
    finally:
        elapsed = time.time() - start_time
        print(f"\n✅ XLA compilation finished! Total time: {elapsed:.2f}s")


# How to use this from the trainer
def add_compilation_monitoring_to_trainer():
    """Example code for adding compilation monitoring to the trainer."""
    example_code = '''
# In the train method of rnn_trainer.py, add:

def train(self):
    from monitor_xla_compilation import compilation_monitor

    self.model.train()
    train_losses = []
    # ... other initialization ...

    self.logger.info("Starting training loop - XLA compilation monitoring enabled...")

    # Monitor compilation of the first batch
    with compilation_monitor():
        for i, batch in enumerate(self.train_loader):
            # The first batch triggers XLA compilation;
            # the monitor shows its progress.

            # ... training code ...

            break  # one batch is enough to trigger compilation

    # Continue with the normal training loop
    for i, batch in enumerate(self.train_loader):
        # ... normal training code ...
'''

    print("📝 How to use compilation monitoring in the trainer:")
    print(example_code)


if __name__ == "__main__":
    print("🧪 XLA Compilation Monitoring Tool")
    print("=" * 50)

    # Demonstrate usage
    print("📖 Usage:")
    print("1. Import this file into your training script")
    print("2. Wrap the first model call in the compilation_monitor() context manager")
    print("3. Compilation progress and system resource usage are shown live")

    add_compilation_monitoring_to_trainer()
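A minimal standalone usage sketch of the context manager above, with a sleep standing in for the first batch that triggers compilation:

import time
from monitor_xla_compilation import compilation_monitor

with compilation_monitor():
    time.sleep(3)  # stand-in for the first forward pass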
@@ -1,27 +0,0 @@
#!/bin/bash

# TPU XLA Multi-threading Environment Setup
# Set these BEFORE starting Python to ensure they take effect

echo "Setting up XLA multi-threading environment..."

# Get the CPU core count
CPU_CORES=$(nproc)
echo "Detected $CPU_CORES CPU cores"

# Set XLA compilation flags
export XLA_FLAGS="--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true --xla_force_host_platform_device_count=$CPU_CORES"
export PYTORCH_XLA_COMPILATION_THREADS=$CPU_CORES

# Additional XLA optimizations
export XLA_USE_BF16=1
export TPU_CORES=8

# Print the current settings
echo "XLA_FLAGS: $XLA_FLAGS"
echo "PYTORCH_XLA_COMPILATION_THREADS: $PYTORCH_XLA_COMPILATION_THREADS"
echo "XLA_USE_BF16: $XLA_USE_BF16"

# Start training
echo "Starting TPU training..."
python train_model.py --config_path rnn_args.yaml
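A quick follow-up check (a sketch, not part of the script): run this inside the launched Python process, before torch_xla is imported, to confirm the exports were inherited:

import os

# '<unset>' means the corresponding export did not reach this process.
for var in ('XLA_FLAGS', 'PYTORCH_XLA_COMPILATION_THREADS', 'XLA_USE_BF16', 'TPU_CORES'):
    print(var, '=', os.environ.get(var, '<unset>'))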