#!/usr/bin/env python3
"""
Minimal TPU smoke test that sidesteps bf16 entirely.

Only float32 tensors and the most basic operations are exercised:
elementwise math, a matmul, a tiny forward pass and one SGD training step.
"""
import os
import sys
import time
import traceback

import torch
import torch.nn as nn

# NOTE: the environment must be prepared BEFORE torch_xla is imported, which is
# why the torch_xla import sits below this setup instead of with the others.

# Minimal XLA tuning only. Append to any XLA_FLAGS the user already exported
# rather than overwriting them.
_existing_xla_flags = os.environ.get('XLA_FLAGS', '')
if '--xla_cpu_multi_thread_eigen' not in _existing_xla_flags:
    os.environ['XLA_FLAGS'] = (
        f"{_existing_xla_flags} --xla_cpu_multi_thread_eigen=true".strip()
    )

# Make sure no bf16 mode is switched on through the environment.
for _bf16_var in ('XLA_USE_BF16', 'XLA_DOWNCAST_BF16'):
    os.environ.pop(_bf16_var, None)

import torch_xla.core.xla_model as xm  # noqa: E402  (must follow env setup)


def _sync():
    """Cut the lazy XLA graph and block until the device has finished it."""
    xm.mark_step()
    xm.wait_device_ops()


def test_basic_operations():
    """Run an elementwise add and a matmul on the XLA device.

    Returns:
        True if both operations compile and execute, False otherwise
        (the exception and its traceback are printed).
    """
    print("🚀 测试最基本的TPU操作...")
    try:
        device = xm.xla_device()
        print(f"📱 设备: {device}")

        # Test 1: basic elementwise tensor ops, explicitly float32.
        print("🔧 测试基本张量操作...")
        a = torch.randn(4, 4, device=device, dtype=torch.float32)
        b = torch.randn(4, 4, device=device, dtype=torch.float32)
        c = a + b
        print(f" a.shape: {a.shape}, dtype: {a.dtype}")
        print(f" b.shape: {b.shape}, dtype: {b.dtype}")
        print(f" c.shape: {c.shape}, dtype: {c.dtype}")
        # XLA tensors are lazy; force execution so failures surface here.
        _sync()
        print("✅ 基本张量操作成功")

        # Test 2: matrix multiplication.
        print("🔧 测试矩阵乘法...")
        d = torch.mm(a, b)
        _sync()
        print(f" 矩阵乘法结果shape: {d.shape}, dtype: {d.dtype}")
        print("✅ 矩阵乘法成功")
        return True
    except Exception as e:
        # Top-level test boundary: report and signal failure to the caller.
        print(f"❌ 基本操作失败: {e}")
        traceback.print_exc()
        return False


def test_simple_model():
    """Run a forward pass of a tiny two-layer MLP on the XLA device.

    Returns:
        True if the forward pass executes, False otherwise
        (the exception and its traceback are printed).
    """
    print("\n🧠 测试最简单的模型...")
    try:
        device = xm.xla_device()

        # Tiny MLP; moving with an explicit dtype guarantees float32 params.
        model = nn.Sequential(
            nn.Linear(10, 5),
            nn.ReLU(),
            nn.Linear(5, 2),
        ).to(device=device, dtype=torch.float32)
        print(f"📊 模型参数: {sum(p.numel() for p in model.parameters())}")

        # Input data, explicitly float32.
        x = torch.randn(8, 10, device=device, dtype=torch.float32)
        print(f"📥 输入: shape={x.shape}, dtype={x.dtype}")

        # Forward pass only; no autograd graph is needed.
        with torch.no_grad():
            output = model(x)
        _sync()
        print(f"📤 输出: shape={output.shape}, dtype={output.dtype}")
        print("✅ 简单模型前向传播成功")
        return True
    except Exception as e:
        print(f"❌ 简单模型失败: {e}")
        traceback.print_exc()
        return False


def test_training_step():
    """Run a single SGD step of a linear regression model on the XLA device.

    Returns:
        True if forward, backward and the optimizer step execute,
        False otherwise (the exception and its traceback are printed).
    """
    print("\n🎯 测试最简单的训练步骤...")
    try:
        device = xm.xla_device()

        # Single linear layer, with parameters explicitly float32.
        model = nn.Linear(10, 1).to(device=device, dtype=torch.float32)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = nn.MSELoss()

        # Inputs and targets, explicitly float32.
        x = torch.randn(4, 10, device=device, dtype=torch.float32)
        y = torch.randn(4, 1, device=device, dtype=torch.float32)
        print(f"📥 输入: {x.shape}, {x.dtype}")
        print(f"📥 标签: {y.shape}, {y.dtype}")

        # One training step.
        optimizer.zero_grad()
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        # Execute the recorded graph before reading the loss back to host.
        _sync()
        print(f"🎯 损失: {loss.item():.4f}")
        print("✅ 训练步骤成功")
        return True
    except Exception as e:
        print(f"❌ 训练步骤失败: {e}")
        traceback.print_exc()
        return False


def main():
    """Run every smoke test in order and print a summary.

    All tests run even if an earlier one fails, so the summary is complete.

    Returns:
        True if every test passed, False otherwise.
    """
    print("=" * 50)
    print("🔬 最简TPU测试 (仅float32)")
    print("=" * 50)

    tests = (
        ("1️⃣ 基本操作", test_basic_operations),
        ("2️⃣ 简单模型", test_simple_model),
        ("3️⃣ 训练步骤", test_training_step),
    )

    all_passed = True
    for label, test_fn in tests:
        passed = test_fn()  # call first so later tests still run after a failure
        print(f"{label} {'✅' if passed else '❌'}")
        all_passed = all_passed and passed

    print("=" * 50)
    if all_passed:
        print("🎉 所有测试通过! TPU工作正常")
        print("💡 现在可以尝试更复杂的模型")
    else:
        print("❌ 部分测试失败")
        print("💡 建议:")
        print(" 1. 检查TPU资源是否可用")
        print(" 2. 确认torch_xla安装正确")
        print(" 3. 重启runtime清理状态")
    return all_passed


if __name__ == "__main__":
    # A non-zero exit status lets CI / shell scripts detect a failing TPU setup.
    sys.exit(0 if main() else 1)