154 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			154 lines
		
	
	
		
			5.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
#!/usr/bin/env python3
"""
XLA Model Verification Script

Verify that the outputs of the XLA-optimized model stay consistent with the
original model.

NOTE(review): the checks in this file exercise TripleGRUDecoder against
itself (return_state on/off, repeated calls, several batch sizes); no separate
reference model or XLA device is set up in the visible code.
"""

import torch
import torch.nn as nn  # NOTE(review): `nn` is not referenced in the visible code
import sys
import os

# Add the model training directory to path: the directory containing this
# script is appended so the sibling `rnn_model` module can be imported even
# when the script is launched from another working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder
|  | 
 | ||
def create_test_data(batch_size=4, seq_len=100, neural_dim=512, n_days=10):
    """Build a random ``(features, day_indices)`` pair for driving the decoder.

    Args:
        batch_size: Number of samples in the batch.
        seq_len: Number of time steps per sample.
        neural_dim: Width of the feature vector at each time step.
        n_days: Exclusive upper bound for the sampled day indices.

    Returns:
        A tuple ``(features, day_indices)``. ``features`` is a float tensor of
        shape ``(batch_size, seq_len, neural_dim)`` drawn from a standard
        normal distribution. ``day_indices`` is an integer tensor of shape
        ``(batch_size,)`` with values in ``[0, n_days)``.
    """
    feature_shape = (batch_size, seq_len, neural_dim)
    neural_features = torch.randn(*feature_shape)

    # Sample indices strictly below n_days so they are valid for a model
    # built with the same n_days (presumably used as a per-day lookup).
    days = torch.randint(low=0, high=n_days, size=(batch_size,))

    return neural_features, days
|  | 
 | ||
def _run_section(title, failure_label, check, *args):
    """Print a section banner, run ``check(*args)`` and report any failure.

    Args:
        title: Text shown inside the ``=== ... ===`` banner.
        failure_label: Prefix of the ``✗ <label> failed: <error>`` message.
        check: Callable that performs the section's work and assertions.
        *args: Positional arguments forwarded to ``check``.

    Raises:
        Exception: Whatever ``check`` raised, re-raised unchanged after the
            failure message is printed, so callers still see the traceback.
    """
    print(f"\n=== {title} ===")
    try:
        check(*args)
    except Exception as e:
        print(f"✗ {failure_label} failed: {e}")
        raise


def _check_inference_mode(model, features, day_indices):
    """Run inference mode and verify ``return_state`` does not alter the logits."""
    clean_logits = model(features, day_indices, states=None, return_state=False, mode='inference')

    print(f"Clean logits shape: {clean_logits.shape}")
    print(f"Clean logits range: [{clean_logits.min().item():.4f}, {clean_logits.max().item():.4f}]")
    print("✓ Inference mode successful")

    # With return_state=True the model returns (logits, hidden); only the
    # logits are compared, the hidden state is discarded.
    clean_logits_with_state, _ = model(features, day_indices, states=None, return_state=True, mode='inference')

    assert torch.allclose(clean_logits, clean_logits_with_state, rtol=1e-5, atol=1e-6), \
        "Inconsistent outputs with/without return_state"
    print("✓ return_state consistency verified")


def _check_full_mode(model, features, day_indices):
    """Run full mode and verify ``return_state`` does not alter any of its three outputs."""
    clean_logits, noisy_logits, noise_output = model(
        features, day_indices, states=None, return_state=False, mode='full')

    print(f"Clean logits shape: {clean_logits.shape}")
    print(f"Noisy logits shape: {noisy_logits.shape}")
    print(f"Noise output shape: {noise_output.shape}")
    print("✓ Full mode successful")

    # With return_state=True the three outputs are nested as
    # ((clean, noisy, noise), hidden); the hidden state is discarded.
    (clean_with_state, noisy_with_state, noise_with_state), _ = model(
        features, day_indices, states=None, return_state=True, mode='full')

    pairs = (
        ("clean logits", clean_logits, clean_with_state),
        ("noisy logits", noisy_logits, noisy_with_state),
        ("noise output", noise_output, noise_with_state),
    )
    for label, without_state, with_state in pairs:
        assert torch.allclose(without_state, with_state, rtol=1e-5, atol=1e-6), f"Inconsistent {label}"
    print("✓ return_state consistency verified")


