#!/usr/bin/env python3
"""
Test script for the dual-GRU architecture.

Builds ``GRUDecoder`` instances on synthetic data and checks output shapes,
the optional state return, patch processing, and that gradients reach the
regression GRU, the residual GRU and the day-specific layers.
"""
import os
import sys

import torch

# Add this file's directory to the import path so the sibling ``rnn_model``
# module can be imported regardless of the current working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import GRUDecoder  # noqa: E402  (must follow the sys.path tweak)

# Model / data parameters (matching the original configuration).
NEURAL_DIM = 512
N_UNITS = 768
N_DAYS = 5
N_CLASSES = 41
N_LAYERS_REGRESSION = 2
N_LAYERS_RESIDUAL = 3
BATCH_SIZE = 4
SEQ_LEN = 100
PATCH_SIZE = 14
PATCH_STRIDE = 4


def _build_model(patch_size=0, patch_stride=0):
    """Construct a GRUDecoder with the test configuration and the given patching.

    ``patch_size=0`` / ``patch_stride=0`` is the configuration the original test
    used for its "no patching" model.
    """
    return GRUDecoder(
        neural_dim=NEURAL_DIM,
        n_units=N_UNITS,
        n_days=N_DAYS,
        n_classes=N_CLASSES,
        rnn_dropout=0.1,
        input_dropout=0.1,
        n_layers_regression=N_LAYERS_REGRESSION,
        n_layers_residual=N_LAYERS_RESIDUAL,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )


def _has_grad(params):
    """Return True if at least one parameter in *params* received a gradient."""
    return any(p.grad is not None for p in params)


def _print_architecture_summary():
    """Print the intended data flow through the dual-GRU model."""
    print("=== Architecture Summary ===")
    print("Data Flow:")
    print(f"1. Input → Day-specific layers ({NEURAL_DIM} → {NEURAL_DIM})")
    # Step 3 subtracts the regression output from the day output, so the
    # regression GRU's hidden size is reported as NEURAL_DIM (512), as in the
    # original hard-coded summary.
    print(
        f"2. Day output → Regression GRU "
        f"({N_LAYERS_REGRESSION} layers, {NEURAL_DIM} hidden)"
    )
    print("3. Residual = Day output - Regression output")
    print(
        f"4. Residual → Residual GRU "
        f"({N_LAYERS_RESIDUAL} layers, {N_UNITS} hidden)"
    )
    print(f"5. Residual GRU output → Linear classifier ({N_UNITS} → {N_CLASSES})")
    print()


def test_dual_gru_architecture():
    """Test the dual-GRU model with synthetic data.

    Every check asserts, so a failure raises ``AssertionError`` instead of only
    printing a cross mark.

    Returns:
        True when all checks pass (kept for backward compatibility with callers
        that inspect the return value).
    """
    print("=== Dual-GRU Architecture Test ===")
    print(f"Neural dim: {NEURAL_DIM}")
    print(f"Hidden units: {N_UNITS}")
    print(f"Regression GRU layers: {N_LAYERS_REGRESSION}")
    print(f"Residual GRU layers: {N_LAYERS_RESIDUAL}")
    print(f"Classes: {N_CLASSES}")
    print()

    # Start without patching for a simpler test.
    model = _build_model()
    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Synthetic input data.
    x = torch.randn(BATCH_SIZE, SEQ_LEN, NEURAL_DIM)
    day_idx = torch.randint(0, N_DAYS, (BATCH_SIZE,))
    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # The shape checks are inference-style: disable dropout so they exercise
    # the evaluation path rather than a randomly-masked training pass.
    model.eval()
    expected_shape = (BATCH_SIZE, SEQ_LEN, N_CLASSES)

    print("Testing forward pass (without states)...")
    with torch.no_grad():
        logits = model(x, day_idx)
    print(f"Output logits shape: {logits.shape}")
    print(f"Expected shape: [{BATCH_SIZE}, {SEQ_LEN}, {N_CLASSES}]")
    assert logits.shape == expected_shape, f"Shape mismatch! Got {logits.shape}"
    print("✓ Forward pass successful")
    print()

    print("Testing forward pass (with state return)...")
    with torch.no_grad():
        logits, states = model(x, day_idx, return_state=True)
    regression_states, residual_states = states
    print(f"Output logits shape: {logits.shape}")
    print(f"Regression states shape: {regression_states.shape}")
    print(f"Residual states shape: {residual_states.shape}")
    assert logits.shape == expected_shape, f"Shape mismatch! Got {logits.shape}"
    print("✓ Forward pass with states successful")
    print()

    print("Testing with patch processing...")
    model_with_patches = _build_model(PATCH_SIZE, PATCH_STRIDE)
    model_with_patches.eval()
    with torch.no_grad():
        logits_patches = model_with_patches(x, day_idx)
    # Number of full PATCH_SIZE windows taken every PATCH_STRIDE steps
    # (same formula the original test printed as "Expected patches").
    expected_patches = (SEQ_LEN - PATCH_SIZE) // PATCH_STRIDE + 1
    print(f"Output logits shape (with patches): {logits_patches.shape}")
    print(f"Expected patches: {expected_patches}")
    expected_patch_shape = (BATCH_SIZE, expected_patches, N_CLASSES)
    assert logits_patches.shape == expected_patch_shape, (
        f"Patch shape mismatch! Got {logits_patches.shape}, "
        f"expected {expected_patch_shape}"
    )
    print("✓ Patch processing successful")
    print()

    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(BATCH_SIZE, SEQ_LEN, NEURAL_DIM, requires_grad=True)
    logits = model(x_grad, day_idx)
    loss = logits.sum()
    loss.backward()

    grad_checks = {
        "Regression GRU": _has_grad(model.gru_regression.parameters()),
        "Residual GRU": _has_grad(model.gru_residual.parameters()),
        "Day-specific layer": _has_grad(model.day_weights),
    }
    for label, ok in grad_checks.items():
        print(f"{label} gradients: {'✓' if ok else '✗'}")
    missing = [label for label, ok in grad_checks.items() if not ok]
    assert not missing, f"No gradients reached: {', '.join(missing)}"
    print("✓ Gradient flow test successful")
    print()

    print("=== All tests passed! ===")
    print()

    _print_architecture_summary()

    return True


if __name__ == "__main__":
    test_dual_gru_architecture()