# ==================== # Cell 3: quick XLA compilation test # ====================

class QuickTestModel(nn.Module):
    """Tiny Linear -> GRU -> Linear model used only to probe XLA compile speed.

    Input:  (batch, seq, 512) float tensor.
    Output: (batch, seq, 41) float tensor (per-step logits).
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(512, 128)
        self.gru = nn.GRU(128, 64, batch_first=True)
        self.linear2 = nn.Linear(64, 41)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        x, _ = self.gru(x)  # discard the final hidden state
        x = self.linear2(x)
        return x


print("🧪 开始XLA编译快速测试...")

# Start the (project-provided) compilation monitor before touching the device.
compilation_monitor.start_monitoring()

try:
    # Acquire the TPU/XLA device.
    device = xm.xla_device()

    # Build the probe model on-device and report its size.
    model = QuickTestModel().to(device)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"📊 测试模型参数: {param_count:,}")

    # Very small batch so compilation, not data movement, dominates.
    x = torch.randn(2, 20, 512, device=device)
    print(f"📥 输入数据形状: {x.shape}")

    print("🔄 开始首次前向传播 (触发XLA编译)...")

    # First forward pass triggers XLA tracing + compilation.
    # NOTE: XLA is lazy — without an explicit mark_step/wait, time.time()
    # would only measure Python-side graph tracing. Force execution and
    # block until the device finishes before reading the clock.
    with torch.no_grad():
        start_compile = time.time()
        output = model(x)
        xm.mark_step()          # materialize the traced graph (compile + run)
        xm.wait_device_ops()    # block until device work is done
        compile_time = time.time() - start_compile

    print("✅ XLA编译完成!")
    print(f"📤 输出形状: {output.shape}")

    # Stop the monitor now that compilation has finished.
    compilation_monitor.complete_monitoring()

    # Benchmark steady-state execution of the already-compiled graph.
    print("\n🚀 测试编译后的执行速度...")
    with torch.no_grad():
        start_exec = time.time()
        for _ in range(10):
            output = model(x)
            xm.mark_step()      # execute each iteration instead of just tracing
        xm.wait_device_ops()    # ensure all 10 runs completed before timing
        avg_exec_time = (time.time() - start_exec) / 10

    print(f"⚡ 平均执行时间: {avg_exec_time*1000:.2f}ms")

    # Rough triage of compile speed to decide how to proceed.
    if compile_time < 30:
        print("✅ 编译速度优秀! 可以尝试完整模型")
        test_result = "excellent"
    elif compile_time < 120:
        print("✅ 编译速度良好! 建议使用简化配置")
        test_result = "good"
    else:
        print("⚠️ 编译速度较慢,建议进一步优化")
        test_result = "slow"

except Exception as e:
    # Best-effort cleanup: always stop the monitor, then report the failure.
    compilation_monitor.complete_monitoring()
    print(f"❌ 测试失败: {e}")
    test_result = "failed"

print(f"\n📋 测试结果: {test_result}")
print("💡 如果测试通过,可以运行下一个单元格进行完整训练")