def _check_repeat_consistency(model, features, day_indices, n_runs=3):
    """Run inference ``n_runs`` times on the same input and require matching outputs."""
    results = [
        model(features, day_indices, states=None, return_state=False, mode='inference')
        for _ in range(n_runs)
    ]

    # Near-exact tolerances: identical input in eval mode should reproduce.
    for i, result in enumerate(results[1:], start=1):
        assert torch.allclose(results[0], result, rtol=1e-7, atol=1e-8), \
            f"Inconsistent results between runs 0 and {i}"

    print("✓ Multiple runs produce identical results")


def _check_batch_sizes(model, seq_len, neural_dim, n_days, n_classes, patch_size, patch_stride,
                       batch_sizes=(1, 2, 8)):
    """Check the inference output shape for several batch sizes.

    The expected time dimension is the number of windows of ``patch_size``
    steps taken every ``patch_stride`` steps, which assumes the model patches
    the time axis that way (TODO confirm against ``rnn_model``).
    """
    n_patches = (seq_len - patch_size) // patch_stride + 1

    for test_batch_size in batch_sizes:
        test_features, test_day_indices = create_test_data(test_batch_size, seq_len, neural_dim, n_days)
        result = model(test_features, test_day_indices, states=None, return_state=False, mode='inference')

        expected_shape = (test_batch_size, n_patches, n_classes)
        assert result.shape == expected_shape, \
            f"Unexpected shape for batch_size={test_batch_size}: {result.shape} vs {expected_shape}"

        print(f"✓ Batch size {test_batch_size}: {result.shape}")


def test_model_consistency():
    """Smoke-test TripleGRUDecoder forward passes for shape and self-consistency.

    No XLA device is selected here, and no separate reference model is
    compared. The model and inputs stay on their default device. The function
    checks that:

    * inference and full modes run;
    * passing ``return_state=True`` does not change the returned outputs;
    * repeated inference calls on the same input agree;
    * inference output has the expected patched sequence length for several
      batch sizes.

    The global torch RNG is seeded first so the random weights and inputs, and
    therefore any failure, are reproducible. This resets global RNG state.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: If a consistency or shape check fails.
        Exception: Any error raised by the model, printed and then re-raised.
    """
    print("Testing XLA-optimized TripleGRUDecoder consistency...")

    # Without a fixed seed the random weights and inputs change every run,
    # so a tolerance failure could not be reproduced.
    torch.manual_seed(0)

    # Model parameters (matching typical configuration)
    neural_dim = 512
    n_units = 768
    n_days = 10
    n_classes = 40  # Typical phoneme count
    batch_size = 4
    seq_len = 100
    patch_size = 14
    patch_stride = 1

    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.0,  # Disable dropout for consistent testing
        input_dropout=0.0,
        patch_size=patch_size,
        patch_stride=patch_stride,
    )
    # Eval mode so repeated forward passes are comparable.
    model.eval()

    features, day_indices = create_test_data(batch_size, seq_len, neural_dim, n_days)

    print("Test data shapes:")
    print(f"  Features: {features.shape}")
    print(f"  Day indices: {day_indices.shape}")
    print(f"  Day indices values: {day_indices.tolist()}")

    # All checks are forward-only, so gradient tracking is disabled throughout.
    with torch.no_grad():
        _run_section("Testing Inference Mode", "Inference mode",
                     _check_inference_mode, model, features, day_indices)
        _run_section("Testing Full Mode", "Full mode",
                     _check_full_mode, model, features, day_indices)
        _run_section("Testing Multiple Run Consistency", "Multiple run consistency",
                     _check_repeat_consistency, model, features, day_indices)
        _run_section("Testing Different Batch Sizes", "Batch size testing",
                     _check_batch_sizes, model, seq_len, neural_dim, n_days,
                     n_classes, patch_size, patch_stride)

    print("\n🎉 All tests passed! XLA-optimized model is working correctly.")
    # Kept for script callers; note pytest warns when a test returns non-None.
    return True
|  | 
 | ||
if __name__ == "__main__":
    # Run the checks when executed directly as a script. A failed check
    # propagates as an uncaught exception, so the process exits non-zero:
    # test_model_consistency re-raises after printing the failure.
    test_model_consistency()