#!/usr/bin/env python3
"""Quick TPU test script - verifies that a simple model can run on a TPU."""

import os
import time

import torch
import torch.nn as nn

# Set XLA environment variables (must be set before torch_xla is imported)
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.core.xla_model as xm


def run_model(model, x):
    """Apply the layers one by one.

    nn.Sequential cannot chain through nn.GRU, because GRU returns an
    (output, hidden) tuple rather than a plain tensor, so the GRU layer
    has to be handled manually.
    """
    x_proj = model[1](model[0](x))  # Linear + ReLU
    gru_out, _ = model[2](x_proj)   # GRU (discard the hidden state)
    return model[3](gru_out)        # final Linear


def quick_test():
    """Quickly check that the TPU works."""
    print("🚀 Starting quick TPU test...")

    try:
        # Get the TPU device
        device = xm.xla_device()
        print(f"📱 TPU device: {device}")

        # Build a simple model
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41),
        ).to(device)

        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")

        # Create test data
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 Input shape: {x.shape}")

        # Test the forward pass
        print("🔄 Testing forward pass...")
        start_time = time.time()

        with torch.no_grad():
            output = run_model(model, x)

        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()

        forward_time = time.time() - start_time
        print(f"✅ Forward pass complete! Elapsed: {forward_time:.3f}s")
        print(f"📤 Output shape: {output.shape}")

        # Test the backward pass
        print("🔄 Testing backward pass...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        start_time = time.time()

        # Create dummy labels
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()

        # Forward pass
        output = run_model(model, x)

        # Compute the loss
        loss = criterion(output.view(-1, 41), labels.view(-1))

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()

        backward_time = time.time() - start_time
        print(f"✅ Backward pass complete! Elapsed: {backward_time:.3f}s")
        print(f"🎯 Loss: {loss.item():.4f}")

        # Summary
        print("\n📈 Performance summary:")
        print(f"   Forward pass:  {forward_time:.3f}s")
        print(f"   Backward pass: {backward_time:.3f}s")
        print(f"   Total:         {forward_time + backward_time:.3f}s")

        if (forward_time + backward_time) < 10:  # finished within 10 seconds
            print("✅ TPU test passed! Ready for full training")
            return True
        else:
            print("⚠️ TPU performance is slow; optimization may be needed")
            return False

    except Exception as e:
        print(f"❌ TPU test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    print("=" * 50)
    print("⚡ Quick TPU test")
    print("=" * 50)

    success = quick_test()

    if success:
        print("\n🎉 Test passed! You can now run:")
        print("   python simple_tpu_model.py")
    else:
        print("\n❌ Test failed; please check your TPU configuration")

    print("=" * 50)
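

# ---------------------------------------------------------------------------
# Optional: a minimal sketch (an illustration, not used by the test above) of
# wrapping the same layers in a custom nn.Module. Because nn.GRU returns an
# (output, hidden) tuple, nn.Sequential cannot chain through it; a wrapper
# module hides the tuple handling so the model can be called as model(x)
# without the manual layer indexing in run_model(). The class name
# GRUClassifier and its default sizes are assumptions, mirroring the
# dimensions used in quick_test().
# ---------------------------------------------------------------------------
class GRUClassifier(nn.Module):
    def __init__(self, in_dim=512, proj_dim=256, hidden_dim=128, num_classes=41):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(in_dim, proj_dim), nn.ReLU())
        self.gru = nn.GRU(proj_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        gru_out, _ = self.gru(self.proj(x))  # discard the hidden state
        return self.head(gru_out)

# Usage sketch: logits = GRUClassifier().to(xm.xla_device())(x)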