Fix data loader inefficiency
model_training_nnn_tpu/test_data_loading.py (new file, +254 lines)
@@ -0,0 +1,254 @@
#!/usr/bin/env python3
"""
Test script for optimized data loading pipeline performance
"""

import os
import time
import psutil
import tensorflow as tf
from omegaconf import OmegaConf
from dataset_tf import BrainToTextDatasetTF, train_test_split_indices, create_input_fn

def get_memory_usage():
    """Return the current process memory usage in MB."""
    process = psutil.Process()
    memory_info = process.memory_info()
    return memory_info.rss / 1024 / 1024  # MB

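# Note: RSS (resident set size) includes pages shared with other processes,
# so the memory deltas printed below are rough estimates rather than exact
# per-object allocation sizes.
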
def test_data_loading_performance():
    """Compare data loading performance with and without caching."""

    # Load configuration
    config_path = "../rnn_args.yaml"
    if not os.path.exists(config_path):
        print("❌ Configuration file not found. Creating minimal test config...")
        # Create a minimal test config
        args = {
            'dataset': {
                'dataset_dir': '../data/hdf5_data_final',
                'sessions': ['t15.2022.03.14', 't15.2022.03.16'],
                'batch_size': 32,
                'days_per_batch': 1,
                'seed': 42,
                'data_transforms': {
                    'smooth_data': False,
                    'white_noise_std': 0.0,
                    'constant_offset_std': 0.0,
                    'random_walk_std': 0.0,
                    'static_gain_std': 0.0,
                    'random_cut': 0
                }
            },
            'num_training_batches': 10  # test only 10 batches
        }
    else:
        args = OmegaConf.load(config_path)
        args = OmegaConf.to_container(args, resolve=True)
        # Limit the number of test batches
        args['num_training_batches'] = 10

print("🔍 Starting data loading performance test...")
|
||||
print(f"📊 Test configuration: {args['num_training_batches']} batches, batch_size={args['dataset']['batch_size']}")
|
||||
|
||||
# 获取文件路径
|
||||
train_file_paths = [
|
||||
os.path.join(args["dataset"]["dataset_dir"], s, 'data_train.hdf5')
|
||||
for s in args['dataset']['sessions']
|
||||
]
|
||||
|
||||
print(f"📁 Testing with files: {train_file_paths}")
|
||||
|
||||
# 检查文件是否存在
|
||||
missing_files = [f for f in train_file_paths if not os.path.exists(f)]
|
||||
if missing_files:
|
||||
print(f"❌ Missing files: {missing_files}")
|
||||
print("⚠️ Creating dummy test data...")
|
||||
return test_with_dummy_data(args)
|
||||
|
||||
# 分割数据
|
||||
print("🔄 Splitting data...")
|
||||
train_trials, _ = train_test_split_indices(
|
||||
file_paths=train_file_paths,
|
||||
test_percentage=0,
|
||||
seed=args['dataset']['seed']
|
||||
)
|
||||
|
||||
print(f"📈 Found {sum(len(trials['trials']) for trials in train_trials.values())} training trials")
|
||||
|
||||
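    # (With test_percentage=0, train_test_split_indices is assumed to assign
    # every trial to the train split, which is why the test half of the tuple
    # is discarded above.)
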
    # Test 1: no caching
    print("\n" + "="*60)
    print("🐌 TEST 1: Standard data loading (no cache)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_no_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=False,        # caching disabled
        preload_all_data=False   # preloading disabled
    )

    tf_dataset_no_cache = create_input_fn(
        dataset_no_cache,
        args['dataset']['data_transforms'],
        training=True
    )

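    # (Assumption: create_input_fn wraps the dataset in a tf.data input
    # pipeline and applies the configured data_transforms; its exact behavior
    # lives in dataset_tf.py.)
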
    # Time the first 3 batches. The timer starts before each iteration: the
    # expensive work happens when the iterator yields the batch, so timing
    # only the .numpy() call would miss the actual data loading.
    batch_times = []
    batch_start = time.time()
    for i, batch in enumerate(tf_dataset_no_cache.take(3)):
        _ = batch['input_features'].numpy()  # force host materialization
        batch_time = time.time() - batch_start
        batch_times.append(batch_time)
        print(f"  Batch {i}: {batch_time:.3f}s")
        batch_start = time.time()

    no_cache_time = time.time() - start_time
    no_cache_memory = get_memory_usage() - initial_memory

    print(f"💾 Memory usage: +{no_cache_memory:.1f} MB")
    print(f"⏱️ Total time: {no_cache_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times)/len(batch_times):.3f}s")

    # Test 2: preloaded cache
    print("\n" + "="*60)
    print("🚀 TEST 2: Optimized data loading (full cache preloading)")
    print("="*60)

    initial_memory = get_memory_usage()
    start_time = time.time()

    dataset_with_cache = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=args['num_training_batches'],
        split='train',
        batch_size=args['dataset']['batch_size'],
        days_per_batch=args['dataset']['days_per_batch'],
        random_seed=args['dataset']['seed'],
        cache_data=True,        # caching enabled
        preload_all_data=True   # preloading enabled
    )

    preload_time = time.time() - start_time
    preload_memory = get_memory_usage() - initial_memory

    print(f"📝 Preloading completed in {preload_time:.3f}s")
    print(f"💾 Preloading memory: +{preload_memory:.1f} MB")

    tf_dataset_with_cache = create_input_fn(
        dataset_with_cache,
        args['dataset']['data_transforms'],
        training=True
    )

    # Time the first 3 batches (same method as above)
    batch_start_time = time.time()
    batch_times_cached = []
    batch_start = time.time()
    for i, batch in enumerate(tf_dataset_with_cache.take(3)):
        _ = batch['input_features'].numpy()  # force host materialization
        batch_time = time.time() - batch_start
        batch_times_cached.append(batch_time)
        print(f"  Batch {i}: {batch_time:.3f}s")
        batch_start = time.time()

    cached_batch_time = time.time() - batch_start_time
    cached_memory = get_memory_usage() - initial_memory

    print(f"💾 Total memory usage: +{cached_memory:.1f} MB")
    print(f"⏱️ Batch loading time: {cached_batch_time:.3f}s")
    print(f"📊 Avg batch time: {sum(batch_times_cached)/len(batch_times_cached):.3f}s")

    # Performance comparison
    print("\n" + "="*60)
    print("📈 PERFORMANCE COMPARISON")
    print("="*60)

    speedup = (sum(batch_times)/len(batch_times)) / (sum(batch_times_cached)/len(batch_times_cached))
    memory_cost = cached_memory - no_cache_memory

    print(f"🚀 Speed improvement: {speedup:.1f}x faster")
    print(f"💾 Memory cost: +{memory_cost:.1f} MB for caching")
    print(f"⚡ First batch time: {batch_times[0]:.3f}s → {batch_times_cached[0]:.3f}s")

    if speedup > 2:
        print("✅ Excellent! Caching significantly improved data loading speed")
    elif speedup > 1.5:
        print("✅ Good! Caching noticeably improved data loading speed")
    else:
        print("⚠️ Warning: caching shows little benefit; the dataset may be too small")

    return True


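# A possible further optimization, not exercised by this test: overlap
# host-side loading with downstream compute via tf.data prefetching. A minimal
# sketch, assuming create_input_fn returns a standard tf.data.Dataset
# (prefetch and AUTOTUNE are stock tf.data APIs; _with_prefetch is a
# hypothetical helper, not part of dataset_tf):
def _with_prefetch(tf_dataset):
    """Return the dataset with asynchronous prefetching enabled."""
    return tf_dataset.prefetch(tf.data.AUTOTUNE)

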
def test_with_dummy_data(args):
    """Run the test with dummy (simulated) data."""
    print("🔧 Creating dummy data for testing...")

    # Create dummy trial indices
    dummy_trials = {
        0: {
            'trials': list(range(100)),  # 100 dummy trials
            'session_path': 'dummy_path'
        }
    }

    print("📊 Testing with dummy data (100 trials)...")

    # Compare initialization time with and without caching
print("\n🐌 Testing without cache...")
|
||||
start_time = time.time()
|
||||
dataset_no_cache = BrainToTextDatasetTF(
|
||||
trial_indices=dummy_trials,
|
||||
n_batches=5,
|
||||
split='train',
|
||||
batch_size=32,
|
||||
days_per_batch=1,
|
||||
random_seed=42,
|
||||
cache_data=False,
|
||||
preload_all_data=False
|
||||
)
|
||||
no_cache_time = time.time() - start_time
|
||||
print(f" Initialization time: {no_cache_time:.3f}s")
|
||||
|
||||
print("\n🚀 Testing with cache...")
|
||||
start_time = time.time()
|
||||
dataset_with_cache = BrainToTextDatasetTF(
|
||||
trial_indices=dummy_trials,
|
||||
n_batches=5,
|
||||
split='train',
|
||||
batch_size=32,
|
||||
days_per_batch=1,
|
||||
random_seed=42,
|
||||
cache_data=True,
|
||||
preload_all_data=True
|
||||
)
|
||||
cache_time = time.time() - start_time
|
||||
print(f" Initialization time: {cache_time:.3f}s")
|
||||
|
||||
print(f"\n✅ 缓存机制已成功集成到数据加载管道中")
|
||||
print(f"📝 实际性能需要用真实的HDF5数据进行测试")
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
    print("🧪 Data Loading Performance Test")
    print("="*60)

    try:
        success = test_data_loading_performance()
        if success:
            print("\n🎉 Data loading optimization test completed successfully!")
            print("💡 You can now run train_model_tf.py with fast data loading")
    except Exception as e:
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
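
# Usage: run from the directory containing this script, so that the relative
# paths ../rnn_args.yaml and ../data/hdf5_data_final resolve:
#
#     python test_data_loading.py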