#!/usr/bin/env python3
"""Simplified-model smoke test: check that XLA compilation works.

Builds a small Linear -> GRU -> Linear model, runs it on the XLA device and
reports how long the first (compiling) and second (cached-graph) forward
passes take, followed by one full training step.
"""
import os
import sys
import time
import traceback

import torch
import torch.nn as nn

# os.cpu_count() may return None when the count cannot be determined; fall
# back to 1 so the flags below never contain the literal string "None".
_NUM_CPUS = os.cpu_count() or 1

# XLA environment variables -- these must be set before torch_xla is imported.
# NOTE(review): --xla_force_host_platform_device_count and the --xla_cpu_*
# flags presumably only affect the XLA CPU backend, and XLA_USE_BF16 is
# deprecated in recent torch_xla releases; confirm against the installed version.
os.environ['XLA_FLAGS'] = (
    '--xla_cpu_multi_thread_eigen=true '
    '--xla_cpu_enable_fast_math=true '
    f'--xla_force_host_platform_device_count={_NUM_CPUS}'
)
os.environ['PYTORCH_XLA_COMPILATION_THREADS'] = str(_NUM_CPUS)
os.environ['XLA_USE_BF16'] = '1'

print("🔧 XLA环境变量设置:")
print(f" CPU核心数: {os.cpu_count()}")
print(f" XLA_FLAGS: {os.environ['XLA_FLAGS']}")
print(f" PYTORCH_XLA_COMPILATION_THREADS: {os.environ['PYTORCH_XLA_COMPILATION_THREADS']}")

# Defer an import failure instead of crashing at import time: main()'s handler
# can then report it (it lists a PyTorch XLA installation problem as a likely
# cause), and SimpleModel stays importable on machines without torch_xla.
try:
    import torch_xla.core.xla_model as xm
except ImportError as _import_err:
    xm = None
    _XLA_IMPORT_ERROR = _import_err
else:
    _XLA_IMPORT_ERROR = None

NUM_CLASSES = 41             # number of phoneme classes
MAX_COMPILE_SECONDS = 60     # first forward pass (incl. compilation) budget
MAX_TRAIN_STEP_SECONDS = 120  # first training step (incl. compilation) budget


def _require_xm():
    """Return torch_xla.core.xla_model, re-raising the deferred ImportError if it failed."""
    if xm is None:
        raise _XLA_IMPORT_ERROR
    return xm


def _world_size(xla):
    """Return the XLA world size.

    Prefers torch_xla.runtime.world_size() and falls back to the deprecated
    xm.xrt_world_size() on torch_xla versions that lack it.
    """
    try:
        import torch_xla.runtime as xr
        return xr.world_size()
    except (ImportError, AttributeError):
        return xla.xrt_world_size()


def _timed(xla, fn):
    """Run ``fn()``, force the lazy XLA graph to compile and execute, and time it.

    torch_xla records operations lazily: calling the model only traces a graph.
    Compilation and execution happen at ``mark_step()``, and ``wait_device_ops()``
    blocks until the device is done, so both must sit inside the timed region.

    Returns:
        (result of fn(), elapsed wall-clock seconds)
    """
    start = time.perf_counter()
    result = fn()  # keep a live reference so mark_step() materializes it
    xla.mark_step()
    xla.wait_device_ops()
    return result, time.perf_counter() - start


class SimpleModel(nn.Module):
    """Small test model: Linear -> ReLU -> GRU -> Linear.

    Input is (batch, seq_len, input_dim) because the GRU uses batch_first=True;
    output is (batch, seq_len, num_classes). Defaults match the original
    hard-coded sizes (512 -> 256 -> 128 -> 41).
    """

    def __init__(self, input_dim=512, hidden_dim=256, gru_dim=128, num_classes=NUM_CLASSES):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.gru = nn.GRU(hidden_dim, gru_dim, batch_first=True)
        self.linear2 = nn.Linear(gru_dim, num_classes)  # 41 phoneme classes by default

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)  # discard the final hidden state
        return self.linear2(x)


