From 558be0ad9851e19b5b1edf0b9b6b2818057698ba Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 19 Oct 2025 10:31:18 +0800
Subject: [PATCH] Refactor individual dataset creation for improved I/O efficiency and add logging for error handling

---
 model_training_nnn_tpu/dataset_tf.py | 101 ++++++++++++++++++++-------
 1 file changed, 76 insertions(+), 25 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index ca4ceea..ffd943a 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -3,6 +3,8 @@ import tensorflow as tf
 import h5py
 import numpy as np
 import math
+import logging
+from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
 
@@ -430,34 +432,64 @@ class BrainToTextDatasetTF:
 
     def create_individual_dataset(self) -> tf.data.Dataset:
         """
-        Create tf.data.Dataset that yields individual examples for TPU-optimized batching
+        Create tf.data.Dataset that yields individual examples with I/O optimization.
 
-        This method creates individual examples instead of pre-batched data,
-        allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU.
+        This generator is refactored to group trial loading by session file,
+        drastically reducing the number of file open/close operations from
+        N_trials to N_sessions, which is ideal for slow disk I/O.
         """
 
        def individual_example_generator():
-            """Generator that yields individual trial examples"""
-            for batch_idx in range(self.n_batches):
-                batch_index = self.batch_indices[batch_idx]
+            """Generator that groups reads by file to minimize disk I/O."""
 
-                # Process each trial in the batch individually
+            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
+            all_trials_to_load = []
+            # Note: the iteration order here determines the approximate read order.
+            # _create_batch_index_train has already randomized the batches for us.
+            for batch_idx in sorted(self.batch_indices.keys()):
+                batch_index = self.batch_indices[batch_idx]
                 for day in batch_index.keys():
                     for trial in batch_index[day]:
-                        trial_data = self._load_trial_data(day, trial)
+                        all_trials_to_load.append((day, trial))
 
-                        # Yield individual example with all required fields
-                        example = {
-                            'input_features': trial_data['input_features'].astype(np.float32),
-                            'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
-                            'n_time_steps': np.int32(trial_data['n_time_steps']),
-                            'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
-                            'day_indices': np.int32(trial_data['day_index']),
-                            'transcriptions': trial_data['transcription'].astype(np.int32),
-                            'block_nums': np.int32(trial_data['block_num']),
-                            'trial_nums': np.int32(trial_data['trial_num'])
-                        }
-                        yield example
+            # 2. Group the trial list by 'day' (i.e. by file).
+            # key=lambda x: x[0] uses the first element of the tuple (day) as the grouping key.
+            for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
+
+                session_path = self.trial_indices[day]['session_path']
+
+                # 3. Open the HDF5 file only once per group (i.e. once per file).
+                try:
+                    with h5py.File(session_path, 'r') as f:
+                        # 4. While the file is open, read every trial needed from this file.
+                        for current_day, current_trial in group:
+                            try:
+                                # Read directly from the open file handle 'f' instead of calling the old loading function.
+                                g = f[f'trial_{current_trial:04d}']
+
+                                input_features = g['input_features'][:]
+                                if self.feature_subset:
+                                    input_features = input_features[:, self.feature_subset]
+
+                                example = {
+                                    'input_features': input_features.astype(np.float32),
+                                    'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
+                                    'n_time_steps': np.int32(g.attrs['n_time_steps']),
+                                    'phone_seq_lens': np.int32(g.attrs['seq_len']),
+                                    'day_indices': np.int32(current_day),
+                                    'transcriptions': g['transcription'][:].astype(np.int32),
+                                    'block_nums': np.int32(g.attrs['block_num']),
+                                    'trial_nums': np.int32(g.attrs['trial_num'])
+                                }
+                                yield example
+
+                            except KeyError:
+                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
+                                continue
+
+                except (IOError, FileNotFoundError) as e:
+                    logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
+                    continue
 
         # Define output signature for individual examples
         output_signature = {
@@ -479,8 +511,8 @@ class BrainToTextDatasetTF:
 
         # Shuffle individual examples if training (more effective than batch-level shuffle)
         if self.split == 'train':
-            # Use a reasonable shuffle buffer - not too large to avoid memory issues
-            shuffle_buffer = min(1000, self.n_trials)
+            # The shuffle buffer can be made somewhat larger now that I/O is more efficient.
+            shuffle_buffer = min(2048, self.n_trials)
             dataset = dataset.shuffle(buffer_size=shuffle_buffer)
 
         return dataset
@@ -727,7 +759,8 @@ def train_test_split_indices(file_paths: List[str],
 
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    training: bool = True) -> tf.data.Dataset:
+                    training: bool = True,
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
     Create input function for TPU training with fixed-shape batching and data augmentation
@@ -735,14 +768,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
+        cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
         tf.data.Dataset ready for TPU training with fixed shapes
     """
 
-    # Create individual example dataset instead of pre-batched dataset
+    # Create individual example dataset with file-grouping I/O optimization
     dataset = dataset_tf.create_individual_dataset()
 
+    # ========================= I/O OPTIMIZATION SOLUTION =========================
+    # Cache after data loading and before any random ops such as data augmentation.
+    if training:
+        # For training, cache either to a file on disk or in memory.
+        if cache_path:
+            dataset = dataset.cache(cache_path)
+            print(f"🗃️ Dataset caching enabled: {cache_path}")
+            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
+        else:
+            # If memory is large enough, caching in memory is even faster,
+            # but file-based caching is recommended for large datasets.
+            dataset = dataset.cache()
+            print("🗃️ Dataset caching enabled: in-memory cache")
+            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
+    # (The validation set usually does not need caching, since it only runs once.)
+    # ================================================================
+
     def apply_transforms(example):
         """Apply data transformations to individual examples"""
         features = example['input_features']
@@ -762,7 +813,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
 
         return example
 
-    # Apply transformations to individual examples
+    # Apply random data augmentation after caching, so each epoch gets different augmentations.
     dataset = dataset.map(
         apply_transforms,
         num_parallel_calls=tf.data.AUTOTUNE
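
Usage note (not part of the patch): a minimal sketch of how the new cache_path argument might be exercised from training code. The import path, the pre-built dataset_tf and transform_args objects, and the cache file location are illustrative assumptions, not taken from this diff.

    # Hypothetical wiring of the refactored input pipeline with a disk cache.
    from model_training_nnn_tpu.dataset_tf import create_input_fn  # import path assumed

    # dataset_tf: an already-constructed BrainToTextDatasetTF for the 'train' split
    # transform_args: the augmentation config dict expected by create_input_fn
    train_ds = create_input_fn(
        dataset_tf,
        transform_args,
        training=True,
        cache_path='/tmp/b2t_train_cache',  # hypothetical path; omit to cache in memory
    )
    # The first pass over train_ds populates the cache; later epochs read from it,
    # which is what the "first epoch will be slow" messages in the patch refer to.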