import os
import sys
import tensorflow as tf
import h5py
import numpy as np
import math
import logging
import time
import random
import multiprocessing
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
from concurrent.futures import ThreadPoolExecutor, as_completed

class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8

    This class creates tf.data.Dataset objects that efficiently load and batch
    brain-to-text data from HDF5 files with TPU-optimized operations.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE,
        cache_data: bool = True,
        preload_all_data: bool = False
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data

        Args:
            trial_indices: Dictionary with day numbers as keys and trial info as values
            n_batches: Number of training batches to create (None for validation)
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility
            must_include_days: Days that must be included in every batch
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Parallel processing threads
            cache_data: Whether to cache loaded data in memory
            preload_all_data: Whether to preload all data at initialization
        """

        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.must_include_days = must_include_days
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls
        self.cache_data = cache_data
        self.preload_all_data = preload_all_data

        # Initialize data cache
        self.data_cache = {} if cache_data else None

        # Calculate total number of trials
        self.n_trials = 0
        for d in trial_indices:
            self.n_trials += len(trial_indices[d]['trials'])

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError(
                    f'must_include_days ({len(must_include_days)} days) must not exceed '
                    f'days_per_batch ({days_per_batch})'
                )

            # Map negative indices
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

        # Preload data if requested (speeds up first batch significantly)
        if self.preload_all_data:
            print(f"🔄 Preloading all data for {self.split} split...")
            self._preload_all_data()
            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")

        # ================= Automatic feature-dimension detection =================
        # Compute and store the feature dimension explicitly to avoid shape
        # mismatches later during padded_batch.
        if self.feature_subset:
            self.feature_dim = len(self.feature_subset)
            print(f"✅ Using feature subset dimension: {self.feature_dim}")
        else:
            # Do not hard-code the value; infer the actual feature dimension from the data.
            try:
                first_day = next(iter(self.trial_indices))
                first_trial = self.trial_indices[first_day]['trials'][0]
                first_sample = self._load_single_trial_data(first_day, first_trial)
                self.feature_dim = first_sample['input_features'].shape[1]
                print(f"✅ Auto-detected feature dimension: {self.feature_dim}")
            except Exception as e:
                print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}")
                self.feature_dim = 512  # last-resort fallback
        # ================= End of feature-dimension detection =================

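    # For reference, `trial_indices` is expected to look like the output of
    # train_test_split_indices() below: integer day keys, each mapping to a
    # 'trials' list and a 'session_path'. The paths here are purely illustrative:
    #
    #   {
    #       0: {'trials': [0, 1, 4, 7], 'session_path': '/data/t15.2023.08.11/data.hdf5'},
    #       1: {'trials': [2, 3],       'session_path': '/data/t15.2023.08.13/data.hdf5'},
    #   }
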
    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """Create training batch indices with random sampling"""
        batch_indices = {}

        # Precompute non-must-include days
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys()
                if d not in self.must_include_days
            ]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Calculate trials per day
            num_trials = math.ceil(self.batch_size / self.days_per_batch)

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """Create test batch indices ensuring all trials are seen once"""
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            num_trials = len(self.trial_indices[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)

                batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
                batch_indices[batch_idx] = {d: batch_trials}
                batch_idx += 1

        return batch_indices

    def _preload_all_data(self):
        """Preload all trial data into memory cache (uses available RAM optimally)"""
        # multiprocessing and ThreadPoolExecutor are imported at module level.

        # Use CPU cores efficiently for parallel I/O
        max_workers = min(multiprocessing.cpu_count(), 32)  # Limit to avoid overwhelming I/O

        # Collect all trials to load
        trials_to_load = []
        for day in self.trial_indices:
            for trial in self.trial_indices[day]['trials']:
                trials_to_load.append((day, trial))

        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")

        # Parallel loading using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all loading tasks
            future_to_trial = {
                executor.submit(self._load_single_trial_data, day, trial): (day, trial)
                for day, trial in trials_to_load
            }

            # Process completed tasks and update cache
            loaded_count = 0
            for future in as_completed(future_to_trial):
                day, trial = future_to_trial[future]
                try:
                    trial_data = future.result()
                    cache_key = f"{day}_{trial}"
                    self.data_cache[cache_key] = trial_data
                    loaded_count += 1

                    # Progress indicator every 100 trials
                    if loaded_count % 100 == 0:
                        print(f"   Loaded {loaded_count}/{len(trials_to_load)} trials...")

                except Exception as e:
                    print(f"   Warning: Failed to load trial {day}_{trial}: {e}")

        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")

    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data - optimized version for parallel loading"""
        try:
            session_path = self.trial_indices[day]['session_path']

            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Convert to float32 for TF compatibility
                input_features = input_features.astype(np.float32)

                trial_data = {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }

            return trial_data

        except Exception as e:
            # Return dummy data for failed loads (100 time steps x 512 features)
            return {
                'input_features': np.zeros((100, 512), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
        """Load a single trial's data from cache or HDF5 file"""
        # Check cache first if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

        # Load from disk if not in cache
        trial_data = self._load_single_trial_data(day, trial)

        # Cache the loaded data if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            self.data_cache[cache_key] = trial_data

        return trial_data

    def _create_batch_generator(self):
        """Generator function that yields individual batches with optimized loading"""
        # time and ThreadPoolExecutor are imported at module level.

        for batch_idx in range(self.n_batches):
            batch_start_time = time.time()

            batch_data = {
                'input_features': [],
                'seq_class_ids': [],
                'n_time_steps': [],
                'phone_seq_lens': [],
                'day_indices': [],
                'transcriptions': [],
                'block_nums': [],
                'trial_nums': []
            }

            batch_index = self.batch_indices[batch_idx]

            # Collect all trials to load for this batch
            trials_to_load = []
            for day in batch_index.keys():
                for trial in batch_index[day]:
                    trials_to_load.append((day, trial))

            # Use parallel loading if not preloaded and have multiple trials
            if not self.preload_all_data and len(trials_to_load) > 4:
                # Parallel loading for faster I/O
                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
                    future_to_trial = {
                        executor.submit(self._load_trial_data, day, trial): (day, trial)
                        for day, trial in trials_to_load
                    }

                    # Collect results in order
                    trial_results = {}
                    for future in future_to_trial:
                        day, trial = future_to_trial[future]
                        trial_results[(day, trial)] = future.result()

                # Add data in original order
                for day, trial in trials_to_load:
                    trial_data = trial_results[(day, trial)]
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])
            else:
                # Sequential loading (fast when data is cached or few trials)
                for day, trial in trials_to_load:
                    trial_data = self._load_trial_data(day, trial)
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])

            data_loading_time = time.time() - batch_start_time

            # Add timing diagnostic for first few batches
            if batch_idx < 3:
                cache_status = "cached" if self.preload_all_data else "disk"
                loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in {data_loading_time:.3f}s ({cache_status}, {loading_method})")

            # Pad sequences to create uniform batch
            max_time_steps = max(batch_data['n_time_steps'])
            max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
            max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])

            # Pad input features
            padded_features = []
            for features in batch_data['input_features']:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad sequences
            padded_seq_ids = []
            for seq in batch_data['seq_class_ids']:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in batch_data['transcriptions']:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create final batch tensors
            batch = {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
                'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
                'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
                'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
            }

            yield batch

    def create_dataset(self) -> tf.data.Dataset:
        """Create optimized tf.data.Dataset for TPU training"""

        # Define output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # Apply TPU-optimized transformations.
        # NOTE (mirroring the GPU implementation): no dataset-level shuffle is needed here.
        # _create_batch_index_train() already randomizes which days/trials go into each
        # batch, and shuffling whole batches at this level would explode memory
        # (e.g. 1000 batches x 256 trials = 256,000 trials resident at once).
        # if self.split == 'train':
        #     dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))  # intentionally disabled

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)

        return dataset

    def create_individual_dataset(self) -> tf.data.Dataset:
        """
        Create tf.data.Dataset that yields individual examples with I/O optimization.

        This generator is refactored to group trial loading by session file,
        drastically reducing the number of file open/close operations from
        N_trials to N_sessions, which is ideal for slow disk I/O.
        """

        def individual_example_generator():
            """Generator that groups reads by file to minimize disk I/O."""

            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
            all_trials_to_load = []
            # Note: the iteration order here determines the rough read order;
            # _create_batch_index_train() has already randomized the batches.
            for batch_idx in sorted(self.batch_indices.keys()):
                batch_index = self.batch_indices[batch_idx]
                for day in batch_index.keys():
                    for trial in batch_index[day]:
                        all_trials_to_load.append((day, trial))

            # 2. Group the trial list by 'day', i.e. by session file.
            # key=lambda x: x[0] groups on the first tuple element (the day).
            for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]), key=lambda x: x[0]):

                session_path = self.trial_indices[day]['session_path']

                # 3. Open each HDF5 file only once per group (per file).
                try:
                    with h5py.File(session_path, 'r') as f:
                        # 4. While the file is open, read every trial needed from it.
                        for current_day, current_trial in group:
                            try:
                                # Read directly from the open file handle 'f'
                                # instead of calling the per-trial loader.
                                g = f[f'trial_{current_trial:04d}']

                                input_features = g['input_features'][:]
                                if self.feature_subset:
                                    input_features = input_features[:, self.feature_subset]

                                example = {
                                    'input_features': input_features.astype(np.float32),
                                    'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
                                    'n_time_steps': np.int32(g.attrs['n_time_steps']),
                                    'phone_seq_lens': np.int32(g.attrs['seq_len']),
                                    'day_indices': np.int32(current_day),
                                    'transcriptions': g['transcription'][:].astype(np.int32),
                                    'block_nums': np.int32(g.attrs['block_num']),
                                    'trial_nums': np.int32(g.attrs['trial_num'])
                                }
                                yield example

                            except KeyError:
                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
                                continue

                except (IOError, FileNotFoundError) as e:
                    logging.error(f"Could not open or read HDF5 file: {session_path}. Error: {e}. Skipping all trials for this day.")
                    continue

        # Define output signature for individual examples
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
        }

        # Create dataset from individual examples
        dataset = tf.data.Dataset.from_generator(
            individual_example_generator,
            output_signature=output_signature
        )

        # Shuffle individual examples if training (more effective than batch-level shuffle)
        if self.split == 'train':
            # The buffer can be reasonably large now that I/O is grouped per file.
            shuffle_buffer = min(2048, self.n_trials)
            dataset = dataset.shuffle(buffer_size=shuffle_buffer)

        return dataset


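# Example usage of the dataset class (a sketch; the path and hyperparameters
# below are hypothetical and should be adapted to the actual data layout):
#
#   train_trials, val_trials = train_test_split_indices(
#       file_paths=['/data/t15.2023.08.11/data.hdf5'],  # hypothetical path
#       test_percentage=0.1,
#       seed=42,
#   )
#   train_ds = BrainToTextDatasetTF(
#       trial_indices=train_trials,
#       n_batches=1000,
#       split='train',
#       batch_size=64,
#       days_per_batch=1,
#       random_seed=42,
#   )
#   dataset = train_ds.create_dataset()                   # batch-level generator pipeline
#   # or: dataset = train_ds.create_individual_dataset()  # per-example pipeline (see create_input_fn)

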
class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation.

        This implementation uses depthwise_conv2d for optimal TPU performance,
        replacing the inefficient Python for-loop that created 512 separate conv1d operations.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with same shape as input
        """
        # Create Gaussian kernel using numpy (computed once)
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
        valid_idx = np.argwhere(gauss_kernel > 0.01)
        gauss_kernel = gauss_kernel[valid_idx].flatten()
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
        gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)

        # ========================= OPTIMIZED SOLUTION =========================
        # Get input dimensions
        num_features = tf.shape(inputs)[-1]
        kernel_size = tf.shape(gauss_kernel)[0]

        # Prepare kernel for depthwise_conv2d
        # Shape needed: [height, width, in_channels, channel_multiplier]
        # Our case: [kernel_size, 1, num_features, 1]
        # Each input channel gets its own independent, identical 1D Gaussian kernel
        kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
        kernel = tf.tile(kernel, [1, 1, num_features, 1])

        # Prepare input for conv2d
        # Shape needed: [batch, height, width, channels]
        # Our case: [batch_size, time_steps, 1, num_features]
        # Add a dummy width dimension
        reshaped_inputs = tf.expand_dims(inputs, axis=2)

        # Execute depthwise convolution
        # This is a single, efficient operation replacing the original Python for-loop
        smoothed = tf.nn.depthwise_conv2d(
            reshaped_inputs,
            kernel,
            strides=[1, 1, 1, 1],
            padding='SAME'
        )

        # Remove the dummy width dimension to restore original shape
        smoothed = tf.squeeze(smoothed, axis=2)
        # ================================================================

        return smoothed

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        time_steps = tf.shape(features)[1]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']
                # Create identity matrices for each batch
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])
                # Add noise to create warp matrices
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise
                # Apply transformation
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
                features = features + random_walk_noise

            # Random cutoff (simplified for TPU - apply to all samples in batch)
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps


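# For reference, transform_data reads the following keys from `transform_args`.
# The values shown here are illustrative placeholders (zeros disable an
# augmentation; the last three mirror the .get() defaults above), not tuned settings:
#
#   transform_args = {
#       'static_gain_std': 0.0,
#       'white_noise_std': 0.0,
#       'constant_offset_std': 0.0,
#       'random_walk_std': 0.0,
#       'random_walk_axis': 1,
#       'random_cut': 0,
#       'smooth_data': True,
#       'smooth_kernel_std': 2.0,
#       'smooth_kernel_size': 100,
#   }

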
def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split data from file_paths into train and test splits

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Percentage of trials for testing
        seed: Random seed for reproducibility
        bad_trials_dict: Dictionary of trials to exclude

    Returns:
        Tuple of (train_trials, test_trials) dictionaries
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'

                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check if trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split trials into train and test
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():
        num_trials = trials_per_day[day]['num_trials']
        all_trial_indices = trials_per_day[day]['trial_indices']

        if test_percentage == 0:
            train_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
        elif test_percentage == 1:
            train_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
        else:
            # Calculate number of test trials
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select test indices
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices for training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            train_trials[day] = {
                'trials': train_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': test_indices,
                'session_path': trials_per_day[day]['session_path']
            }

    return train_trials, test_trials


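# The optional `bad_trials_dict` is keyed by session name, then block number
# (as a string), mapping to the trial numbers to exclude. An illustrative
# (hypothetical) example matching the lookup pattern above:
#
#   bad_trials_dict = {
#       't15.2023.08.11': {'2': [5, 9], '3': [1]},
#   }

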
def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
    """
    Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
    utilizing multiple CPU cores and the dataset's caching mechanism.

    Args:
        dataset_tf: Dataset instance to analyze
        sample_size: Number of samples to analyze (set to -1 for all data)

    Returns:
        Dictionary with maximum dimensions
    """
    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
    start_time = time.time()

    # 1. Collect every (day, trial) pair that needs analysis, skipping duplicates.
    all_trials = []
    unique_trials = set()
    for batch_idx in sorted(dataset_tf.batch_indices.keys()):
        batch = dataset_tf.batch_indices[batch_idx]
        for day, trials in batch.items():
            for trial in trials:
                if (day, trial) not in unique_trials:
                    unique_trials.add((day, trial))
                    all_trials.append((day, trial))

    # 2. Subsample the list if a sample size was requested.
    if 0 < sample_size < len(all_trials):
        # Fix the seed so the sampling is reproducible.
        random.seed(42)
        trials_to_check = random.sample(all_trials, sample_size)
    else:
        trials_to_check = all_trials

    total_trials_to_analyze = len(trials_to_check)
    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")

    # Helper executed by each worker thread.
    def analyze_single_trial(day_trial_pair):
        """Loads and analyzes a single trial, returns its shapes."""
        day, trial = day_trial_pair
        try:
            # Reuse the dataset's loading and caching logic.
            trial_data = dataset_tf._load_trial_data(day, trial)

            # Read the shape information directly from the loaded data.
            time_steps = int(trial_data['n_time_steps'])
            phone_seq_len = int(trial_data['phone_seq_lens'])

            # The transcription field may be an array or a scalar.
            transcription_data = trial_data['transcription']
            if hasattr(transcription_data, '__len__'):
                transcription_len = len(transcription_data)
            else:
                transcription_len = 1  # scalar -> length 1

            return (time_steps, phone_seq_len, transcription_len)
        except Exception as e:
            logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
            return None  # None marks a failed trial

    # 3. Analyze trials in parallel with a ThreadPoolExecutor.
    max_workers = min(224, len(trials_to_check))  # never more workers than tasks
    local_max_shapes = []

    print(f"🔧 Using {max_workers} parallel workers for analysis...")

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all analysis tasks
        future_to_trial = {
            executor.submit(analyze_single_trial, trial_pair): trial_pair
            for trial_pair in trials_to_check
        }

        count = 0
        progress_interval = max(1, total_trials_to_analyze // 20)  # roughly 20 progress updates

        for future in as_completed(future_to_trial):
            result = future.result()
            if result:
                local_max_shapes.append(result)

            count += 1
            if count % progress_interval == 0 or count == total_trials_to_analyze:
                elapsed = time.time() - start_time
                rate = count / elapsed if elapsed > 0 else 0
                print(f"  📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")

    # 4. Aggregate the per-trial results.
    if not local_max_shapes:
        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")

    # Convert [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
    unzipped_shapes = list(zip(*local_max_shapes))

    original_max_shapes = {
        'max_time_steps': int(np.max(unzipped_shapes[0])),
        'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
        'max_transcription_len': int(np.max(unzipped_shapes[2])),
        'n_features': dataset_tf.feature_dim  # use the dimension detected by the dataset
    }

    # 5. Add a safety margin (10% buffer).
    safety_margin = 1.1

    final_max_shapes = {
        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
        'n_features': original_max_shapes['n_features']
    }

    analysis_time = time.time() - start_time
    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100

    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
    print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
    print(f"   Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
    print(f"   Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
    print(f"   Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
    print(f"   Number of features: {final_max_shapes['n_features']}")

    return final_max_shapes


# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    max_shapes: Dict[str, int],
                    training: bool = True,
                    cache_path: Optional[str] = None) -> tf.data.Dataset:
    """
    Create input function for TPU training with PRE-ANALYZED FIXED shapes

    This function uses pre-computed maximum shapes to create STATIC-size batches,
    ensuring XLA compilation success on TPU hardware. This is CRITICAL for the
    final resolution of both CTC loss compatibility and graph structure issues.

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        max_shapes: Pre-computed maximum shapes dictionary with keys:
                    'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance

    Returns:
        tf.data.Dataset ready for TPU training with FIXED STATIC shapes
    """

    # Step 1: Create individual example dataset with file-grouping I/O optimization
    dataset = dataset_tf.create_individual_dataset()

    # Step 2: Cache raw samples BEFORE any augmentation
    if cache_path:
        dataset = dataset.cache(cache_path)
        split_name = "training" if training else "validation"
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
    else:
        # No cache path given: fall back to an in-memory cache.
        dataset = dataset.cache()
        split_name = "training" if training else "validation"
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")

    # Step 3: Apply transformations to individual examples BEFORE batching
    def apply_transforms(example):
        """Apply data transformations to individual examples"""
        features = example['input_features']
        n_time_steps = example['n_time_steps']

        # Apply transformations
        features, n_time_steps = DataAugmentationTF.transform_data(
            tf.expand_dims(features, 0),  # Add batch dimension for transforms
            tf.expand_dims(n_time_steps, 0),
            transform_args,
            training=training
        )

        # Remove batch dimension
        example['input_features'] = tf.squeeze(features, 0)
        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)

        return example

    # Apply transforms to cached data
    dataset = dataset.map(
        apply_transforms,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    # ========================= DEBUG INSTRUMENTATION =========================
    def debug_print_shape(example):
        """Debug helper: print each sample's shape just before padded_batch."""
        tf.print("🔍 Sample Shape Debug:",
                 tf.shape(example['input_features']),
                 "Expected feature dim:", dataset_tf.feature_dim,
                 output_stream=sys.stdout)
        return example

    # Add shape debugging - this prints at graph execution time
    dataset = dataset.map(debug_print_shape)
    print(f"⚠️ Debug mode: Will print each sample shape before padded_batch")
    # =========================================================================

    # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")

    # Extract pre-analyzed shape information
    max_time_steps = max_shapes['max_time_steps']
    max_phone_seq_len = max_shapes['max_phone_seq_len']
    max_transcription_len = max_shapes['max_transcription_len']

    # ========================= Unified feature dimension =========================
    # Use the verified feature dimension stored on the dataset object rather than
    # trusting an externally supplied value.
    n_features = dataset_tf.feature_dim  # key change: use the auto-detected feature dimension
    print(f"🔧 Using verified feature dimension from dataset: {n_features}")
    # ========================= End of feature-dimension change =========================

    print(f"   Fixed time steps: {max_time_steps}")
    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
    print(f"   Fixed transcription length: {max_transcription_len}")
    print(f"   Number of features: {n_features}")

    # Define FIXED padded shapes. Note: padded_batch expects plain shapes
    # (lists of ints / tf.TensorShape), not tf.TensorSpec objects.
    padded_shapes = {
        'input_features': [max_time_steps, n_features],
        'seq_class_ids': [max_phone_seq_len],
        'n_time_steps': [],             # scalar
        'phone_seq_lens': [],           # scalar
        'day_indices': [],              # scalar
        'transcriptions': [max_transcription_len],
        'block_nums': [],               # scalar
        'trial_nums': []                # scalar
    }

    # Define padding values for each field
    padding_values = {
        'input_features': 0.0,
        'seq_class_ids': 0,
        'n_time_steps': 0,
        'phone_seq_lens': 0,
        'day_indices': 0,
        'transcriptions': 0,
        'block_nums': 0,
        'trial_nums': 0
    }

    # Create batches with FIXED padding so the XLA compiler sees static shapes
    dataset = dataset.padded_batch(
        batch_size=dataset_tf.batch_size,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        drop_remainder=True  # Critical for TPU: ensures all batches have the same size
    )

    # Prefetch for optimal performance
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
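

# -----------------------------------------------------------------------------
# Minimal end-to-end sketch of the intended pipeline. The data path and
# hyperparameters below are hypothetical placeholders; edit them to match the
# actual dataset layout before running.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    example_files = ['/data/t15.2023.08.11/data.hdf5']  # hypothetical path

    if not all(os.path.exists(p) for p in example_files):
        print('Example data files not found; edit `example_files` to run this demo.')
    else:
        train_trials, val_trials = train_test_split_indices(
            file_paths=example_files,
            test_percentage=0.1,
            seed=42,
        )

        train_dataset_tf = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=100,
            split='train',
            batch_size=32,
            days_per_batch=1,
            random_seed=42,
        )

        # Determine static shapes once, then build the fixed-shape input pipeline.
        max_shapes = analyze_dataset_shapes(train_dataset_tf, sample_size=100)
        train_input = create_input_fn(
            train_dataset_tf,
            transform_args={'smooth_data': True},  # see the key reference above
            max_shapes=max_shapes,
            training=True,
        )

        # Inspect one batch of statically shaped tensors.
        for batch in train_input.take(1):
            print({k: v.shape for k, v in batch.items()})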