Refactor individual dataset creation for improved I/O efficiency and add logging for error handling

Zchen
2025-10-19 10:31:18 +08:00
parent d83f990beb
commit 558be0ad98


@@ -3,6 +3,8 @@ import tensorflow as tf
 import h5py
 import numpy as np
 import math
+import logging
+from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
@@ -430,34 +432,64 @@ class BrainToTextDatasetTF:
     def create_individual_dataset(self) -> tf.data.Dataset:
         """
-        Create tf.data.Dataset that yields individual examples for TPU-optimized batching
-
-        This method creates individual examples instead of pre-batched data,
-        allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU.
+        Create tf.data.Dataset that yields individual examples with I/O optimization.
+
+        This generator is refactored to group trial loading by session file,
+        drastically reducing the number of file open/close operations from
+        N_trials to N_sessions, which is ideal for slow disk I/O.
         """
         def individual_example_generator():
-            """Generator that yields individual trial examples"""
-            for batch_idx in range(self.n_batches):
-                batch_index = self.batch_indices[batch_idx]
-
-                # Process each trial in the batch individually
+            """Generator that groups reads by file to minimize disk I/O."""
+
+            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
+            all_trials_to_load = []
+            # Note: the iteration order here determines the rough read order;
+            # _create_batch_index_train has already randomized the batches for us.
+            for batch_idx in sorted(self.batch_indices.keys()):
+                batch_index = self.batch_indices[batch_idx]
                 for day in batch_index.keys():
                     for trial in batch_index[day]:
-                        trial_data = self._load_trial_data(day, trial)
-
-                        # Yield individual example with all required fields
-                        example = {
-                            'input_features': trial_data['input_features'].astype(np.float32),
-                            'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
-                            'n_time_steps': np.int32(trial_data['n_time_steps']),
-                            'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
-                            'day_indices': np.int32(trial_data['day_index']),
-                            'transcriptions': trial_data['transcription'].astype(np.int32),
-                            'block_nums': np.int32(trial_data['block_num']),
-                            'trial_nums': np.int32(trial_data['trial_num'])
-                        }
-
-                        yield example
+                        all_trials_to_load.append((day, trial))
+
+            # 2. Group the trial list by 'day' (i.e. by file);
+            #    key=lambda x: x[0] uses the first tuple element (day) as the grouping key.
+            for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
+
+                session_path = self.trial_indices[day]['session_path']
+
+                # 3. Open the HDF5 file only once per group (per file)
+                try:
+                    with h5py.File(session_path, 'r') as f:
+                        # 4. With the file open, read every trial needed from this file
+                        for current_day, current_trial in group:
+                            try:
+                                # Read directly from the open file handle 'f' instead of calling the old loader
+                                g = f[f'trial_{current_trial:04d}']
+
+                                input_features = g['input_features'][:]
+                                if self.feature_subset:
+                                    input_features = input_features[:, self.feature_subset]
+
+                                example = {
+                                    'input_features': input_features.astype(np.float32),
+                                    'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
+                                    'n_time_steps': np.int32(g.attrs['n_time_steps']),
+                                    'phone_seq_lens': np.int32(g.attrs['seq_len']),
+                                    'day_indices': np.int32(current_day),
+                                    'transcriptions': g['transcription'][:].astype(np.int32),
+                                    'block_nums': np.int32(g.attrs['block_num']),
+                                    'trial_nums': np.int32(g.attrs['trial_num'])
+                                }
+
+                                yield example
+                            except KeyError:
+                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
+                                continue
+                except (IOError, FileNotFoundError) as e:
+                    logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
+                    continue
 
         # Define output signature for individual examples
         output_signature = {
@@ -479,8 +511,8 @@ class BrainToTextDatasetTF:
 
         # Shuffle individual examples if training (more effective than batch-level shuffle)
         if self.split == 'train':
-            # Use a reasonable shuffle buffer - not too large to avoid memory issues
-            shuffle_buffer = min(1000, self.n_trials)
+            # The buffer can be enlarged a bit now that I/O is more efficient
+            shuffle_buffer = min(2048, self.n_trials)
             dataset = dataset.shuffle(buffer_size=shuffle_buffer)
 
         return dataset
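As a reference point, the file-grouping read order that the new generator relies on can be illustrated with a minimal standalone sketch. The (day, trial) pairs below are hypothetical; the real generator builds this list from self.batch_indices.

# Minimal sketch of the grouping idea, independent of the dataset class.
from itertools import groupby

# Hypothetical (day, trial) pairs; in the real code they come from self.batch_indices.
all_trials_to_load = [(0, 12), (1, 3), (0, 7), (1, 9)]

# Sorting by day makes groupby emit one contiguous group per day/file,
# so each session file only needs to be opened once.
for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
    trials = [trial for _, trial in group]
    print(f"day {day}: open its HDF5 file once, then read trials {trials}")

Running this prints one line per day, confirming that all trials belonging to the same session file are read back-to-back while the file handle stays open.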
@@ -727,7 +759,8 @@ def train_test_split_indices(file_paths: List[str],
 
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    training: bool = True) -> tf.data.Dataset:
+                    training: bool = True,
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
     Create input function for TPU training with fixed-shape batching and data augmentation
@@ -735,14 +768,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
+        cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
         tf.data.Dataset ready for TPU training with fixed shapes
     """
-    # Create individual example dataset instead of pre-batched dataset
+    # Create individual example dataset with file-grouping I/O optimization
     dataset = dataset_tf.create_individual_dataset()
 
+    # ========================= I/O OPTIMIZATION SOLUTION =========================
+    # Cache after data loading and before any random ops (such as data augmentation).
+    if training:
+        # For training, cache to a disk file or to memory.
+        if cache_path:
+            dataset = dataset.cache(cache_path)
+            print(f"🗃️ Dataset caching enabled: {cache_path}")
+            print("⚠️ First epoch will be slow while building the cache; subsequent epochs will be much faster")
+        else:
+            # If memory allows, caching in memory is faster,
+            # but a file cache is recommended for large datasets.
+            dataset = dataset.cache()
+            print("🗃️ Dataset caching enabled: in-memory cache")
+            print("⚠️ First epoch will be slow while building the cache; subsequent epochs will be much faster")
+    # (The validation set usually does not need caching because it only runs once.)
+    # ==============================================================================
+
     def apply_transforms(example):
         """Apply data transformations to individual examples"""
         features = example['input_features']
@@ -762,7 +813,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         return example
 
-    # Apply transformations to individual examples
+    # Apply random augmentations after caching so that each epoch gets different augmentations
     dataset = dataset.map(
         apply_transforms,
         num_parallel_calls=tf.data.AUTOTUNE
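For context, a hedged usage sketch of the new cache_path parameter. The names train_dataset_tf, val_dataset_tf, transform_args, and the cache path are hypothetical stand-ins for a typical training script and are not taken from this commit.

# Hypothetical call site; only the cache_path argument is new in this commit.
train_ds = create_input_fn(
    dataset_tf=train_dataset_tf,
    transform_args=transform_args,
    training=True,
    cache_path='/tmp/b2t_train.cache',  # on-disk cache; omit to cache in memory instead
)

# Validation data is not cached here: create_input_fn only caches when training=True,
# and the validation split is read far less often.
val_ds = create_input_fn(
    dataset_tf=val_dataset_tf,
    transform_args=transform_args,
    training=False,
)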