Files
b2txt25/model_training_nnn_tpu/jupyter_debug_full_model.py
Zchen 56fa336af0 tpu
2025-10-15 14:26:11 +08:00

124 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# ====================
# Cell 4: step-by-step debugging of full-model compilation
# ====================
# Run this cell only after the cell-3 tests have passed.
print("🔧 逐步测试完整TripleGRUDecoder模型...")

# Import the full model; '.' is appended so local modules resolve.
import sys
sys.path.append('.')

try:
    from rnn_model import TripleGRUDecoder
except ImportError as err:
    print(f"❌ 模型导入失败: {err}")
    print("请确保rnn_model.py在当前目录中")
else:
    # Only announce success once the import actually worked.
    print("✅ TripleGRUDecoder导入成功")
# 分阶段测试
def test_model_compilation_stages():
    """Compile-test the model on the XLA device in three stages of increasing size.

    Stage 1 compiles NoiseModel alone, stage 2 CleanSpeechModel, and stage 3
    the full TripleGRUDecoder in inference mode.

    BUG FIX: the original returned from inside stage 1 (both on success and
    on failure), so stages 2 and 3 were unreachable dead code, and the shared
    inputs ``x``/``day_idx`` were created inside stage 1's ``try`` (a stage-1
    failure would have caused a NameError later).  All three stages now run,
    and the shared inputs are created up front.

    Returns:
        tuple[bool, float]: (all_stages_succeeded, total_compile_seconds).
            Backward-compatible with the original (bool, float) shape.
    """
    device = xm.xla_device()

    # Shared dummy inputs for every stage: batch=2, 20 timesteps, 512 features.
    # NOTE(review): assumes all three models accept (batch, time, neural_dim)
    # plus a per-sample day index — confirm against rnn_model.py.
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    results = []  # one (success, compile_time) entry per stage

    # Stage 1: compile NoiseModel on its own.
    print("\n🔬 阶段1: 测试NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced width for faster first compile
            n_days=4,
            patch_size=8   # smaller patch size
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}")
        print(f" 参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ NoiseModel编译失败: {e}")
        results.append((False, 0))

    # Stage 2: compile CleanSpeechModel.
    print("\n🔬 阶段2: 测试CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8
        ).to(device)
        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time
        print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}")
        results.append((True, compile_time))
    except Exception as e:
        print(f"❌ CleanSpeechModel编译失败: {e}")
        results.append((False, 0))

    # Stage 3: compile the full TripleGRUDecoder in inference mode.
    print("\n🔬 阶段3: 测试TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer recording days
            n_classes=41,
            patch_size=8   # smaller patch size
        ).to(device)
        print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")
        # Track XLA compilation for the big graph.
        compilation_monitor.start_monitoring()
        start_time = time.time()
        with torch.no_grad():
            # Inference mode: no targets, teacher forcing disabled.
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time
        compilation_monitor.complete_monitoring()
        print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}")
        print(f"📤 输出形状: {logits.shape}")
        results.append((True, compile_time))
    except Exception as e:
        # Stop the monitor even on failure so it is not left running.
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder编译失败: {e}")
        results.append((False, 0))

    all_ok = all(ok for ok, _ in results)
    total_time = sum(t for _, t in results)
    return all_ok, total_time
# Run the staged compilation test and report the outcome.
stage_results = test_model_compilation_stages()

# BUG FIX: the function returns a non-empty tuple, which is ALWAYS truthy,
# so the original `if stage_results:` printed the success message even when
# every stage failed.  Check the boolean success flag (first element) instead.
if stage_results and stage_results[0]:
    print(f"\n🎉 所有编译测试完成!")
    print("💡 下一步可以尝试:")
    print(" 1. 使用简化配置进行训练")
    print(" 2. 逐步增加模型复杂度")
    print(" 3. 监控TPU资源使用情况")
else:
    print(f"\n⚠️ 编译测试发现问题")
    print("💡 建议:")
    print(" 1. 进一步减小模型参数")
    print(" 2. 检查内存使用情况")
    print(" 3. 使用CPU模式进行调试")