#!/usr/bin/env python3
"""
Test script for the dual-GRU architecture
"""

import os
import sys

import torch

# Add the current directory to the path to import rnn_model
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import GRUDecoder


def test_dual_gru_architecture():
    """Test the dual-GRU model with synthetic data"""

    # Model parameters (matching the original configuration)
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100

    print("=== Dual-GRU Architecture Test ===")
    print(f"Neural dim: {neural_dim}")
    print(f"Hidden units: {n_units}")
    print("Regression GRU layers: 2")
    print("Residual GRU layers: 3")
    print(f"Classes: {n_classes}")
    print()

    # Create model instance
    model = GRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        n_layers_regression=2,
        n_layers_residual=3,
        patch_size=0,  # Start without patching for a simpler test
        patch_stride=0
    )

    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test forward pass without states
    print("Testing forward pass (without states)...")
    with torch.no_grad():
        logits = model(x, day_idx)
        print(f"Output logits shape: {logits.shape}")
        print(f"Expected shape: [{batch_size}, {seq_len}, {n_classes}]")
        assert logits.shape == (batch_size, seq_len, n_classes), f"Shape mismatch! Got {logits.shape}"
        print("✓ Forward pass successful")
        print()

    # Test forward pass with state return
    print("Testing forward pass (with state return)...")
    with torch.no_grad():
        logits, states = model(x, day_idx, return_state=True)
        regression_states, residual_states = states
        print(f"Output logits shape: {logits.shape}")
        print(f"Regression states shape: {regression_states.shape}")
        print(f"Residual states shape: {residual_states.shape}")
        print("✓ Forward pass with states successful")
        print()
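    # Added sanity check (an assumption: each returned state tensor follows
    # torch.nn.GRU's (num_layers, batch, hidden) layout, so dim 1 is the batch;
    # drop this if GRUDecoder packs its states differently)
    assert regression_states.shape[1] == batch_size
    assert residual_states.shape[1] == batch_size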
    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = GRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        n_layers_regression=2,
        n_layers_residual=3,
        patch_size=14,
        patch_stride=4
    )

    with torch.no_grad():
        logits_patches = model_with_patches(x, day_idx)
        expected_patches = (seq_len - 14) // 4 + 1  # Number of sliding windows
        print(f"Output logits shape (with patches): {logits_patches.shape}")
        print(f"Expected patches: {expected_patches}")
        print("✓ Patch processing successful")
        print()
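    # Added assertion (an assumption: with patching enabled, the model emits one
    # output step per sliding window, so the time dimension should equal the
    # expected_patches count computed above)
    assert logits_patches.shape[1] == expected_patches, \
        f"Expected {expected_patches} output steps, got {logits_patches.shape[1]}"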
    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(batch_size, seq_len, neural_dim, requires_grad=True)
    logits = model(x_grad, day_idx)
    loss = logits.sum()
    loss.backward()

    # Check that gradients reached each component
    regression_grad_exists = any(p.grad is not None for p in model.gru_regression.parameters())
    residual_grad_exists = any(p.grad is not None for p in model.gru_residual.parameters())
    day_grad_exists = any(p.grad is not None for p in model.day_weights)

    print(f"Regression GRU gradients: {'✓' if regression_grad_exists else '✗'}")
    print(f"Residual GRU gradients: {'✓' if residual_grad_exists else '✗'}")
    print(f"Day-specific layer gradients: {'✓' if day_grad_exists else '✗'}")
|  |     print("✓ Gradient flow test successful") | ||
|  |     print() | ||
|  | 
 | ||
|  |     print("=== All tests passed! ===") | ||
|  |     print() | ||
|  | 
 | ||
|  |     # Print architecture summary | ||
|  |     print("=== Architecture Summary ===") | ||
|  |     print("Data Flow:") | ||
|  |     print("1. Input → Day-specific layers (512 → 512)") | ||
|  |     print("2. Day output → Regression GRU (2 layers, 512 hidden)") | ||
|  |     print("3. Residual = Day output - Regression output") | ||
|  |     print("4. Residual → Residual GRU (3 layers, 768 hidden)") | ||
|  |     print("5. Residual GRU output → Linear classifier (768 → 41)") | ||
|  |     print() | ||
|  | 
 | ||
|  |     return True | ||
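

# --- Illustrative reference only (added sketch, not the real model) ----------
# A minimal, hypothetical rendering of the data flow the summary above prints,
# assuming the plainest reading of each step. The real rnn_model.GRUDecoder
# also has per-day input layers, dropout, patching, and state handling; this
# class exists only to make the tested architecture concrete and is never run.
class _DualGRUSketch(torch.nn.Module):
    def __init__(self, neural_dim=512, n_units=768, n_classes=41):
        super().__init__()
        # Step 1: day-specific layers, sketched here as one shared linear map
        self.day_layer = torch.nn.Linear(neural_dim, neural_dim)
        # Step 2: regression GRU; its hidden size matches the input dimension
        # so the subtraction in step 3 is dimensionally valid
        self.gru_regression = torch.nn.GRU(neural_dim, neural_dim,
                                           num_layers=2, batch_first=True)
        # Step 4: residual GRU with the larger hidden size
        self.gru_residual = torch.nn.GRU(neural_dim, n_units,
                                         num_layers=3, batch_first=True)
        # Step 5: linear classifier over the phoneme classes
        self.classifier = torch.nn.Linear(n_units, n_classes)

    def forward(self, x):  # x: (batch, time, neural_dim)
        day_out = self.day_layer(x)
        regression_out, _ = self.gru_regression(day_out)
        residual = day_out - regression_out  # step 3: residual signal
        residual_out, _ = self.gru_residual(residual)
        return self.classifier(residual_out)  # (batch, time, n_classes)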


if __name__ == "__main__":
    test_dual_gru_architecture()