52 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			52 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
"""
Quick XLA Model Test

Fast smoke test for ``TripleGRUDecoder`` from the sibling ``rnn_model`` module.

NOTE(review): nothing in this script selects an XLA device. Tensors are created
on PyTorch's default device. Presumably the point is to check that the model's
XLA-oriented code path builds and runs -- confirm.
"""

import os
import sys

import torch

# Put the directory containing this file on sys.path so the sibling
# ``rnn_model`` module can be imported. When this file is executed directly,
# Python already places the script's directory first on sys.path. The append
# matters mainly when this file is imported from somewhere else. Note that an
# appended entry is searched last.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder  # noqa: E402  (requires the path tweak above)


def _raise_if_nan(label, tensor):
    """Raise ``RuntimeError`` if *tensor* contains any NaN value.

    Only NaN is rejected. A NaN is never an intended forward-pass result,
    whereas +/-inf could conceivably be (e.g. masked logits), so a full
    ``isfinite`` check risks false failures.
    """
    if torch.isnan(tensor).any():
        raise RuntimeError(f"{label} output contains NaN values")


def quick_test(batch_size=2, seq_len=20, seed=0):
    """Run a small, fast smoke test of ``TripleGRUDecoder`` in both forward modes.

    Builds a deliberately small decoder and feeds it one batch of random
    features. It then runs one forward pass with ``mode='inference'`` and one
    with ``mode='full'``, rejecting NaN outputs and printing the resulting
    shapes. No device is specified, so everything runs on PyTorch's default
    device.

    Args:
        batch_size: Number of sequences in the synthetic batch.
        seq_len: Number of time steps per sequence.
        seed: Value passed to ``torch.manual_seed`` before the model is built,
            so both the initial weights and the random inputs are reproducible.
            Pass ``None`` to leave the global RNG untouched.

    Raises:
        RuntimeError: if any forward-pass output contains NaN.
        Any exception from building the model or from its forward pass
        propagates, so the script exits non-zero on failure.
    """
    print("Quick XLA model test...")

    if seed is not None:
        # Seed before constructing the model so weight initialisation is
        # reproducible as well as the inputs.
        torch.manual_seed(seed)

    # Deliberately small configuration for fast testing.
    neural_dim = 64
    n_days = 3
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=128,
        n_days=n_days,
        n_classes=10,
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=4,
        patch_stride=1,
    )
    model.eval()

    # Synthetic batch. The feature width is tied to ``neural_dim`` so the two
    # cannot drift apart. Day indices are kept in [0, n_days), which is assumed
    # to be the valid range the model accepts -- confirm against rnn_model.
    # With the default batch_size=2 this yields tensor([0, 1]).
    features = torch.randn(batch_size, seq_len, neural_dim)
    day_indices = torch.arange(batch_size) % n_days

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    with torch.no_grad():
        # Inference mode returns a single output tensor.
        result = model(features, day_indices, mode='inference')
        _raise_if_nan("Inference", result)
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # Full mode returns three tensors: clean, noisy and noise.
        clean, noisy, noise = model(features, day_indices, mode='full')
        for label, tensor in (("Clean", clean), ("Noisy", noisy), ("Noise", noise)):
            _raise_if_nan(label, tensor)
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")


if __name__ == "__main__":
    # Run the smoke test only when executed as a script, not when imported.
    quick_test()