#!/usr/bin/env python3
"""
Test script for the triple-GRU adversarial architecture
"""

import os
import sys

import torch

# Add the current directory to the path to import rnn_model
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from rnn_model import (  # noqa: E402  (must follow the sys.path tweak above)
    CleanSpeechModel,
    NoiseModel,
    NoisySpeechModel,
    TripleGRUDecoder,
)

# Hyperparameters shared by every smoke test below. They only size the
# synthetic tensors and the models built from them.
NEURAL_DIM = 512
N_UNITS = 768
N_DAYS = 5
N_CLASSES = 41
BATCH_SIZE = 4
SEQ_LEN = 100


def test_individual_models():
    """Run one forward pass through each sub-model and print output shapes.

    Builds NoiseModel, CleanSpeechModel and NoisySpeechModel, feeds them
    random input under ``torch.no_grad()`` and prints the resulting shapes.
    The NoisySpeechModel is fed the NoiseModel's output rather than the raw
    input.

    Returns:
        True once all three forward passes have completed. Any error raised
        by a model propagates to the caller.
    """
    print("=== Testing Individual Models ===")

    # Create synthetic input
    x = torch.randn(BATCH_SIZE, SEQ_LEN, NEURAL_DIM)
    day_idx = torch.randint(0, N_DAYS, (BATCH_SIZE,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test NoiseModel
    print("1. Testing NoiseModel...")
    noise_model = NoiseModel(NEURAL_DIM, N_UNITS, N_DAYS)
    with torch.no_grad():
        noise_out, noise_hidden = noise_model(x, day_idx)
    print(f" Noise output shape: {noise_out.shape}")
    print(f" Noise hidden shape: {noise_hidden.shape}")
    print(" ✓ NoiseModel working")
    print()

    # Test CleanSpeechModel
    print("2. Testing CleanSpeechModel...")
    clean_model = CleanSpeechModel(NEURAL_DIM, N_UNITS, N_DAYS, N_CLASSES)
    with torch.no_grad():
        clean_logits = clean_model(x, day_idx)
    print(f" Clean logits shape: {clean_logits.shape}")
    print(" ✓ CleanSpeechModel working")
    print()

    # Test NoisySpeechModel
    print("3. Testing NoisySpeechModel...")
    noisy_model = NoisySpeechModel(NEURAL_DIM, N_UNITS, N_DAYS, N_CLASSES)
    with torch.no_grad():
        noisy_logits = noisy_model(noise_out)  # Use noise output as input
    print(f" Noisy logits shape: {noisy_logits.shape}")
    print(" ✓ NoisySpeechModel working")
    print()

    return True


def test_triple_gru_architecture():
    """Exercise TripleGRUDecoder in full, inference, patched and training use.

    Steps, in order:
      1. ``mode='full'`` forward pass under ``no_grad``; the three outputs
         are checked against the expected shapes.
      2. ``mode='inference'`` forward pass; the logits shape is checked.
      3. A second decoder built with ``patch_size=14, patch_stride=4`` is run
         in full mode; its clean-logits shape and the computed patch count
         are printed (not asserted).
      4. A forward/backward pass with cross-entropy on random targets;
         whether each sub-model received any gradient is printed (not
         asserted).

    Returns:
        True when every step ran and the shape assertions held.

    Raises:
        AssertionError: if an output shape from step 1 or 2 is unexpected.
    """
    print("=== Triple-GRU Architecture Test ===")

    print(f"Neural dim: {NEURAL_DIM}")
    print(f"Hidden units: {N_UNITS}")
    print(f"Days: {N_DAYS}")
    print(f"Classes: {N_CLASSES}")
    print()

    # Create model instance
    model = TripleGRUDecoder(
        neural_dim=NEURAL_DIM,
        n_units=N_UNITS,
        n_days=N_DAYS,
        n_classes=N_CLASSES,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=0,  # Start without patching
        patch_stride=0,
    )

    print("Model created successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print()

    # Create synthetic input data
    x = torch.randn(BATCH_SIZE, SEQ_LEN, NEURAL_DIM)
    day_idx = torch.randint(0, N_DAYS, (BATCH_SIZE,))

    print(f"Input shape: {x.shape}")
    print(f"Day indices: {day_idx}")
    print()

    # Test full training mode
    print("Testing full training mode...")
    with torch.no_grad():
        outputs = model(x, day_idx, mode='full')

    print(f"Clean logits shape: {outputs['clean_logits'].shape}")
    print(f"Noisy logits shape: {outputs['noisy_logits'].shape}")
    print(f"Noise output shape: {outputs['noise_output'].shape}")

    # Verify shapes
    logits_shape = (BATCH_SIZE, SEQ_LEN, N_CLASSES)
    assert outputs['clean_logits'].shape == logits_shape, "Clean logits shape mismatch"
    assert outputs['noisy_logits'].shape == logits_shape, "Noisy logits shape mismatch"
    assert outputs['noise_output'].shape == (BATCH_SIZE, SEQ_LEN, NEURAL_DIM), "Noise output shape mismatch"
    print("✓ Full training mode successful")
    print()

    # Test inference mode
    print("Testing inference mode...")
    with torch.no_grad():
        logits = model(x, day_idx, mode='inference')
    print(f"Inference logits shape: {logits.shape}")
    assert logits.shape == logits_shape, "Inference logits shape mismatch"
    print("✓ Inference mode successful")
    print()

    # Test with patch processing
    print("Testing with patch processing...")
    model_with_patches = TripleGRUDecoder(
        neural_dim=NEURAL_DIM,
        n_units=N_UNITS,
        n_days=N_DAYS,
        n_classes=N_CLASSES,
        rnn_dropout=0.1,
        input_dropout=0.1,
        patch_size=14,
        patch_stride=4,
    )

    with torch.no_grad():
        outputs_patches = model_with_patches(x, day_idx, mode='full')

    # Window count for size 14 / stride 4; printed for manual comparison only.
    expected_patches = (SEQ_LEN - 14) // 4 + 1
    print("Output shapes with patches:")
    print(f" Clean logits: {outputs_patches['clean_logits'].shape}")
    print(f" Expected patches: {expected_patches}")
    print("✓ Patch processing successful")
    print()

    # Test gradient flow
    print("Testing gradient flow...")
    model.train()
    x_grad = torch.randn(BATCH_SIZE, SEQ_LEN, NEURAL_DIM, requires_grad=True)

    # Forward pass in training mode
    outputs = model(x_grad, day_idx, mode='full')

    # Calculate losses
    target = torch.randint(0, N_CLASSES, (BATCH_SIZE, SEQ_LEN))
    flat_target = target.reshape(-1)
    clean_loss = torch.nn.functional.cross_entropy(
        outputs['clean_logits'].reshape(-1, N_CLASSES),
        flat_target,
    )
    noisy_loss = torch.nn.functional.cross_entropy(
        outputs['noisy_logits'].reshape(-1, N_CLASSES),
        flat_target,
    )

    # Backward passes; the first keeps the graph alive so the second backward
    # call can still run over it.
    clean_loss.backward(retain_graph=True)
    noisy_loss.backward()

    # Check gradients
    clean_has_grad = any(p.grad is not None for p in model.clean_speech_model.parameters())
    noisy_has_grad = any(p.grad is not None for p in model.noisy_speech_model.parameters())
    noise_has_grad = any(p.grad is not None for p in model.noise_model.parameters())

    print(f"Clean model gradients: {'✓' if clean_has_grad else '✗'}")
    print(f"Noisy model gradients: {'✓' if noisy_has_grad else '✗'}")
    print(f"Noise model gradients: {'✓' if noise_has_grad else '✗'}")
    print("✓ Gradient flow test successful")
    print()

    return True


def test_adversarial_training_simulation():
    """Simulate adversarial training with gradient combination.

    Builds a TripleGRUDecoder and calls ``apply_gradient_combination`` with
    two random (N_CLASSES, N_UNITS) tensors standing in for output-layer
    gradients.

    Returns:
        True if the call completed, False if it raised. The failure is
        printed rather than re-raised so the script still reaches its
        summary output.
    """
    print("=== Adversarial Training Simulation ===")

    # Simple test to verify gradient combination logic
    model = TripleGRUDecoder(NEURAL_DIM, N_UNITS, N_DAYS, N_CLASSES, rnn_dropout=0.1)

    # Create fake gradients
    fake_clean_grad = torch.randn(N_CLASSES, N_UNITS)  # Output layer gradients
    fake_noisy_grad = torch.randn(N_CLASSES, N_UNITS)

    print("Testing gradient combination...")
    try:
        model.apply_gradient_combination(fake_clean_grad, fake_noisy_grad, learning_rate=1e-3)
    except Exception as e:  # broad on purpose: any failure is reported, not raised
        print(f"✗ Gradient combination failed: {e}")
        passed = False
    else:
        print("✓ Gradient combination mechanism working")
        passed = True

    print()
    return passed


if __name__ == "__main__":
    print("Starting comprehensive tests for Triple-GRU architecture...\n")

    # Run all tests
    results = [
        test_individual_models(),
        test_triple_gru_architecture(),
        test_adversarial_training_simulation(),
    ]

    print("=== All tests completed! ===")
    print()

    # Print architecture summary
    print("=== Triple-GRU Architecture Summary ===")
    print("Training Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Clean Recognition")
    print("3. Noise → NoisySpeechModel → Noisy Recognition")
    print("4. Gradient Combination: -Clean_grad + Noisy_grad → NoiseModel")
    print()
    print("Inference Mode Data Flow:")
    print("1. Input → NoiseModel → Noise Estimation")
    print("2. Input - Noise → CleanSpeechModel → Final Recognition")
    print()

    # Reflect a reported failure in the process exit status.
    if not all(results):
        sys.exit(1)