#!/usr/bin/env python3
"""
Test script for the triple-GRU adversarial architecture
"""

import torch
import sys
import os

# Add the current directory to the path to import rnn_model
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import TripleGRUDecoder, NoiseModel, CleanSpeechModel, NoisySpeechModel

def test_individual_models():
    """Test each individual model first"""
    print("=== Testing Individual Models ===")

    # Model parameters
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100

    # Create synthetic input
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test NoiseModel
    print("1. Testing NoiseModel...")
    noise_model = NoiseModel(neural_dim, n_units, n_days)
    with torch.no_grad():
        noise_out, noise_hidden = noise_model(x, day_idx)
        print(f"   Noise output shape: {noise_out.shape}")
        print(f"   Noise hidden shape: {noise_hidden.shape}")
        print("   ✓ NoiseModel working")
    print()

    # Test CleanSpeechModel
    print("2. Testing CleanSpeechModel...")
    clean_model = CleanSpeechModel(neural_dim, n_units, n_days, n_classes)
    with torch.no_grad():
        clean_logits = clean_model(x, day_idx)
        print(f"   Clean logits shape: {clean_logits.shape}")
        print("   ✓ CleanSpeechModel working")
    print()

    # Test NoisySpeechModel
    print("3. Testing NoisySpeechModel...")
    noisy_model = NoisySpeechModel(neural_dim, n_units, n_days, n_classes)
    with torch.no_grad():
        noisy_logits = noisy_model(noise_out)  # Use noise output as input
        print(f"   Noisy logits shape: {noisy_logits.shape}")
        print("   ✓ NoisySpeechModel working")
    print()

    return True

def test_triple_gru_architecture():
    """Test the complete triple-GRU architecture"""
    print("=== Triple-GRU Architecture Test ===")

    # Model parameters
    neural_dim = 512
    n_units = 768
    n_days = 5
    n_classes = 41
    batch_size = 4
    seq_len = 100

    print(f"Neural dim: {neural_dim}")
    print(f"Hidden units: {n_units}")
    print(f"Days: {n_days}")
    print(f"Classes: {n_classes}")
    print()

    # Create model instance
    model = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=0,  # Start without patching
        patch_stride=0
    )

    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test full training mode
    print("Testing full training mode...")
    with torch.no_grad():
        outputs = model(x, day_idx, mode='full')

        print(f"Clean logits shape: {outputs['clean_logits'].shape}")
        print(f"Noisy logits shape: {outputs['noisy_logits'].shape}")
        print(f"Noise output shape: {outputs['noise_output'].shape}")

        # Verify shapes
        assert outputs['clean_logits'].shape == (batch_size, seq_len, n_classes), \
            f"Clean logits shape mismatch: {outputs['clean_logits'].shape}"
        assert outputs['noisy_logits'].shape == (batch_size, seq_len, n_classes), \
            f"Noisy logits shape mismatch: {outputs['noisy_logits'].shape}"
        assert outputs['noise_output'].shape == (batch_size, seq_len, neural_dim), \
            f"Noise output shape mismatch: {outputs['noise_output'].shape}"
        print("✓ Full training mode successful")
    print()

    # Test inference mode
    print("Testing inference mode...")
    with torch.no_grad():
        logits = model(x, day_idx, mode='inference')
        print(f"Inference logits shape: {logits.shape}")
        assert logits.shape == (batch_size, seq_len, n_classes), \
            f"Inference logits shape mismatch: {logits.shape}"
        print("✓ Inference mode successful")
    print()

    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = TripleGRUDecoder(
        neural_dim=neural_dim,
        n_units=n_units,
        n_days=n_days,
        n_classes=n_classes,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=14,
        patch_stride=4
    )

    with torch.no_grad():
        outputs_patches = model_with_patches(x, day_idx, mode='full')
        expected_patches = (seq_len - 14) // 4 + 1
        print("Output shapes with patches:")
        print(f"  Clean logits: {outputs_patches['clean_logits'].shape}")
        print(f"  Expected patches: {expected_patches}")
        print("✓ Patch processing successful")
    print()
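    # Sanity check on the patch arithmetic above (printed, not asserted, since the
    # exact padding behaviour of TripleGRUDecoder is not shown in this file):
    # with seq_len=100, patch_size=14 and patch_stride=4,
    # (100 - 14) // 4 + 1 = 22, so the patched logits are expected to have
    # 22 time steps instead of the original 100.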

    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(batch_size, seq_len, neural_dim, requires_grad=True)

    # Forward pass in training mode
    outputs = model(x_grad, day_idx, mode='full')

    # Calculate losses
    target = torch.randint(0, n_classes, (batch_size, seq_len))
    clean_loss = torch.nn.functional.cross_entropy(
        outputs['clean_logits'].reshape(-1, n_classes),
        target.reshape(-1)
    )
    noisy_loss = torch.nn.functional.cross_entropy(
        outputs['noisy_logits'].reshape(-1, n_classes),
        target.reshape(-1)
    )

    # Backward passes
    clean_loss.backward(retain_graph=True)
    noisy_loss.backward()

    # Check gradients
    clean_has_grad = any(p.grad is not None for p in model.clean_speech_model.parameters())
    noisy_has_grad = any(p.grad is not None for p in model.noisy_speech_model.parameters())
    noise_has_grad = any(p.grad is not None for p in model.noise_model.parameters())

    print(f"Clean model gradients: {'✓' if clean_has_grad else '✗'}")
    print(f"Noisy model gradients: {'✓' if noisy_has_grad else '✗'}")
    print(f"Noise model gradients: {'✓' if noise_has_grad else '✗'}")
    print("✓ Gradient flow test successful")
    print()

    return True
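
# The helper below hand-wires the inference-mode data flow described in the
# architecture summary printed at the bottom of this script (Input -> NoiseModel ->
# noise estimate, then Input - noise -> CleanSpeechModel -> logits), using only the
# sub-model calls exercised in test_individual_models(). It is an illustrative
# sketch, not a reference implementation: the function name is new, the subtraction
# step and matching shapes are assumptions (the full-mode test asserts the noise
# output has the input's shape), and TripleGRUDecoder's mode='inference' path may
# differ internally. It is intentionally not called from __main__.
def sketch_manual_inference(neural_dim=512, n_units=768, n_days=5, n_classes=41,
                            batch_size=4, seq_len=100):
    """Hand-wired sketch of the inference-mode flow using the individual models."""
    x = torch.randn(batch_size, seq_len, neural_dim)
    day_idx = torch.randint(0, n_days, (batch_size,))

    noise_model = NoiseModel(neural_dim, n_units, n_days)
    clean_model = CleanSpeechModel(neural_dim, n_units, n_days, n_classes)

    with torch.no_grad():
        # Step 1: estimate the noise component of the neural input
        noise_est, _ = noise_model(x, day_idx)
        # Step 2: denoise the input and decode with the clean-speech branch
        logits = clean_model(x - noise_est, day_idx)

    print(f"Manual inference sketch logits shape: {logits.shape}")
    return logits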

def test_adversarial_training_simulation():
    """Simulate adversarial training with gradient combination"""
    print("=== Adversarial Training Simulation ===")

    # Simple test to verify gradient combination logic
    model = TripleGRUDecoder(512, 768, 5, 41, rnn_dropout=0.1)

    # Create fake gradients
    fake_clean_grad = torch.randn(41, 768)  # Output layer gradients
    fake_noisy_grad = torch.randn(41, 768)

    print("Testing gradient combination...")
    try:
        model.apply_gradient_combination(fake_clean_grad, fake_noisy_grad, learning_rate=1e-3)
        print("✓ Gradient combination mechanism working")
    except Exception as e:
        print(f"✗ Gradient combination failed: {e}")

    print()
    return True
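
# For reference, a minimal sketch of what a gradient combination step could look
# like, following the "-Clean_grad + Noisy_grad -> NoiseModel" rule stated in the
# architecture summary below. The function name and the plain SGD-style update are
# assumptions for illustration only; the actual logic lives in
# TripleGRUDecoder.apply_gradient_combination (not shown in this file) and may
# differ, e.g. in scaling, clipping, or which parameters receive the update.
# It is intentionally not called from __main__.
def sketch_gradient_combination(param, clean_grad, noisy_grad, learning_rate=1e-3):
    """Illustrative combined update: push the noise model away from the clean
    branch's gradient and toward the noisy branch's gradient."""
    combined_grad = -clean_grad + noisy_grad  # adversarial combination
    with torch.no_grad():
        param -= learning_rate * combined_grad  # SGD-style step (assumed)
    return param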

if __name__ == "__main__":
    print("Starting comprehensive tests for Triple-GRU architecture...\n")

    # Run all tests
    test_individual_models()
    test_triple_gru_architecture()
    test_adversarial_training_simulation()

    print("=== All tests completed! ===")
    print()

    # Print architecture summary
    print("=== Triple-GRU Architecture Summary ===")
    print("Training Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Clean Recognition")
    print("3. Noise → NoisySpeechModel → Noisy Recognition")
    print("4. Gradient Combination: -Clean_grad + Noisy_grad → NoiseModel")
    print()
    print("Inference Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Final Recognition")
    print()