diff --git a/model_training_nnn_tpu/quick_test_fixes.py b/model_training_nnn_tpu/quick_test_fixes.py deleted file mode 100644 index 00405b7..0000000 --- a/model_training_nnn_tpu/quick_test_fixes.py +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env python3 -""" -Quick test to verify TensorFlow implementation fixes -This tests the core fixes without requiring external dependencies -""" - -try: - import tensorflow as tf - print("✅ TensorFlow imported successfully") -except ImportError as e: - print(f"❌ TensorFlow import failed: {e}") - exit(1) - -def test_gradient_reversal(): - """Test gradient reversal layer fix""" - print("\n=== Testing Gradient Reversal Fix ===") - try: - # Import our fixed gradient reversal function - import sys - import os - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - - from rnn_model_tf import gradient_reverse - - x = tf.constant([[1.0, 2.0], [3.0, 4.0]]) - - # Test forward pass (should be identity) - y = gradient_reverse(x, lambd=0.5) - - # Check forward pass - if tf.reduce_all(tf.equal(x, y)): - print("✅ Gradient reversal forward pass works") - - # Test gradient computation - with tf.GradientTape() as tape: - tape.watch(x) - y = gradient_reverse(x, lambd=0.5) - loss = tf.reduce_sum(y) - - grad = tape.gradient(loss, x) - expected_grad = -0.5 * tf.ones_like(x) - - if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6): - print("✅ Gradient reversal gradients work correctly") - return True - else: - print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}") - return False - else: - print("❌ Gradient reversal forward pass failed") - return False - - except Exception as e: - print(f"❌ Gradient reversal test failed: {e}") - return False - -def test_ctc_loss(): - """Test CTC loss fix""" - print("\n=== Testing CTC Loss Fix ===") - try: - from rnn_model_tf import CTCLoss - - ctc_loss = CTCLoss(blank_index=0, reduction='none') - - # Create simple test data - batch_size = 2 - time_steps = 5 - n_classes = 4 - - logits = tf.random.normal((batch_size, time_steps, n_classes)) - labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32) - input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32) - label_lengths = tf.constant([2, 3], dtype=tf.int32) - - loss_input = { - 'labels': labels, - 'input_lengths': input_lengths, - 'label_lengths': label_lengths - } - - loss = ctc_loss(loss_input, logits) - - if tf.reduce_all(tf.math.is_finite(loss)) and loss.shape == (batch_size,): - print("✅ CTC loss computation works") - return True - else: - print(f"❌ CTC loss failed: shape {loss.shape}, finite: {tf.reduce_all(tf.math.is_finite(loss))}") - return False - - except Exception as e: - print(f"❌ CTC loss test failed: {e}") - return False - -def test_basic_model(): - """Test basic model creation""" - print("\n=== Testing Basic Model Creation ===") - try: - from rnn_model_tf import TripleGRUDecoder - - model = TripleGRUDecoder( - neural_dim=64, # Smaller for testing - n_units=32, - n_days=2, - n_classes=10, - rnn_dropout=0.1, - input_dropout=0.1, - patch_size=2, - patch_stride=1 - ) - - # Test forward pass - batch_size = 2 - time_steps = 10 - x = tf.random.normal((batch_size, time_steps, 64)) - day_idx = tf.constant([0, 1], dtype=tf.int32) - - # Test inference mode - logits = model(x, day_idx, mode='inference', training=False) - expected_time_steps = (time_steps - 2) // 1 + 1 - - if logits.shape == (batch_size, expected_time_steps, 10): - print("✅ Basic model inference works") - return True - else: - print(f"❌ Model output shape incorrect: {logits.shape}") - return False - - except Exception as e: - print(f"❌ Basic model test failed: {e}") - return False - -def main(): - """Run all tests""" - print("🧪 Testing TensorFlow Implementation Fixes") - print("=" * 50) - - tests = [ - test_gradient_reversal, - test_ctc_loss, - test_basic_model - ] - - passed = 0 - total = len(tests) - - for test in tests: - if test(): - passed += 1 - - print("\n" + "=" * 50) - print(f"📊 Test Results: {passed}/{total} tests passed") - - if passed == total: - print("🎉 All fixes working correctly!") - return 0 - else: - print("❌ Some fixes still need work") - return 1 - -if __name__ == "__main__": - exit(main()) \ No newline at end of file