106 lines
3.9 KiB
Python
106 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Quick test of the core functionality of the GA_optimize module.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
sys.path.append('.')
|
|
|
|
import GA_optimize
|
|
import numpy as np
|
|
import time
|
|
|
|
def _limit_test_data(test_data, max_sessions=2, max_trials=2):
    """Return a reduced copy of *test_data* for a fast smoke test.

    Only the first ``max_sessions`` sessions are kept.  Within each kept
    session, every value that is a ``list`` is truncated to its first
    ``max_trials`` items; values of any other type are passed through
    unchanged.
    """
    limited = {}
    for session in list(test_data.keys())[:max_sessions]:
        limited[session] = {
            key: value[:max_trials] if isinstance(value, list) else value
            for key, value in test_data[session].items()
        }
    return limited


def quick_test():
    """Run a quick smoke test of the GA_optimize module.

    Creates an optimizer, exercises the prediction cache, swaps in a reduced
    copy of the optimizer's test data, then runs prediction generation,
    parameter evaluation and the fitness function on it.  Progress is printed
    step by step.

    Returns:
        bool: ``True`` if every step completed without raising, ``False``
        otherwise (the error message and traceback are printed).
    """
    print("=== GA_optimize Quick Test ===")

    optimizer = None
    # Set to True once the optimizer's data has been replaced, so the
    # ``finally`` clause knows there is something to put back.
    swapped = False
    original_data = None
    original_trials = None
    had_trials = False

    try:
        print("1. Testing module import... ✓")

        print("2. Creating optimizer instance...")
        optimizer = GA_optimize.TTAEGeneticOptimizer()
        print(" Optimizer created successfully ✓")

        print("3. Testing cache system...")
        cache = GA_optimize.TTAEnsembleCache('./test_cache')
        cache.add_prediction('gru', 'test_session', 0, 'original', np.random.rand(1, 10, 40))
        cached = cache.get_prediction('gru', 'test_session', 0, 'original')
        # Explicit raise instead of ``assert`` so the check still runs when
        # Python is started with -O (which strips assert statements).
        if cached is None:
            raise AssertionError("Cache retrieval failed")
        print(" Cache system working ✓")

        print("4. Testing parameter bounds...")
        # Sample parameter values (only printed here, not evaluated)
        gru_weight = 0.6
        tta_weights = [2.0, 1.0, 0.5, 0.0, 0.0]
        print(f" Test params: gru_weight={gru_weight}, tta_weights={tta_weights} ✓")

        print("5. Testing fitness function with dummy data...")
        # Test the fitness function on a small amount of data:
        # the first 2 sessions, and the first 2 trials of each list field.
        limited_test_data = _limit_test_data(optimizer.test_data)

        # Temporarily replace the optimizer's test data; remember the
        # previous state so it can be restored afterwards.
        original_data = optimizer.test_data
        had_trials = hasattr(optimizer, "trials_per_session")
        original_trials = getattr(optimizer, "trials_per_session", None)
        swapped = True
        optimizer.test_data = limited_test_data
        optimizer.trials_per_session = {k: len(v['neural_features']) for k, v in limited_test_data.items()}

        print(" Limited test data prepared ✓")
        print(f" Testing with {sum(optimizer.trials_per_session.values())} trials")

        # Generate a small number of predictions for the test
        print("6. Generating small cache for testing...")
        start_time = time.time()
        optimizer.generate_all_predictions()
        cache_time = time.time() - start_time
        print(f" Cache generation completed in {cache_time:.2f}s ✓")

        print("7. Testing parameter evaluation...")
        # First entry is passed as the first argument, the next five as the
        # second (presumably GRU weight followed by TTA weights — TODO confirm).
        test_params = [0.5, 1.0, 1.0, 0.0, 0.0, 0.0]
        start_time = time.time()
        per = optimizer.evaluate_parameters(test_params[0], test_params[1:6])
        eval_time = time.time() - start_time
        print(f" Parameter evaluation completed in {eval_time:.2f}s ✓")
        print(f" Test PER: {per:.2f}%")

        print("8. Testing fitness function...")
        fitness = optimizer.fitness_function(None, test_params, 0)
        print(f" Fitness value: {fitness:.2f} ✓")

        print("\n=== All Tests Passed! ===")
        print("✓ Module import and initialization")
        print("✓ Cache system functionality")
        print("✓ Model loading and setup")
        print("✓ Data preprocessing")
        print("✓ Prediction generation and caching")
        print("✓ Parameter evaluation")
        print("✓ Fitness function")
        print("\nThe GA_optimize module is ready for full optimization!")

        return True

    except Exception as e:
        # Top-level boundary of the smoke test: report any failure and
        # signal it through the return value instead of propagating.
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Restore the original data on success AND on failure, and keep
        # trials_per_session consistent with the restored test_data.
        if swapped:
            optimizer.test_data = original_data
            if had_trials:
                optimizer.trials_per_session = original_trials
|
|
|
|
if __name__ == "__main__":
    # Run the smoke test and tell the user what to do next.
    success = quick_test()
    closing_lines = (
        [
            "\n🎉 Ready to run full genetic algorithm optimization!",
            "To run full optimization, use: python GA_optimize.py",
        ]
        if success
        else ["\n⚠️ Please fix the issues before running full optimization."]
    )
    for line in closing_lines:
        print(line)