import os
import tensorflow as tf
import h5py
import numpy as np
import math
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d


class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8

    This class creates tf.data.Dataset objects that efficiently load and batch
    brain-to-text data from HDF5 files with TPU-optimized operations.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE,
        cache_data: bool = True,
        preload_all_data: bool = False
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data

        Args:
            trial_indices: Dictionary with day numbers as keys and trial info as values
            n_batches: Number of training batches to create (None for validation)
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility
            must_include_days: Days that must be included in every batch
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Number of parallel processing threads
            cache_data: Whether to cache loaded data in memory
            preload_all_data: Whether to preload all data at initialization
        """
        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.must_include_days = must_include_days
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls
        self.cache_data = cache_data
        self.preload_all_data = preload_all_data

        # Initialize data cache
        self.data_cache = {} if cache_data else None

        # Calculate total number of trials
        self.n_trials = 0
        for d in trial_indices:
            self.n_trials += len(trial_indices[d]['trials'])

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError(
                    f'len(must_include_days) ({len(must_include_days)}) must be <= days_per_batch ({days_per_batch})'
                )

            # Map negative indices to actual day indices
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

        # Preload data if requested (speeds up the first batch significantly)
        if self.preload_all_data:
            print(f"🔄 Preloading all data for {self.split} split...")
            self._preload_all_data()
            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """Create training batch indices with random sampling"""
        batch_indices = {}

        # Precompute non-must-include days
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys()
                if d not in self.must_include_days
            ]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Calculate trials per day
            num_trials = math.ceil(self.batch_size / self.days_per_batch)

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """Create test batch indices ensuring all trials are seen once"""
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            num_trials = len(self.trial_indices[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)

                batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
                batch_indices[batch_idx] = {d: batch_trials}
                batch_idx += 1

        return batch_indices
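
    # Structure of the batch index returned by both builders (the values shown are
    # illustrative, not taken from the original file): a dict mapping batch number to
    # a per-day list of trial indices, i.e. Dict[batch_idx, Dict[day, List[trial]]].
    #
    #     {
    #         0: {3: [12, 40, 7, ...]},   # batch 0: trials sampled from day 3
    #         1: {5: [0, 1, 2, ...]},     # batch 1: trials taken from day 5
    #     }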

    def _preload_all_data(self):
        """Preload all trial data into the in-memory cache using parallel file I/O"""
        import multiprocessing
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Use CPU cores efficiently for parallel I/O
        max_workers = min(multiprocessing.cpu_count(), 32)  # Limit to avoid overwhelming I/O

        # Collect all trials to load
        trials_to_load = []
        for day in self.trial_indices:
            for trial in self.trial_indices[day]['trials']:
                trials_to_load.append((day, trial))

        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")

        # Parallel loading using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all loading tasks
            future_to_trial = {
                executor.submit(self._load_single_trial_data, day, trial): (day, trial)
                for day, trial in trials_to_load
            }

            # Process completed tasks and update the cache
            loaded_count = 0
            for future in as_completed(future_to_trial):
                day, trial = future_to_trial[future]
                try:
                    trial_data = future.result()
                    cache_key = f"{day}_{trial}"
                    self.data_cache[cache_key] = trial_data
                    loaded_count += 1

                    # Progress indicator every 100 trials
                    if loaded_count % 100 == 0:
                        print(f"   Loaded {loaded_count}/{len(trials_to_load)} trials...")

                except Exception as e:
                    print(f"   Warning: Failed to load trial {day}_{trial}: {e}")

        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")

    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data - optimized version for parallel loading"""
        try:
            session_path = self.trial_indices[day]['session_path']

            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Convert to float32 for TF compatibility
                input_features = input_features.astype(np.float32)

                trial_data = {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }

            return trial_data

        except Exception as e:
            # Return dummy data for failed loads
            return {
                'input_features': np.zeros((100, 512), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data from the cache or from the HDF5 file"""
        # Check the cache first if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

        # Load from disk if not in cache
        trial_data = self._load_single_trial_data(day, trial)

        # Cache the loaded data if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            self.data_cache[cache_key] = trial_data

        return trial_data

    def _create_batch_generator(self):
        """Generator function that yields individual batches with optimized loading"""
        import time
        from concurrent.futures import ThreadPoolExecutor

        for batch_idx in range(self.n_batches):
            batch_start_time = time.time()

            batch_data = {
                'input_features': [],
                'seq_class_ids': [],
                'n_time_steps': [],
                'phone_seq_lens': [],
                'day_indices': [],
                'transcriptions': [],
                'block_nums': [],
                'trial_nums': []
            }

            batch_index = self.batch_indices[batch_idx]

            # Collect all trials to load for this batch
            trials_to_load = []
            for day in batch_index.keys():
                for trial in batch_index[day]:
                    trials_to_load.append((day, trial))

            # Use parallel loading if data is not preloaded and there are multiple trials
            if not self.preload_all_data and len(trials_to_load) > 4:
                # Parallel loading for faster I/O
                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
                    future_to_trial = {
                        executor.submit(self._load_trial_data, day, trial): (day, trial)
                        for day, trial in trials_to_load
                    }

                    # Collect results keyed by (day, trial)
                    trial_results = {}
                    for future in future_to_trial:
                        day, trial = future_to_trial[future]
                        trial_results[(day, trial)] = future.result()

                # Add data in the original order
                for day, trial in trials_to_load:
                    trial_data = trial_results[(day, trial)]
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])
            else:
                # Sequential loading (fast when data is cached or there are few trials)
                for day, trial in trials_to_load:
                    trial_data = self._load_trial_data(day, trial)
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])

            data_loading_time = time.time() - batch_start_time

            # Timing diagnostic for the first few batches
            if batch_idx < 3:
                cache_status = "cached" if self.preload_all_data else "disk"
                loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in {data_loading_time:.3f}s ({cache_status}, {loading_method})")

            # Pad sequences to create a uniform batch
            max_time_steps = max(batch_data['n_time_steps'])
            max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
            max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])

            # Pad input features
            padded_features = []
            for features in batch_data['input_features']:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]), dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad phoneme sequences
            padded_seq_ids = []
            for seq in batch_data['seq_class_ids']:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in batch_data['transcriptions']:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create final batch tensors
            batch = {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
                'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
                'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
                'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
            }

            yield batch

    def create_dataset(self) -> tf.data.Dataset:
        """Create an optimized tf.data.Dataset for TPU training"""

        # Define the output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # Apply TPU-optimized transformations
        # 🚨 Following the GPU version's strategy, no dataset-level shuffle is needed here:
        # the GPU version already performs random sampling in _create_batch_index_train() (lines 107-118).
        # Shuffling again at this level would blow up memory
        # (1000 batches × 256 trials = 256,000 trials held in memory at once).
        # if self.split == 'train':
        #     dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))  # ← memory killer, intentionally disabled

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)

        return dataset
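
# Usage sketch (illustrative only, not part of the original module): how one might
# build the training dataset and pull a single batch. The `train_trials` dict is
# assumed to come from train_test_split_indices() defined below; n_batches,
# batch_size and days_per_batch are hypothetical values.
#
#     train_ds_builder = BrainToTextDatasetTF(
#         trial_indices=train_trials,
#         n_batches=1000,
#         split='train',
#         batch_size=64,
#         days_per_batch=4,
#         random_seed=42,
#     )
#     train_ds = train_ds_builder.create_dataset()
#     for batch in train_ds.take(1):
#         print(batch['input_features'].shape)  # [64, max_time_steps, n_features]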


class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using TensorFlow operations

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with same shape as input
        """
        # Create the Gaussian kernel with numpy (computed once per call)
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
        valid_idx = np.argwhere(gauss_kernel > 0.01)
        gauss_kernel = gauss_kernel[valid_idx].flatten()
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)

        # Convert to a TensorFlow tensor and reshape for conv1d
        gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
        kernel_size = tf.shape(gauss_kernel)[0]
        gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1])  # [kernel_size, in_channels, out_channels]

        # Get tensor dimensions
        batch_size = tf.shape(inputs)[0]
        time_steps = tf.shape(inputs)[1]
        num_features = tf.shape(inputs)[2]

        # Apply the convolution to each feature channel separately
        smoothed_features = []

        # Use the static channel count if it is known, otherwise fall back to the dynamic shape
        num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]

        if isinstance(num_features_py, tf.Tensor):
            # Dynamic number of features: map over channel indices with tf.map_fn
            def smooth_single_feature(i):
                # Extract a single feature channel: [batch_size, time_steps, 1]
                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
                # Apply the 1D convolution
                return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')

            indices = tf.range(num_features)
            smoothed_features_tensor = tf.map_fn(
                smooth_single_feature,
                indices,
                fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
            )
            # Transpose to get [batch_size, time_steps, features]
            smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
            smoothed = tf.squeeze(smoothed, axis=-1)
        else:
            # Static number of features: use a Python loop
            for i in range(num_features_py):
                # Extract a single feature channel: [batch_size, time_steps, 1]
                feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
                # Apply the 1D convolution
                smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
                smoothed_features.append(smoothed_channel)

            # Concatenate all smoothed channels
            smoothed = tf.concat(smoothed_features, axis=-1)  # [batch_size, time_steps, features]

        return smoothed
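
    # Note (illustrative alternative, not used above): since the same 1-D kernel is
    # applied to every channel, the per-channel loop/map could also be expressed as a
    # single conv1d over a reshaped tensor, avoiding Python-level iteration:
    #
    #     x = tf.transpose(inputs, [0, 2, 1])                    # [B, C, T]
    #     x = tf.reshape(x, [-1, tf.shape(inputs)[1], 1])        # [B*C, T, 1]
    #     x = tf.nn.conv1d(x, gauss_kernel, stride=1, padding='SAME')
    #     x = tf.reshape(x, [tf.shape(inputs)[0], -1, tf.shape(inputs)[1]])  # [B, C, T]
    #     smoothed = tf.transpose(x, [0, 2, 1])                  # [B, T, C]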

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        time_steps = tf.shape(features)[1]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']
                # Create identity matrices for each batch
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])
                # Add noise to create warp matrices
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise
                # Apply transformation
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
                features = features + random_walk_noise

            # Random cutoff (simplified for TPU - apply to all samples in batch)
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps
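
    # Example transform_args (the keys are exactly those read above; the values are
    # illustrative placeholders, not taken from the original configuration):
    #
    #     transform_args = {
    #         'static_gain_std': 0.0,
    #         'white_noise_std': 1.0,
    #         'constant_offset_std': 0.2,
    #         'random_walk_std': 0.0,
    #         'random_walk_axis': 1,
    #         'random_cut': 3,
    #         'smooth_data': True,
    #         'smooth_kernel_std': 2.0,
    #         'smooth_kernel_size': 100,
    #     }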


def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split the data in file_paths into train and test splits

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Fraction of trials to hold out for testing (0.0-1.0)
        seed: Random seed for reproducibility
        bad_trials_dict: Dictionary of trials to exclude

    Returns:
        Tuple of (train_trials, test_trials) dictionaries
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get the trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]

        good_trial_indices = []

        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'

                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check whether the trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split trials into train and test
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():
        num_trials = trials_per_day[day]['num_trials']
        all_trial_indices = trials_per_day[day]['trial_indices']

        if test_percentage == 0:
            train_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
        elif test_percentage == 1:
            train_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
        else:
            # Calculate the number of test trials
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select test indices
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices are used for training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            train_trials[day] = {
                'trials': train_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': test_indices,
                'session_path': trials_per_day[day]['session_path']
            }

    return train_trials, test_trials
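
# Usage sketch (illustrative only; the session file paths below are hypothetical,
# chosen to match the 't15.20*' naming the parser above expects):
#
#     train_trials, test_trials = train_test_split_indices(
#         file_paths=['/data/hdf5/t15.2023.08.11/data_train.hdf5',
#                     '/data/hdf5/t15.2023.08.13/data_train.hdf5'],
#         test_percentage=0.1,
#         seed=42,
#     )
#     train_ds_builder = BrainToTextDatasetTF(train_trials, n_batches=1000, split='train')
#     val_ds_builder = BrainToTextDatasetTF(test_trials, n_batches=None, split='test')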


# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with data augmentation

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)

    Returns:
        tf.data.Dataset ready for TPU training
    """
    dataset = dataset_tf.create_dataset()

    def apply_transforms(batch):
        """Apply data transformations to a batch"""
        features = batch['input_features']
        n_time_steps = batch['n_time_steps']

        # Apply transformations
        features, n_time_steps = DataAugmentationTF.transform_data(
            features, n_time_steps, transform_args, training=training
        )

        # Update batch with transformed data
        batch['input_features'] = features
        batch['n_time_steps'] = n_time_steps

        return batch

    # Apply transformations
    dataset = dataset.map(
        apply_transforms,
        num_parallel_calls=tf.data.AUTOTUNE
    )

    return dataset
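
# End-to-end sketch (illustrative only; the TPU resolver/strategy setup shown is the
# standard TensorFlow boilerplate rather than part of this module, and `train_trials`
# / `transform_args` are assumed to exist as built in the examples above):
#
#     resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
#     tf.config.experimental_connect_to_cluster(resolver)
#     tf.tpu.experimental.initialize_tpu_system(resolver)
#     strategy = tf.distribute.TPUStrategy(resolver)
#
#     train_builder = BrainToTextDatasetTF(train_trials, n_batches=1000, split='train')
#     train_ds = create_input_fn(train_builder, transform_args, training=True)
#     dist_train_ds = strategy.experimental_distribute_dataset(train_ds)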