#!/usr/bin/env python3
"""
Quick TPU test script - verify that a simple model can run on the TPU.
"""
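
# Usage sketch (assumes a TPU VM with torch_xla installed; PJRT_DEVICE is
# torch_xla's runtime device selector, and the script filename here is
# illustrative):
#   PJRT_DEVICE=TPU python quick_tpu_test.py
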
import os
import time
import torch
import torch.nn as nn

# Set environment variables
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true --xla_cpu_enable_fast_math=true'
os.environ['XLA_USE_BF16'] = '1'
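# (XLA reads these at initialization, hence they are set before the import
# below. XLA_USE_BF16=1 makes torch_xla store float32 tensors as bfloat16 on
# the TPU; the xla_cpu_* flags only tune the CPU backend.)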

import torch_xla.core.xla_model as xm


def quick_test():
    """Quickly check that the TPU is working."""
    print("🚀 Starting quick TPU test...")

    try:
        # Get the TPU device
        device = xm.xla_device()
        print(f"📱 TPU device: {device}")
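        # (xm.xla_device() returns the default XLA device, e.g. "xla:0";
        # which hardware backs it is chosen by the torch_xla runtime, e.g.
        # via the PJRT_DEVICE environment variable.)
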
        # Build a simple model
        model = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.GRU(256, 128, batch_first=True),
            nn.Linear(128, 41)
        ).to(device)

        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters()):,}")
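
        # Note: this Sequential is only a layer container; calling model(x)
        # end-to-end would fail because nn.GRU returns a tuple (output, h_n)
        # that the following nn.Linear cannot consume. The forward passes
        # below therefore index the layers manually.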

        # Create test data: (batch=8, seq_len=50, features=512)
        x = torch.randn(8, 50, 512, device=device)
        print(f"📥 Input shape: {x.shape}")

        # Test the forward pass
        print("🔄 Testing forward pass...")
        start_time = time.time()

        with torch.no_grad():
            # Run the layers manually so the GRU's tuple output can be unpacked
            x_proj = model[1](model[0](x))  # Linear + ReLU
            gru_out, _ = model[2](x_proj)   # GRU
            output = model[3](gru_out)      # final Linear

        # Synchronize TPU operations
        xm.mark_step()
        xm.wait_device_ops()
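        # (XLA tensors are lazy: mark_step() cuts the pending graph and sends
        # it off for compilation and execution, and wait_device_ops() blocks
        # until the device finishes, so the timing below is accurate. On the
        # first run this includes XLA compilation time.)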

        forward_time = time.time() - start_time
        print(f"✅ Forward pass done in {forward_time:.3f}s")
        print(f"📤 Output shape: {output.shape}")

        # Test the backward pass
        print("🔄 Testing backward pass...")
        model.train()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        start_time = time.time()

        # Create dummy labels
        labels = torch.randint(0, 41, (8, 50), device=device)
        criterion = nn.CrossEntropyLoss()

        # Forward pass (same manual layer chaining as above)
        x_proj = model[1](model[0](x))
        gru_out, _ = model[2](x_proj)
        output = model[3](gru_out)

        # Compute the loss
        loss = criterion(output.view(-1, 41), labels.view(-1))
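        # (CrossEntropyLoss expects (N, C) logits and (N,) targets, so the
        # (8, 50, 41) outputs and (8, 50) labels are flattened to (400, 41)
        # and (400,).)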

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Synchronize TPU operations again before stopping the timer
        xm.mark_step()
        xm.wait_device_ops()

        backward_time = time.time() - start_time
        print(f"✅ Backward pass done in {backward_time:.3f}s")
        print(f"🎯 Loss: {loss.item():.4f}")

        # Summary
        print("\n📈 Performance summary:")
        print(f"   Forward pass:  {forward_time:.3f}s")
        print(f"   Backward pass: {backward_time:.3f}s")
        print(f"   Total:         {forward_time + backward_time:.3f}s")

        if (forward_time + backward_time) < 10:  # done within 10 seconds
            print("✅ TPU test passed! Ready for full training")
            return True
        else:
            print("⚠️ TPU is running slowly and may need tuning")
            return False

    except Exception as e:
        print(f"❌ TPU test failed: {e}")
        import traceback
        traceback.print_exc()
        return False


if __name__ == "__main__":
    print("=" * 50)
    print("⚡ Quick TPU test")
    print("=" * 50)

    success = quick_test()

    if success:
        print("\n🎉 Test passed! You can now run:")
        print("   python simple_tpu_model.py")
    else:
        print("\n❌ Test failed, please check your TPU configuration")

    print("=" * 50)