#!/usr/bin/env python3 """ XLA Model Verification Script 验证XLA优化后的模型输出与原始模型保持一致 """ import torch import torch.nn as nn import sys import os # Add the model training directory to path sys.path.append(os.path.dirname(os.path.abspath(__file__))) from rnn_model import TripleGRUDecoder def create_test_data(batch_size=4, seq_len=100, neural_dim=512, n_days=10): """Create synthetic test data matching expected model inputs""" # Create random neural features features = torch.randn(batch_size, seq_len, neural_dim) # Create random day indices (should be valid indices < n_days) day_indices = torch.randint(0, n_days, (batch_size,)) return features, day_indices def test_model_consistency(): """Test that XLA-optimized model produces consistent outputs""" print("Testing XLA-optimized TripleGRUDecoder consistency...") # Model parameters (matching typical configuration) neural_dim = 512 n_units = 768 n_days = 10 n_classes = 40 # Typical phoneme count batch_size = 4 seq_len = 100 patch_size = 14 patch_stride = 1 # Create model model = TripleGRUDecoder( neural_dim=neural_dim, n_units=n_units, n_days=n_days, n_classes=n_classes, rnn_dropout=0.0, # Disable dropout for consistent testing input_dropout=0.0, patch_size=patch_size, patch_stride=patch_stride ) # Set to eval mode for consistent results model.eval() # Create test data features, day_indices = create_test_data(batch_size, seq_len, neural_dim, n_days) print(f"Test data shapes:") print(f" Features: {features.shape}") print(f" Day indices: {day_indices.shape}") print(f" Day indices values: {day_indices.tolist()}") # Test inference mode (most commonly used) print("\n=== Testing Inference Mode ===") with torch.no_grad(): try: # Run inference mode clean_logits = model(features, day_indices, states=None, return_state=False, mode='inference') print(f"Clean logits shape: {clean_logits.shape}") print(f"Clean logits range: [{clean_logits.min().item():.4f}, {clean_logits.max().item():.4f}]") print("✓ Inference mode successful") # Test with return_state=True clean_logits_with_state, noise_hidden = model(features, day_indices, states=None, return_state=True, mode='inference') # Verify consistency assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent outputs with/without return_state" print("✓ return_state consistency verified") except Exception as e: print(f"✗ Inference mode failed: {e}") raise # Test full mode (training) print("\n=== Testing Full Mode ===") with torch.no_grad(): try: # Run full mode clean_logits, noisy_logits, noise_output = model(features, day_indices, states=None, return_state=False, mode='full') print(f"Clean logits shape: {clean_logits.shape}") print(f"Noisy logits shape: {noisy_logits.shape}") print(f"Noise output shape: {noise_output.shape}") print("✓ Full mode successful") # Test with return_state=True (clean_logits_with_state, noisy_logits_with_state, noise_output_with_state), noise_hidden = model( features, day_indices, states=None, return_state=True, mode='full') # Verify consistency assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent clean logits" assert torch.allclose(noisy_logits, noisy_logits_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noisy logits" assert torch.allclose(noise_output, noise_output_with_state, rtol=1e-5, atol=1e-6), "Inconsistent noise output" print("✓ return_state consistency verified") except Exception as e: print(f"✗ Full mode failed: {e}") raise # Test multiple runs for consistency print("\n=== Testing Multiple Run 
Consistency ===") with torch.no_grad(): try: # Run same input multiple times results = [] for i in range(3): result = model(features, day_indices, states=None, return_state=False, mode='inference') results.append(result) # Verify all runs produce identical results for i in range(1, len(results)): assert torch.allclose(results[0], results[i], rtol=1e-7, atol=1e-8), f"Inconsistent results between runs 0 and {i}" print("✓ Multiple runs produce identical results") except Exception as e: print(f"✗ Multiple run consistency failed: {e}") raise # Test different batch sizes print("\n=== Testing Different Batch Sizes ===") with torch.no_grad(): try: for test_batch_size in [1, 2, 8]: test_features, test_day_indices = create_test_data(test_batch_size, seq_len, neural_dim, n_days) result = model(test_features, test_day_indices, states=None, return_state=False, mode='inference') expected_shape = (test_batch_size, (seq_len - patch_size) // patch_stride + 1, n_classes) assert result.shape == expected_shape, f"Unexpected shape for batch_size={test_batch_size}: {result.shape} vs {expected_shape}" print(f"✓ Batch size {test_batch_size}: {result.shape}") except Exception as e: print(f"✗ Batch size testing failed: {e}") raise print("\n🎉 All tests passed! XLA-optimized model is working correctly.") return True if __name__ == "__main__": test_model_consistency()