52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
![]() |
#!/usr/bin/env python3
"""
Quick XLA Model Test

Smoke-test script: builds a deliberately small ``TripleGRUDecoder`` from the
sibling ``rnn_model`` module, runs forward passes in its ``'inference'`` and
``'full'`` modes on random input, and prints the resulting tensor shapes.
"""

import os
import sys

import torch

# Put this file's own directory on the module search path so that the sibling
# ``rnn_model`` module can be imported regardless of the current working
# directory the script is launched from.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder  # noqa: E402  (requires the path tweak above)


def quick_test(*, batch_size: int = 2, seq_len: int = 20) -> None:
    """Run a fast forward-pass smoke test of a small ``TripleGRUDecoder``.

    Builds a deliberately tiny model in eval mode and feeds it random input.
    It then exercises the ``'inference'`` and ``'full'`` forward modes and
    prints the shapes of what comes back. Nothing is asserted: the test
    "passes" if every call completes without raising.

    Args:
        batch_size: Number of sequences in the random test batch (default 2).
        seq_len: Number of time steps per sequence (default 20).
    """
    print("Quick XLA model test...")

    # Small model for fast testing. The feature tensor's last axis is built
    # from the same ``neural_dim`` value, so the two cannot drift apart.
    # (Presumably the model requires them to match; TODO confirm in rnn_model.)
    neural_dim = 64
    n_days = 3
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=128,
        n_days=n_days,
        n_classes=10,
        rnn_dropout=0.0,
        input_dropout=0.0,
        patch_size=4,
        patch_stride=1,
    )
    model.eval()

    # Random test data with shape (batch_size, seq_len, neural_dim).
    features = torch.randn(batch_size, seq_len, neural_dim)
    # One day index per batch element, cycling through [0, n_days) so the
    # indices stay in range for any batch size (presumably they must be
    # < n_days; verify). With the default batch_size=2 this is tensor([0, 1]).
    day_indices = torch.arange(batch_size) % n_days

    print(f"Input shape: {features.shape}")
    print(f"Day indices: {day_indices}")

    # Forward passes only; no backward is run, so skip recording an autograd
    # graph for both modes.
    with torch.no_grad():
        # Test inference
        result = model(features, day_indices, mode='inference')
        print(f"Inference result shape: {result.shape}")
        print("✓ Inference mode works")

        # Test full mode (returns a clean / noisy / noise triple)
        clean, noisy, noise = model(features, day_indices, mode='full')
        print(f"Full mode shapes: clean={clean.shape}, noisy={noisy.shape}, noise={noise.shape}")
        print("✓ Full mode works")

    print("🎉 Quick test passed!")


if __name__ == "__main__":
    # Run the smoke test only when this file is executed as a script,
    # not when it is imported as a module.
    quick_test()