This commit is contained in:
Zchen
2025-10-15 15:22:13 +08:00
parent 082018cd46
commit b466e97463
4 changed files with 447 additions and 496 deletions

View File

@@ -0,0 +1,194 @@
#!/usr/bin/env python3
"""
最简单的TPU测试 - 完全避开bf16问题
只使用float32最基本的操作
"""
import os
import time
import torch
import torch.nn as nn
# Basic XLA tuning only. Note that this assignment replaces any XLA_FLAGS
# value that was already present in the environment.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'
# Keep the run in float32: remove the PyTorch/XLA switches that request a
# downcast to bfloat16. They are cleared ahead of the torch_xla import that
# follows these statements.
os.environ.pop('XLA_USE_BF16', None)
os.environ.pop('XLA_DOWNCAST_BF16', None)
import torch_xla.core.xla_model as xm
def test_basic_operations():
    """Run an elementwise addition and a matrix multiply on the XLA device.

    Both checks use 4x4 float32 tensors created directly on the device. They
    only confirm that the operations execute; the numeric results are not
    compared against a reference.

    Returns:
        bool: True when every step completes, False if any exception is
        raised (the error message and its traceback are printed).
    """
    print("🚀 测试最基本的TPU操作...")
    try:
        device = xm.xla_device()
        print(f"📱 设备: {device}")

        # Check 1: basic tensor arithmetic.
        print("🔧 测试基本张量操作...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b
        for label, tensor in (("a", a), ("b", b), ("c", c)):
            print(f" {label}.shape: {tensor.shape}, dtype: {tensor.dtype}")
        # Synchronize with the device before reporting success.
        xm.mark_step()
        xm.wait_device_ops()
        print("✅ 基本张量操作成功")

        # Check 2: matrix multiplication.
        print("🔧 测试矩阵乘法...")
        d = torch.mm(a, b)
        xm.mark_step()
        xm.wait_device_ops()
        print(f" 矩阵乘法结果shape: {d.shape}, dtype: {d.dtype}")
        print("✅ 矩阵乘法成功")
        return True
    except Exception as e:
        print(f"❌ 基本操作失败: {e}")
        # Print the stack as well, matching the other checks in this file.
        import traceback
        traceback.print_exc()
        return False
def test_simple_model():
    """Run one forward pass of a tiny MLP on the XLA device in float32.

    The model is Linear(10, 5) -> ReLU -> Linear(5, 2) and is fed a random
    (8, 10) float32 batch under ``torch.no_grad()``.

    Returns:
        bool: True if the forward pass and device sync complete, False if an
        exception is raised (the error message and traceback are printed).
    """
    print("\n🧠 测试最简单的模型...")
    try:
        device = xm.xla_device()

        # Move the model to the device and pin float32 with one Module.to()
        # call instead of reassigning each parameter's .data by hand.
        model = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ).to(device=device, dtype=torch.float32)
        param_count = sum(p.numel() for p in model.parameters())
        print(f"📊 模型参数: {param_count}")

        # Input batch, explicitly float32.
        x = torch.randn(8, 10, device=device, dtype=torch.float32)
        print(f"📥 输入: shape={x.shape}, dtype={x.dtype}")

        # Forward pass only; no gradients are needed here.
        with torch.no_grad():
            output = model(x)
        xm.mark_step()
        xm.wait_device_ops()
        print(f"📤 输出: shape={output.shape}, dtype={output.dtype}")
        print("✅ 简单模型前向传播成功")
        return True
    except Exception as e:
        print(f"❌ 简单模型失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_training_step():
    """Run a single SGD training step of a Linear(10, 1) model on the XLA device.

    Uses a random (4, 10) float32 input, a random (4, 1) float32 target,
    MSE loss and SGD with lr=0.01, then prints the loss value.

    Returns:
        bool: True if the step and device sync complete, False if an
        exception is raised (the error message and traceback are printed).
    """
    print("\n🎯 测试最简单的训练步骤...")
    try:
        device = xm.xla_device()

        # Move the model to the device and pin float32 with one Module.to()
        # call instead of reassigning each parameter's .data by hand. This is
        # done before the optimizer is built from model.parameters().
        model = nn.Linear(10, 1).to(device=device, dtype=torch.float32)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.MSELoss()

        # Training data, explicitly float32.
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)
        print(f"📥 输入: {x.shape}, {x.dtype}")
        print(f"📥 标签: {y.shape}, {y.dtype}")

        # One optimization step: forward, loss, backward, update.
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Synchronize with the device before reading the loss value.
        xm.mark_step()
        xm.wait_device_ops()
        print(f"🎯 损失: {loss.item():.4f}")
        print("✅ 训练步骤成功")
        return True
    except Exception as e:
        print(f"❌ 训练步骤失败: {e}")
        import traceback
        traceback.print_exc()
        return False
def main():
    """Run the three float32 smoke checks in order and print a summary.

    Every check is executed even if an earlier one fails; the per-check
    result line is printed right after that check finishes.

    Returns:
        bool: True when all checks passed, False otherwise.
    """
    print("=" * 50)
    print("🔬 最简TPU测试 (仅float32)")
    print("=" * 50)

    # (result-line label, check function) pairs, in execution order.
    checks = (
        ("1⃣ 基本操作", test_basic_operations),
        ("2⃣ 简单模型", test_simple_model),
        ("3⃣ 训练步骤", test_training_step),
    )
    all_passed = True
    for label, check in checks:
        if check():
            print(f"{label} ✅")
        else:
            print(f"{label} ❌")
            all_passed = False

    print("=" * 50)
    if all_passed:
        print("🎉 所有测试通过! TPU工作正常")
        print("💡 现在可以尝试更复杂的模型")
    else:
        print("❌ 部分测试失败")
        print("💡 建议:")
        print(" 1. 检查TPU资源是否可用")
        print(" 2. 确认torch_xla安装正确")
        print(" 3. 重启runtime清理状态")
    return all_passed


if __name__ == "__main__":
    # Exit non-zero when any check failed so shells and CI can detect it.
    import sys
    sys.exit(0 if main() else 1)