tpu test
This commit is contained in:
194
model_training_nnn_tpu/minimal_tpu_test.py
Normal file
194
model_training_nnn_tpu/minimal_tpu_test.py
Normal file
@@ -0,0 +1,194 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
最简单的TPU测试 - 完全避开bf16问题
|
||||
只使用float32,最基本的操作
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Deliberately configure nothing bf16-related; only one basic XLA flag is set.
# These statements run before the torch_xla import further down — presumably
# so the settings are in place when the library initializes (verify).
# NOTE(review): this assignment replaces any XLA_FLAGS value that is already
# present in the environment — confirm that discarding user flags is intended.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'

# Make sure bf16 cannot be switched on through the environment. pop() with a
# default is a no-op when the variable is absent, so no membership test is
# needed. The "downcast" switch is cleared as well as the "use" switch so an
# inherited setting cannot silently defeat the float32-only intent.
os.environ.pop('XLA_USE_BF16', None)
os.environ.pop('XLA_DOWNCAST_BF16', None)
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
def test_basic_operations():
    """Smoke-test elementary float32 tensor operations on the XLA device.

    Creates two 4x4 float32 tensors on the device, adds them, then multiplies
    them with ``torch.mm``, synchronizing with the device after each stage.

    Returns:
        bool: True when every step completes without raising; False when any
        exception occurs (the message and the traceback are printed).
    """
    print("🚀 测试最基本的TPU操作...")

    try:
        device = xm.xla_device()
        print(f"📱 设备: {device}")

        # Test 1: basic tensor arithmetic.
        print("🔧 测试基本张量操作...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b

        print(f" a.shape: {a.shape}, dtype: {a.dtype}")
        print(f" b.shape: {b.shape}, dtype: {b.dtype}")
        print(f" c.shape: {c.shape}, dtype: {c.dtype}")

        # Synchronize with the device before reporting success.
        xm.mark_step()
        xm.wait_device_ops()
        print("✅ 基本张量操作成功")

        # Test 2: matrix multiplication.
        print("🔧 测试矩阵乘法...")
        d = torch.mm(a, b)
        xm.mark_step()
        xm.wait_device_ops()
        print(f" 矩阵乘法结果shape: {d.shape}, dtype: {d.dtype}")
        print("✅ 矩阵乘法成功")

        return True

    except Exception as e:
        # Broad catch is deliberate: this is a diagnostic script and any
        # failure should be reported rather than abort the remaining tests.
        print(f"❌ 基本操作失败: {e}")
        # Print the traceback too, matching the other tests in this file, so
        # the failing call can be located.
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
def test_simple_model():
    """Smoke-test a forward pass of a tiny float32 MLP on the XLA device.

    Builds ``Linear(10, 5) -> ReLU -> Linear(5, 2)``, feeds it a random
    ``(8, 10)`` float32 batch under ``torch.no_grad()``, and synchronizes with
    the device.

    Returns:
        bool: True when the forward pass completes without raising; False when
        any exception occurs (the message and the traceback are printed).
    """
    print("\n🧠 测试最简单的模型...")

    try:
        device = xm.xla_device()

        # Tiny linear model. Device placement and the float32 guarantee are
        # expressed through Module.to() instead of reassigning each
        # parameter's .data by hand.
        model = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ).to(device=device, dtype=torch.float32)

        print(f"📊 模型参数: {sum(p.numel() for p in model.parameters())}")

        # Input batch, explicitly float32.
        x = torch.randn(8, 10, device=device, dtype=torch.float32)

        print(f"📥 输入: shape={x.shape}, dtype={x.dtype}")

        # Forward pass only; no gradients are needed.
        with torch.no_grad():
            output = model(x)
            xm.mark_step()
            xm.wait_device_ops()

        print(f"📤 输出: shape={output.shape}, dtype={output.dtype}")
        print("✅ 简单模型前向传播成功")

        return True

    except Exception as e:
        # Broad catch is deliberate: report the failure and let the caller
        # continue with the remaining tests.
        print(f"❌ 简单模型失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
def test_training_step():
    """Smoke-test one SGD training step of a float32 linear model on XLA.

    Runs zero_grad -> forward -> MSE loss -> backward -> optimizer step for a
    ``Linear(10, 1)`` model on a random ``(4, 10)`` batch, synchronizes with
    the device, and prints the loss.

    Returns:
        bool: True when the step completes without raising; False when any
        exception occurs (the message and the traceback are printed).
    """
    print("\n🎯 测试最简单的训练步骤...")

    try:
        device = xm.xla_device()

        # Minimal model. Device placement and the float32 guarantee are
        # expressed through Module.to() instead of reassigning each
        # parameter's .data by hand.
        model = nn.Linear(10, 1).to(device=device, dtype=torch.float32)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.MSELoss()

        # Training data, explicitly float32.
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)

        print(f"📥 输入: {x.shape}, {x.dtype}")
        print(f"📥 标签: {y.shape}, {y.dtype}")

        # A single training step.
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Synchronize with the device before reading the loss value.
        xm.mark_step()
        xm.wait_device_ops()

        print(f"🎯 损失: {loss.item():.4f}")
        print("✅ 训练步骤成功")

        return True

    except Exception as e:
        # Broad catch is deliberate: report the failure and let the caller
        # continue with the remaining tests.
        print(f"❌ 训练步骤失败: {e}")
        import traceback
        traceback.print_exc()
        return False
|
||||
|
||||
|
||||
def main():
    """Run every TPU smoke test in order and print a summary.

    All tests are executed even when an earlier one fails, so the summary
    always covers the full set.

    Returns:
        bool: True when all tests passed, False when at least one failed.
    """
    print("=" * 50)
    print("🔬 最简TPU测试 (仅float32)")
    print("=" * 50)

    # (label, test function) pairs, executed in this order.
    checks = (
        ("1️⃣ 基本操作", test_basic_operations),
        ("2️⃣ 简单模型", test_simple_model),
        ("3️⃣ 训练步骤", test_training_step),
    )

    all_passed = True
    for label, check in checks:
        if check():
            print(f"{label} ✅")
        else:
            print(f"{label} ❌")
            all_passed = False

    print("=" * 50)

    if all_passed:
        print("🎉 所有测试通过! TPU工作正常")
        print("💡 现在可以尝试更复杂的模型")
    else:
        print("❌ 部分测试失败")
        print("💡 建议:")
        print(" 1. 检查TPU资源是否可用")
        print(" 2. 确认torch_xla安装正确")
        print(" 3. 重启runtime清理状态")

    return all_passed


if __name__ == "__main__":
    # Surface failures through the process exit status so shells and CI can
    # detect them; the success path exits normally, as before.
    if not main():
        raise SystemExit(1)
|
Reference in New Issue
Block a user