Refactor individual dataset creation for improved I/O efficiency and add logging for error handling
This commit is contained in:
@@ -3,6 +3,8 @@ import tensorflow as tf
|
||||
import h5py
|
||||
import numpy as np
|
||||
import math
|
||||
import logging
|
||||
from itertools import groupby
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
from scipy.ndimage import gaussian_filter1d
|
||||
|
||||
@@ -430,34 +432,64 @@ class BrainToTextDatasetTF:
|
||||
|
||||
def create_individual_dataset(self) -> tf.data.Dataset:
|
||||
"""
|
||||
Create tf.data.Dataset that yields individual examples for TPU-optimized batching
|
||||
Create tf.data.Dataset that yields individual examples with I/O optimization.
|
||||
|
||||
This method creates individual examples instead of pre-batched data,
|
||||
allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU.
|
||||
This generator is refactored to group trial loading by session file,
|
||||
drastically reducing the number of file open/close operations from
|
||||
N_trials to N_sessions, which is ideal for slow disk I/O.
|
||||
"""
|
||||
|
||||
def individual_example_generator():
|
||||
"""Generator that yields individual trial examples"""
|
||||
for batch_idx in range(self.n_batches):
|
||||
batch_index = self.batch_indices[batch_idx]
|
||||
"""Generator that groups reads by file to minimize disk I/O."""
|
||||
|
||||
# Process each trial in the batch individually
|
||||
# 1. 创建一个所有试验的扁平列表: [(day, trial), (day, trial), ...]
|
||||
all_trials_to_load = []
|
||||
# 注意:这里的迭代顺序决定了大致的读取顺序
|
||||
# _create_batch_index_train 已经为我们随机化了批次
|
||||
for batch_idx in sorted(self.batch_indices.keys()):
|
||||
batch_index = self.batch_indices[batch_idx]
|
||||
for day in batch_index.keys():
|
||||
for trial in batch_index[day]:
|
||||
trial_data = self._load_trial_data(day, trial)
|
||||
all_trials_to_load.append((day, trial))
|
||||
|
||||
# Yield individual example with all required fields
|
||||
example = {
|
||||
'input_features': trial_data['input_features'].astype(np.float32),
|
||||
'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
|
||||
'n_time_steps': np.int32(trial_data['n_time_steps']),
|
||||
'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
|
||||
'day_indices': np.int32(trial_data['day_index']),
|
||||
'transcriptions': trial_data['transcription'].astype(np.int32),
|
||||
'block_nums': np.int32(trial_data['block_num']),
|
||||
'trial_nums': np.int32(trial_data['trial_num'])
|
||||
}
|
||||
yield example
|
||||
# 2. 按 'day' (即按文件) 对试验列表进行分组
|
||||
# key=lambda x: x[0] 表示使用元组的第一个元素 (day) 作为分组键
|
||||
for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
|
||||
|
||||
session_path = self.trial_indices[day]['session_path']
|
||||
|
||||
# 3. 为每个分组(每个文件)只打开一次 HDF5 文件
|
||||
try:
|
||||
with h5py.File(session_path, 'r') as f:
|
||||
# 4. 在文件打开的状态下,读取这个文件中需要的所有试验
|
||||
for current_day, current_trial in group:
|
||||
try:
|
||||
# 直接从打开的文件句柄 'f' 中读取,而不是调用旧的加载函数
|
||||
g = f[f'trial_{current_trial:04d}']
|
||||
|
||||
input_features = g['input_features'][:]
|
||||
if self.feature_subset:
|
||||
input_features = input_features[:, self.feature_subset]
|
||||
|
||||
example = {
|
||||
'input_features': input_features.astype(np.float32),
|
||||
'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
|
||||
'n_time_steps': np.int32(g.attrs['n_time_steps']),
|
||||
'phone_seq_lens': np.int32(g.attrs['seq_len']),
|
||||
'day_indices': np.int32(current_day),
|
||||
'transcriptions': g['transcription'][:].astype(np.int32),
|
||||
'block_nums': np.int32(g.attrs['block_num']),
|
||||
'trial_nums': np.int32(g.attrs['trial_num'])
|
||||
}
|
||||
yield example
|
||||
|
||||
except KeyError:
|
||||
logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
|
||||
continue
|
||||
|
||||
except (IOError, FileNotFoundError) as e:
|
||||
logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
|
||||
continue
|
||||
|
||||
# Define output signature for individual examples
|
||||
output_signature = {
|
||||
@@ -479,8 +511,8 @@ class BrainToTextDatasetTF:
|
||||
|
||||
# Shuffle individual examples if training (more effective than batch-level shuffle)
|
||||
if self.split == 'train':
|
||||
# Use a reasonable shuffle buffer - not too large to avoid memory issues
|
||||
shuffle_buffer = min(1000, self.n_trials)
|
||||
# 可以适当增大buffer,因为现在I/O更高效了
|
||||
shuffle_buffer = min(2048, self.n_trials)
|
||||
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
|
||||
|
||||
return dataset
|
||||
@@ -727,7 +759,8 @@ def train_test_split_indices(file_paths: List[str],
|
||||
# Utility functions for TPU-optimized data pipeline
|
||||
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
transform_args: Dict[str, Any],
|
||||
training: bool = True) -> tf.data.Dataset:
|
||||
training: bool = True,
|
||||
cache_path: Optional[str] = None) -> tf.data.Dataset:
|
||||
"""
|
||||
Create input function for TPU training with fixed-shape batching and data augmentation
|
||||
|
||||
@@ -735,14 +768,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
dataset_tf: BrainToTextDatasetTF instance
|
||||
transform_args: Data transformation configuration
|
||||
training: Whether this is for training (applies augmentations)
|
||||
cache_path: Optional path for disk caching to improve I/O performance
|
||||
|
||||
Returns:
|
||||
tf.data.Dataset ready for TPU training with fixed shapes
|
||||
"""
|
||||
|
||||
# Create individual example dataset instead of pre-batched dataset
|
||||
# Create individual example dataset with file-grouping I/O optimization
|
||||
dataset = dataset_tf.create_individual_dataset()
|
||||
|
||||
# ========================= I/O OPTIMIZATION SOLUTION =========================
|
||||
# 在数据加载之后、随机操作(如数据增强)之前进行缓存
|
||||
if training:
|
||||
# 对于训练,缓存到磁盘文件或内存
|
||||
if cache_path:
|
||||
dataset = dataset.cache(cache_path)
|
||||
print(f"🗃️ Dataset caching enabled: {cache_path}")
|
||||
print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
|
||||
else:
|
||||
# 如果内存足够大,可以缓存到内存,速度更快
|
||||
# 但对于大型数据集,推荐使用文件缓存
|
||||
dataset = dataset.cache()
|
||||
print("🗃️ Dataset caching enabled: in-memory cache")
|
||||
print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
|
||||
# (对于验证集通常不需要缓存,因为它只运行一次)
|
||||
# ================================================================
|
||||
|
||||
def apply_transforms(example):
|
||||
"""Apply data transformations to individual examples"""
|
||||
features = example['input_features']
|
||||
@@ -762,7 +813,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
|
||||
return example
|
||||
|
||||
# Apply transformations to individual examples
|
||||
# 在缓存之后应用随机的数据增强,确保每个epoch的增强都不同
|
||||
dataset = dataset.map(
|
||||
apply_transforms,
|
||||
num_parallel_calls=tf.data.AUTOTUNE
|
||||
|
Reference in New Issue
Block a user