154 lines
5.8 KiB
Python
154 lines
5.8 KiB
Python
![]() |
#!/usr/bin/env python3
|
||
|
"""
|
||
|
XLA Model Verification Script
|
||
|
Verify that the XLA-optimized model's outputs stay consistent with the original model.
|
||
|
"""
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import sys
|
||
|
import os
|
||
|
|
||
|
# Add the model training directory to path
|
||
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||
|
|
||
|
from rnn_model import TripleGRUDecoder
|
||
|
|
||
|
def create_test_data(batch_size=4, seq_len=100, neural_dim=512, n_days=10):
    """Build a random ``(features, day_indices)`` pair for exercising the model.

    Returns:
        A 2-tuple ``(features, day_indices)``:

        * ``features`` -- tensor of standard-normal samples with shape
          ``(batch_size, seq_len, neural_dim)``.
        * ``day_indices`` -- integer tensor of shape ``(batch_size,)`` whose
          entries are drawn uniformly from ``[0, n_days)``.
    """
    # Draw the features before the day indices: keeping this order means the
    # values produced under a fixed torch seed do not change.
    feature_shape = (batch_size, seq_len, neural_dim)
    neural_features = torch.randn(feature_shape)

    # The upper bound is exclusive, so every index is a valid day (< n_days).
    day_ids = torch.randint(low=0, high=n_days, size=(batch_size,))

    return neural_features, day_ids
|
||
|
|
||
|
def _require(condition, message):
    """Raise ``AssertionError(message)`` when ``condition`` is false.

    Used instead of the ``assert`` statement so that the checks still run when
    Python is started with ``-O`` (which strips ``assert`` statements).
    """
    if not condition:
        raise AssertionError(message)


def _run_section(title, failure_label, check, *args):
    """Run one group of checks with gradients disabled, reporting failures.

    Prints the section banner, calls ``check(*args)`` inside
    ``torch.no_grad()``, and on any exception prints
    ``"✗ <failure_label> failed: <error>"`` before re-raising it unchanged.
    """
    print(f"\n=== {title} ===")
    with torch.no_grad():
        try:
            check(*args)
        except Exception as e:
            # Say which section broke, then let the original error propagate.
            print(f"✗ {failure_label} failed: {e}")
            raise


def _check_inference_mode(model, features, day_indices):
    """Check that ``mode='inference'`` output is the same with/without state."""
    clean_logits = model(features, day_indices, states=None, return_state=False, mode='inference')

    print(f"Clean logits shape: {clean_logits.shape}")
    print(f"Clean logits range: [{clean_logits.min().item():.4f}, {clean_logits.max().item():.4f}]")
    print("✓ Inference mode successful")

    # With return_state=True the call yields a (logits, hidden_state) pair;
    # only the logits are compared, but unpacking still checks the arity.
    clean_logits_with_state, _noise_hidden = model(
        features, day_indices, states=None, return_state=True, mode='inference')

    _require(
        torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6),
        "Inconsistent outputs with/without return_state",
    )
    print("✓ return_state consistency verified")


def _check_full_mode(model, features, day_indices):
    """Check that ``mode='full'`` outputs are the same with/without state."""
    clean_logits, noisy_logits, noise_output = model(
        features, day_indices, states=None, return_state=False, mode='full')

    print(f"Clean logits shape: {clean_logits.shape}")
    print(f"Noisy logits shape: {noisy_logits.shape}")
    print(f"Noise output shape: {noise_output.shape}")
    print("✓ Full mode successful")

    # With return_state=True the three outputs come back nested in a tuple,
    # followed by the hidden state.
    (clean_logits_with_state, noisy_logits_with_state, noise_output_with_state), _noise_hidden = model(
        features, day_indices, states=None, return_state=True, mode='full')

    compared = (
        (clean_logits, clean_logits_with_state, "Inconsistent clean logits"),
        (noisy_logits, noisy_logits_with_state, "Inconsistent noisy logits"),
        (noise_output, noise_output_with_state, "Inconsistent noise output"),
    )
    for without_state, with_state, message in compared:
        _require(torch.allclose(without_state, with_state, rtol=1e-5, atol=1e-6), message)
    print("✓ return_state consistency verified")


