Refactor individual dataset creation for improved I/O efficiency and add logging for error handling

Zchen
2025-10-19 10:31:18 +08:00
parent d83f990beb
commit 558be0ad98


@@ -3,6 +3,8 @@ import tensorflow as tf
import h5py
import numpy as np
import math
import logging
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
@@ -430,35 +432,65 @@ class BrainToTextDatasetTF:
def create_individual_dataset(self) -> tf.data.Dataset:
"""
Create tf.data.Dataset that yields individual examples for TPU-optimized batching
Create tf.data.Dataset that yields individual examples with I/O optimization.
This method creates individual examples instead of pre-batched data,
allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU.
This generator is refactored to group trial loading by session file,
drastically reducing the number of file open/close operations from
N_trials to N_sessions, which is ideal for slow disk I/O.
"""
def individual_example_generator():
"""Generator that yields individual trial examples"""
for batch_idx in range(self.n_batches):
batch_index = self.batch_indices[batch_idx]
"""Generator that groups reads by file to minimize disk I/O."""
# Process each trial in the batch individually
# 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
all_trials_to_load = []
# Note: the iteration order here determines the rough read order.
# _create_batch_index_train has already randomized the batches for us.
for batch_idx in sorted(self.batch_indices.keys()):
batch_index = self.batch_indices[batch_idx]
for day in batch_index.keys():
for trial in batch_index[day]:
trial_data = self._load_trial_data(day, trial)
all_trials_to_load.append((day, trial))
# 2. Group the trial list by 'day' (i.e. by file).
# key=lambda x: x[0] uses the first element of the tuple (day) as the grouping key.
for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
session_path = self.trial_indices[day]['session_path']
# 3. Open the HDF5 file only once per group (i.e. per file).
try:
with h5py.File(session_path, 'r') as f:
# 4. With the file open, read every trial needed from this file.
for current_day, current_trial in group:
try:
# Read directly from the open file handle 'f' instead of calling the old loader.
g = f[f'trial_{current_trial:04d}']
input_features = g['input_features'][:]
if self.feature_subset:
input_features = input_features[:, self.feature_subset]
# Yield individual example with all required fields
example = {
'input_features': trial_data['input_features'].astype(np.float32),
'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
'n_time_steps': np.int32(trial_data['n_time_steps']),
'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
'day_indices': np.int32(trial_data['day_index']),
'transcriptions': trial_data['transcription'].astype(np.int32),
'block_nums': np.int32(trial_data['block_num']),
'trial_nums': np.int32(trial_data['trial_num'])
'input_features': input_features.astype(np.float32),
'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
'n_time_steps': np.int32(g.attrs['n_time_steps']),
'phone_seq_lens': np.int32(g.attrs['seq_len']),
'day_indices': np.int32(current_day),
'transcriptions': g['transcription'][:].astype(np.int32),
'block_nums': np.int32(g.attrs['block_num']),
'trial_nums': np.int32(g.attrs['trial_num'])
}
yield example
except KeyError:
logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
continue
except (IOError, FileNotFoundError) as e:
logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
continue
# Define output signature for individual examples
output_signature = {
'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
@@ -479,8 +511,8 @@ class BrainToTextDatasetTF:
# Shuffle individual examples if training (more effective than batch-level shuffle)
if self.split == 'train':
# Use a reasonable shuffle buffer - not too large to avoid memory issues
shuffle_buffer = min(1000, self.n_trials)
# The buffer can be a bit larger now that I/O is more efficient.
shuffle_buffer = min(2048, self.n_trials)
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
return dataset
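
For reference outside the diff, here is a minimal standalone sketch of the same file-grouping idea, using hypothetical trials / session_paths inputs and the trial_XXXX/input_features layout seen above; it sorts (day, trial) pairs by day and walks them with itertools.groupby so each HDF5 file is opened once per pass rather than once per trial:

import logging
from itertools import groupby

import h5py
import numpy as np


def read_trials_grouped_by_file(trials, session_paths):
    """Yield (day, trial, features), opening each session file only once.

    trials: iterable of (day, trial) pairs in any order (hypothetical input).
    session_paths: dict mapping day -> HDF5 file path (hypothetical input).
    """
    # Sort by day so groupby produces one contiguous group per file.
    for day, group in groupby(sorted(trials, key=lambda t: t[0]), key=lambda t: t[0]):
        path = session_paths[day]
        try:
            # One open/close per session file instead of one per trial.
            with h5py.File(path, 'r') as f:
                for _, trial in group:
                    try:
                        g = f[f'trial_{trial:04d}']
                        yield day, trial, g['input_features'][:].astype(np.float32)
                    except KeyError:
                        logging.warning("Trial %d not found in %s. Skipping.", trial, path)
        except (IOError, FileNotFoundError) as e:
            logging.error("Could not open %s: %s. Skipping day %d.", path, e, day)
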
@@ -727,7 +759,8 @@ def train_test_split_indices(file_paths: List[str],
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True) -> tf.data.Dataset:
training: bool = True,
cache_path: Optional[str] = None) -> tf.data.Dataset:
"""
Create input function for TPU training with fixed-shape batching and data augmentation
@@ -735,14 +768,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
Returns:
tf.data.Dataset ready for TPU training with fixed shapes
"""
# Create individual example dataset instead of pre-batched dataset
# Create individual example dataset with file-grouping I/O optimization
dataset = dataset_tf.create_individual_dataset()
# ========================= I/O OPTIMIZATION SOLUTION =========================
# Cache after data loading and before any random ops (e.g. data augmentation).
if training:
# For training, cache to a file on disk or in memory.
if cache_path:
dataset = dataset.cache(cache_path)
print(f"🗃️ Dataset caching enabled: {cache_path}")
print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
else:
# If memory is large enough, caching in memory is faster,
# but file-based caching is recommended for large datasets.
dataset = dataset.cache()
print("🗃️ Dataset caching enabled: in-memory cache")
print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
# (The validation set usually does not need caching, since it only runs once.)
# ================================================================
def apply_transforms(example):
"""Apply data transformations to individual examples"""
features = example['input_features']
@@ -762,7 +813,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
return example
# Apply transformations to individual examples
# Apply the random data augmentations after caching so each epoch gets different augmentations.
dataset = dataset.map(
apply_transforms,
num_parallel_calls=tf.data.AUTOTUNE
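
As a closing reference, a minimal generic sketch (not the repository's actual create_input_fn) of the stage ordering this hunk sets up, with placeholder augment_fn and batch_fn callables: cache the deterministic examples first, then shuffle and apply the random augmentation map, so raw HDF5 reads are paid only once while every epoch still sees different augmentations.

import tensorflow as tf


def build_pipeline(dataset, augment_fn, batch_fn, cache_path=None, shuffle_buffer=2048):
    """Ordering sketch: cache -> shuffle -> augment -> batch -> prefetch."""
    # Cache the deterministic, expensive-to-read examples. With a path, tf.data
    # writes the cache to disk; without one, it is held in memory.
    dataset = dataset.cache(cache_path) if cache_path else dataset.cache()
    # Shuffle and augment *after* the cache so each epoch gets a fresh order and
    # fresh random augmentations while the HDF5 reads are reused from the cache.
    dataset = dataset.shuffle(buffer_size=shuffle_buffer)
    dataset = dataset.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = batch_fn(dataset)  # e.g. padded_batch with fixed shapes for TPU
    return dataset.prefetch(tf.data.AUTOTUNE)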