106 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			106 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | #!/usr/bin/env python3 | ||
|  | """
 | ||
|  | 快速测试GA_optimize模块的核心功能 | ||
|  | """
 | ||
|  | 
 | ||
|  | import sys | ||
|  | import os | ||
|  | sys.path.append('.') | ||
|  | 
 | ||
|  | import GA_optimize | ||
|  | import numpy as np | ||
|  | import time | ||
|  | 
 | ||
def _limit_test_data(test_data, max_sessions=2, max_trials=2):
    """Return a reduced copy of *test_data* for a fast smoke test.

    Keeps only the first ``max_sessions`` sessions (in the mapping's
    iteration order). Within each kept session, every list-valued entry is
    truncated to its first ``max_trials`` items; all other values are passed
    through unchanged (shared by reference, not copied).

    Args:
        test_data: Mapping of session key -> dict of per-session fields.
        max_sessions: Number of leading sessions to keep.
        max_trials: Number of leading items kept from each list-valued field.

    Returns:
        A new dict with the same two-level structure as *test_data*.
    """
    limited = {}
    for session in list(test_data)[:max_sessions]:
        limited[session] = {
            key: value[:max_trials] if isinstance(value, list) else value
            for key, value in test_data[session].items()
        }
    return limited


def quick_test(*, max_sessions=2, max_trials=2):
    """Run a fast end-to-end smoke test of the GA_optimize module.

    Steps: construct ``GA_optimize.TTAEGeneticOptimizer``, round-trip one
    entry through ``GA_optimize.TTAEnsembleCache``, shrink the optimizer's
    ``test_data`` to a few sessions/trials, generate predictions, then call
    ``evaluate_parameters`` and ``fitness_function`` once each. Progress is
    printed to stdout.

    Args:
        max_sessions: Number of leading sessions of ``optimizer.test_data``
            to keep for the test.
        max_trials: Number of leading items kept from each list-valued
            field of every kept session.

    Returns:
        True if every step completed, False if any step raised (the
        traceback is printed).
    """
    print("=== GA_optimize Quick Test ===")

    # (test_data, trials_per_session) captured while the reduced data is
    # installed on the optimizer; restored in ``finally``.
    saved_state = None
    try:
        # Reaching this point means the module-level ``import GA_optimize`` worked.
        print("1. Testing module import... ✓")

        print("2. Creating optimizer instance...")
        optimizer = GA_optimize.TTAEGeneticOptimizer()
        print("   Optimizer created successfully ✓")

        print("3. Testing cache system...")
        # NOTE(review): the cache is created under ./test_cache relative to the
        # current working directory and is not cleaned up here.
        cache = GA_optimize.TTAEnsembleCache('./test_cache')
        cache.add_prediction('gru', 'test_session', 0, 'original', np.random.rand(1, 10, 40))
        cached = cache.get_prediction('gru', 'test_session', 0, 'original')
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if cached is None:
            raise AssertionError("Cache retrieval failed")
        print("   Cache system working ✓")

        print("4. Testing parameter bounds...")
        # Sample parameter values; this step only echoes them.
        gru_weight = 0.6
        tta_weights = [2.0, 1.0, 0.5, 0.0, 0.0]
        print(f"   Test params: gru_weight={gru_weight}, tta_weights={tta_weights} ✓")

        print("5. Testing fitness function with dummy data...")
        # Work on a small subset of the real test data to keep the run short.
        limited_test_data = _limit_test_data(optimizer.test_data, max_sessions, max_trials)

        # Temporarily install the reduced data; ``finally`` puts both
        # attributes back so they stay consistent with each other.
        saved_state = (optimizer.test_data, getattr(optimizer, 'trials_per_session', None))
        optimizer.test_data = limited_test_data
        optimizer.trials_per_session = {k: len(v['neural_features']) for k, v in limited_test_data.items()}

        total_trials = sum(optimizer.trials_per_session.values())
        if total_trials == 0:
            # With no trials the remaining steps would "pass" without testing anything.
            raise RuntimeError("No test trials available in optimizer.test_data")

        print("   Limited test data prepared ✓")
        print(f"   Testing with {total_trials} trials")

        # Generate predictions for the reduced data set.
        print("6. Generating small cache for testing...")
        start_time = time.perf_counter()
        optimizer.generate_all_predictions()
        cache_time = time.perf_counter() - start_time
        print(f"   Cache generation completed in {cache_time:.2f}s ✓")

        print("7. Testing parameter evaluation...")
        # Layout: [gru_weight, tta_weight_0 .. tta_weight_4]
        test_params = [0.5, 1.0, 1.0, 0.0, 0.0, 0.0]
        start_time = time.perf_counter()
        per = optimizer.evaluate_parameters(test_params[0], test_params[1:6])
        eval_time = time.perf_counter() - start_time
        print(f"   Parameter evaluation completed in {eval_time:.2f}s ✓")
        print(f"   Test PER: {per:.2f}%")

        print("8. Testing fitness function...")
        # Called as fitness_function(ga_instance, solution, solution_idx);
        # no GA instance is needed for a single direct evaluation.
        fitness = optimizer.fitness_function(None, test_params, 0)
        print(f"   Fitness value: {fitness:.2f} ✓")

        print("\n=== All Tests Passed! ===")
        print("✓ Module import and initialization")
        print("✓ Cache system functionality")
        print("✓ Model loading and setup")
        print("✓ Data preprocessing")
        print("✓ Prediction generation and caching")
        print("✓ Parameter evaluation")
        print("✓ Fitness function")
        print("\nThe GA_optimize module is ready for full optimization!")

        return True

    except Exception as e:
        # Top-level harness boundary: report any failure instead of crashing.
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # Restore the optimizer's original data whether or not a step failed.
        if saved_state is not None:
            optimizer.test_data, optimizer.trials_per_session = saved_state
|  | 
 | ||
if __name__ == "__main__":
    # Run the smoke test and report the outcome; exit non-zero on failure
    # so shells and CI pipelines can detect it.
    success = quick_test()
    if success:
        print("\n🎉 Ready to run full genetic algorithm optimization!")
        print("To run full optimization, use: python GA_optimize.py")
    else:
        print("\n⚠️  Please fix the issues before running full optimization.")
        sys.exit(1)