Files
b2txt25/model_training_nnn_tpu/dataset_tf.py

595 lines
24 KiB
Python
Raw Normal View History

2025-10-15 16:55:52 +08:00
import os
import tensorflow as tf
import h5py
import numpy as np
import math
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8

    This class creates tf.data.Dataset objects that load and batch
    brain-to-text data from HDF5 files. A batch index (day -> trial ids) is
    built once at construction time; the dataset itself is produced lazily by
    a Python generator that reads the trials of one batch, zero-pads them to a
    common length and yields the batch as a dict of numpy arrays.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data

        Args:
            trial_indices: Dictionary with day numbers as keys; each value
                holds a 'trials' list and the 'session_path' of the HDF5 file
            n_batches: Number of training batches to create (None for
                validation; required when split == 'train')
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility (-1 leaves the
                global seeds untouched)
            must_include_days: Days that must be included in every batch.
                Negative values are mapped to n_days + value. The caller's
                list is not modified.
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Parallel processing threads

        Raises:
            ValueError: On an unknown split, more must-include days than
                days_per_batch, days_per_batch larger than the number of
                available days (train only), or a missing n_batches for the
                train split.
        """
        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls

        # Calculate total number of trials
        self.n_trials = sum(len(info['trials']) for info in trial_indices.values())

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError('must_include_days must be <= days_per_batch')
            # Map negative indices on a copy so the caller's list stays intact
            must_include_days = [
                self.n_days + d if d < 0 else d for d in must_include_days
            ]
        self.must_include_days = must_include_days

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            if n_batches is None:
                # range(None) further down would otherwise fail with an opaque TypeError
                raise ValueError('n_batches must be provided when split is "train"')
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
            self.n_batches = len(self.batch_indices)

    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """
        Create training batch indices with random sampling

        For each of the n_batches batches, days_per_batch days are drawn
        without replacement (always containing must_include_days when given),
        and trials are drawn with replacement from every selected day. Surplus
        trials are dropped from randomly chosen days until the batch holds
        exactly batch_size trials.

        Returns:
            Mapping batch index -> {day: [trial ids]}
        """
        batch_indices = {}
        all_days = list(self.trial_indices.keys())
        pinned_days = self.must_include_days
        has_pinned_days = pinned_days is not None and len(pinned_days) > 0

        # Precompute non-must-include days
        if has_pinned_days:
            optional_days = [d for d in all_days if d not in pinned_days]

        # Trials per day, rounded up; the surplus is trimmed per batch below
        num_trials = math.ceil(self.batch_size / self.days_per_batch)

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if has_pinned_days:
                additional_days = np.random.choice(
                    optional_days,
                    size=self.days_per_batch - len(pinned_days),
                    replace=False
                )
                days = np.concatenate((pinned_days, additional_days))
            else:
                days = np.random.choice(
                    all_days,
                    size=self.days_per_batch,
                    replace=False
                )

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """
        Create test batch indices ensuring all trials are seen once

        Every batch contains trials of a single day, taken in order in chunks
        of at most batch_size. Days without trials contribute no batch.

        Returns:
            Mapping batch index -> {day: [trial ids]}
        """
        batch_indices = {}
        batch_idx = 0

        for d, info in self.trial_indices.items():
            day_trials = info['trials']
            for start_idx in range(0, len(day_trials), self.batch_size):
                batch_indices[batch_idx] = {
                    d: day_trials[start_idx:start_idx + self.batch_size]
                }
                batch_idx += 1

        return batch_indices

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, tf.Tensor]:
        """
        Load a single trial's data from HDF5 file

        Reads group 'trial_XXXX' from the day's session file. On any failure
        the error is printed and an all-zero placeholder trial is returned so
        that the surrounding batch can still be assembled.

        Args:
            day: Key into trial_indices
            trial: Trial number inside the session file

        Returns:
            Dict with 'input_features' (float32), 'seq_class_ids',
            'transcription', 'n_time_steps', 'phone_seq_lens', 'day_index',
            'block_num' and 'trial_num'
        """
        try:
            session_path = self.trial_indices[day]['session_path']
            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Features are handed over as float32; no bfloat16 cast happens here
                input_features = input_features.astype(np.float32)

                return {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }
        except Exception as e:
            print(f'Error loading trial {trial} from day {day}: {e}')
            # Return dummy data to maintain batch structure. The feature width
            # must follow feature_subset, otherwise stacking the placeholder
            # with real (subset) trials fails on mismatched shapes.
            n_features = len(self.feature_subset) if self.feature_subset else 512
            return {
                'input_features': np.zeros((100, n_features), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    @staticmethod
    def _pad_leading_axis(array: np.ndarray, target_len: int, dtype) -> np.ndarray:
        """Zero-pad `array` along axis 0 up to `target_len`; longer arrays are returned unchanged."""
        missing = target_len - array.shape[0]
        if missing <= 0:
            return array
        padding = np.zeros((missing,) + tuple(array.shape[1:]), dtype=dtype)
        return np.concatenate([array, padding], axis=0)

    def _create_batch_generator(self):
        """
        Generator function that yields individual batches

        Walks batch_indices in order, loads every trial of the batch and
        zero-pads features, phoneme sequences and transcriptions to the
        longest item of that batch.

        Yields:
            Dict of numpy arrays with keys 'input_features', 'seq_class_ids',
            'n_time_steps', 'phone_seq_lens', 'day_indices', 'transcriptions',
            'block_nums' and 'trial_nums'
        """
        for batch_idx in range(self.n_batches):
            batch_index = self.batch_indices[batch_idx]

            # Load data for each day in the batch
            trials = [
                self._load_trial_data(day, trial)
                for day, day_trials in batch_index.items()
                for trial in day_trials
            ]

            features = [t['input_features'] for t in trials]
            seq_ids = [t['seq_class_ids'] for t in trials]
            transcriptions = [t['transcription'] for t in trials]
            n_time_steps = [t['n_time_steps'] for t in trials]

            # Pad sequences to create uniform batch
            max_time_steps = max(n_time_steps)
            max_phone_len = max(len(seq) for seq in seq_ids)
            max_transcription_len = max(len(trans) for trans in transcriptions)

            pad = self._pad_leading_axis

            # Create final batch tensors
            yield {
                'input_features': np.stack(
                    [pad(x, max_time_steps, np.float32) for x in features]
                ),
                'seq_class_ids': np.stack(
                    [pad(x, max_phone_len, np.int32) for x in seq_ids]
                ),
                'n_time_steps': np.array(n_time_steps, dtype=np.int32),
                'phone_seq_lens': np.array(
                    [t['phone_seq_lens'] for t in trials], dtype=np.int32
                ),
                'day_indices': np.array(
                    [t['day_index'] for t in trials], dtype=np.int32
                ),
                'transcriptions': np.stack(
                    [pad(x, max_transcription_len, np.int32) for x in transcriptions]
                ),
                'block_nums': np.array(
                    [t['block_num'] for t in trials], dtype=np.int32
                ),
                'trial_nums': np.array(
                    [t['trial_num'] for t in trials], dtype=np.int32
                )
            }

    def create_dataset(self) -> tf.data.Dataset:
        """
        Create optimized tf.data.Dataset for TPU training

        Each dataset element is one complete batch produced by
        _create_batch_generator. Training datasets are shuffled at batch
        level; all datasets are prefetched with prefetch_buffer.

        Returns:
            tf.data.Dataset yielding dicts of batched tensors
        """
        # Define output signature for the dataset (all batch dims are dynamic)
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # Apply TPU-optimized transformations
        if self.split == 'train':
            # For training, add shuffling; keep the buffer >= 1 so an empty
            # batch index does not produce an invalid buffer size of 0
            dataset = dataset.shuffle(buffer_size=max(1, min(1000, self.n_batches)))

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)
        return dataset
class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using TensorFlow operations

        The kernel is built by filtering a unit impulse of length
        smooth_kernel_size, keeping only taps above 0.01 and renormalising
        them to sum to 1. Every feature channel is convolved independently
        with that kernel using 'SAME' padding.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with same shape as input

        Raises:
            ValueError: If no kernel tap exceeds the 0.01 threshold (e.g. a
                very large smooth_kernel_std), which would leave an empty kernel
        """
        # Create Gaussian kernel using numpy (computed once)
        impulse = np.zeros(smooth_kernel_size, dtype=np.float32)
        impulse[smooth_kernel_size // 2] = 1
        kernel = gaussian_filter1d(impulse, smooth_kernel_std)
        kernel = kernel[kernel > 0.01]
        if kernel.size == 0:
            raise ValueError(
                'Gaussian kernel is empty: no tap exceeds 0.01 for '
                f'smooth_kernel_std={smooth_kernel_std}, '
                f'smooth_kernel_size={smooth_kernel_size}'
            )
        kernel = kernel / np.sum(kernel)

        # Convert to TensorFlow tensor and reshape for conv1d:
        # [kernel_size, in_channels, out_channels]. The length is known
        # statically from the numpy array, so no dynamic tf.shape is needed.
        kernel_len = int(kernel.shape[0])
        gauss_kernel = tf.reshape(
            tf.constant(kernel, dtype=tf.float32), [kernel_len, 1, 1]
        )

        def smooth_single_feature(i):
            # Extract single feature channel: [batch_size, time_steps, 1]
            feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
            # Apply 1D convolution
            return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')

        static_num_features = inputs.shape[-1]
        if static_num_features is None:
            # Dynamic number of features: iterate inside the graph.
            # fn_output_signature is required because the int32 indices map
            # to float32 outputs (it replaces the deprecated `dtype` argument).
            indices = tf.range(tf.shape(inputs)[2])
            per_feature = tf.map_fn(
                smooth_single_feature, indices, fn_output_signature=tf.float32
            )
            # [features, batch, time, 1] -> [batch_size, time_steps, features]
            smoothed = tf.transpose(per_feature, [1, 2, 0, 3])
            smoothed = tf.squeeze(smoothed, axis=-1)
        else:
            # Static number of features - use a Python loop and concatenate
            smoothed = tf.concat(
                [smooth_single_feature(i) for i in range(static_num_features)],
                axis=-1
            )  # [batch_size, time_steps, features]

        return smoothed

    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Recognised transform_args keys: 'static_gain_std', 'white_noise_std',
        'constant_offset_std', 'random_walk_std', 'random_walk_axis',
        'random_cut' (all training only), and 'smooth_data',
        'smooth_kernel_std', 'smooth_kernel_size' (always applied). A missing
        or non-positive value disables the corresponding step.

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise: multiply by (identity + noise) per sample
            gain_std = transform_args.get('static_gain_std', 0)
            if gain_std > 0:
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise: independent per element
            white_noise_std = transform_args.get('white_noise_std', 0)
            if white_noise_std > 0:
                features = features + tf.random.normal(tf.shape(features)) * white_noise_std

            # Constant offset noise: one offset per sample and channel,
            # broadcast over time
            constant_offset_std = transform_args.get('constant_offset_std', 0)
            if constant_offset_std > 0:
                features = features + (
                    tf.random.normal([batch_size, 1, channels]) * constant_offset_std
                )

            # Random walk noise: cumulative sum of white noise along an axis
            random_walk_std = transform_args.get('random_walk_std', 0)
            if random_walk_std > 0:
                steps = tf.random.normal(tf.shape(features)) * random_walk_std
                axis = transform_args.get('random_walk_axis', 1)
                features = features + tf.cumsum(steps, axis=axis)

            # Random cutoff (simplified for TPU - apply to all samples in batch):
            # drop the same number of leading time steps from every sample
            max_cut = transform_args.get('random_cut', 0)
            if max_cut > 0:
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps
def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split data from file_paths into train and test splits

    Each file is treated as one day; days are keyed by their position in
    file_paths. Files that do not exist contribute a day with no trials.

    Args:
        file_paths: List of HDF5 file paths. Each path must contain a
            component starting with 't15.20' or 't12.20' (the session name).
        test_percentage: Fraction of each day's trials used for testing.
            0 puts everything into train, 1 everything into test; otherwise
            at least one trial per non-empty day goes to test.
        seed: Random seed for reproducibility (-1 leaves the numpy seed untouched)
        bad_trials_dict: Dictionary of trials to exclude, laid out as
            {session: {str(block_num): [trial_num, ...]}}

    Returns:
        Tuple of (train_trials, test_trials) dictionaries, each mapping
        day -> {'trials': [...], 'session_path': path}

    Raises:
        IndexError: If a path has no session component
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if s.startswith(('t15.20', 't12.20'))][0]

        good_trial_indices = []
        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'
                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check if trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split trials into train and test
    train_trials = {}
    test_trials = {}
    for day, day_info in trials_per_day.items():
        num_trials = day_info['num_trials']
        all_trial_indices = day_info['trial_indices']
        session_path = day_info['session_path']

        if test_percentage == 0:
            train_indices, test_indices = all_trial_indices, []
        elif test_percentage == 1:
            train_indices, test_indices = [], all_trial_indices
        elif num_trials == 0:
            # Nothing to sample from: drawing the "at least one" test trial
            # from an empty day would make np.random.choice raise
            train_indices, test_indices = [], []
        else:
            # Calculate number of test trials
            num_test = max(1, int(num_trials * test_percentage))
            # Randomly select test indices
            test_indices = np.random.choice(
                all_trial_indices, size=num_test, replace=False
            ).tolist()
            # Remaining indices for training (set lookup keeps this linear)
            held_out = set(test_indices)
            train_indices = [idx for idx in all_trial_indices if idx not in held_out]

        train_trials[day] = {'trials': train_indices, 'session_path': session_path}
        test_trials[day] = {'trials': test_indices, 'session_path': session_path}

    return train_trials, test_trials
# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with data augmentation

    Builds the batched dataset from dataset_tf and maps
    DataAugmentationTF.transform_data over every batch, replacing
    'input_features' and 'n_time_steps' with their transformed values.

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)

    Returns:
        tf.data.Dataset ready for TPU training
    """
    dataset = dataset_tf.create_dataset()

    def apply_transforms(batch):
        """Apply data transformations to a batch"""
        features, n_time_steps = DataAugmentationTF.transform_data(
            batch['input_features'],
            batch['n_time_steps'],
            transform_args,
            training=training
        )
        # Return a new dict instead of mutating the incoming batch
        transformed = dict(batch)
        transformed['input_features'] = features
        transformed['n_time_steps'] = n_time_steps
        return transformed

    # Honour the parallelism configured on the dataset object; fall back to
    # AUTOTUNE (the previous hard-coded value) when it is not set
    num_parallel_calls = getattr(dataset_tf, 'num_parallel_calls', tf.data.AUTOTUNE)

    # Apply transformations
    return dataset.map(apply_transforms, num_parallel_calls=num_parallel_calls)