remove quick test script for TensorFlow implementation fixes
This commit is contained in:
@@ -1,161 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Quick test to verify TensorFlow implementation fixes
|
||||
This tests the core fixes without requiring external dependencies
|
||||
"""
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
print("✅ TensorFlow imported successfully")
|
||||
except ImportError as e:
|
||||
print(f"❌ TensorFlow import failed: {e}")
|
||||
exit(1)
|
||||
|
||||
def test_gradient_reversal():
|
||||
"""Test gradient reversal layer fix"""
|
||||
print("\n=== Testing Gradient Reversal Fix ===")
|
||||
try:
|
||||
# Import our fixed gradient reversal function
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from rnn_model_tf import gradient_reverse
|
||||
|
||||
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
|
||||
|
||||
# Test forward pass (should be identity)
|
||||
y = gradient_reverse(x, lambd=0.5)
|
||||
|
||||
# Check forward pass
|
||||
if tf.reduce_all(tf.equal(x, y)):
|
||||
print("✅ Gradient reversal forward pass works")
|
||||
|
||||
# Test gradient computation
|
||||
with tf.GradientTape() as tape:
|
||||
tape.watch(x)
|
||||
y = gradient_reverse(x, lambd=0.5)
|
||||
loss = tf.reduce_sum(y)
|
||||
|
||||
grad = tape.gradient(loss, x)
|
||||
expected_grad = -0.5 * tf.ones_like(x)
|
||||
|
||||
if tf.reduce_all(tf.abs(grad - expected_grad) < 1e-6):
|
||||
print("✅ Gradient reversal gradients work correctly")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Gradient reversal gradients incorrect: got {grad}, expected {expected_grad}")
|
||||
return False
|
||||
else:
|
||||
print("❌ Gradient reversal forward pass failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Gradient reversal test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_ctc_loss():
|
||||
"""Test CTC loss fix"""
|
||||
print("\n=== Testing CTC Loss Fix ===")
|
||||
try:
|
||||
from rnn_model_tf import CTCLoss
|
||||
|
||||
ctc_loss = CTCLoss(blank_index=0, reduction='none')
|
||||
|
||||
# Create simple test data
|
||||
batch_size = 2
|
||||
time_steps = 5
|
||||
n_classes = 4
|
||||
|
||||
logits = tf.random.normal((batch_size, time_steps, n_classes))
|
||||
labels = tf.constant([[1, 2, 0, 0], [3, 1, 2, 0]], dtype=tf.int32)
|
||||
input_lengths = tf.constant([time_steps, time_steps], dtype=tf.int32)
|
||||
label_lengths = tf.constant([2, 3], dtype=tf.int32)
|
||||
|
||||
loss_input = {
|
||||
'labels': labels,
|
||||
'input_lengths': input_lengths,
|
||||
'label_lengths': label_lengths
|
||||
}
|
||||
|
||||
loss = ctc_loss(loss_input, logits)
|
||||
|
||||
if tf.reduce_all(tf.math.is_finite(loss)) and loss.shape == (batch_size,):
|
||||
print("✅ CTC loss computation works")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ CTC loss failed: shape {loss.shape}, finite: {tf.reduce_all(tf.math.is_finite(loss))}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ CTC loss test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_basic_model():
|
||||
"""Test basic model creation"""
|
||||
print("\n=== Testing Basic Model Creation ===")
|
||||
try:
|
||||
from rnn_model_tf import TripleGRUDecoder
|
||||
|
||||
model = TripleGRUDecoder(
|
||||
neural_dim=64, # Smaller for testing
|
||||
n_units=32,
|
||||
n_days=2,
|
||||
n_classes=10,
|
||||
rnn_dropout=0.1,
|
||||
input_dropout=0.1,
|
||||
patch_size=2,
|
||||
patch_stride=1
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
batch_size = 2
|
||||
time_steps = 10
|
||||
x = tf.random.normal((batch_size, time_steps, 64))
|
||||
day_idx = tf.constant([0, 1], dtype=tf.int32)
|
||||
|
||||
# Test inference mode
|
||||
logits = model(x, day_idx, mode='inference', training=False)
|
||||
expected_time_steps = (time_steps - 2) // 1 + 1
|
||||
|
||||
if logits.shape == (batch_size, expected_time_steps, 10):
|
||||
print("✅ Basic model inference works")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ Model output shape incorrect: {logits.shape}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Basic model test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("🧪 Testing TensorFlow Implementation Fixes")
|
||||
print("=" * 50)
|
||||
|
||||
tests = [
|
||||
test_gradient_reversal,
|
||||
test_ctc_loss,
|
||||
test_basic_model
|
||||
]
|
||||
|
||||
passed = 0
|
||||
total = len(tests)
|
||||
|
||||
for test in tests:
|
||||
if test():
|
||||
passed += 1
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print(f"📊 Test Results: {passed}/{total} tests passed")
|
||||
|
||||
if passed == total:
|
||||
print("🎉 All fixes working correctly!")
|
||||
return 0
|
||||
else:
|
||||
print("❌ Some fixes still need work")
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
Reference in New Issue
Block a user