def _check_repeated_runs(model, features, day_indices):
    """Check that three inference runs on the same input agree."""
    results = [
        model(features, day_indices, states=None, return_state=False, mode='inference')
        for _ in range(3)
    ]

    # Every later run is compared against the first one, with tight tolerances.
    for i in range(1, len(results)):
        _require(
            torch.allclose(results[0], results[i], rtol=1e-7, atol=1e-8),
            f"Inconsistent results between runs 0 and {i}",
        )

    print("✓ Multiple runs produce identical results")


def _check_batch_sizes(model, seq_len, neural_dim, n_days, n_classes, patch_size, patch_stride):
    """Check the inference output shape for batch sizes 1, 2 and 8."""
    # The expected time dimension is the number of patch windows that fit in
    # the sequence; it does not depend on the batch size.
    n_patches = (seq_len - patch_size) // patch_stride + 1

    for test_batch_size in [1, 2, 8]:
        test_features, test_day_indices = create_test_data(test_batch_size, seq_len, neural_dim, n_days)
        result = model(test_features, test_day_indices, states=None, return_state=False, mode='inference')

        expected_shape = (test_batch_size, n_patches, n_classes)
        _require(
            result.shape == expected_shape,
            f"Unexpected shape for batch_size={test_batch_size}: {result.shape} vs {expected_shape}",
        )

        print(f"✓ Batch size {test_batch_size}: {result.shape}")


def test_model_consistency():
    """Run self-consistency checks on a freshly built ``TripleGRUDecoder``.

    The model is built with dropout disabled and put in eval mode, then four
    groups of checks run on random input, each under ``torch.no_grad()``:

    1. ``mode='inference'`` gives the same logits with and without
       ``return_state``.
    2. ``mode='full'`` gives the same three outputs with and without
       ``return_state``.
    3. Three repeated inference runs on the same input agree.
    4. Batch sizes 1, 2 and 8 produce the expected output shape.

    NOTE(review): nothing here compares against a separate reference (non-XLA)
    model; all checks are self-consistency checks on one model instance.

    Returns:
        ``True`` when every check passes.

    Raises:
        AssertionError: if a consistency or shape check fails. The checks are
            explicit raises, so they also run under ``python -O``.
        Exception: any error raised by the model itself is reported and
            re-raised unchanged.
    """
    print("Testing XLA-optimized TripleGRUDecoder consistency...")

    # Model parameters (matching typical configuration)
    neural_dim = 512
    n_units = 768
    n_days = 10
    n_classes = 40  # Typical phoneme count
    batch_size = 4
    seq_len = 100
    patch_size = 14
    patch_stride = 1

    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.0,  # Disable dropout for consistent testing
        input_dropout=0.0,
        patch_size=patch_size,
        patch_stride=patch_stride
    )

    # Set to eval mode for consistent results
    model.eval()

    features, day_indices = create_test_data(batch_size, seq_len, neural_dim, n_days)

    print("Test data shapes:")
    print(f" Features: {features.shape}")
    print(f" Day indices: {day_indices.shape}")
    print(f" Day indices values: {day_indices.tolist()}")

    # Inference mode (most commonly used)
    _run_section("Testing Inference Mode", "Inference mode",
                 _check_inference_mode, model, features, day_indices)

    # Full mode (training)
    _run_section("Testing Full Mode", "Full mode",
                 _check_full_mode, model, features, day_indices)

    _run_section("Testing Multiple Run Consistency", "Multiple run consistency",
                 _check_repeated_runs, model, features, day_indices)

    _run_section("Testing Different Batch Sizes", "Batch size testing",
                 _check_batch_sizes, model, seq_len, neural_dim, n_days,
                 n_classes, patch_size, patch_stride)

    print("\n🎉 All tests passed! XLA-optimized model is working correctly.")
    return True
|
||
|
|
||
|
# Run the consistency checks only when this file is executed directly as a
# script; a failing check propagates as an uncaught exception.
if __name__ == "__main__":
    test_model_consistency()
|