#!/usr/bin/env python3
"""
Minimal TPU test - completely sidesteps bf16 issues.
Uses float32 only, with the most basic operations.
"""

import os
import time
import torch
import torch.nn as nn

# Do not set any bf16-related environment variables;
# only enable the most basic XLA optimization.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'

# Make sure bf16 is disabled.
if 'XLA_USE_BF16' in os.environ:
    del os.environ['XLA_USE_BF16']

import torch_xla.core.xla_model as xm


def test_basic_operations():
    """Test the most basic TPU operations."""
    print("🚀 Testing the most basic TPU operations...")

    try:
        device = xm.xla_device()
        print(f"📱 Device: {device}")

        # Test 1: basic tensor operations
        print("🔧 Testing basic tensor operations...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b

        print(f"   a.shape: {a.shape}, dtype: {a.dtype}")
        print(f"   b.shape: {b.shape}, dtype: {b.dtype}")
        print(f"   c.shape: {c.shape}, dtype: {c.dtype}")

        # Synchronize: flush the lazily built XLA graph and wait for the device
        xm.mark_step()
        xm.wait_device_ops()
        print("✅ Basic tensor operations succeeded")

        # Test 2: matrix multiplication
        print("🔧 Testing matrix multiplication...")
        d = torch.mm(a, b)
        xm.mark_step()
        xm.wait_device_ops()
        print(f"   Matmul result shape: {d.shape}, dtype: {d.dtype}")
        print("✅ Matrix multiplication succeeded")

        return True

    except Exception as e:
        print(f"❌ Basic operations failed: {e}")
        return False


def test_simple_model():
    """Test the simplest possible model."""
    print("\n🧠 Testing the simplest possible model...")

    try:
        device = xm.xla_device()

        # A tiny linear model
        model = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2)
        ).to(device)

        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters())}")

        # Make sure every parameter is float32
        for param in model.parameters():
            param.data = param.data.to(torch.float32)

        # Create input data - explicitly float32
        x = torch.randn(8, 10, device=device, dtype=torch.float32)

        print(f"📥 Input: shape={x.shape}, dtype={x.dtype}")

        # Forward pass
        with torch.no_grad():
            output = model(x)
        xm.mark_step()
        xm.wait_device_ops()

        print(f"📤 Output: shape={output.shape}, dtype={output.dtype}")
        print("✅ Simple model forward pass succeeded")

        return True

    except Exception as e:
        print(f"❌ Simple model failed: {e}")
        import traceback
        traceback.print_exc()
        return False


def test_training_step():
    """Test the simplest possible training step."""
    print("\n🎯 Testing the simplest possible training step...")

    try:
        device = xm.xla_device()

        # A tiny model
        model = nn.Linear(10, 1).to(device)

        # Make sure the weights are float32
        for param in model.parameters():
            param.data = param.data.to(torch.float32)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.MSELoss()

        # Create data - explicitly float32
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)

        print(f"📥 Input: {x.shape}, {x.dtype}")
        print(f"📥 Labels: {y.shape}, {y.dtype}")

        # One training step
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Synchronize: flush the lazy graph and wait for device ops
        xm.mark_step()
        xm.wait_device_ops()

        print(f"🎯 Loss: {loss.item():.4f}")
        print("✅ Training step succeeded")

        return True

    except Exception as e:
        print(f"❌ Training step failed: {e}")
        import traceback
        traceback.print_exc()
        return False


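# Optional sketch (not called by main): a short multi-step training loop.
# This is an illustrative assumption about how the single-step test above
# could be extended, not part of the original checks. It uses
# xm.optimizer_step(optimizer, barrier=True), which on a single device folds
# the gradient reduction and the graph flush (mark_step) into one call.
def run_short_training_loop(num_steps=5):
    """Hypothetical helper: run a few float32 training steps on the TPU."""
    device = xm.xla_device()
    model = nn.Linear(10, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.MSELoss()

    for step in range(num_steps):
        # Fresh random batch each step, explicitly float32
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)

        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        # barrier=True also flushes the lazy graph, so no explicit mark_step
        xm.optimizer_step(optimizer, barrier=True)
        print(f"   step {step}: loss={loss.item():.4f}")

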
def main():
    """Main function."""
    print("=" * 50)
    print("🔬 Minimal TPU test (float32 only)")
    print("=" * 50)

    all_passed = True

    # Test 1: basic operations
    if test_basic_operations():
        print("1️⃣ Basic operations ✅")
    else:
        print("1️⃣ Basic operations ❌")
        all_passed = False

    # Test 2: simple model
    if test_simple_model():
        print("2️⃣ Simple model ✅")
    else:
        print("2️⃣ Simple model ❌")
        all_passed = False

    # Test 3: training step
    if test_training_step():
        print("3️⃣ Training step ✅")
    else:
        print("3️⃣ Training step ❌")
        all_passed = False

    print("=" * 50)

    if all_passed:
        print("🎉 All tests passed! The TPU is working correctly")
        print("💡 You can now try more complex models")
    else:
        print("❌ Some tests failed")
        print("💡 Suggestions:")
        print("   1. Check that TPU resources are available")
        print("   2. Confirm that torch_xla is installed correctly")
        print("   3. Restart the runtime to clear stale state")


if __name__ == "__main__":
    main()