# (notebook export metadata removed: 78 lines, 2.3 KiB, Python)
# ====================
# 单元格3: 快速XLA编译测试  (Cell 3: quick XLA compilation test)
# ====================
|
|
|
# Minimal model used only to exercise the XLA compiler quickly.
class QuickTestModel(nn.Module):
    """Tiny Linear -> GRU -> Linear stack for fast XLA compile tests.

    Expects a float tensor shaped (batch, seq, 512) and produces
    (batch, seq, 41).
    """

    def __init__(self):
        super().__init__()
        # Small layers keep graph tracing/compilation cheap.
        self.linear1 = nn.Linear(512, 128)
        self.gru = nn.GRU(128, 64, batch_first=True)
        self.linear2 = nn.Linear(64, 41)

    def forward(self, x):
        """Project to 128, run the GRU over the sequence, project to 41."""
        hidden = torch.relu(self.linear1(x))
        hidden, _ = self.gru(hidden)
        return self.linear2(hidden)
|
|
|
|
# Quick XLA compilation benchmark: run a tiny model once to trigger (and
# time) XLA compilation, then time the compiled graph over repeated runs.
# `compilation_monitor`, `xm` (torch_xla.core.xla_model), and `time` are
# expected to be defined/imported by earlier cells.
#
# NOTE(review): XLA executes lazily — without xm.mark_step() +
# xm.wait_device_ops() inside the timed regions, the original code only
# measured graph *tracing*, not compilation/execution.  Syncs added below.

print("🧪 开始XLA编译快速测试...")

# 启动监控 (start the compilation monitor from the previous cell)
compilation_monitor.start_monitoring()

try:
    # 获取TPU设备 (acquire the XLA/TPU device)
    device = xm.xla_device()

    # 创建小模型 (small model keeps the compiled graph tiny)
    model = QuickTestModel().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"📊 测试模型参数: {param_count:,}")

    # 创建测试数据 (很小的batch): (batch=2, seq=20, features=512)
    x = torch.randn(2, 20, 512, device=device)
    print(f"📥 输入数据形状: {x.shape}")

    print("🔄 开始首次前向传播 (触发XLA编译)...")

    # 首次前向传播 - 这会触发XLA编译.  Force the graph to actually
    # compile and run before stopping the clock.
    with torch.no_grad():
        start_compile = time.time()
        output = model(x)
        xm.mark_step()        # cut the lazy graph and dispatch it
        xm.wait_device_ops()  # block until the device has finished
        compile_time = time.time() - start_compile

    print(f"✅ XLA编译完成!")
    print(f"📤 输出形状: {output.shape}")

    # 完成监控
    compilation_monitor.complete_monitoring()

    # 测试编译后的性能: sync after the loop so the average reflects real
    # device execution of the cached graph.
    print("\n🚀 测试编译后的执行速度...")
    with torch.no_grad():
        start_exec = time.time()
        for _ in range(10):
            output = model(x)
            xm.mark_step()
        xm.wait_device_ops()
        avg_exec_time = (time.time() - start_exec) / 10

    print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")

    # 性能评估: classify compile time so the next cell can pick a config.
    if compile_time < 30:
        print("✅ 编译速度优秀! 可以尝试完整模型")
        test_result = "excellent"
    elif compile_time < 120:
        print("✅ 编译速度良好! 建议使用简化配置")
        test_result = "good"
    else:
        print("⚠️ 编译速度较慢,建议进一步优化")
        test_result = "slow"

except Exception as e:
    # Stop the monitor even on failure so it is never left running.
    compilation_monitor.complete_monitoring()
    print(f"❌ 测试失败: {e}")
    test_result = "failed"

print(f"\n📋 测试结果: {test_result}")
print("💡 如果测试通过,可以运行下一个单元格进行完整训练")