Files
b2txt25/model_training_nnn_tpu/jupyter_xla_test.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

78 lines
2.3 KiB
Python

# ====================
# Cell 3: quick XLA compilation smoke test
# ====================
# Minimal probe model: small enough to compile fast on a TPU.
class QuickTestModel(nn.Module):
    """Tiny Linear -> GRU -> Linear stack used to exercise XLA compilation.

    Input:  (batch, seq, 512) float tensor.
    Output: (batch, seq, 41) logits.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 128)
        self.gru = nn.GRU(128, 64, batch_first=True)
        self.linear2 = nn.Linear(64, 41)

    def forward(self, x):
        hidden = torch.relu(self.linear1(x))
        seq_out, _ = self.gru(hidden)  # discard the final hidden state
        return self.linear2(seq_out)
print("🧪 开始XLA编译快速测试...")
# Start the compilation monitor (helper defined in an earlier notebook cell).
compilation_monitor.start_monitoring()
try:
    # Acquire the TPU/XLA device.
    device = xm.xla_device()

    # Build the small probe model on the device.
    model = QuickTestModel().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"📊 测试模型参数: {param_count:,}")

    # Very small batch keeps the traced graph (and compile time) small.
    x = torch.randn(2, 20, 512, device=device)
    print(f"📥 输入数据形状: {x.shape}")

    print("🔄 开始首次前向传播 (触发XLA编译)...")
    # First forward pass triggers XLA tracing + compilation.
    with torch.no_grad():
        start_compile = time.time()
        output = model(x)
        # XLA executes lazily: without a mark_step + wait, the timer would
        # only measure Python-side graph construction, not compile/run.
        xm.mark_step()
        xm.wait_device_ops()
        compile_time = time.time() - start_compile
    print(f"✅ XLA编译完成!")
    print(f"📤 输出形状: {output.shape}")

    # Stop monitoring once compilation has been observed.
    compilation_monitor.complete_monitoring()

    # Measure steady-state execution speed of the already-compiled graph.
    print("\n🚀 测试编译后的执行速度...")
    with torch.no_grad():
        start_exec = time.time()
        for _ in range(10):
            output = model(x)
            xm.mark_step()  # flush each step so we time real device work
        xm.wait_device_ops()  # block until every queued step has finished
    avg_exec_time = (time.time() - start_exec) / 10
    print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")

    # Classify compile speed so the next cell can choose a model config.
    if compile_time < 30:
        print("✅ 编译速度优秀! 可以尝试完整模型")
        test_result = "excellent"
    elif compile_time < 120:
        print("✅ 编译速度良好! 建议使用简化配置")
        test_result = "good"
    else:
        print("⚠️ 编译速度较慢,建议进一步优化")
        test_result = "slow"
except Exception as e:
    # Best-effort cleanup; NOTE(review): if the failure happens after the
    # success-path complete_monitoring() above, this calls it a second time —
    # confirm the monitor tolerates that.
    compilation_monitor.complete_monitoring()
    print(f"❌ 测试失败: {e}")
    test_result = "failed"

print(f"\n📋 测试结果: {test_result}")
print("💡 如果测试通过,可以运行下一个单元格进行完整训练")