Files
b2txt25/model_training_nnn_tpu/dataset_tf.py

1094 lines
47 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import tensorflow as tf
import h5py
import numpy as np
import math
import logging
import time
import random
import multiprocessing
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
from concurrent.futures import ThreadPoolExecutor, as_completed
class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8

    This class creates tf.data.Dataset objects that load and batch
    brain-to-text data from HDF5 files. Each "day" maps to one HDF5 session
    file whose groups are named ``trial_XXXX``.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE,
        cache_data: bool = True,
        preload_all_data: bool = False
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data

        Args:
            trial_indices: Dictionary with day numbers as keys and trial info as
                values (each value has 'trials' and 'session_path' entries)
            n_batches: Number of training batches to create (None for validation)
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility (-1 leaves RNGs untouched)
            must_include_days: Days that must be included in every batch;
                negative values are resolved relative to the number of days.
                The caller's list is not modified.
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Parallel processing threads
            cache_data: Whether to cache loaded data in memory
            preload_all_data: Whether to preload all data at initialization
                (implies caching, since preloaded trials live in the cache)

        Raises:
            ValueError: On an unknown split, inconsistent day settings, or a
                missing ``n_batches`` for the train split.
        """
        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices)
        self.feature_subset = feature_subset
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls
        # Preloading writes into the cache, so it must force the cache on;
        # otherwise data_cache would be None and preloading would fail.
        self.cache_data = cache_data or preload_all_data
        self.preload_all_data = preload_all_data
        self.data_cache = {} if self.cache_data else None

        # Total number of trials across all days
        self.n_trials = sum(len(info['trials']) for info in trial_indices.values())

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError('must_include_days must be <= days_per_batch')
            # Map negative indices on a copy so the caller's list is untouched
            must_include_days = [
                self.n_days + d if d < 0 else d for d in must_include_days
            ]
        self.must_include_days = must_include_days

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            if self.n_batches is None:
                raise ValueError('n_batches must be specified for the "train" split')
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

        # Determine the feature dimension before any loader can hit its
        # dummy-data fallback, which needs self.feature_dim to exist.
        self.feature_dim = self._detect_feature_dim()

        # Preload data if requested (speeds up first batch significantly)
        if self.preload_all_data:
            print(f"🔄 Preloading all data for {self.split} split...")
            self._preload_all_data()
            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")

    def _detect_feature_dim(self) -> int:
        """
        Return the per-timestep feature dimension.

        Uses the feature subset length when a subset is configured; otherwise
        reads the first trial of the first day that has a loadable trial.
        Falls back to 512 when no trial can be read.
        """
        if self.feature_subset:
            feature_dim = len(self.feature_subset)
            print(f"✅ Using feature subset dimension: {feature_dim}")
            return feature_dim

        # Walk the days until one yields a readable trial
        for day, info in self.trial_indices.items():
            if not info['trials']:
                continue
            first_valid_trial = info['trials'][0]
            try:
                # Use the raising reader so a failed read cannot be mistaken
                # for real data (the public loader returns dummy data instead).
                first_sample = self._read_trial_from_disk(day, first_valid_trial)
                detected_dim = first_sample['input_features'].shape[1]
            except Exception as e:
                # Try the next day if this trial could not be read
                print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} for dimension check. Error: {e}")
                continue
            print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}")
            return detected_dim

        # No readable trial anywhere: last-resort fallback
        print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. No valid trials found in the dataset split '{self.split}'. Falling back to 512.")
        return 512

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """
        Create training batch indices with random sampling.

        Each batch picks ``days_per_batch`` distinct days (always including
        ``must_include_days`` when given) and samples trials with replacement
        from each day, then trims random days until the batch holds exactly
        ``batch_size`` trials.
        """
        batch_indices = {}

        # Precompute non-must-include days
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys()
                if d not in self.must_include_days
            ]

        # Trials requested per day before trimming
        num_trials = math.ceil(self.batch_size / self.days_per_batch)

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size. Only count a
            # removal when a trial was actually dropped from the chosen day.
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """
        Create test batch indices ensuring all trials are seen once.

        Batches never mix days; each day's trials are split into consecutive
        chunks of at most ``batch_size``.
        """
        batch_indices = {}
        batch_idx = 0
        for d, info in self.trial_indices.items():
            day_trials = info['trials']
            for start_idx in range(0, len(day_trials), self.batch_size):
                batch_indices[batch_idx] = {d: day_trials[start_idx:start_idx + self.batch_size]}
                batch_idx += 1
        return batch_indices

    def _preload_all_data(self):
        """
        Preload all trial data into the in-memory cache using a thread pool.

        Trials that cannot be read are reported and left out of the cache.
        """
        # Limit workers to avoid overwhelming I/O
        max_workers = min(multiprocessing.cpu_count(), 32)

        # Collect all trials to load
        trials_to_load = [
            (day, trial)
            for day in self.trial_indices
            for trial in self.trial_indices[day]['trials']
        ]
        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")

        loaded_count = 0
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all loading tasks; the raising reader is used so that
            # failures are skipped rather than cached as dummy data.
            future_to_trial = {
                executor.submit(self._read_trial_from_disk, day, trial): (day, trial)
                for day, trial in trials_to_load
            }

            # Process completed tasks and update cache
            for future in as_completed(future_to_trial):
                day, trial = future_to_trial[future]
                try:
                    self.data_cache[f"{day}_{trial}"] = future.result()
                except Exception as e:
                    print(f" Warning: Failed to load trial {day}_{trial}: {e}")
                    continue
                loaded_count += 1
                # Progress indicator every 100 trials
                if loaded_count % 100 == 0:
                    print(f" Loaded {loaded_count}/{len(trials_to_load)} trials...")

        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")

    def _read_trial_group(self, g, day: int) -> Dict[str, Any]:
        """
        Build a trial dict from an open HDF5 trial group ``g``.

        Applies the feature subset (if any) and converts the neural features
        to float32; the remaining fields are returned as stored in the file.
        """
        input_features = g['input_features'][:]
        if self.feature_subset:
            input_features = input_features[:, self.feature_subset]
        return {
            # float32 for TF compatibility
            'input_features': input_features.astype(np.float32),
            'seq_class_ids': g['seq_class_ids'][:],
            'transcription': g['transcription'][:],
            'n_time_steps': g.attrs['n_time_steps'],
            'phone_seq_lens': g.attrs['seq_len'],
            'day_index': day,
            'block_num': g.attrs['block_num'],
            'trial_num': g.attrs['trial_num']
        }

    def _read_trial_from_disk(self, day: int, trial: int) -> Dict[str, Any]:
        """Read one trial from its session file; raises if anything fails."""
        session_path = self.trial_indices[day]['session_path']
        with h5py.File(session_path, 'r') as f:
            return self._read_trial_group(f[f'trial_{trial:04d}'], day)

    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """
        Load a single trial's data from disk.

        On any failure a warning is logged and zero-filled dummy data with the
        dataset's feature dimension is returned instead of raising.
        """
        # Resolved before the try so it is always bound in the warning message
        session_path = None
        try:
            session_path = self.trial_indices[day]['session_path']
            return self._read_trial_from_disk(day, trial)
        except Exception as e:
            # Log the error and return dummy data with correct feature dimension
            logging.warning(f"Failed to load trial {day}_{trial} from {session_path}. Error: {e}. Returning dummy data.")
            return {
                'input_features': np.zeros((100, self.feature_dim), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
        """Load a single trial's data from cache or HDF5 file"""
        if not self.cache_data:
            return self._load_single_trial_data(day, trial)

        cache_key = f"{day}_{trial}"
        # Check cache first
        if cache_key in self.data_cache:
            return self.data_cache[cache_key]

        # Load from disk and remember the result
        trial_data = self._load_single_trial_data(day, trial)
        self.data_cache[cache_key] = trial_data
        return trial_data

    def _create_batch_generator(self):
        """
        Generator function that yields whole, zero-padded batches.

        ⚠️ DEPRECATED: This method is deprecated. Use create_input_fn() instead for better performance
        and TPU compatibility. This method will be removed in a future version.
        """
        import warnings
        warnings.warn(
            "_create_batch_generator is deprecated. Use create_input_fn() instead for better performance.",
            DeprecationWarning,
            stacklevel=2
        )

        for batch_idx in range(self.n_batches):
            batch_start_time = time.time()
            batch_index = self.batch_indices[batch_idx]

            # Collect all trials to load for this batch
            trials_to_load = [
                (day, trial) for day in batch_index for trial in batch_index[day]
            ]

            # Use parallel loading if not preloaded and have multiple trials
            use_parallel = not self.preload_all_data and len(trials_to_load) > 4
            if use_parallel:
                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
                    # executor.map keeps results in submission order
                    trial_results = list(executor.map(
                        lambda day_trial: self._load_trial_data(*day_trial),
                        trials_to_load
                    ))
            else:
                # Sequential loading (fast when data is cached or few trials)
                trial_results = [
                    self._load_trial_data(day, trial) for day, trial in trials_to_load
                ]

            input_features = [t['input_features'] for t in trial_results]
            seq_class_ids = [t['seq_class_ids'] for t in trial_results]
            transcriptions = [t['transcription'] for t in trial_results]
            n_time_steps = [t['n_time_steps'] for t in trial_results]
            phone_seq_lens = [t['phone_seq_lens'] for t in trial_results]
            day_indices = [t['day_index'] for t in trial_results]
            block_nums = [t['block_num'] for t in trial_results]
            trial_nums = [t['trial_num'] for t in trial_results]

            data_loading_time = time.time() - batch_start_time

            # Add timing diagnostic for first few batches
            if batch_idx < 3:
                cache_status = "cached" if self.preload_all_data else "disk"
                loading_method = "parallel" if use_parallel else "sequential"
                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in {data_loading_time:.3f}s ({cache_status}, {loading_method})")

            # Pad sequences to create uniform batch
            max_time_steps = max(n_time_steps)
            max_phone_len = max(len(seq) for seq in seq_class_ids)
            max_transcription_len = max(len(trans) for trans in transcriptions)

            # Pad input features along the time axis
            padded_features = []
            for features in input_features:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad phoneme sequences
            padded_seq_ids = []
            for seq in seq_class_ids:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in transcriptions:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create final batch tensors
            yield {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(n_time_steps, dtype=np.int32),
                'phone_seq_lens': np.array(phone_seq_lens, dtype=np.int32),
                'day_indices': np.array(day_indices, dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(block_nums, dtype=np.int32),
                'trial_nums': np.array(trial_nums, dtype=np.int32)
            }

    def create_dataset(self) -> tf.data.Dataset:
        """
        Create a tf.data.Dataset of pre-batched, padded examples.

        ⚠️ DEPRECATED: This method is deprecated. Use create_input_fn() instead for better performance
        and TPU compatibility. This method will be removed in a future version.
        """
        import warnings
        warnings.warn(
            "create_dataset is deprecated. Use create_input_fn() instead for better performance.",
            DeprecationWarning,
            stacklevel=2
        )

        # Define output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # No dataset-level shuffle here: training batches are already randomly
        # sampled in _create_batch_index_train(), and shuffling whole batches
        # would hold many batches of trials in memory at once.

        # Prefetch for better performance
        return dataset.prefetch(self.prefetch_buffer)

    def create_individual_dataset(self) -> tf.data.Dataset:
        """
        Create tf.data.Dataset that yields individual examples with I/O optimization.

        Trials are grouped by session file so each HDF5 file is opened once per
        pass instead of once per trial. For the train split the examples are
        shuffled with a bounded buffer.
        """
        def individual_example_generator():
            """Generator that groups reads by file to minimize disk I/O."""
            # 1. Flatten every batch into one list of (day, trial) pairs.
            #    Batch order fixes the approximate read order; training
            #    batches were already randomized by _create_batch_index_train.
            all_trials_to_load = []
            for batch_idx in sorted(self.batch_indices.keys()):
                batch_index = self.batch_indices[batch_idx]
                for day in batch_index.keys():
                    for trial in batch_index[day]:
                        all_trials_to_load.append((day, trial))

            # 2. Group the trials by day (i.e. by file); the sort is stable so
            #    the within-day order is kept.
            by_day = lambda pair: pair[0]
            for day, group in groupby(sorted(all_trials_to_load, key=by_day), key=by_day):
                session_path = self.trial_indices[day]['session_path']
                # 3. Open each session file only once
                try:
                    with h5py.File(session_path, 'r') as f:
                        # 4. Read every requested trial while the file is open
                        for current_day, current_trial in group:
                            try:
                                trial_data = self._read_trial_group(
                                    f[f'trial_{current_trial:04d}'], current_day
                                )
                            except KeyError:
                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
                                continue
                            yield {
                                'input_features': trial_data['input_features'],
                                'seq_class_ids': trial_data['seq_class_ids'].astype(np.int32),
                                'n_time_steps': np.int32(trial_data['n_time_steps']),
                                'phone_seq_lens': np.int32(trial_data['phone_seq_lens']),
                                'day_indices': np.int32(current_day),
                                'transcriptions': trial_data['transcription'].astype(np.int32),
                                'block_nums': np.int32(trial_data['block_num']),
                                'trial_nums': np.int32(trial_data['trial_num'])
                            }
                except (IOError, FileNotFoundError) as e:
                    logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
                    continue

        # Define output signature for individual examples
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, self.feature_dim), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
        }

        # Create dataset from individual examples
        dataset = tf.data.Dataset.from_generator(
            individual_example_generator,
            output_signature=output_signature
        )

        # Shuffle individual examples if training (more effective than batch-level shuffle)
        if self.split == 'train':
            # shuffle() requires a positive buffer, so guard the empty-dataset case
            shuffle_buffer = max(1, min(2048, self.n_trials))
            dataset = dataset.shuffle(buffer_size=shuffle_buffer)

        return dataset
