#!/usr/bin/env python3
"""
Test script for optimized data loading pipeline performance.
"""

import os
import time

import psutil
import tensorflow as tf
from omegaconf import OmegaConf

from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn


def get_memory_usage():
    """Return the resident memory usage of the current process in MB."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return memory_info.rss / 1024 / 1024  # MB


def test_data_loading_performance():
    """Compare data loading performance with and without caching."""
    # Load the configuration, falling back to a minimal test config
    config_path = "../rnn_args.yaml"
    if not os.path.exists(config_path):
        print("❌ Configuration file not found. Creating minimal test config...")
        args = {
            'dataset': {
                'dataset_dir': '../data/hdf5_data_final',
                'sessions': ['t15.2022.03.14', 't15.2022.03.16'],
                'batch_size': 32,
                'days_per_batch': 1,
                'seed': 42,
                'data_transforms': {
                    'smooth_data': False,
                    'white_noise_std': 0.0,
                    'constant_offset_std': 0.0,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 0
                }
            },
            'num_training_batches': 10  # test only 10 batches
        }
    else:
        args = OmegaConf.load(config_path)
        args = OmegaConf.to_container(args, resolve=True)
        # Limit the number of batches under test
        args['num_training_batches'] = 10

    print("🔍 Starting data loading performance test...")
    print(f"📊 Test configuration: {args['num_training_batches']} batches, "
          f"batch_size={args['dataset']['batch_size']}")

    # Build the per-session training file paths
    train_file_paths = [
        os.path.join(args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
        for s in args['dataset']['sessions']
    ]
    print(f"📁 Testing with files: {train_file_paths}")

    # Fall back to dummy data if any file is missing
    missing_files = [f for f in train_file_paths if not os.path.exists(f)]
    if missing_files:
        print(f"❌ Missing files: {missing_files}")
        print("⚠️ Creating dummy test data...")
        return test_with_dummy_data(args)

    # Split the data
    print("🔄 Splitting data...")
    train_trials, _ = train_test_split_indices(
        file_paths=train_file_paths,
        test_percentage=0,
        seed=args['dataset']['seed']
    )
    print(f"📈 Found {sum(len(trials['trials']) for trials in train_trials.values())} training trials")

    # Test 1: no caching
    print("\n" + "="*60)
    print("🐌 TEST 1: Standard data loading (no cache)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=False,        # disable caching
        preload_all_data=False   # disable preloading
    )

    tf_dataset_no_cache = create_input_fn(
        dataset_no_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the first 3 batches. In eager mode the batch is materialized by
    # the iterator itself, so the timer must start before each next() call,
    # not inside the loop body.
    batch_times = []
    batch_start = time.time()
    for i, batch in enumerate(tf_dataset_no_cache.take(3)):
        _ = batch['input_features'].numpy()  # force the device-to-host copy
        batch_time = time.time() - batch_start
        batch_times.append(batch_time)
        print(f"  Batch {i}: {batch_time:.3f}s")
        batch_start = time.time()

    no_cache_time = time.time() - start_time
    no_cache_memory = get_memory_usage() - initial_memory

    print(f"💾 Memory usage: +{no_cache_memory:.1f} MB")
    print(f"⏱️ Total time: {no_cache_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times)/len(batch_times):.3f}s")

    # Test 2: full cache with preloading
    print("\n" + "="*60)
    print("🚀 TEST 2: Optimized data loading (full cache preload)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=True,        # enable caching
        preload_all_data=True   # enable preloading
    )

    preload_time = time.time() - start_time
    preload_memory = get_memory_usage() - initial_memory
    print(f"📝 Preloading completed in {preload_time:.3f}s")
    print(f"💾 Preloading memory: +{preload_memory:.1f} MB")

    tf_dataset_with_cache = create_input_fn(
        dataset_with_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the first 3 batches, same method as in test 1
    batch_start_time = time.time()
    batch_times_cached = []
    batch_start = time.time()
    for i, batch in enumerate(tf_dataset_with_cache.take(3)):
        _ = batch['input_features'].numpy()  # force the device-to-host copy
        batch_time = time.time() - batch_start
        batch_times_cached.append(batch_time)
        print(f"  Batch {i}: {batch_time:.3f}s")
        batch_start = time.time()

    cached_batch_time = time.time() - batch_start_time
    cached_memory = get_memory_usage() - initial_memory

    print(f"💾 Total memory usage: +{cached_memory:.1f} MB")
    print(f"⏱️ Batch loading time: {cached_batch_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times_cached)/len(batch_times_cached):.3f}s")

    # Performance comparison
    print("\n" + "="*60)
    print("📈 PERFORMANCE COMPARISON")
    print("="*60)

    speedup = (sum(batch_times)/len(batch_times)) / (sum(batch_times_cached)/len(batch_times_cached))
    memory_cost = cached_memory - no_cache_memory

    print(f"🚀 Speed improvement: {speedup:.1f}x faster")
    print(f"💾 Memory cost: +{memory_cost:.1f} MB for caching")
    print(f"⚡ First batch time: {batch_times[0]:.3f}s → {batch_times_cached[0]:.3f}s")

    if speedup > 2:
        print("✅ Excellent! Caching significantly speeds up data loading")
    elif speedup > 1.5:
        print("✅ Good! Caching noticeably speeds up data loading")
    else:
        print("⚠️ Warning: caching shows little benefit; the dataset may be too small")

    return True


def test_with_dummy_data(args):
    """Measure dataset initialization time using dummy trial indices."""
    print("🔧 Creating dummy data for testing...")

    # Dummy trial indices; 'session_path' is a placeholder, so only the
    # initialization-time difference between cached and uncached datasets
    # is measured below.
    dummy_trials = {
        0: {
            'trials': list(range(100)),  # 100 dummy trials
            'session_path': 'dummy_path'
        }
    }

    print("📊 Testing with dummy data (100 trials)...")

    print("\n🐌 Testing without cache...")
    start_time = time.time()
    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
        cache_data=False,
        preload_all_data=False
    )
    no_cache_time = time.time() - start_time
    print(f"  Initialization time: {no_cache_time:.3f}s")

    print("\n🚀 Testing with cache...")
    start_time = time.time()
    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=dummy_trials,
        n_batches=5,
        split='train',
        batch_size=32,
        days_per_batch=1,
        random_seed=42,
        cache_data=True,
        preload_all_data=True
    )
    cache_time = time.time() - start_time
    print(f"  Initialization time: {cache_time:.3f}s")

    print("\n✅ The caching mechanism is integrated into the data loading pipeline")
    print("📝 Real performance should be measured against actual HDF5 data")
    return True


if __name__ == "__main__":
    print("🧪 Data Loading Performance Test")
    print("="*60)

    try:
        success = test_data_loading_performance()
        if success:
            print("\n🎉 Data loading optimization test completed successfully!")
            print("💡 You can now run train_model_tf.py with fast data loading")
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
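
# ---------------------------------------------------------------------------
# Optional follow-up (sketch, not executed by this test): if create_input_fn
# does not already prefetch internally, the cached pipeline can usually be
# combined with tf.data prefetching so the next batch is prepared while the
# current one is consumed. A minimal sketch, reusing the names from
# test_data_loading_performance() above:
#
#     tf_dataset = create_input_fn(dataset_with_cache,
#                                  args['dataset']['data_transforms'],
#                                  training=True)
#     tf_dataset = tf_dataset.prefetch(tf.data.AUTOTUNE)
# ---------------------------------------------------------------------------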