b2txt25/model_training_nnn_tpu/minimal_tpu_test.py

#!/usr/bin/env python3
"""
Minimal TPU test - completely avoids bf16 issues.
Uses only the most basic float32 operations.
"""
import os
import time
import torch
import torch.nn as nn

# Do not set any bf16-related environment variables.
# Only enable the most basic XLA optimization.
os.environ['XLA_FLAGS'] = '--xla_cpu_multi_thread_eigen=true'

# Make sure bf16 is not used.
if 'XLA_USE_BF16' in os.environ:
    del os.environ['XLA_USE_BF16']

import torch_xla.core.xla_model as xm

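# Note: XLA tensors are lazy - operations are recorded into a graph and only
# compiled and executed on the device when xm.mark_step() is called;
# xm.wait_device_ops() then blocks until all pending device work has finished.
# Every test below uses this mark_step()/wait_device_ops() pair to force
# execution before declaring success.
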
def test_basic_operations():
    """Test the most basic TPU operations."""
    print("🚀 Testing basic TPU operations...")
    try:
        device = xm.xla_device()
        print(f"📱 Device: {device}")

        # Test 1: basic tensor operations
        print("🔧 Testing basic tensor operations...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b
        print(f"   a.shape: {a.shape}, dtype: {a.dtype}")
        print(f"   b.shape: {b.shape}, dtype: {b.dtype}")
        print(f"   c.shape: {c.shape}, dtype: {c.dtype}")

        # Synchronize: flush the lazy graph and wait for the device
        xm.mark_step()
        xm.wait_device_ops()
        print("✅ Basic tensor operations succeeded")

        # Test 2: matrix multiplication
        print("🔧 Testing matrix multiplication...")
        d = torch.mm(a, b)
        xm.mark_step()
        xm.wait_device_ops()
        print(f"   Matmul result shape: {d.shape}, dtype: {d.dtype}")
        print("✅ Matrix multiplication succeeded")

        return True
    except Exception as e:
        print(f"❌ Basic operations failed: {e}")
        return False

def test_simple_model():
    """Test the simplest possible model."""
    print("\n🧠 Testing the simplest model...")
    try:
        device = xm.xla_device()

        # A tiny linear model
        model = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2)
        ).to(device)
        print(f"📊 Model parameters: {sum(p.numel() for p in model.parameters())}")

        # Make sure all parameters are float32
        for param in model.parameters():
            param.data = param.data.to(torch.float32)

        # Create input data - explicitly float32
        x = torch.randn(8, 10, device=device, dtype=torch.float32)
        print(f"📥 Input: shape={x.shape}, dtype={x.dtype}")

        # Forward pass
        with torch.no_grad():
            output = model(x)
        xm.mark_step()
        xm.wait_device_ops()
        print(f"📤 Output: shape={output.shape}, dtype={output.dtype}")
        print("✅ Simple model forward pass succeeded")

        return True
    except Exception as e:
        print(f"❌ Simple model failed: {e}")
        import traceback
        traceback.print_exc()
        return False

def test_training_step():
    """Test the simplest possible training step."""
    print("\n🎯 Testing the simplest training step...")
    try:
        device = xm.xla_device()

        # A tiny model
        model = nn.Linear(10, 1).to(device)

        # Make sure the weights are float32
        for param in model.parameters():
            param.data = param.data.to(torch.float32)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.MSELoss()

        # Create data - explicitly float32
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)
        print(f"📥 Input: {x.shape}, {x.dtype}")
        print(f"📥 Labels: {y.shape}, {y.dtype}")

        # One training step
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Synchronize
        xm.mark_step()
        xm.wait_device_ops()
        print(f"🎯 Loss: {loss.item():.4f}")
        print("✅ Training step succeeded")

        return True
    except Exception as e:
        print(f"❌ Training step failed: {e}")
        import traceback
        traceback.print_exc()
        return False

def main():
    """Entry point."""
    print("=" * 50)
    print("🔬 Minimal TPU test (float32 only)")
    print("=" * 50)

    all_passed = True

    # Test 1: basic operations
    if test_basic_operations():
        print("1⃣ Basic operations ✅")
    else:
        print("1⃣ Basic operations ❌")
        all_passed = False

    # Test 2: simple model
    if test_simple_model():
        print("2⃣ Simple model ✅")
    else:
        print("2⃣ Simple model ❌")
        all_passed = False

    # Test 3: training step
    if test_training_step():
        print("3⃣ Training step ✅")
    else:
        print("3⃣ Training step ❌")
        all_passed = False

    print("=" * 50)
    if all_passed:
        print("🎉 All tests passed! The TPU is working correctly")
        print("💡 You can now try more complex models")
    else:
        print("❌ Some tests failed")
        print("💡 Suggestions:")
        print("   1. Check that TPU resources are available")
        print("   2. Confirm that torch_xla is installed correctly")
        print("   3. Restart the runtime to clear any stale state")


if __name__ == "__main__":
    main()
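
# Typical invocation on a Cloud TPU VM (a sketch, assuming a torch_xla 2.x /
# PJRT setup; adjust to your environment):
#   PJRT_DEVICE=TPU python minimal_tpu_test.py
# Once all three float32 tests pass, bf16 can be re-introduced in isolation
# (for example via XLA_USE_BF16=1, depending on the torch_xla version) to
# narrow down the original bf16 problem.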