final version? maybe
This commit is contained in:
154
model_training_nnn/test_xla_model.py
Normal file
154
model_training_nnn/test_xla_model.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
XLA Model Verification Script
|
||||
验证XLA优化后的模型输出与原始模型保持一致
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the model training directory to path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from rnn_model import TripleGRUDecoder
|
||||
|
||||
def create_test_data(batch_size=4, seq_len=100, neural_dim=512, n_days=10):
|
||||
"""Create synthetic test data matching expected model inputs"""
|
||||
# Create random neural features
|
||||
features = torch.randn(batch_size, seq_len, neural_dim)
|
||||
|
||||
# Create random day indices (should be valid indices < n_days)
|
||||
day_indices = torch.randint(0, n_days, (batch_size,))
|
||||
|
||||
return features, day_indices
|
||||
|
||||
def test_model_consistency():
|
||||
"""Test that XLA-optimized model produces consistent outputs"""
|
||||
|
||||
print("Testing XLA-optimized TripleGRUDecoder consistency...")
|
||||
|
||||
# Model parameters (matching typical configuration)
|
||||
neural_dim = 512
|
||||
n_units = 768
|
||||
n_days = 10
|
||||
n_classes = 40 # Typical phoneme count
|
||||
batch_size = 4
|
||||
seq_len = 100
|
||||
patch_size = 14
|
||||
patch_stride = 1
|
||||
|
||||
# Create model
|
||||
model = TripleGRUDecoder(
|
||||
neural_dim=neural_dim,
|
||||
n_units=n_units,
|
||||
n_days=n_days,
|
||||
n_classes=n_classes,
|
||||
rnn_dropout=0.0, # Disable dropout for consistent testing
|
||||
input_dropout=0.0,
|
||||
patch_size=patch_size,
|
||||
patch_stride=patch_stride
|
||||
)
|
||||
|
||||
# Set to eval mode for consistent results
|
||||
model.eval()
|
||||
|
||||
# Create test data
|
||||
features, day_indices = create_test_data(batch_size, seq_len, neural_dim, n_days)
|
||||
|
||||
print(f"Test data shapes:")
|
||||
print(f" Features: {features.shape}")
|
||||
print(f" Day indices: {day_indices.shape}")
|
||||
print(f" Day indices values: {day_indices.tolist()}")
|
||||
|
||||
# Test inference mode (most commonly used)
|
||||
print("\n=== Testing Inference Mode ===")
|
||||
with torch.no_grad():
|
||||
try:
|
||||
# Run inference mode
|
||||
clean_logits = model(features, day_indices, states=None, return_state=False, mode='inference')
|
||||
|
||||
print(f"Clean logits shape: {clean_logits.shape}")
|
||||
print(f"Clean logits range: [{clean_logits.min().item():.4f}, {clean_logits.max().item():.4f}]")
|
||||
print("✓ Inference mode successful")
|
||||
|
||||
# Test with return_state=True
|
||||
clean_logits_with_state, noise_hidden = model(features, day_indices, states=None, return_state=True, mode='inference')
|
||||
|
||||
# Verify consistency
|
||||
assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent outputs with/without return_state"
|
||||
print("✓ return_state consistency verified")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Inference mode failed: {e}")
|
||||
raise
|
||||
|
||||
# Test full mode (training)
|
||||
print("\n=== Testing Full Mode ===")
|
||||
with torch.no_grad():
|
||||
try:
|
||||
# Run full mode
|
||||
clean_logits, noisy_logits, noise_output = model(features, day_indices, states=None, return_state=False, mode='full')
|
||||
|
||||
print(f"Clean logits shape: {clean_logits.shape}")
|
||||
print(f"Noisy logits shape: {noisy_logits.shape}")
|
||||
print(f"Noise output shape: {noise_output.shape}")
|
||||
print("✓ Full mode successful")
|
||||
|
||||
# Test with return_state=True
|
||||
(clean_logits_with_state, noisy_logits_with_state, noise_output_with_state), noise_hidden = model(
|
||||
features, day_indices, states=None, return_state=True, mode='full')
|
||||
|
||||
# Verify consistency
|
||||
assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent clean logits"
|
||||
assert torch.allclose(noisy_logits, noisy_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noisy logits"
|
||||
assert torch.allclose(noise_output, noise_output_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noise output"
|
||||
print("✓ return_state consistency verified")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Full mode failed: {e}")
|
||||
raise
|
||||
|
||||
# Test multiple runs for consistency
|
||||
print("\n=== Testing Multiple Run Consistency ===")
|
||||
with torch.no_grad():
|
||||
try:
|
||||
# Run same input multiple times
|
||||
results = []
|
||||
for i in range(3):
|
||||
result = model(features, day_indices, states=None, return_state=False, mode='inference')
|
||||
results.append(result)
|
||||
|
||||
# Verify all runs produce identical results
|
||||
for i in range(1, len(results)):
|
||||
assert torch.allclose(results[0], results[i], rtol=1e-7, atol=1e-8), f"Inconsistent results between runs 0 and {i}"
|
||||
|
||||
print("✓ Multiple runs produce identical results")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Multiple run consistency failed: {e}")
|
||||
raise
|
||||
|
||||
# Test different batch sizes
|
||||
print("\n=== Testing Different Batch Sizes ===")
|
||||
with torch.no_grad():
|
||||
try:
|
||||
for test_batch_size in [1, 2, 8]:
|
||||
test_features, test_day_indices = create_test_data(test_batch_size, seq_len, neural_dim, n_days)
|
||||
result = model(test_features, test_day_indices, states=None, return_state=False, mode='inference')
|
||||
|
||||
expected_shape = (test_batch_size, (seq_len - patch_size) // patch_stride + 1, n_classes)
|
||||
assert result.shape == expected_shape, f"Unexpected shape for batch_size={test_batch_size}: {result.shape} vs {expected_shape}"
|
||||
|
||||
print(f"✓ Batch size {test_batch_size}: {result.shape}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Batch size testing failed: {e}")
|
||||
raise
|
||||
|
||||
print("\n🎉 All tests passed! XLA-optimized model is working correctly.")
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_model_consistency()
|
Reference in New Issue
Block a user