def test_xla_compilation():
    """Time the first (compiling) and second (cached) forward pass on the XLA device.

    Returns:
        True if the first pass, including compilation, finished within
        MAX_COMPILE_SECONDS.

    Raises:
        ImportError: if torch_xla could not be imported.
    """
    print("\n🚀 开始简化模型XLA编译测试...")
    xla = _require_xm()

    device = xla.xla_device()
    print(f"📱 TPU设备: {device}")
    print(f"🌍 TPU World Size: {_world_size(xla)}")

    model = SimpleModel().to(device)
    print(f"📊 模型参数数量: {sum(p.numel() for p in model.parameters()):,}")

    batch_size = 8  # small batch
    seq_len = 100   # short sequence
    x = torch.randn(batch_size, seq_len, 512, device=device)
    print(f"📥 输入形状: {x.shape}")

    # First pass: traces the graph and triggers XLA compilation + execution.
    print("🔄 开始首次前向传播 (XLA编译)...")
    with torch.no_grad():
        output, compile_time = _timed(xla, lambda: model(x))
    print(f"✅ XLA编译完成! 耗时: {compile_time:.2f}秒")
    print(f"📤 输出形状: {output.shape}")

    # Second pass with identical shapes: should reuse the cached compiled graph.
    print("🔄 第二次前向传播 (使用编译后的图)...")
    with torch.no_grad():
        _, execution_time = _timed(xla, lambda: model(x))
    print(f"⚡ 执行完成! 耗时: {execution_time:.4f}秒")

    speedup = compile_time / execution_time if execution_time > 0 else float('inf')
    print("\n📈 性能分析:")
    print(f" 编译时间: {compile_time:.2f}秒")
    print(f" 执行时间: {execution_time:.4f}秒")
    print(f" 加速比: {speedup:.1f}x")

    if compile_time < MAX_COMPILE_SECONDS:
        print("✅ XLA编译正常!")
        return True
    print("❌ XLA编译过慢,可能有问题")
    return False


def test_training_step():
    """Run and time one full training step (forward, loss, backward, Adam update).

    Returns:
        True if the step, including compilation, finished within
        MAX_TRAIN_STEP_SECONDS.

    Raises:
        ImportError: if torch_xla could not be imported.
    """
    print("\n🎯 测试简化训练步骤...")
    xla = _require_xm()

    device = xla.xla_device()
    model = SimpleModel().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    x = torch.randn(4, 50, 512, device=device)
    labels = torch.randint(0, NUM_CLASSES, (4, 50), device=device)

    def train_step():
        outputs = model(x)
        # Flatten (batch, seq) so every timestep is one classification sample.
        loss = criterion(outputs.reshape(-1, NUM_CLASSES), labels.reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss

    print("🔄 开始训练步骤 (包含反向传播)...")
    loss, step_time = _timed(xla, train_step)
    print(f"✅ 训练步骤完成! 耗时: {step_time:.2f}秒, 损失: {loss.item():.4f}")
    return step_time < MAX_TRAIN_STEP_SECONDS


# These are manual TPU smoke checks that return a bool, not pytest tests; stop
# pytest from collecting them if this file is ever picked up by a test run.
test_xla_compilation.__test__ = False
test_training_step.__test__ = False


def main():
    """Run the compilation check, then (if it passes) the training-step check.

    Returns:
        Process exit status: 0 if all checks passed, 1 otherwise.
    """
    print("=" * 60)
    print("🧪 XLA编译快速测试")
    print("=" * 60)

    status = 1
    try:
        if not test_xla_compilation():
            print("\n❌ XLA编译有问题,需要检查环境")
        elif not test_training_step():
            print("\n⚠️ 训练步骤较慢,建议优化")
        else:
            print("\n✅ 所有测试通过! 可以尝试完整模型训练")
            print("💡 建议:")
            print(" 1. 确保有足够内存 (32GB+)")
            print(" 2. 减小batch_size (比如从32改为16)")
            print(" 3. 使用gradient_accumulation_steps补偿")
            status = 0
    except Exception as e:  # top-level boundary: report, don't crash silently
        print(f"\n💥 测试失败: {e}")
        traceback.print_exc()
        print("💡 可能的问题:")
        print(" - TPU资源不可用")
        print(" - PyTorch XLA安装问题")
        print(" - 内存不足")

    print("=" * 60)
    return status


if __name__ == "__main__":
    sys.exit(main())