class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis with one depthwise convolution.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel before truncation

        Returns:
            Smoothed tensor with same shape as input

        Raises:
            ValueError: If no kernel tap exceeds the 0.01 truncation threshold
                (the kernel would be empty and normalization would divide by zero).
        """
        # Build the Gaussian kernel in numpy: smooth a unit impulse, keep only
        # the taps above 0.01, and renormalize so the kernel sums to one.
        impulse = np.zeros(smooth_kernel_size, dtype=np.float32)
        impulse[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(impulse, smooth_kernel_std)
        gauss_kernel = gauss_kernel[gauss_kernel > 0.01]
        if gauss_kernel.size == 0:
            raise ValueError(
                f'Gaussian kernel is empty for std={smooth_kernel_std}, '
                f'size={smooth_kernel_size}; no tap exceeds the 0.01 threshold'
            )
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)

        # The kernel length is known in Python; the feature count is taken
        # from the static shape when available so the convolution keeps a
        # static channel dimension, and only falls back to a dynamic tensor.
        kernel_size = int(gauss_kernel.shape[0])
        static_features = inputs.shape[-1]
        num_features = static_features if static_features is not None else tf.shape(inputs)[-1]

        # Kernel layout for depthwise_conv2d:
        # [height, width, in_channels, channel_multiplier] = [kernel_size, 1, num_features, 1]
        # Every feature channel gets its own copy of the same 1D kernel.
        kernel = tf.reshape(tf.constant(gauss_kernel, dtype=tf.float32), [kernel_size, 1, 1, 1])
        kernel = tf.tile(kernel, [1, 1, num_features, 1])

        # Input layout for conv2d: [batch, height, width, channels]
        # = [batch_size, time_steps, 1, num_features] (dummy width axis)
        reshaped_inputs = tf.expand_dims(inputs, axis=2)

        smoothed = tf.nn.depthwise_conv2d(
            reshaped_inputs,
            kernel,
            strides=[1, 1, 1, 1],
            padding='SAME'
        )

        # Remove the dummy width dimension to restore original shape
        return tf.squeeze(smoothed, axis=2)

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration. Keys read here:
                'static_gain_std', 'white_noise_std', 'constant_offset_std',
                'random_walk_std', 'random_walk_axis', 'random_cut',
                'smooth_data', 'smooth_kernel_std', 'smooth_kernel_size'
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise: multiply by (identity + noise) per sample
            gain_std = transform_args.get('static_gain_std', 0)
            if gain_std > 0:
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            white_noise_std = transform_args.get('white_noise_std', 0)
            if white_noise_std > 0:
                features = features + tf.random.normal(tf.shape(features)) * white_noise_std

            # Constant offset noise (one offset per sample and channel)
            constant_offset_std = transform_args.get('constant_offset_std', 0)
            if constant_offset_std > 0:
                features = features + tf.random.normal([batch_size, 1, channels]) * constant_offset_std

            # Random walk noise
            random_walk_std = transform_args.get('random_walk_std', 0)
            if random_walk_std > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * random_walk_std
                axis = transform_args.get('random_walk_axis', 1)
                features = features + tf.cumsum(random_walk_noise, axis=axis)

            # Random cutoff (simplified for TPU - apply to all samples in batch)
            # NOTE(review): slicing by a random amount makes the time axis
            # dynamic again even when the batch was padded to a static shape.
            max_cut = transform_args.get('random_cut', 0)
            if max_cut > 0:
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                # Clamp so samples shorter than the cut never report a
                # negative number of valid steps.
                n_time_steps = tf.maximum(n_time_steps - cut, 0)

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps
def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split data from file_paths into train and test splits

    The day key of each entry is the position of its file in ``file_paths``.
    Files whose path contains no session component starting with 't15.20' or
    't12.20' are skipped. A path that does not exist yields a day with no
    trials.

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Fraction of trials for testing. 0 puts everything in
            train and 1 everything in test; otherwise at least one trial per
            non-empty day goes to test.
        seed: Random seed for reproducibility (-1 leaves the RNG untouched)
        bad_trials_dict: Dictionary of trials to exclude, indexed as
            ``bad_trials_dict[session][str(block_num)]`` -> trial numbers

    Returns:
        Tuple of (train_trials, test_trials) dictionaries, each mapping
        day -> {'trials': [...], 'session_path': path}
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        try:
            # Handle both Windows and Unix path separators
            path_parts = path.replace('\\', '/').split('/')
            session_candidates = [
                s for s in path_parts if s.startswith(('t15.20', 't12.20'))
            ]
            if not session_candidates:
                logging.error(f"Could not parse session name from path: {path}. Skipping this file.")
                continue
            session = session_candidates[0]
        except Exception as e:
            logging.error(f"Error parsing path {path}: {e}. Skipping this file.")
            continue

        good_trial_indices = []
        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'
                    if key not in f:
                        continue
                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']
                    # Check if trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue
                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split trials into train and test
    train_trials = {}
    test_trials = {}
    for day, day_info in trials_per_day.items():
        num_trials = day_info['num_trials']
        all_trial_indices = day_info['trial_indices']
        session_path = day_info['session_path']

        if test_percentage == 0:
            train_indices, test_indices = all_trial_indices, []
        elif test_percentage == 1:
            train_indices, test_indices = [], all_trial_indices
        else:
            # At least one test trial, but never more than the day has;
            # a day without trials must not reach np.random.choice, which
            # rejects sampling from an empty population.
            num_test = min(num_trials, max(1, int(num_trials * test_percentage)))
            if num_test > 0:
                test_indices = np.random.choice(
                    all_trial_indices, size=num_test, replace=False
                ).tolist()
            else:
                test_indices = []
            # Remaining indices for training (set lookup keeps this linear)
            held_out = set(test_indices)
            train_indices = [idx for idx in all_trial_indices if idx not in held_out]

        train_trials[day] = {'trials': train_indices, 'session_path': session_path}
        test_trials[day] = {'trials': test_indices, 'session_path': session_path}

    return train_trials, test_trials
def analyze_dataset_shapes(dataset_tf: "BrainToTextDatasetTF", sample_size: int = 100) -> Dict[str, int]:
    """
    Analyze trial shapes in parallel to determine maximum dimensions for padded_batch.

    Every unique (day, trial) pair referenced by ``dataset_tf.batch_indices``
    is loaded through ``dataset_tf._load_trial_data`` (so the dataset's cache
    is reused) and the largest time-step count, phoneme-sequence length and
    transcription length are collected. A safety margin is then applied:
    2% when ``sample_size == -1``, 30% otherwise.

    Args:
        dataset_tf: Dataset instance to analyze
        sample_size: Number of samples to analyze (set to -1 for all data).
            Sampling uses a fixed private seed and does not touch the global
            ``random`` state.

    Returns:
        Dictionary with 'max_time_steps', 'max_phone_seq_len',
        'max_transcription_len' and 'n_features'

    Raises:
        ValueError: If there are no trials or none could be analyzed.
    """
    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
    start_time = time.time()

    # 1. Collect every unique (day, trial) pair, keeping first-seen order
    all_trials = []
    unique_trials = set()
    for batch_idx in sorted(dataset_tf.batch_indices.keys()):
        for day, trials in dataset_tf.batch_indices[batch_idx].items():
            for trial in trials:
                pair = (day, trial)
                if pair not in unique_trials:
                    unique_trials.add(pair)
                    all_trials.append(pair)

    # 2. Sample the list when requested. A private Random(42) gives the same
    #    reproducible sample as seeding the global generator would, without
    #    resetting the caller's random state.
    if 0 < sample_size < len(all_trials):
        trials_to_check = random.Random(42).sample(all_trials, sample_size)
    else:
        trials_to_check = all_trials

    total_trials_to_analyze = len(trials_to_check)
    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")

    # ThreadPoolExecutor rejects max_workers == 0, so report the real problem
    if total_trials_to_analyze == 0:
        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")

    def analyze_single_trial(day_trial_pair):
        """Loads and analyzes a single trial, returns its shapes (or None on failure)."""
        day, trial = day_trial_pair
        try:
            # Reuse the dataset's loading and caching logic
            trial_data = dataset_tf._load_trial_data(day, trial)
            time_steps = int(trial_data['n_time_steps'])
            phone_seq_len = int(trial_data['phone_seq_lens'])
            # The transcription may be an array or a scalar
            transcription_data = trial_data['transcription']
            if hasattr(transcription_data, '__len__'):
                transcription_len = len(transcription_data)
            else:
                transcription_len = 1
            return (time_steps, phone_seq_len, transcription_len)
        except Exception as e:
            logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
            return None

    # 3. Analyze in parallel; worker count follows the CPU count, capped at 32
    cpu_count = os.cpu_count() or 4  # Fallback to 4 if cpu_count() returns None
    max_workers = max(1, min(32, cpu_count, total_trials_to_analyze))
    local_max_shapes = []
    print(f"🔧 Using {max_workers} parallel workers for analysis...")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_trial = {
            executor.submit(analyze_single_trial, trial_pair): trial_pair
            for trial_pair in trials_to_check
        }
        count = 0
        progress_interval = max(1, total_trials_to_analyze // 20)  # about 20 progress updates
        for future in as_completed(future_to_trial):
            result = future.result()
            if result:
                local_max_shapes.append(result)
            count += 1
            if count % progress_interval == 0 or count == total_trials_to_analyze:
                elapsed = time.time() - start_time
                rate = count / elapsed if elapsed > 0 else 0
                print(f" 📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")

    # 4. Aggregate the per-trial results
    if not local_max_shapes:
        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")

    # [(t1, p1, tr1), (t2, p2, tr2), ...] -> ([t1, t2, ...], [p1, p2, ...], ...)
    all_time_steps, all_phone_lens, all_transcription_lens = zip(*local_max_shapes)
    original_max_shapes = {
        'max_time_steps': int(np.max(all_time_steps)),
        'max_phone_seq_len': int(np.max(all_phone_lens)),
        'max_transcription_len': int(np.max(all_transcription_lens)),
        'n_features': dataset_tf.feature_dim
    }

    # 5. Apply a safety margin that depends on how much data was inspected
    if sample_size == -1:
        # Full analysis: only a small buffer is needed
        safety_margin = 1.02  # 2% buffer for rounding errors
        margin_reason = "minimal buffer for full dataset analysis"
    else:
        # Sampled analysis: leave room for extremes that were not sampled
        safety_margin = 1.3  # 30% buffer for sampling uncertainty
        margin_reason = f"larger buffer due to sampling only {sample_size} trials"

    final_max_shapes = {
        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
        'n_features': original_max_shapes['n_features']
    }

    analysis_time = time.time() - start_time
    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
    print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin - {margin_reason}):")
    # Separate the original and padded values so they do not run together
    print(f" Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
    print(f" Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
    print(f" Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
    print(f" Number of features: {final_max_shapes['n_features']}")

    return final_max_shapes
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True,
                    cache_path: Optional[str] = None,
                    use_static_shapes: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with configurable shape handling

    Pipeline: individual examples -> cache -> (training only) shuffle ->
    padded_batch -> (training only) batch augmentation -> prefetch.

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance;
            an in-memory cache is used when omitted
        use_static_shapes: If True, use pre-computed static shapes for XLA compatibility

    Returns:
        tf.data.Dataset ready for TPU training
    """
    split_name = "training" if training else "validation"

    # Step 1: Create individual example dataset
    dataset = dataset_tf.create_individual_dataset()

    # Step 2: Cache raw samples BEFORE any augmentation or batching
    if cache_path:
        dataset = dataset.cache(cache_path)
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
    else:
        dataset = dataset.cache()
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")

    # A cache replays exactly the element order it recorded, so the shuffle
    # done upstream of the cache is frozen after the first pass. Shuffle again
    # after the cache so every training epoch sees a fresh order.
    if training:
        dataset = dataset.shuffle(buffer_size=max(1, min(2048, dataset_tf.n_trials)))

    # Step 3: Batch samples with shape handling optimized for TPU
    if use_static_shapes:
        print("🔧 Using STATIC shapes for XLA compatibility")
        # Analyze dataset to get maximum shapes
        print("📊 Analyzing dataset for maximum shapes...")
        max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # Analyze ALL data for maximum accuracy
        time_dim = max_shapes['max_time_steps']
        phone_dim = max_shapes['max_phone_seq_len']
        transcription_dim = max_shapes['max_transcription_len']
        print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
              f"phone_len={max_shapes['max_phone_seq_len']}, "
              f"transcription_len={max_shapes['max_transcription_len']}")
    else:
        print("🔧 Using DYNAMIC shapes (may cause XLA compilation issues)")
        # None pads each batch to its own longest element
        time_dim = phone_dim = transcription_dim = None

    padded_shapes = {
        'input_features': (time_dim, dataset_tf.feature_dim),
        'seq_class_ids': (phone_dim,),
        'n_time_steps': (),
        'phone_seq_lens': (),
        'day_indices': (),
        'transcriptions': (transcription_dim,),
        'block_nums': (),
        'trial_nums': ()
    }
    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")

    # Define padding values for each field
    padding_values = {
        'input_features': 0.0,
        'seq_class_ids': 0,
        'n_time_steps': 0,
        'phone_seq_lens': 0,
        'day_indices': 0,
        'transcriptions': 0,
        'block_nums': 0,
        'trial_nums': 0
    }

    # With static shapes every element must fit inside the analyzed maxima;
    # with dynamic shapes each batch is padded to its own longest element.
    dataset = dataset.padded_batch(
        batch_size=dataset_tf.batch_size,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        drop_remainder=True  # Critical for TPU: ensures all batches have same size
    )

    # Step 4: Apply data augmentation to BATCHES (after batching)
    def apply_batch_transforms(batch):
        """Run transform_data on a whole batch and write the results back."""
        features, n_time_steps = DataAugmentationTF.transform_data(
            batch['input_features'],  # Already has batch dimension
            batch['n_time_steps'],
            transform_args,
            training=training
        )
        batch['input_features'] = features
        batch['n_time_steps'] = n_time_steps
        return batch

    # Apply batch transforms only during training.
    # NOTE(review): transform_data also performs Gaussian smoothing when
    # transform_args['smooth_data'] is set, so validation batches are not
    # smoothed here — confirm the caller smooths them elsewhere.
    if training:
        dataset = dataset.map(
            apply_batch_transforms,
            num_parallel_calls=tf.data.AUTOTUNE
        )
        print("✅ Batch augmentation enabled for training")
    else:
        print("✅ No augmentation for validation")

    # Step 5: Prefetch for optimal performance
    return dataset.prefetch(tf.data.AUTOTUNE)