Refactor individual dataset creation for improved I/O efficiency and add logging for error handling
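
A minimal usage sketch of the new `cache_path` argument introduced below (the variable names, config key, and cache file path are illustrative, not part of this commit):

```python
# Hypothetical call site; only the cache_path keyword is new in this commit.
train_dataset = create_input_fn(
    dataset_tf=train_dataset_tf,                            # BrainToTextDatasetTF instance (assumed)
    transform_args=config['dataset']['data_transforms'],    # illustrative config key
    training=True,
    cache_path='/tmp/b2t_train.cache',                      # first epoch writes the cache, later epochs read it
)
```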
@@ -3,6 +3,8 @@ import tensorflow as tf
 import h5py
 import numpy as np
 import math
+import logging
+from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
 
@@ -430,34 +432,64 @@ class BrainToTextDatasetTF:
 
     def create_individual_dataset(self) -> tf.data.Dataset:
         """
-        Create tf.data.Dataset that yields individual examples for TPU-optimized batching
+        Create tf.data.Dataset that yields individual examples with I/O optimization.
 
-        This method creates individual examples instead of pre-batched data,
-        allowing TensorFlow's padded_batch to handle fixed-shape batching for TPU.
+        This generator is refactored to group trial loading by session file,
+        drastically reducing the number of file open/close operations from
+        N_trials to N_sessions, which is ideal for slow disk I/O.
         """
 
         def individual_example_generator():
-            """Generator that yields individual trial examples"""
-            for batch_idx in range(self.n_batches):
-                batch_index = self.batch_indices[batch_idx]
-
-                # Process each trial in the batch individually
-                for day in batch_index.keys():
-                    for trial in batch_index[day]:
-                        trial_data = self._load_trial_data(day, trial)
-
-                        # Yield individual example with all required fields
-                        example = {
-                            'input_features': trial_data['input_features'].astype(np.float32),
-                            'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
-                            'n_time_steps': np.int32(trial_data['n_time_steps']),
-                            'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
-                            'day_indices': np.int32(trial_data['day_index']),
-                            'transcriptions': trial_data['transcription'].astype(np.int32),
-                            'block_nums': np.int32(trial_data['block_num']),
-                            'trial_nums': np.int32(trial_data['trial_num'])
-                        }
-
-                        yield example
+            """Generator that groups reads by file to minimize disk I/O."""
+
+            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
+            all_trials_to_load = []
+            # Note: the iteration order here determines the rough read order.
+            # _create_batch_index_train has already randomized the batches for us.
+            for batch_idx in sorted(self.batch_indices.keys()):
+                batch_index = self.batch_indices[batch_idx]
+                for day in batch_index.keys():
+                    for trial in batch_index[day]:
+                        all_trials_to_load.append((day, trial))
+
+            # 2. Group the trial list by 'day' (i.e. by file).
+            # key=lambda x: x[0] means the first tuple element (day) is used as the grouping key.
+            for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
+
+                session_path = self.trial_indices[day]['session_path']
+
+                # 3. Open the HDF5 file only once per group (i.e. once per file).
+                try:
+                    with h5py.File(session_path, 'r') as f:
+                        # 4. While the file is open, read every trial that lives in it.
+                        for current_day, current_trial in group:
+                            try:
+                                # Read directly from the open file handle 'f' instead of calling the old loading function.
+                                g = f[f'trial_{current_trial:04d}']
+
+                                input_features = g['input_features'][:]
+                                if self.feature_subset:
+                                    input_features = input_features[:, self.feature_subset]
+
+                                example = {
+                                    'input_features': input_features.astype(np.float32),
+                                    'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
+                                    'n_time_steps': np.int32(g.attrs['n_time_steps']),
+                                    'phone_seq_lens': np.int32(g.attrs['seq_len']),
+                                    'day_indices': np.int32(current_day),
+                                    'transcriptions': g['transcription'][:].astype(np.int32),
+                                    'block_nums': np.int32(g.attrs['block_num']),
+                                    'trial_nums': np.int32(g.attrs['trial_num'])
+                                }
+
+                                yield example
+
+                            except KeyError:
+                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
+                                continue
+
+                except (IOError, FileNotFoundError) as e:
+                    logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
+                    continue
 
         # Define output signature for individual examples
         output_signature = {
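
For reference, a small standalone sketch (not repository code) of why the trial list is sorted before `groupby`: `itertools.groupby` only merges *consecutive* items with the same key, so without sorting the same session file could be opened more than once.

```python
from itertools import groupby

trials = [(0, 3), (1, 7), (0, 5)]  # hypothetical (day, trial) pairs from two sessions

unsorted_groups = [(d, list(g)) for d, g in groupby(trials, key=lambda x: x[0])]
# -> [(0, [(0, 3)]), (1, [(1, 7)]), (0, [(0, 5)])]   day 0 appears twice: its file would be opened twice

sorted_groups = [(d, list(g)) for d, g in groupby(sorted(trials, key=lambda x: x[0]), key=lambda x: x[0])]
# -> [(0, [(0, 3), (0, 5)]), (1, [(1, 7)])]          each day grouped once: one open per file
```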
@@ -479,8 +511,8 @@ class BrainToTextDatasetTF:
 
         # Shuffle individual examples if training (more effective than batch-level shuffle)
         if self.split == 'train':
-            # Use a reasonable shuffle buffer - not too large to avoid memory issues
-            shuffle_buffer = min(1000, self.n_trials)
+            # The buffer can be somewhat larger now that I/O is more efficient.
+            shuffle_buffer = min(2048, self.n_trials)
             dataset = dataset.shuffle(buffer_size=shuffle_buffer)
 
         return dataset
@@ -727,7 +759,8 @@ def train_test_split_indices(file_paths: List[str],
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    training: bool = True) -> tf.data.Dataset:
+                    training: bool = True,
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
     Create input function for TPU training with fixed-shape batching and data augmentation
 
@@ -735,14 +768,32 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
+        cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
         tf.data.Dataset ready for TPU training with fixed shapes
     """
 
-    # Create individual example dataset instead of pre-batched dataset
+    # Create individual example dataset with file-grouping I/O optimization
     dataset = dataset_tf.create_individual_dataset()
 
+    # ========================= I/O OPTIMIZATION SOLUTION =========================
+    # Cache after data loading and before any random ops (e.g. data augmentation).
+    if training:
+        # For training, cache either to a disk file or in memory.
+        if cache_path:
+            dataset = dataset.cache(cache_path)
+            print(f"🗃️ Dataset caching enabled: {cache_path}")
+            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
+        else:
+            # With enough memory, an in-memory cache is faster;
+            # for large datasets a file cache is recommended instead.
+            dataset = dataset.cache()
+            print("🗃️ Dataset caching enabled: in-memory cache")
+            print("⚠️ First epoch will be slow while building cache, subsequent epochs will be much faster")
+    # (The validation split usually does not need caching, since it is only run once.)
+    # ==============================================================================
+
     def apply_transforms(example):
         """Apply data transformations to individual examples"""
         features = example['input_features']
@@ -762,7 +813,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
 
         return example
 
-    # Apply transformations to individual examples
+    # Apply random data augmentation after caching so that the augmentation differs every epoch
     dataset = dataset.map(
         apply_transforms,
         num_parallel_calls=tf.data.AUTOTUNE
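
As a side note, a minimal sketch (illustrative names, not repository code) of the cache-then-augment ordering described by the comments above: deterministic loading is cached once, while random augmentation stays after the cache so it is re-drawn every epoch.

```python
import tensorflow as tf

def add_noise(example):
    # Stand-in for apply_transforms: any op that must be re-sampled each epoch.
    return example + tf.random.normal(tf.shape(example), stddev=0.1)

# Stand-in for the HDF5-backed example dataset.
dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([8, 4]))
dataset = dataset.cache('/tmp/example.cache')   # or .cache() for an in-memory cache
dataset = dataset.map(add_noise, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.prefetch(tf.data.AUTOTUNE)
```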