b2txt25/model_training_nnn_tpu/quick_tpu_test.py

#!/usr/bin/env python3
"""
快速TPU测试脚本 - 验证简单模型是否可以在TPU上运行
"""
import os
import time
import torch
import torch.nn as nn
# Set XLA environment variables
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
os.environ['XLA_USE_BF16'] = '1'
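# Note: XLA picks these variables up when torch_xla initializes, which is why
# they are set before the import below. XLA_USE_BF16=1 maps float32 tensors to
# bfloat16 on TPU; the xla_cpu_* flags only affect the CPU backend.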
import torch_xla.core.xla_model as xm
def quick_test():
"""快速测试TPU是否工作正常"""
print("🚀 开始快速TPU测试...")
try:
# 获取TPU设备
device = xm.xla_device()
print(f"📱 TPU设备: {device}")
        # Build a simple model
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41)
        ).to(device)
        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
        # Create test data
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 Input shape: {x.shape}")
        # Test the forward pass
        print("🔄 Testing forward pass...")
        start_time = time.time()
        with torch.no_grad():
            if isinstance(model, nn.Sequential):
                # Handle the GRU layer manually for the Sequential model
                x_proj = model[1](model[0](x))  # Linear + ReLU
                gru_out, _ = model[2](x_proj)   # GRU
                output = model[3](gru_out)      # final Linear
            else:
                output = model(x)
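        # XLA tensors are lazy: nothing has actually executed on the TPU yet,
        # so the graph must be flushed before the elapsed time means anything.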
        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()
        forward_time = time.time() - start_time
        print(f"✅ Forward pass done! Time: {forward_time:.3f}s")
        print(f"📤 Output shape: {output.shape}")
        # Test the backward pass
        print("🔄 Testing backward pass...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        start_time = time.time()
        # Create dummy labels
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()
        # Forward pass
        if isinstance(model, nn.Sequential):
            x_proj = model[1](model[0](x))
            gru_out, _ = model[2](x_proj)
            output = model[3](gru_out)
        else:
            output = model(x)
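        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so the
        # (8, 50, 41) output is flattened to (400, 41) and labels to (400,).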
        # Compute the loss
        loss = criterion(output.view(-1, 41), labels.view(-1))
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()
        backward_time = time.time() - start_time
        print(f"✅ Backward pass done! Time: {backward_time:.3f}s")
        print(f"🎯 Loss: {loss.item():.4f}")
        # Summary
        print("\n📈 Performance summary:")
        print(f"   Forward pass:  {forward_time:.3f}s")
        print(f"   Backward pass: {backward_time:.3f}s")
        print(f"   Total:         {forward_time + backward_time:.3f}s")
        if (forward_time + backward_time) < 10:  # finished within 10 seconds
            print("✅ TPU test passed! Ready to run full training")
            return True
        else:
            print("⚠️ TPU performance is slow; it may need optimization")
            return False
    except Exception as e:
        print(f"❌ TPU test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
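
# A tidier alternative (a sketch, not used by this test): wrap the GRU in a
# small nn.Module so the whole model is directly callable, avoiding the
# layer-indexing workaround above. The name GRUClassifier is illustrative,
# not part of the original script.
class GRUClassifier(nn.Module):
    def __init__(self, in_dim=512, hidden=128, n_classes=41):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(in_dim, 256), nn.ReLU())
        self.gru = nn.GRU(256, hidden, batch_first=True)
        self.head = nn.Linear(hidden, n_classes)

    def forward(self, x):
        out, _ = self.gru(self.proj(x))  # drop the final hidden state
        return self.head(out)
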
if __name__ == "__main__":
    print("=" * 50)
    print("⚡ Quick TPU test")
    print("=" * 50)
    success = quick_test()
    if success:
        print("\n🎉 Test passed! You can now run:")
        print("   python simple_tpu_model.py")
    else:
        print("\n❌ Test failed. Please check your TPU configuration.")
    print("=" * 50)