Fix data loader inefficiency

Zchen committed 2025-10-16 17:14:06 +08:00
commit be578f2e1d (parent a545cc5648)
3 changed files with 412 additions and 13 deletions


@@ -323,7 +323,13 @@ class BrainToTextDecoderTrainerTF:
         with open(os.path.join(self.args['output_dir'], 'train_val_trials.json'), 'w') as f:
             json.dump({'train': train_trials, 'val': val_trials}, f)
 
-        # Create TensorFlow datasets
+        # Create TensorFlow datasets with aggressive data preloading for TPU optimization
+        # Monitor memory usage during data preloading
+        import psutil
+        initial_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
+
+        print("🔄 Initializing training dataset with full data preloading...")
+        preload_start_time = time.time()
         self.train_dataset_tf = BrainToTextDatasetTF(
             trial_indices=train_trials,
             n_batches=self.args['num_training_batches'],
@@ -332,9 +338,19 @@ class BrainToTextDecoderTrainerTF:
             days_per_batch=self.args['dataset']['days_per_batch'],
             random_seed=self.args['dataset']['seed'],
             must_include_days=self.args['dataset'].get('must_include_days'),
-            feature_subset=self.args['dataset'].get('feature_subset')
+            feature_subset=self.args['dataset'].get('feature_subset'),
+            cache_data=True,  # Enable data caching
+            preload_all_data=True  # Load all training data into memory at once
         )
 
+        # Log training data preloading performance
+        train_preload_time = time.time() - preload_start_time
+        train_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
+        train_memory_used = train_memory_mb - initial_memory_mb
+        print(f"✅ Training data preloaded in {train_preload_time:.2f}s, using {train_memory_used:.1f} MB RAM")
+
+        print("🔄 Initializing validation dataset with caching...")
+        val_preload_start_time = time.time()
         self.val_dataset_tf = BrainToTextDatasetTF(
             trial_indices=val_trials,
             n_batches=None,  # Use all validation data
@@ -342,9 +358,19 @@ class BrainToTextDecoderTrainerTF:
             batch_size=self.args['dataset']['batch_size'],
             days_per_batch=1,  # One day per validation batch
             random_seed=self.args['dataset']['seed'],
-            feature_subset=self.args['dataset'].get('feature_subset')
+            feature_subset=self.args['dataset'].get('feature_subset'),
+            cache_data=True,  # Enable data caching
+            preload_all_data=True  # Load all validation data into memory at once
         )
 
+        # Log validation data preloading performance
+        val_preload_time = time.time() - val_preload_start_time
+        final_memory_mb = psutil.Process().memory_info().rss / 1024 / 1024
+        total_memory_used = final_memory_mb - initial_memory_mb
+        val_memory_used = final_memory_mb - train_memory_mb
+
+        print(f"✅ Validation data preloaded in {val_preload_time:.2f}s, using {val_memory_used:.1f} MB RAM")
+        print(f"📊 Total data cache: {total_memory_used:.1f} MB RAM used for all datasets")
         self.logger.info("Successfully initialized TensorFlow datasets")
 
     def _build_model(self) -> TripleGRUDecoder:
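The memory figures logged above come from sampling the process resident set size (RSS) via psutil before and after each preload step. A minimal standalone sketch of the same measurement pattern (`timed_preload` and `load_fn` are hypothetical names, not part of the repository):

```python
import time

import psutil


def rss_mb() -> float:
    """Resident set size of the current process, in MB."""
    return psutil.Process().memory_info().rss / 1024 / 1024


def timed_preload(name, load_fn):
    """Run load_fn once and report wall time and RSS growth, mirroring the trainer's logging."""
    before_mb = rss_mb()
    start = time.time()
    data = load_fn()  # e.g. the dataset's preload step
    elapsed = time.time() - start
    print(f"✅ {name} preloaded in {elapsed:.2f}s, using {rss_mb() - before_mb:.1f} MB RAM")
    return data
```

Note that an RSS delta is only an approximation of the cache size: allocator reuse and shared pages can skew it, so it is best read as a coarse monitor rather than an exact accounting.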
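The `cache_data` / `preload_all_data` flags themselves are implemented in `BrainToTextDatasetTF`, which lives in one of the other changed files and is not shown in this diff. A minimal sketch of what such flags plausibly do, assuming a per-trial disk read; `PreloadingDataset` and `_load_trial` are hypothetical stand-ins:

```python
import numpy as np


class PreloadingDataset:
    """Sketch of a dataset with optional caching and up-front preloading.

    _load_trial stands in for whatever disk read the real dataset performs.
    """

    def __init__(self, trial_indices, cache_data=False, preload_all_data=False):
        self.trial_indices = trial_indices
        self.cache_data = cache_data
        self._cache = {}
        if preload_all_data:
            # Pay the full I/O cost once, up front, so the training loop
            # never blocks on disk (the inefficiency this commit targets).
            for idx in trial_indices:
                self._cache[idx] = self._load_trial(idx)

    def _load_trial(self, idx):
        # Placeholder for the real per-trial read (e.g. an HDF5 lookup).
        return np.zeros((100, 512), dtype=np.float32)

    def get_trial(self, idx):
        if self.cache_data and idx in self._cache:
            return self._cache[idx]
        data = self._load_trial(idx)
        if self.cache_data:
            self._cache[idx] = data  # lazy caching when not preloaded
        return data
```

With `preload_all_data=True` every `get_trial` call is a dictionary lookup; with only `cache_data=True` the first epoch still pays the disk cost trial by trial, which matches the trade-off the log messages above are measuring.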