#!/usr/bin/env python3
"""
Quick test to verify the TensorFlow implementation fixes.

These tests exercise the core fixes without requiring external
dependencies beyond TensorFlow and the rnn_model_tf module under test.
The script expects rnn_model_tf.py to live in the same directory.
"""

import os
import sys

try:
    import tensorflow as tf
    print("✅ TensorFlow imported successfully")
except ImportError as e:
    print(f"❌ TensorFlow import failed: {e}")
    sys.exit(1)

# Make the module under test importable regardless of the caller's
# working directory.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

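# For reference, the gradient-reversal behavior verified below can be built
# with tf.custom_gradient. This is only a sketch of the assumed semantics
# (identity on the forward pass, gradient scaled by -lambd on the backward
# pass); the actual implementation lives in rnn_model_tf.gradient_reverse.
def _gradient_reverse_sketch(x, lambd=1.0):
    @tf.custom_gradient
    def _identity_with_reversed_grad(x):
        def grad(dy):
            # Pass values through untouched, flip and scale the gradient.
            return -lambd * dy
        return tf.identity(x), grad
    return _identity_with_reversed_grad(x)
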
def test_gradient_reversal():
    """Test the gradient reversal layer fix."""
    print("\n=== Testing Gradient Reversal Fix ===")
    try:
        # Import our fixed gradient reversal function
        from rnn_model_tf import gradient_reverse

        x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

        # Forward pass should be the identity
        y = gradient_reverse(x, lambd=0.5)

        if tf.reduce_all(tf.equal(x, y)):
            print("✅ Gradient reversal forward pass works")

            # Backward pass: d(sum(y))/dx is all ones, so the reversed
            # gradient should be -lambd everywhere.
            with tf.GradientTape() as tape:
                tape.watch(x)
                y = gradient_reverse(x, lambd=0.5)
                loss = tf.reduce_sum(y)

            grad = tape.gradient(loss, x)
            expected_grad = -0.5 * tf.ones_like(x)

            if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
                print("✅ Gradient reversal gradients work correctly")
                return True
            else:
                print(f"❌ Gradient reversal gradients incorrect: "
                      f"got {grad}, expected {expected_grad}")
                return False
        else:
            print("❌ Gradient reversal forward pass failed")
            return False

    except Exception as e:
        print(f"❌ Gradient reversal test failed: {e}")
        return False

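# The CTCLoss wrapper exercised below is assumed to delegate to tf.nn.ctc_loss
# roughly as in this sketch (dense zero-padded labels, configurable blank
# index, one loss value per batch element); the real implementation is in
# rnn_model_tf.CTCLoss.
def _ctc_loss_sketch(labels, logits, input_lengths, label_lengths, blank_index=0):
    return tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=label_lengths,
        logit_length=input_lengths,
        logits_time_major=False,  # logits are (batch, time, classes)
        blank_index=blank_index,
    )
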
def test_ctc_loss():
    """Test the CTC loss fix."""
    print("\n=== Testing CTC Loss Fix ===")
    try:
        from rnn_model_tf import CTCLoss

        ctc_loss = CTCLoss(blank_index=0, reduction='none')

        # Create simple test data
        batch_size = 2
        time_steps = 5
        n_classes = 4

        logits = tf.random.normal((batch_size, time_steps, n_classes))
        # Labels are zero-padded; label_lengths holds the true lengths (2 and 3)
        labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
        input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
        label_lengths = tf.constant([2, 3], dtype=tf.int32)

        loss_input = {
            'labels': labels,
            'input_lengths': input_lengths,
            'label_lengths': label_lengths
        }

        loss = ctc_loss(loss_input, logits)

        # With reduction='none' we expect one finite loss value per example
        if tf.reduce_all(tf.math.is_finite(loss)) and loss.shape == (batch_size,):
            print("✅ CTC loss computation works")
            return True
        else:
            print(f"❌ CTC loss failed: shape {loss.shape}, "
                  f"finite: {tf.reduce_all(tf.math.is_finite(loss))}")
            return False

    except Exception as e:
        print(f"❌ CTC loss test failed: {e}")
        return False

def test_basic_model():
    """Test basic model creation and a forward pass."""
    print("\n=== Testing Basic Model Creation ===")
    try:
        from rnn_model_tf import TripleGRUDecoder

        model = TripleGRUDecoder(
            neural_dim=64,  # Smaller than production for a quick test
            n_units=32,
            n_days=2,
            n_classes=10,
            rnn_dropout=0.1,
            input_dropout=0.1,
            patch_size=2,
            patch_stride=1
        )

        # Test forward pass
        batch_size = 2
        time_steps = 10
        x = tf.random.normal((batch_size, time_steps, 64))
        day_idx = tf.constant([0, 1], dtype=tf.int32)

        # Test inference mode. Patching shortens the sequence to
        # (time_steps - patch_size) // patch_stride + 1 frames.
        logits = model(x, day_idx, mode='inference', training=False)
        expected_time_steps = (time_steps - 2) // 1 + 1

        if logits.shape == (batch_size, expected_time_steps, 10):
            print("✅ Basic model inference works")
            return True
        else:
            print(f"❌ Model output shape incorrect: {logits.shape}")
            return False

    except Exception as e:
        print(f"❌ Basic model test failed: {e}")
        return False

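# Worked example of the shape arithmetic asserted above, assuming the model
# patches its input like a 1-D sliding window (which is what the test checks):
# with time_steps=10, patch_size=2, patch_stride=1 we get (10 - 2) // 1 + 1 == 9.
def _patched_length(time_steps, patch_size, patch_stride):
    """Frames produced by sliding a window of patch_size by patch_stride."""
    return (time_steps - patch_size) // patch_stride + 1
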
def main():
    """Run all tests."""
    print("🧪 Testing TensorFlow Implementation Fixes")
    print("=" * 50)

    tests = [
        test_gradient_reversal,
        test_ctc_loss,
        test_basic_model
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        if test():
            passed += 1

    print("\n" + "=" * 50)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All fixes working correctly!")
        return 0
    else:
        print("❌ Some fixes still need work")
        return 1

if __name__ == "__main__":
    sys.exit(main())
