124 lines
3.8 KiB
Python
124 lines
3.8 KiB
Python
# ====================
# Cell 4: step-by-step debugging of full-model compilation
# ====================

# Run this cell only after the tests in cell 3 have passed.
print("🔧 逐步测试完整TripleGRUDecoder模型...")

# Make sure modules in the current directory are importable.
import sys

sys.path.append('.')

try:
    from rnn_model import TripleGRUDecoder
except ImportError as e:
    print(f"❌ 模型导入失败: {e}")
    print("请确保rnn_model.py在当前目录中")
else:
    # Only report success when the import actually worked.
    print("✅ TripleGRUDecoder导入成功")
|
||
|
||
# Staged compilation test
def test_model_compilation_stages():
    """Compile-test the decoder sub-models one stage at a time on the XLA device.

    Runs NoiseModel, CleanSpeechModel and the full TripleGRUDecoder in
    sequence, timing each compilation.  The previous version returned from
    inside stage 1, which made stages 2 and 3 unreachable, and stages 2/3
    referenced inputs created inside stage 1's ``try`` (a NameError risk on
    failure).  Now the shared inputs are built once and every stage runs;
    a failing stage is reported and the remaining stages still execute.

    Returns:
        tuple: ``(all_ok, total_time)`` — ``all_ok`` is True only when every
        stage compiled; ``total_time`` is the summed compile time in seconds.
        Same ``(bool, number)`` tuple shape as before.

    NOTE(review): relies on ``xm`` (torch_xla), ``torch``, ``time`` and
    ``compilation_monitor`` being defined by earlier notebook cells — confirm.
    """
    device = xm.xla_device()

    # Shared test inputs for all stages: batch of 2, 20 timesteps,
    # 512 neural channels.
    x = torch.randn(2, 20, 512, device=device)
    day_idx = torch.tensor([0, 1], device=device)

    all_ok = True
    total_time = 0.0

    # Stage 1: compile NoiseModel on its own.
    print("\n🔬 阶段1: 测试NoiseModel...")
    try:
        from rnn_model import NoiseModel
        noise_model = NoiseModel(
            neural_dim=512,
            n_units=384,   # reduced parameter count
            n_days=4,
            patch_size=8   # reduced patch size
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output, states = noise_model(x, day_idx)
        compile_time = time.time() - start_time
        total_time += compile_time

        print(f"✅ NoiseModel编译成功! 耗时: {compile_time:.2f}秒")
        print(f" 参数数量: {sum(p.numel() for p in noise_model.parameters()):,}")
    except Exception as e:
        print(f"❌ NoiseModel编译失败: {e}")
        all_ok = False

    # Stage 2: compile CleanSpeechModel.
    print("\n🔬 阶段2: 测试CleanSpeechModel...")
    try:
        from rnn_model import CleanSpeechModel
        clean_model = CleanSpeechModel(
            neural_dim=512,
            n_units=384,
            n_days=4,
            n_classes=41,
            patch_size=8
        ).to(device)

        start_time = time.time()
        with torch.no_grad():
            output = clean_model(x, day_idx)
        compile_time = time.time() - start_time
        total_time += compile_time

        print(f"✅ CleanSpeechModel编译成功! 耗时: {compile_time:.2f}秒")
    except Exception as e:
        print(f"❌ CleanSpeechModel编译失败: {e}")
        all_ok = False

    # Stage 3: compile the full TripleGRUDecoder.
    print("\n🔬 阶段3: 测试TripleGRUDecoder...")
    try:
        model = TripleGRUDecoder(
            neural_dim=512,
            n_units=384,   # smaller than the original 768
            n_days=4,      # fewer recording days
            n_classes=41,
            patch_size=8   # reduced patch size
        ).to(device)

        print(f"📊 完整模型参数: {sum(p.numel() for p in model.parameters()):,}")

        # compilation_monitor is expected from an earlier cell; it tracks
        # XLA compilation while the graph is traced — TODO confirm.
        compilation_monitor.start_monitoring()

        start_time = time.time()
        with torch.no_grad():
            # Exercise the inference path only.
            logits = model(x, day_idx, None, False, 'inference')
        compile_time = time.time() - start_time
        total_time += compile_time

        compilation_monitor.complete_monitoring()

        print(f"✅ TripleGRUDecoder编译成功! 耗时: {compile_time:.2f}秒")
        print(f"📤 输出形状: {logits.shape}")
    except Exception as e:
        # Stop the monitor even on failure so it is not left running.
        compilation_monitor.complete_monitoring()
        print(f"❌ TripleGRUDecoder编译失败: {e}")
        all_ok = False

    return all_ok, total_time
|
||
|
||
# Run the staged tests; the result is a (success, total_compile_time) tuple.
stage_results = test_model_compilation_stages()

# BUG FIX: a non-empty tuple is always truthy, so the old
# `if stage_results:` could never reach the failure branch.
# Test the success flag explicitly instead.
if stage_results and stage_results[0]:
    print("\n🎉 所有编译测试完成!")
    print("💡 下一步可以尝试:")
    print(" 1. 使用简化配置进行训练")
    print(" 2. 逐步增加模型复杂度")
    print(" 3. 监控TPU资源使用情况")
else:
    print("\n⚠️ 编译测试发现问题")
    print("💡 建议:")
    print(" 1. 进一步减小模型参数")
    print(" 2. 检查内存使用情况")
    print(" 3. 使用CPU模式进行调试")