Files
b2txt25/model_training_nnn/test_dual_gru.py

139 lines
4.5 KiB
Python
Raw Permalink Normal View History

2025-10-12 09:11:32 +08:00
#!/usr/bin/env python3
"""
Test script for the dual-GRU architecture
"""
import torch
import sys
import os
# Add the current directory to the path to import rnn_model
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from rnn_model import GRUDecoder
def test_dual_gru_architecture():
    """Smoke-test the dual-GRU ``GRUDecoder`` on synthetic data.

    Builds the decoder once without patching and once with patching, runs
    forward passes on random input, and verifies the output shapes. It then
    runs a backward pass and verifies that gradients reach the regression
    GRU, the residual GRU, and the day-specific weights.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: If an output shape is wrong or one of the checked
            sub-modules receives no gradient.
    """
    # Model parameters (matching the original configuration)
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100
    n_layers_regression = 2
    n_layers_residual = 3
    patch_size = 14
    patch_stride = 4

    print("=== Dual-GRU Architecture Test ===")
    print(f"Neural dim: {neural_dim}")
    print(f"Hidden units: {n_units}")
    print(f"Regression GRU layers: {n_layers_regression}")
    print(f"Residual GRU layers: {n_layers_residual}")
    print(f"Classes: {n_classes}")
    print()

    # Arguments shared by both model variants; only the patch settings differ.
    shared_kwargs = dict(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        n_layers_regression=n_layers_regression,
        n_layers_residual=n_layers_residual,
    )

    # Create model instance (start without patching for a simpler test)
    model = GRUDecoder(**shared_kwargs, patch_size=0, patch_stride=0)
    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))
    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Inference-style checks run in eval mode so dropout does not add noise.
    model.eval()
    expected_shape = (batch_size, seq_len, n_classes)

    # Test forward pass without states
    print("Testing forward pass (without states)...")
    with torch.no_grad():
        logits = model(x, day_idx)
    print(f"Output logits shape: {logits.shape}")
    print(f"Expected shape: [{batch_size}, {seq_len}, {n_classes}]")
    assert logits.shape == expected_shape, f"Shape mismatch! Got {logits.shape}"
    print("✓ Forward pass successful")
    print()

    # Test forward pass with state return
    print("Testing forward pass (with state return)...")
    with torch.no_grad():
        logits, states = model(x, day_idx, return_state=True)
    regression_states, residual_states = states
    print(f"Output logits shape: {logits.shape}")
    print(f"Regression states shape: {regression_states.shape}")
    print(f"Residual states shape: {residual_states.shape}")
    # Returning states must not change the logits' shape.
    assert logits.shape == expected_shape, f"Shape mismatch! Got {logits.shape}"
    print("✓ Forward pass with states successful")
    print()

    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = GRUDecoder(
        **shared_kwargs,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    model_with_patches.eval()
    with torch.no_grad():
        logits_patches = model_with_patches(x, day_idx)
    # Number of sliding windows of patch_size taken every patch_stride steps.
    expected_patches = (seq_len - patch_size) // patch_stride + 1
    print(f"Output logits shape (with patches): {logits_patches.shape}")
    print(f"Expected patches: {expected_patches}")
    # NOTE(review): assumes patching shortens only the time axis, as the
    # expected_patches computation implies - confirm against rnn_model.
    expected_patch_shape = (batch_size, expected_patches, n_classes)
    assert logits_patches.shape == expected_patch_shape, (
        f"Patch shape mismatch! Got {logits_patches.shape}, "
        f"expected {expected_patch_shape}"
    )
    print("✓ Patch processing successful")
    print()

    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(batch_size, seq_len, neural_dim, requires_grad=True)
    logits = model(x_grad, day_idx)
    loss = logits.sum()
    loss.backward()

    # Check if gradients exist
    regression_grad_exists = any(
        p.grad is not None for p in model.gru_regression.parameters()
    )
    residual_grad_exists = any(
        p.grad is not None for p in model.gru_residual.parameters()
    )
    day_grad_exists = any(p.grad is not None for p in model.day_weights)
    print(f"Regression GRU gradients: {'✓' if regression_grad_exists else '✗'}")
    print(f"Residual GRU gradients: {'✓' if residual_grad_exists else '✗'}")
    print(f"Day-specific layer gradients: {'✓' if day_grad_exists else '✗'}")
    # Fail loudly: without these asserts the test would always report success.
    assert regression_grad_exists, "No gradient reached the regression GRU"
    assert residual_grad_exists, "No gradient reached the residual GRU"
    assert day_grad_exists, "No gradient reached the day-specific weights"
    print("✓ Gradient flow test successful")
    print()

    print("=== All tests passed! ===")
    print()

    # Print architecture summary
    # NOTE(review): the regression GRU width is reported as neural_dim, which
    # reproduces the original hard-coded "512 hidden" - confirm in rnn_model.
    print("=== Architecture Summary ===")
    print("Data Flow:")
    print(f"1. Input → Day-specific layers ({neural_dim} → {neural_dim})")
    print(
        f"2. Day output → Regression GRU "
        f"({n_layers_regression} layers, {neural_dim} hidden)"
    )
    print("3. Residual = Day output - Regression output")
    print(
        f"4. Residual → Residual GRU "
        f"({n_layers_residual} layers, {n_units} hidden)"
    )
    print(
        f"5. Residual GRU output → Linear classifier ({n_units} → {n_classes})"
    )
    print()
    return True
if __name__ == "__main__":
    # Script entry point: run the smoke test directly, without a test runner.
    # Any failed assertion propagates and ends the process with an error.
    test_dual_gru_architecture()