# b2txt25/model_training_nnn_tpu/dataset_tf.py
import os
import sys
import tensorflow as tf
import h5py
import numpy as np
import math
import logging
import time
import random
import multiprocessing
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
from concurrent.futures import ThreadPoolExecutor, as_completed
class BrainToTextDatasetTF:
"""
TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8
This class creates tf.data.Dataset objects that efficiently load and batch
brain-to-text data from HDF5 files with TPU-optimized operations.
"""
def __init__(
self,
trial_indices: Dict[int, Dict[str, Any]],
n_batches: Optional[int],
split: str = 'train',
batch_size: int = 64,
days_per_batch: int = 1,
random_seed: int = -1,
must_include_days: Optional[List[int]] = None,
feature_subset: Optional[List[int]] = None,
prefetch_buffer: int = tf.data.AUTOTUNE,
num_parallel_calls: int = tf.data.AUTOTUNE,
cache_data: bool = True,
preload_all_data: bool = False
):
"""
Initialize TensorFlow dataset for brain-to-text data
Args:
trial_indices: Dictionary with day numbers as keys and trial info as values
n_batches: Number of training batches to create (None for validation)
split: 'train' or 'test'
batch_size: Number of examples per batch
days_per_batch: Number of unique days per batch (for day-specific layers)
random_seed: Random seed for reproducibility
must_include_days: Days that must be included in every batch
feature_subset: Subset of neural features to use
prefetch_buffer: Buffer size for prefetching
num_parallel_calls: Parallel processing threads
cache_data: Whether to cache loaded data in memory
preload_all_data: Whether to preload all data at initialization
"""
# Set random seed for reproducibility
if random_seed != -1:
tf.random.set_seed(random_seed)
np.random.seed(random_seed)
self.split = split
if self.split not in ['train', 'test']:
raise ValueError(f'split must be either "train" or "test". Received {self.split}')
self.days_per_batch = days_per_batch
self.batch_size = batch_size
self.n_batches = n_batches
self.trial_indices = trial_indices
self.n_days = len(trial_indices.keys())
self.feature_subset = feature_subset
self.must_include_days = must_include_days
self.prefetch_buffer = prefetch_buffer
self.num_parallel_calls = num_parallel_calls
self.cache_data = cache_data
self.preload_all_data = preload_all_data
# Initialize data cache
self.data_cache = {} if cache_data else None
# Calculate total number of trials
self.n_trials = 0
for d in trial_indices:
self.n_trials += len(trial_indices[d]['trials'])
# Validation checks
if must_include_days is not None:
if len(must_include_days) > days_per_batch:
                raise ValueError(f'must_include_days has {len(must_include_days)} entries, which exceeds days_per_batch ({days_per_batch})')
# Map negative indices
for i, d in enumerate(must_include_days):
if d < 0:
must_include_days[i] = self.n_days + d
if self.split == 'train' and self.days_per_batch > self.n_days:
raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')
# Create batch indices
if self.split == 'train':
self.batch_indices = self._create_batch_index_train()
else:
self.batch_indices = self._create_batch_index_test()
self.n_batches = len(self.batch_indices)
# Preload data if requested (speeds up first batch significantly)
if self.preload_all_data:
print(f"🔄 Preloading all data for {self.split} split...")
self._preload_all_data()
print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
        # ========================= Feature-dimension auto-detection (more robust version) =========================
        # Explicitly compute and store the feature dimension to avoid shape mismatches in padded_batch
if self.feature_subset:
self.feature_dim = len(self.feature_subset)
print(f"✅ Using feature subset dimension: {self.feature_dim}")
else:
            # Robustly infer the actual feature dimension from the data:
            # iterate over the days until we find the first one that contains a trial
detected_dim = None
for day in self.trial_indices:
                # Check whether the trial list is non-empty
if self.trial_indices[day]['trials']:
try:
                        # The list is non-empty; try to load its first trial
first_valid_trial = self.trial_indices[day]['trials'][0]
first_sample = self._load_single_trial_data(day, first_valid_trial)
detected_dim = first_sample['input_features'].shape[1]
print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}")
                        break  # Dimension detected successfully; stop searching
except Exception as e:
                        # Loading this trial failed; keep trying the next day
print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} for dimension check. Error: {e}")
continue
if detected_dim is not None:
self.feature_dim = detected_dim
else:
                # No valid trial was found after checking every day; report the problem and fall back
print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. No valid trials found in the dataset split '{self.split}'. Falling back to 512.")
                self.feature_dim = 512  # Last-resort fallback
        # ========================= End of feature-dimension detection =========================
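    # Illustrative constructor usage (comment-only sketch; the path, batch counts and
    # hyperparameters below are hypothetical, not values shipped with this module):
    #
    #   train_trials, val_trials = train_test_split_indices(
    #       file_paths=['/data/sessions/t15.2023.08.11/data_train.hdf5'],  # hypothetical path
    #       test_percentage=0.1,
    #       seed=42,
    #   )
    #   train_dataset_tf = BrainToTextDatasetTF(
    #       trial_indices=train_trials,
    #       n_batches=1000,
    #       split='train',
    #       batch_size=64,
    #       days_per_batch=4,
    #       random_seed=42,
    #   )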
def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
"""Create training batch indices with random sampling"""
batch_indices = {}
# Precompute non-must-include days
if self.must_include_days is not None:
non_must_include_days = [
d for d in self.trial_indices.keys()
if d not in self.must_include_days
]
for batch_idx in range(self.n_batches):
batch = {}
# Select days for this batch
if self.must_include_days is not None and len(self.must_include_days) > 0:
additional_days = np.random.choice(
non_must_include_days,
size=self.days_per_batch - len(self.must_include_days),
replace=False
)
days = np.concatenate((self.must_include_days, additional_days))
else:
days = np.random.choice(
list(self.trial_indices.keys()),
size=self.days_per_batch,
replace=False
)
# Calculate trials per day
num_trials = math.ceil(self.batch_size / self.days_per_batch)
for d in days:
# Sample trials with replacement
trial_idxs = np.random.choice(
self.trial_indices[d]['trials'],
size=num_trials,
replace=True
)
batch[d] = trial_idxs.tolist()
# Remove extra trials to match exact batch size
extra_trials = (num_trials * len(days)) - self.batch_size
while extra_trials > 0:
d = np.random.choice(days)
if len(batch[d]) > 0:
batch[d] = batch[d][:-1]
extra_trials -= 1
batch_indices[batch_idx] = batch
return batch_indices
def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
"""Create test batch indices ensuring all trials are seen once"""
batch_indices = {}
batch_idx = 0
for d in self.trial_indices.keys():
num_trials = len(self.trial_indices[d]['trials'])
num_batches = (num_trials + self.batch_size - 1) // self.batch_size
for i in range(num_batches):
start_idx = i * self.batch_size
end_idx = min((i + 1) * self.batch_size, num_trials)
batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
batch_indices[batch_idx] = {d: batch_trials}
batch_idx += 1
return batch_indices
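    # Resulting index layout (illustrative values):
    #
    #   self.batch_indices = {
    #       0: {day_a: [trial, ...], day_b: [trial, ...]},  # train: days_per_batch days, batch_size trials total
    #       1: {...},
    #   }
    #
    # For the test split each batch holds trials from exactly one day, and every
    # trial appears exactly once across all batches.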
def _preload_all_data(self):
"""Preload all trial data into memory cache (uses available RAM optimally)"""
import multiprocessing
from concurrent.futures import ThreadPoolExecutor, as_completed
# Use CPU cores efficiently for parallel I/O
max_workers = min(multiprocessing.cpu_count(), 32) # Limit to avoid overwhelming I/O
# Collect all trials to load
trials_to_load = []
for day in self.trial_indices:
for trial in self.trial_indices[day]['trials']:
trials_to_load.append((day, trial))
print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")
# Parallel loading using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all loading tasks
future_to_trial = {
executor.submit(self._load_single_trial_data, day, trial): (day, trial)
for day, trial in trials_to_load
}
# Process completed tasks and update cache
loaded_count = 0
for future in as_completed(future_to_trial):
day, trial = future_to_trial[future]
try:
trial_data = future.result()
cache_key = f"{day}_{trial}"
self.data_cache[cache_key] = trial_data
loaded_count += 1
# Progress indicator every 100 trials
if loaded_count % 100 == 0:
print(f" Loaded {loaded_count}/{len(trials_to_load)} trials...")
except Exception as e:
print(f" Warning: Failed to load trial {day}_{trial}: {e}")
print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")
def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
"""Load a single trial's data - optimized version for parallel loading"""
        # Resolve the session path outside the try block so the except handler can reference it
        session_path = self.trial_indices[day]['session_path']
        try:
            with h5py.File(session_path, 'r') as f:
g = f[f'trial_{trial:04d}']
# Load neural features
input_features = g['input_features'][:]
if self.feature_subset:
input_features = input_features[:, self.feature_subset]
# Convert to float32 for TF compatibility
input_features = input_features.astype(np.float32)
trial_data = {
'input_features': input_features,
'seq_class_ids': g['seq_class_ids'][:],
'transcription': g['transcription'][:],
'n_time_steps': g.attrs['n_time_steps'],
'phone_seq_lens': g.attrs['seq_len'],
'day_index': day,
'block_num': g.attrs['block_num'],
'trial_num': g.attrs['trial_num']
}
return trial_data
except Exception as e:
# Log the error and return dummy data with correct feature dimension
logging.warning(f"Failed to load trial {day}_{trial} from {session_path}. Error: {e}. Returning dummy data.")
# Use self.feature_dim to ensure dimension consistency
feature_dim = self.feature_dim
return {
'input_features': np.zeros((100, feature_dim), dtype=np.float32),
'seq_class_ids': np.zeros((10,), dtype=np.int32),
'transcription': np.zeros((50,), dtype=np.int32),
'n_time_steps': 100,
'phone_seq_lens': 10,
'day_index': day,
'block_num': 0,
'trial_num': 0
}
def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
"""Load a single trial's data from cache or HDF5 file"""
# Check cache first if caching is enabled
if self.cache_data:
cache_key = f"{day}_{trial}"
if cache_key in self.data_cache:
return self.data_cache[cache_key]
# Load from disk if not in cache
trial_data = self._load_single_trial_data(day, trial)
# Cache the loaded data if caching is enabled
if self.cache_data:
cache_key = f"{day}_{trial}"
self.data_cache[cache_key] = trial_data
return trial_data
def _create_batch_generator(self):
"""
Generator function that yields individual batches with optimized loading
DEPRECATED: This method is deprecated. Use create_input_fn() instead for better performance
and TPU compatibility. This method will be removed in a future version.
"""
import warnings
warnings.warn(
"_create_batch_generator is deprecated. Use create_input_fn() instead for better performance.",
DeprecationWarning,
stacklevel=2
)
import time
from concurrent.futures import ThreadPoolExecutor
for batch_idx in range(self.n_batches):
batch_start_time = time.time()
batch_data = {
'input_features': [],
'seq_class_ids': [],
'n_time_steps': [],
'phone_seq_lens': [],
'day_indices': [],
'transcriptions': [],
'block_nums': [],
'trial_nums': []
}
batch_index = self.batch_indices[batch_idx]
# Collect all trials to load for this batch
trials_to_load = []
for day in batch_index.keys():
for trial in batch_index[day]:
trials_to_load.append((day, trial))
# Use parallel loading if not preloaded and have multiple trials
if not self.preload_all_data and len(trials_to_load) > 4:
# Parallel loading for faster I/O
with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
future_to_trial = {
executor.submit(self._load_trial_data, day, trial): (day, trial)
for day, trial in trials_to_load
}
# Collect results in order
trial_results = {}
for future in future_to_trial:
day, trial = future_to_trial[future]
trial_results[(day, trial)] = future.result()
# Add data in original order
for day, trial in trials_to_load:
trial_data = trial_results[(day, trial)]
batch_data['input_features'].append(trial_data['input_features'])
batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
batch_data['transcriptions'].append(trial_data['transcription'])
batch_data['n_time_steps'].append(trial_data['n_time_steps'])
batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
batch_data['day_indices'].append(trial_data['day_index'])
batch_data['block_nums'].append(trial_data['block_num'])
batch_data['trial_nums'].append(trial_data['trial_num'])
else:
# Sequential loading (fast when data is cached or few trials)
for day, trial in trials_to_load:
trial_data = self._load_trial_data(day, trial)
batch_data['input_features'].append(trial_data['input_features'])
batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
batch_data['transcriptions'].append(trial_data['transcription'])
batch_data['n_time_steps'].append(trial_data['n_time_steps'])
batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
batch_data['day_indices'].append(trial_data['day_index'])
batch_data['block_nums'].append(trial_data['block_num'])
batch_data['trial_nums'].append(trial_data['trial_num'])
data_loading_time = time.time() - batch_start_time
# Add timing diagnostic for first few batches
if batch_idx < 3:
cache_status = "cached" if self.preload_all_data else "disk"
loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in {data_loading_time:.3f}s ({cache_status}, {loading_method})")
# Pad sequences to create uniform batch
max_time_steps = max(batch_data['n_time_steps'])
max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])
# Pad input features
padded_features = []
for features in batch_data['input_features']:
if features.shape[0] < max_time_steps:
padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
features = np.vstack([features, padding])
padded_features.append(features)
# Pad sequences
padded_seq_ids = []
for seq in batch_data['seq_class_ids']:
if len(seq) < max_phone_len:
padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
seq = np.concatenate([seq, padding])
padded_seq_ids.append(seq)
# Pad transcriptions
padded_transcriptions = []
for trans in batch_data['transcriptions']:
if len(trans) < max_transcription_len:
padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
trans = np.concatenate([trans, padding])
padded_transcriptions.append(trans)
# Create final batch tensors
batch = {
'input_features': np.stack(padded_features),
'seq_class_ids': np.stack(padded_seq_ids),
'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
'transcriptions': np.stack(padded_transcriptions),
'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
}
yield batch
def create_dataset(self) -> tf.data.Dataset:
"""
Create optimized tf.data.Dataset for TPU training
DEPRECATED: This method is deprecated. Use create_input_fn() instead for better performance
and TPU compatibility. This method will be removed in a future version.
"""
import warnings
warnings.warn(
"create_dataset is deprecated. Use create_input_fn() instead for better performance.",
DeprecationWarning,
stacklevel=2
)
# Define output signature for the dataset
output_signature = {
'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
}
# Create dataset from generator
dataset = tf.data.Dataset.from_generator(
self._create_batch_generator,
output_signature=output_signature
)
# Apply TPU-optimized transformations
        # 🚨 Following the GPU-version strategy, no Dataset-level shuffle is needed here!
        # The GPU version already does the random sampling in _create_batch_index_train() (lines 107-118).
        # Shuffling again here would blow up memory (1000 batches × 256 trials = 256,000 trials held in memory at once).
        # if self.split == 'train':
        #     dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))  # ← commented out: memory killer
# Prefetch for better performance
dataset = dataset.prefetch(self.prefetch_buffer)
return dataset
def create_individual_dataset(self) -> tf.data.Dataset:
"""
Create tf.data.Dataset that yields individual examples with I/O optimization.
This generator is refactored to group trial loading by session file,
drastically reducing the number of file open/close operations from
N_trials to N_sessions, which is ideal for slow disk I/O.
"""
def individual_example_generator():
"""Generator that groups reads by file to minimize disk I/O."""
            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
all_trials_to_load = []
            # Note: the iteration order here determines the rough read order;
            # _create_batch_index_train has already randomized the batches for us
for batch_idx in sorted(self.batch_indices.keys()):
batch_index = self.batch_indices[batch_idx]
for day in batch_index.keys():
for trial in batch_index[day]:
all_trials_to_load.append((day, trial))
            # 2. Group the trial list by 'day' (i.e. by file)
            # key=lambda x: x[0] uses the first element of each tuple (the day) as the grouping key
for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):
session_path = self.trial_indices[day]['session_path']
                # 3. Open the HDF5 file only once per group (per file)
try:
with h5py.File(session_path, 'r') as f:
                        # 4. While the file is open, read every trial needed from this file
for current_day, current_trial in group:
try:
                                # Read directly from the open file handle 'f' instead of calling the old loader
g = f[f'trial_{current_trial:04d}']
input_features = g['input_features'][:]
if self.feature_subset:
input_features = input_features[:, self.feature_subset]
example = {
'input_features': input_features.astype(np.float32),
'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
'n_time_steps': np.int32(g.attrs['n_time_steps']),
'phone_seq_lens': np.int32(g.attrs['seq_len']),
'day_indices': np.int32(current_day),
'transcriptions': g['transcription'][:].astype(np.int32),
'block_nums': np.int32(g.attrs['block_num']),
'trial_nums': np.int32(g.attrs['trial_num'])
}
yield example
except KeyError:
logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
continue
except (IOError, FileNotFoundError) as e:
logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
continue
# Define output signature for individual examples
output_signature = {
'input_features': tf.TensorSpec(shape=(None, self.feature_dim), dtype=tf.float32),
'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
}
# Create dataset from individual examples
dataset = tf.data.Dataset.from_generator(
individual_example_generator,
output_signature=output_signature
)
# Shuffle individual examples if training (more effective than batch-level shuffle)
if self.split == 'train':
            # The buffer can be somewhat larger now that I/O is more efficient
shuffle_buffer = min(2048, self.n_trials)
dataset = dataset.shuffle(buffer_size=shuffle_buffer)
return dataset
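    # Consumption sketch, assuming an already-constructed `dataset_tf` instance
    # (create_input_fn() below is the intended entry point; this only shows the
    # per-example dict the generator yields):
    #
    #   ds = dataset_tf.create_individual_dataset()
    #   for example in ds.take(1):
    #       print(example['input_features'].shape)  # (time_steps, feature_dim)
    #       print(example['seq_class_ids'].shape)   # (phone_seq_len,)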
class DataAugmentationTF:
"""
TensorFlow data augmentation functions optimized for TPU v5e-8
"""
@staticmethod
def gauss_smooth(inputs: tf.Tensor,
smooth_kernel_std: float = 2.0,
smooth_kernel_size: int = 100) -> tf.Tensor:
"""
Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation.
This implementation uses depthwise_conv2d for optimal TPU performance,
replacing the inefficient Python for-loop that created 512 separate conv1d operations.
Args:
inputs: Input tensor [batch_size, time_steps, features]
smooth_kernel_std: Standard deviation of Gaussian kernel
smooth_kernel_size: Size of the Gaussian kernel
Returns:
Smoothed tensor with same shape as input
"""
# Create Gaussian kernel using numpy (computed once)
inp = np.zeros(smooth_kernel_size, dtype=np.float32)
inp[smooth_kernel_size // 2] = 1
gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
valid_idx = np.argwhere(gauss_kernel > 0.01)
gauss_kernel = gauss_kernel[valid_idx].flatten()
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
# ========================= OPTIMIZED SOLUTION =========================
# Get input dimensions
num_features = tf.shape(inputs)[-1]
kernel_size = tf.shape(gauss_kernel)[0]
# Prepare kernel for depthwise_conv2d
# Shape needed: [height, width, in_channels, channel_multiplier]
# Our case: [kernel_size, 1, num_features, 1]
# This means each input channel (num_features) has its own independent, identical 1D Gaussian kernel
kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
kernel = tf.tile(kernel, [1, 1, num_features, 1])
# Prepare input for conv2d
# Shape needed: [batch, height, width, channels]
# Our case: [batch_size, time_steps, 1, num_features]
# Add a dummy width dimension
reshaped_inputs = tf.expand_dims(inputs, axis=2)
# Execute depthwise convolution
# This is a single, efficient operation replacing the original Python for-loop
smoothed = tf.nn.depthwise_conv2d(
reshaped_inputs,
kernel,
strides=[1, 1, 1, 1],
padding='SAME'
)
# Remove the dummy width dimension to restore original shape
smoothed = tf.squeeze(smoothed, axis=2)
# ================================================================
return smoothed
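    # Shape sanity-check sketch (assumes eager execution; small numerical
    # differences at the sequence edges and from kernel truncation are expected
    # versus scipy.ndimage.gaussian_filter1d):
    #
    #   x = tf.random.normal([2, 200, 512])
    #   y = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0, smooth_kernel_size=100)
    #   assert y.shape == x.shape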
@staticmethod
def transform_data(features: tf.Tensor,
n_time_steps: tf.Tensor,
transform_args: Dict[str, Any],
training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Apply data transformations optimized for TPU
Args:
features: Input features [batch_size, time_steps, channels]
n_time_steps: Number of valid time steps per sample
transform_args: Transformation configuration
training: Whether to apply training-only augmentations
Returns:
Transformed features and updated time steps
"""
batch_size = tf.shape(features)[0]
time_steps = tf.shape(features)[1]
channels = tf.shape(features)[2]
# Training-only augmentations
if training:
# Static gain noise
if transform_args.get('static_gain_std', 0) > 0:
gain_std = transform_args['static_gain_std']
# Create identity matrices for each batch
identity_matrices = tf.eye(channels, batch_shape=[batch_size])
# Add noise to create warp matrices
noise = tf.random.normal([batch_size, channels, channels]) * gain_std
warp_matrices = identity_matrices + noise
# Apply transformation
features = tf.linalg.matmul(features, warp_matrices)
# White noise
if transform_args.get('white_noise_std', 0) > 0:
white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
features = features + white_noise
# Constant offset noise
if transform_args.get('constant_offset_std', 0) > 0:
offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
features = features + offset_noise
# Random walk noise
if transform_args.get('random_walk_std', 0) > 0:
random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
axis = transform_args.get('random_walk_axis', 1)
random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
features = features + random_walk_noise
# Random cutoff (simplified for TPU - apply to all samples in batch)
if transform_args.get('random_cut', 0) > 0:
max_cut = transform_args['random_cut']
cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
features = features[:, cut:, :]
n_time_steps = n_time_steps - cut
# Apply Gaussian smoothing (both training and validation)
if transform_args.get('smooth_data', False):
features = DataAugmentationTF.gauss_smooth(
features,
smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
)
return features, n_time_steps
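    # Example transform_args layout read by transform_data() via .get(); the
    # numeric values here are placeholders, not recommended settings:
    #
    #   transform_args = {
    #       'static_gain_std': 0.0,
    #       'white_noise_std': 1.0,
    #       'constant_offset_std': 0.2,
    #       'random_walk_std': 0.0,
    #       'random_walk_axis': 1,
    #       'random_cut': 3,
    #       'smooth_data': True,
    #       'smooth_kernel_std': 2.0,
    #       'smooth_kernel_size': 100,
    #   }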
def train_test_split_indices(file_paths: List[str],
test_percentage: float = 0.1,
seed: int = -1,
bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
"""
Split data from file_paths into train and test splits
Args:
file_paths: List of HDF5 file paths
test_percentage: Percentage of trials for testing
seed: Random seed for reproducibility
bad_trials_dict: Dictionary of trials to exclude
Returns:
Tuple of (train_trials, test_trials) dictionaries
"""
# Set seed for reproducibility
if seed != -1:
np.random.seed(seed)
# Get trials in each day
trials_per_day = {}
for i, path in enumerate(file_paths):
try:
# Handle both Windows and Unix path separators
path_parts = path.replace('\\', '/').split('/')
session_candidates = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))]
if not session_candidates:
logging.error(f"Could not parse session name from path: {path}. Skipping this file.")
continue
session = session_candidates[0]
except Exception as e:
logging.error(f"Error parsing path {path}: {e}. Skipping this file.")
continue
good_trial_indices = []
if os.path.exists(path):
with h5py.File(path, 'r') as f:
num_trials = len(list(f.keys()))
for t in range(num_trials):
key = f'trial_{t:04d}'
if key not in f:
continue
block_num = f[key].attrs['block_num']
trial_num = f[key].attrs['trial_num']
# Check if trial should be excluded
if (bad_trials_dict is not None
and session in bad_trials_dict
and str(block_num) in bad_trials_dict[session]
and trial_num in bad_trials_dict[session][str(block_num)]):
continue
good_trial_indices.append(t)
trials_per_day[i] = {
'num_trials': len(good_trial_indices),
'trial_indices': good_trial_indices,
'session_path': path
}
# Split trials into train and test
train_trials = {}
test_trials = {}
for day in trials_per_day.keys():
num_trials = trials_per_day[day]['num_trials']
all_trial_indices = trials_per_day[day]['trial_indices']
if test_percentage == 0:
train_trials[day] = {
'trials': all_trial_indices,
'session_path': trials_per_day[day]['session_path']
}
test_trials[day] = {
'trials': [],
'session_path': trials_per_day[day]['session_path']
}
elif test_percentage == 1:
train_trials[day] = {
'trials': [],
'session_path': trials_per_day[day]['session_path']
}
test_trials[day] = {
'trials': all_trial_indices,
'session_path': trials_per_day[day]['session_path']
}
else:
# Calculate number of test trials
num_test = max(1, int(num_trials * test_percentage))
# Randomly select test indices
test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()
# Remaining indices for training
train_indices = [idx for idx in all_trial_indices if idx not in test_indices]
train_trials[day] = {
'trials': train_indices,
'session_path': trials_per_day[day]['session_path']
}
test_trials[day] = {
'trials': test_indices,
'session_path': trials_per_day[day]['session_path']
}
return train_trials, test_trials
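# Usage sketch for the split helper. The glob pattern is hypothetical; any list of
# session HDF5 paths whose directory names start with 't15.20' or 't12.20' works:
#
#   import glob
#   file_paths = sorted(glob.glob('/data/sessions/t15.20*/data_train.hdf5'))
#   train_trials, test_trials = train_test_split_indices(
#       file_paths, test_percentage=0.1, seed=42, bad_trials_dict=None)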
def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
"""
Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
utilizing multiple CPU cores and the dataset's caching mechanism.
Args:
dataset_tf: Dataset instance to analyze
sample_size: Number of samples to analyze (set to -1 for all data)
Returns:
Dictionary with maximum dimensions
"""
print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
start_time = time.time()
    # 1. Collect all (day, trial) pairs that need analysis, avoiding duplicates
all_trials = []
unique_trials = set()
for batch_idx in sorted(dataset_tf.batch_indices.keys()):
batch = dataset_tf.batch_indices[batch_idx]
for day, trials in batch.items():
for trial in trials:
if (day, trial) not in unique_trials:
unique_trials.add((day, trial))
all_trials.append((day, trial))
    # 2. Sample the list if sampling was requested
if 0 < sample_size < len(all_trials):
        # Fix the seed for reproducibility (if desired)
random.seed(42)
trials_to_check = random.sample(all_trials, sample_size)
else:
trials_to_check = all_trials
total_trials_to_analyze = len(trials_to_check)
print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
    # Helper function executed by each worker thread
def analyze_single_trial(day_trial_pair):
"""Loads and analyzes a single trial, returns its shapes."""
day, trial = day_trial_pair
try:
            # Reuse dataset_tf's loading and caching logic
trial_data = dataset_tf._load_trial_data(day, trial)
            # Read the needed values directly from the loaded data
time_steps = int(trial_data['n_time_steps'])
phone_seq_len = int(trial_data['phone_seq_lens'])
            # Handle the transcription data - it may be an array
transcription_data = trial_data['transcription']
if hasattr(transcription_data, '__len__'):
transcription_len = len(transcription_data)
else:
                transcription_len = 1  # scalar, so treat it as length 1
return (time_steps, phone_seq_len, transcription_len)
except Exception as e:
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
            return None  # None signals failure
    # 3. Run the analysis in parallel with a ThreadPoolExecutor
# Use dynamic calculation based on CPU cores with reasonable upper limit
cpu_count = os.cpu_count() or 4 # Fallback to 4 if cpu_count() returns None
max_workers = min(32, cpu_count, len(trials_to_check))
local_max_shapes = []
print(f"🔧 Using {max_workers} parallel workers for analysis...")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all analysis tasks
future_to_trial = {
executor.submit(analyze_single_trial, trial_pair): trial_pair
for trial_pair in trials_to_check
}
count = 0
        progress_interval = max(1, total_trials_to_analyze // 20)  # show about 20 progress updates
for future in as_completed(future_to_trial):
result = future.result()
if result:
local_max_shapes.append(result)
count += 1
if count % progress_interval == 0 or count == total_trials_to_analyze:
elapsed = time.time() - start_time
rate = count / elapsed if elapsed > 0 else 0
print(f" 📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")
    # 4. Aggregate the results from all threads
if not local_max_shapes:
raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")
    # Convert [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
unzipped_shapes = list(zip(*local_max_shapes))
original_max_shapes = {
'max_time_steps': int(np.max(unzipped_shapes[0])),
'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
'max_transcription_len': int(np.max(unzipped_shapes[2])),
'n_features': dataset_tf.feature_dim
}
    # 5. Add an appropriate safety margin, depending on how much of the data was analyzed
if sample_size == -1:
        # Full-data analysis: only a tiny margin is needed to absorb possible rounding error
safety_margin = 1.02 # 2% buffer for rounding errors
margin_reason = "minimal buffer for full dataset analysis"
else:
        # Sampled analysis: a larger margin is needed to cover extremes that were not sampled
safety_margin = 1.3 # 30% buffer for sampling uncertainty
margin_reason = f"larger buffer due to sampling only {sample_size} trials"
final_max_shapes = {
'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
'n_features': original_max_shapes['n_features']
}
analysis_time = time.time() - start_time
successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin - {margin_reason}):")
print(f" Time steps: {original_max_shapes['max_time_steps']}{final_max_shapes['max_time_steps']}")
print(f" Phone sequence length: {original_max_shapes['max_phone_seq_len']}{final_max_shapes['max_phone_seq_len']}")
print(f" Transcription length: {original_max_shapes['max_transcription_len']}{final_max_shapes['max_transcription_len']}")
print(f" Number of features: {final_max_shapes['n_features']}")
return final_max_shapes
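# Quick-use sketch: sampling is faster but pads with a 30% safety margin, while
# sample_size=-1 scans every trial and pads with only a 2% margin:
#
#   max_shapes = analyze_dataset_shapes(train_dataset_tf, sample_size=500)  # train_dataset_tf assumed
#   print(max_shapes['max_time_steps'], max_shapes['n_features'])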
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True,
cache_path: Optional[str] = None,
use_static_shapes: bool = True) -> tf.data.Dataset:
"""
Create input function for TPU training with configurable shape handling
Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
use_static_shapes: If True, use pre-computed static shapes for XLA compatibility
Returns:
tf.data.Dataset ready for TPU training
"""
# Step 1: Create individual example dataset
dataset = dataset_tf.create_individual_dataset()
# Step 2: Cache raw samples BEFORE any augmentation or batching
if cache_path:
dataset = dataset.cache(cache_path)
split_name = "training" if training else "validation"
print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
else:
dataset = dataset.cache()
split_name = "training" if training else "validation"
print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
# Step 3: Batch samples with shape handling optimized for TPU
if use_static_shapes:
print(f"🔧 Using STATIC shapes for XLA compatibility")
# Analyze dataset to get maximum shapes
print("📊 Analyzing dataset for maximum shapes...")
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1) # Analyze ALL data for maximum accuracy
# Use static shapes based on analysis
padded_shapes = {
'input_features': (max_shapes['max_time_steps'], dataset_tf.feature_dim),
'seq_class_ids': (max_shapes['max_phone_seq_len'],),
'n_time_steps': (),
'phone_seq_lens': (),
'day_indices': (),
'transcriptions': (max_shapes['max_transcription_len'],),
'block_nums': (),
'trial_nums': ()
}
print(f"📏 Using static shapes: time_steps={max_shapes['max_time_steps']}, "
f"phone_len={max_shapes['max_phone_seq_len']}, "
f"transcription_len={max_shapes['max_transcription_len']}")
else:
print(f"🔧 Using DYNAMIC shapes (may cause XLA compilation issues)")
# Use dynamic shapes - may cause XLA compilation issues
padded_shapes = {
'input_features': (None, dataset_tf.feature_dim),
'seq_class_ids': (None,),
'n_time_steps': (),
'phone_seq_lens': (),
'day_indices': (),
'transcriptions': (None,),
'block_nums': (),
'trial_nums': ()
}
print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
# Define padding values for each field
padding_values = {
'input_features': 0.0,
'seq_class_ids': 0,
'n_time_steps': 0,
'phone_seq_lens': 0,
'day_indices': 0,
'transcriptions': 0,
'block_nums': 0,
'trial_nums': 0
}
    # Create padded batches; dynamic shapes pad to the longest example in each batch, static shapes pad to the precomputed maxima
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,
padding_values=padding_values,
drop_remainder=True # Critical for TPU: ensures all batches have same size
)
# Step 4: Apply data augmentation to BATCHES (after dynamic batching)
def apply_batch_transforms(batch):
"""Apply data transformations to entire batches - resolves time paradox"""
features = batch['input_features']
n_time_steps = batch['n_time_steps']
# Apply transformations to the entire batch
features, n_time_steps = DataAugmentationTF.transform_data(
features, # Already has batch dimension
n_time_steps,
transform_args,
training=training
)
# Update batch with transformed data
batch['input_features'] = features
batch['n_time_steps'] = n_time_steps
return batch
# Apply batch transforms only during training
if training:
dataset = dataset.map(
apply_batch_transforms,
num_parallel_calls=tf.data.AUTOTUNE
)
print(f"✅ Batch augmentation enabled for training")
else:
print(f"✅ No augmentation for validation")
# Step 5: Prefetch for optimal performance
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
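

# ---------------------------------------------------------------------------
# End-to-end wiring sketch. Everything below is illustrative: the data root is
# hypothetical and transform_args only lists keys that transform_data() reads.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import glob

    # Hypothetical location of per-session HDF5 files.
    file_paths = sorted(glob.glob('/data/hdf5_data_final/t15.20*/data_train.hdf5'))

    train_trials, val_trials = train_test_split_indices(
        file_paths, test_percentage=0.1, seed=42)

    train_dataset_tf = BrainToTextDatasetTF(
        trial_indices=train_trials,
        n_batches=1000,
        split='train',
        batch_size=64,
        days_per_batch=4,
        random_seed=42,
    )

    transform_args = {
        'smooth_data': True,
        'smooth_kernel_std': 2.0,
        'smooth_kernel_size': 100,
        'white_noise_std': 1.0,
    }

    train_ds = create_input_fn(train_dataset_tf, transform_args, training=True)
    for batch in train_ds.take(1):
        print({name: tensor.shape for name, tensor in batch.items()})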