#!/usr/bin/env python3
"""
Test script for optimized data loading pipeline performance.
"""

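# Run from the directory containing this script: it imports the local
# dataset_tf module and, if present, reads ../rnn_args.yaml; otherwise it
# falls back to the minimal config defined below.
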
import os
import time

import psutil
import tensorflow as tf
from omegaconf import OmegaConf

from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn


def get_memory_usage():
    """Return this process's resident set size (RSS) in MB."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return memory_info.rss / 1024 / 1024  # bytes -> MB

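# Note: RSS counts pages shared with other processes. If a stricter
# per-process figure is needed, psutil also exposes USS on most platforms:
#   psutil.Process().memory_full_info().uss / 1024 / 1024
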
def test_data_loading_performance():
    """Compare data loading performance with and without caching."""

    # Load the configuration
    config_path = "../rnn_args.yaml"
    if not os.path.exists(config_path):
        print("❌ Configuration file not found. Creating minimal test config...")
        # Create a minimal test configuration
        args = {
            'dataset': {
                'dataset_dir': '../data/hdf5_data_final',
                'sessions': ['t15.2022.03.14', 't15.2022.03.16'],
                'batch_size': 32,
                'days_per_batch': 1,
                'seed': 42,
                'data_transforms': {
                    'smooth_data': False,
                    'white_noise_std': 0.0,
                    'constant_offset_std': 0.0,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 0
                }
            },
            'num_training_batches': 10  # test only 10 batches
        }
    else:
        args = OmegaConf.load(config_path)
        args = OmegaConf.to_container(args, resolve=True)
        # Cap the number of batches for the test
        args['num_training_batches'] = 10

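    # For reference, a minimal ../rnn_args.yaml covering the keys read here
    # might look like the sketch below (inferred from the fallback config
    # above; the real training config may contain more fields):
    #
    #   dataset:
    #     dataset_dir: ../data/hdf5_data_final
    #     sessions: [t15.2022.03.14, t15.2022.03.16]
    #     batch_size: 32
    #     days_per_batch: 1
    #     seed: 42
    #     data_transforms: {smooth_data: false, white_noise_std: 0.0}
    #   num_training_batches: 10
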
    print("🔍 Starting data loading performance test...")
    print(f"📊 Test configuration: {args['num_training_batches']} batches, batch_size={args['dataset']['batch_size']}")

    # Build the file paths
    train_file_paths = [
        os.path.join(args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
        for s in args['dataset']['sessions']
    ]

    print(f"📁 Testing with files: {train_file_paths}")

    # Check that the files exist
    missing_files = [f for f in train_file_paths if not os.path.exists(f)]
    if missing_files:
        print(f"❌ Missing files: {missing_files}")
        print("⚠️ Creating dummy test data...")
        return test_with_dummy_data(args)

    # Split the data
    print("🔄 Splitting data...")
    train_trials, _ = train_test_split_indices(
        file_paths=train_file_paths,
        test_percentage=0,
        seed=args['dataset']['seed']
    )

    print(f"📈 Found {sum(len(trials['trials']) for trials in train_trials.values())} training trials")

    # Test 1: no caching
    print("\n" + "="*60)
    print("🐌 TEST 1: Standard data loading (no cache)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=False,           # disable caching
        preload_all_data=False      # disable preloading
    )

    tf_dataset_no_cache = create_input_fn(
        dataset_no_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the first 3 batches. tf.data does its work inside the iterator's
    # __next__, so the timer must wrap the iteration itself; calling .numpy()
    # afterwards forces the batch to be fully materialized on the host.
    batch_times = []
    batch_start = time.time()
    for i, batch in enumerate(tf_dataset_no_cache.take(3)):
        _ = batch['input_features'].numpy()
        batch_time = time.time() - batch_start
        batch_times.append(batch_time)
        print(f"   Batch {i}: {batch_time:.3f}s")
        batch_start = time.time()

    no_cache_time = time.time() - start_time
    no_cache_memory = get_memory_usage() - initial_memory

    print(f"💾 Memory usage: +{no_cache_memory:.1f} MB")
    print(f"⏱️ Total time: {no_cache_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times)/len(batch_times):.3f}s")

    # Test 2: preloaded cache
    print("\n" + "="*60)
    print("🚀 TEST 2: Optimized data loading (full cache preload)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=True,            # enable caching
        preload_all_data=True       # enable preloading
    )

    preload_time = time.time() - start_time
    preload_memory = get_memory_usage() - initial_memory

    print(f"📝 Preloading completed in {preload_time:.3f}s")
    print(f"💾 Preloading memory: +{preload_memory:.1f} MB")

    tf_dataset_with_cache = create_input_fn(
        dataset_with_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the first 3 batches, wrapping the iteration as above
    batch_start_time = time.time()
    batch_times_cached = []
    batch_start = time.time()
    for i, batch in enumerate(tf_dataset_with_cache.take(3)):
        _ = batch['input_features'].numpy()
        batch_time = time.time() - batch_start
        batch_times_cached.append(batch_time)
        print(f"   Batch {i}: {batch_time:.3f}s")
        batch_start = time.time()

    cached_batch_time = time.time() - batch_start_time
    cached_memory = get_memory_usage() - initial_memory

    print(f"💾 Total memory usage: +{cached_memory:.1f} MB")
    print(f"⏱️ Batch loading time: {cached_batch_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times_cached)/len(batch_times_cached):.3f}s")

    # Performance comparison
    print("\n" + "="*60)
    print("📈 PERFORMANCE COMPARISON")
    print("="*60)

    speedup = (sum(batch_times)/len(batch_times)) / (sum(batch_times_cached)/len(batch_times_cached))
    memory_cost = cached_memory - no_cache_memory

    print(f"🚀 Speed improvement: {speedup:.1f}x faster")
    print(f"💾 Memory cost: +{memory_cost:.1f} MB for caching")
    print(f"⚡ First batch time: {batch_times[0]:.3f}s → {batch_times_cached[0]:.3f}s")

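    # Note: the speedup figure excludes the one-off preload cost (reported
    # separately above as preload_time); for short runs, factor it in when
    # judging whether caching pays off overall.
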
    if speedup > 2:
        print("✅ Excellent! Caching significantly sped up data loading")
    elif speedup > 1.5:
        print("✅ Good! Caching measurably sped up data loading")
    else:
        print("⚠️ Warning: caching showed little benefit; the dataset may be too small")

    return True


def test_with_dummy_data(args):
    """Run the pipeline test with dummy data."""
    print("🔧 Creating dummy data for testing...")

    # Create dummy trial indices
    dummy_trials = {
        0: {
            'trials': list(range(100)),  # 100 dummy trials
            'session_path': 'dummy_path'
        }
    }

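    # Schema note: trial_indices maps a day index to {'trials': [...],
    # 'session_path': str}, mirroring how the train_test_split_indices output
    # is consumed above (see dataset_tf.py for the authoritative structure).
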
    print("📊 Testing with dummy data (100 trials)...")

    # Compare initialization time with and without caching
    print("\n🐌 Testing without cache...")
    start_time = time.time()
    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
        cache_data=False,
        preload_all_data=False
    )
    no_cache_time = time.time() - start_time
    print(f"   Initialization time: {no_cache_time:.3f}s")

    print("\n🚀 Testing with cache...")
    start_time = time.time()
    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
        cache_data=True,
        preload_all_data=True
    )
    cache_time = time.time() - start_time
    print(f"   Initialization time: {cache_time:.3f}s")

    print("\n✅ The caching mechanism is integrated into the data loading pipeline")
    print("📝 Real performance must be measured against actual HDF5 data")

    return True


if __name__ == "__main__":
    print("🧪 Data Loading Performance Test")
    print("="*60)

    try:
        success = test_data_loading_performance()
        if success:
            print("\n🎉 Data loading optimization test completed successfully!")
            print("💡 You can now run train_model_tf.py to get fast data loading")
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
