# ====================
# Cell 4: step-by-step debugging of full-model compilation
# ====================
# Run this cell only after the cell-3 smoke test passes.
# NOTE(review): assumes `xm` (torch_xla.core.xla_model) and
# `compilation_monitor` were defined by an earlier cell — confirm.
import sys
import time

import torch

sys.path.append('.')  # make sure local modules (rnn_model.py) are importable

print("🔧 逐步测试完整TripleGRUDecoder模型...")

try:
    from rnn_model import TripleGRUDecoder
    print("✅ TripleGRUDecoder导入成功")
except ImportError as e:
    print(f"❌ 模型导入失败: {e}")
    print("请确保rnn_model.py在当前目录中")


def _test_noise_model(device, x, day_idx):
    """Stage 1: compile NoiseModel alone. Returns (ok, compile_seconds)."""
    print("\n🔬 阶段1: 测试NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced width to keep XLA compilation fast
            n_days=4,
            patch_size=8,  # smaller patch size than the production config
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}秒")
        print(f"   参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
        return True, compile_time
    except Exception as e:
        print(f"❌ NoiseModel编译失败: {e}")
        return False, 0


def _test_clean_speech_model(device, x, day_idx):
    """Stage 2: compile CleanSpeechModel. Returns (ok, compile_seconds)."""
    print("\n🔬 阶段2: 测试CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8,
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}秒")
        return True, compile_time
    except Exception as e:
        print(f"❌ CleanSpeechModel编译失败: {e}")
        return False, 0


def _test_triple_decoder(device, x, day_idx):
    """Stage 3: compile the full TripleGRUDecoder. Returns (ok, compile_seconds)."""
    print("\n🔬 阶段3: 测试TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer days
            n_classes=41,
            patch_size=8,  # reduced patch size
        ).to(device)
        print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")
        # Track XLA compilation while the first forward pass traces the graph.
        compilation_monitor.start_monitoring()
        start_time = time.time()
        with torch.no_grad():
            # Exercise the 'inference' forward path used at decode time.
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time
        compilation_monitor.complete_monitoring()
        print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}秒")
        print(f"📤 输出形状: {logits.shape}")
        return True, compile_time
    except Exception as e:
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder编译失败: {e}")
        return False, 0


def test_model_compilation_stages():
    """Run all three compilation stages in order; return True iff every stage passed.

    BUGFIX: the original returned from inside stage 1, so stages 2 and 3
    were unreachable dead code, and the (bool, float) tuple it returned
    was always truthy at the call site (even on failure). Each stage now
    lives in its own helper and the shared test tensors are created once
    up front so a stage-1 failure cannot NameError a later stage.
    """
    device = xm.xla_device()
    # Shared dummy batch: 2 samples, 20 timesteps, 512 neural channels
    # (matches neural_dim=512 passed to every model below).
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    results = [
        _test_noise_model(device, x, day_idx),
        _test_clean_speech_model(device, x, day_idx),
        _test_triple_decoder(device, x, day_idx),
    ]
    return all(ok for ok, _ in results)


# Run the staged test; `stage_results` is now a plain bool so the
# failure branch below is actually reachable.
stage_results = test_model_compilation_stages()

if stage_results:
    print(f"\n🎉 所有编译测试完成!")
    print("💡 下一步可以尝试:")
    print("   1. 使用简化配置进行训练")
    print("   2. 逐步增加模型复杂度")
    print("   3. 监控TPU资源使用情况")
else:
    print(f"\n⚠️ 编译测试发现问题")
    print("💡 建议:")
    print("   1. 进一步减小模型参数")
    print("   2. 检查内存使用情况")
    print("   3. 使用CPU模式进行调试")