import os
import tensorflow as tf
import h5py
import numpy as np
import math
import logging
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d


class BrainToTextDatasetTF:
    """
    TensorFlow Dataset for brain-to-text data optimized for TPU v5e-8

    This class creates tf.data.Dataset objects that efficiently load and batch
    brain-to-text data from HDF5 files with TPU-optimized operations.
    """

    def __init__(
        self,
        trial_indices: Dict[int, Dict[str, Any]],
        n_batches: Optional[int],
        split: str = 'train',
        batch_size: int = 64,
        days_per_batch: int = 1,
        random_seed: int = -1,
        must_include_days: Optional[List[int]] = None,
        feature_subset: Optional[List[int]] = None,
        prefetch_buffer: int = tf.data.AUTOTUNE,
        num_parallel_calls: int = tf.data.AUTOTUNE,
        cache_data: bool = True,
        preload_all_data: bool = False
    ):
        """
        Initialize TensorFlow dataset for brain-to-text data

        Args:
            trial_indices: Dictionary with day numbers as keys and trial info as values
            n_batches: Number of training batches to create (None for validation)
            split: 'train' or 'test'
            batch_size: Number of examples per batch
            days_per_batch: Number of unique days per batch (for day-specific layers)
            random_seed: Random seed for reproducibility
            must_include_days: Days that must be included in every batch
            feature_subset: Subset of neural features to use
            prefetch_buffer: Buffer size for prefetching
            num_parallel_calls: Parallel processing threads
            cache_data: Whether to cache loaded data in memory
            preload_all_data: Whether to preload all data at initialization
        """
        # Set random seed for reproducibility
        if random_seed != -1:
            tf.random.set_seed(random_seed)
            np.random.seed(random_seed)

        self.split = split
        if self.split not in ['train', 'test']:
            raise ValueError(f'split must be either "train" or "test". Received {self.split}')

        self.days_per_batch = days_per_batch
        self.batch_size = batch_size
        self.n_batches = n_batches
        self.trial_indices = trial_indices
        self.n_days = len(trial_indices.keys())
        self.feature_subset = feature_subset
        self.must_include_days = must_include_days
        self.prefetch_buffer = prefetch_buffer
        self.num_parallel_calls = num_parallel_calls
        self.cache_data = cache_data
        self.preload_all_data = preload_all_data

        # Initialize data cache
        self.data_cache = {} if cache_data else None

        # Calculate total number of trials
        self.n_trials = 0
        for d in trial_indices:
            self.n_trials += len(trial_indices[d]['trials'])

        # Validation checks
        if must_include_days is not None:
            if len(must_include_days) > days_per_batch:
                raise ValueError('must_include_days must be <= days_per_batch')

            # Map negative indices
            for i, d in enumerate(must_include_days):
                if d < 0:
                    must_include_days[i] = self.n_days + d

        if self.split == 'train' and self.days_per_batch > self.n_days:
            raise ValueError(f'days_per_batch ({days_per_batch}) > available days ({self.n_days})')

        # Create batch indices
        if self.split == 'train':
            self.batch_indices = self._create_batch_index_train()
        else:
            self.batch_indices = self._create_batch_index_test()
        self.n_batches = len(self.batch_indices)

        # Preload data if requested (speeds up first batch significantly)
        if self.preload_all_data:
            print(f"🔄 Preloading all data for {self.split} split...")
            self._preload_all_data()
            print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
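    # Illustrative sketch (assumed layout, not enforced anywhere): `trial_indices`
    # is expected to look like
    #
    #   {
    #       0: {'trials': [0, 1, 2, ...], 'session_path': '/path/to/t15.2023.08.11/data_train.hdf5'},
    #       1: {'trials': [0, 1, ...],    'session_path': '/path/to/t15.2023.08.13/data_train.hdf5'},
    #   }
    #
    # i.e. one entry per recording day, each carrying the usable trial numbers and
    # the HDF5 file they live in (this is what train_test_split_indices() below
    # produces). The paths shown here are placeholders.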
    def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
        """Create training batch indices with random sampling"""
        batch_indices = {}

        # Precompute non-must-include days
        if self.must_include_days is not None:
            non_must_include_days = [
                d for d in self.trial_indices.keys() if d not in self.must_include_days
            ]

        for batch_idx in range(self.n_batches):
            batch = {}

            # Select days for this batch
            if self.must_include_days is not None and len(self.must_include_days) > 0:
                additional_days = np.random.choice(
                    non_must_include_days,
                    size=self.days_per_batch - len(self.must_include_days),
                    replace=False
                )
                days = np.concatenate((self.must_include_days, additional_days))
            else:
                days = np.random.choice(
                    list(self.trial_indices.keys()),
                    size=self.days_per_batch,
                    replace=False
                )

            # Calculate trials per day
            num_trials = math.ceil(self.batch_size / self.days_per_batch)

            for d in days:
                # Sample trials with replacement
                trial_idxs = np.random.choice(
                    self.trial_indices[d]['trials'],
                    size=num_trials,
                    replace=True
                )
                batch[d] = trial_idxs.tolist()

            # Remove extra trials to match exact batch size
            extra_trials = (num_trials * len(days)) - self.batch_size
            while extra_trials > 0:
                d = np.random.choice(days)
                if len(batch[d]) > 0:
                    batch[d] = batch[d][:-1]
                    extra_trials -= 1

            batch_indices[batch_idx] = batch

        return batch_indices

    def _create_batch_index_test(self) -> Dict[int, Dict[int, List[int]]]:
        """Create test batch indices ensuring all trials are seen once"""
        batch_indices = {}
        batch_idx = 0

        for d in self.trial_indices.keys():
            num_trials = len(self.trial_indices[d]['trials'])
            num_batches = (num_trials + self.batch_size - 1) // self.batch_size

            for i in range(num_batches):
                start_idx = i * self.batch_size
                end_idx = min((i + 1) * self.batch_size, num_trials)
                batch_trials = self.trial_indices[d]['trials'][start_idx:end_idx]
                batch_indices[batch_idx] = {d: batch_trials}
                batch_idx += 1

        return batch_indices
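    # Example of the resulting index (illustrative values): for batch_size=4 and
    # days_per_batch=2, _create_batch_index_train() might produce
    #
    #   self.batch_indices = {
    #       0: {3: [17, 17], 5: [2, 40]},   # two trials from day 3, two from day 5
    #       1: {0: [8, 51], 3: [4, 29]},
    #       ...
    #   }
    #
    # (repeated trial numbers are possible because training samples with
    # replacement), while _create_batch_index_test() simply walks each day's
    # trials in order, one day per batch, so every trial is seen exactly once.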
    def _preload_all_data(self):
        """Preload all trial data into the in-memory cache"""
        import multiprocessing
        from concurrent.futures import ThreadPoolExecutor, as_completed

        # Use CPU cores efficiently for parallel I/O
        max_workers = min(multiprocessing.cpu_count(), 32)  # Limit to avoid overwhelming I/O

        # Collect all trials to load
        trials_to_load = []
        for day in self.trial_indices:
            for trial in self.trial_indices[day]['trials']:
                trials_to_load.append((day, trial))

        print(f"📊 Preloading {len(trials_to_load)} trials using {max_workers} workers...")

        # Parallel loading using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Submit all loading tasks
            future_to_trial = {
                executor.submit(self._load_single_trial_data, day, trial): (day, trial)
                for day, trial in trials_to_load
            }

            # Process completed tasks and update the cache
            loaded_count = 0
            for future in as_completed(future_to_trial):
                day, trial = future_to_trial[future]
                try:
                    trial_data = future.result()
                    cache_key = f"{day}_{trial}"
                    self.data_cache[cache_key] = trial_data
                    loaded_count += 1

                    # Progress indicator every 100 trials
                    if loaded_count % 100 == 0:
                        print(f"   Loaded {loaded_count}/{len(trials_to_load)} trials...")
                except Exception as e:
                    print(f"   Warning: Failed to load trial {day}_{trial}: {e}")

        print(f"✅ Preloading completed: {loaded_count}/{len(trials_to_load)} trials cached")

    def _load_single_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data - optimized version for parallel loading"""
        try:
            session_path = self.trial_indices[day]['session_path']
            with h5py.File(session_path, 'r') as f:
                g = f[f'trial_{trial:04d}']

                # Load neural features
                input_features = g['input_features'][:]
                if self.feature_subset:
                    input_features = input_features[:, self.feature_subset]

                # Convert to float32 for TF compatibility
                input_features = input_features.astype(np.float32)

                trial_data = {
                    'input_features': input_features,
                    'seq_class_ids': g['seq_class_ids'][:],
                    'transcription': g['transcription'][:],
                    'n_time_steps': g.attrs['n_time_steps'],
                    'phone_seq_lens': g.attrs['seq_len'],
                    'day_index': day,
                    'block_num': g.attrs['block_num'],
                    'trial_num': g.attrs['trial_num']
                }

            return trial_data
        except Exception as e:
            # Return dummy data for failed loads, but make the failure visible
            logging.warning(f"Failed to load trial {day}_{trial}: {e}. Returning dummy data.")
            return {
                'input_features': np.zeros((100, 512), dtype=np.float32),
                'seq_class_ids': np.zeros((10,), dtype=np.int32),
                'transcription': np.zeros((50,), dtype=np.int32),
                'n_time_steps': 100,
                'phone_seq_lens': 10,
                'day_index': day,
                'block_num': 0,
                'trial_num': 0
            }

    def _load_trial_data(self, day: int, trial: int) -> Dict[str, Any]:
        """Load a single trial's data from cache or HDF5 file"""
        # Check the cache first if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            if cache_key in self.data_cache:
                return self.data_cache[cache_key]

        # Load from disk if not in cache
        trial_data = self._load_single_trial_data(day, trial)

        # Cache the loaded data if caching is enabled
        if self.cache_data:
            cache_key = f"{day}_{trial}"
            self.data_cache[cache_key] = trial_data

        return trial_data
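    # Assumed HDF5 layout, inferred from the reads above (one group per trial):
    #
    #   trial_0000/
    #       input_features   dataset, shape (n_time_steps, n_channels) neural features
    #       seq_class_ids    dataset, target sequence class ids
    #       transcription    dataset, transcription ids
    #       attrs: n_time_steps, seq_len, block_num, trial_num
    #   trial_0001/
    #       ...
    #
    # A trial that deviates from this layout falls back to the dummy trial
    # returned by _load_single_trial_data().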
    def _create_batch_generator(self):
        """Generator function that yields individual batches with optimized loading"""
        import time
        from concurrent.futures import ThreadPoolExecutor

        for batch_idx in range(self.n_batches):
            batch_start_time = time.time()

            batch_data = {
                'input_features': [],
                'seq_class_ids': [],
                'n_time_steps': [],
                'phone_seq_lens': [],
                'day_indices': [],
                'transcriptions': [],
                'block_nums': [],
                'trial_nums': []
            }

            batch_index = self.batch_indices[batch_idx]

            # Collect all trials to load for this batch
            trials_to_load = []
            for day in batch_index.keys():
                for trial in batch_index[day]:
                    trials_to_load.append((day, trial))

            # Use parallel loading if not preloaded and there are multiple trials
            if not self.preload_all_data and len(trials_to_load) > 4:
                # Parallel loading for faster I/O
                with ThreadPoolExecutor(max_workers=min(8, len(trials_to_load))) as executor:
                    future_to_trial = {
                        executor.submit(self._load_trial_data, day, trial): (day, trial)
                        for day, trial in trials_to_load
                    }

                    # Collect results
                    trial_results = {}
                    for future in future_to_trial:
                        day, trial = future_to_trial[future]
                        trial_results[(day, trial)] = future.result()

                # Add data in the original order
                for day, trial in trials_to_load:
                    trial_data = trial_results[(day, trial)]
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])
            else:
                # Sequential loading (fast when data is cached or there are few trials)
                for day, trial in trials_to_load:
                    trial_data = self._load_trial_data(day, trial)
                    batch_data['input_features'].append(trial_data['input_features'])
                    batch_data['seq_class_ids'].append(trial_data['seq_class_ids'])
                    batch_data['transcriptions'].append(trial_data['transcription'])
                    batch_data['n_time_steps'].append(trial_data['n_time_steps'])
                    batch_data['phone_seq_lens'].append(trial_data['phone_seq_lens'])
                    batch_data['day_indices'].append(trial_data['day_index'])
                    batch_data['block_nums'].append(trial_data['block_num'])
                    batch_data['trial_nums'].append(trial_data['trial_num'])

            data_loading_time = time.time() - batch_start_time

            # Timing diagnostic for the first few batches
            if batch_idx < 3:
                cache_status = "cached" if self.preload_all_data else "disk"
                loading_method = "parallel" if (not self.preload_all_data and len(trials_to_load) > 4) else "sequential"
                print(f"⏱️ Batch {batch_idx}: {len(trials_to_load)} trials loaded in "
                      f"{data_loading_time:.3f}s ({cache_status}, {loading_method})")

            # Pad sequences to create a uniform batch
            max_time_steps = max(batch_data['n_time_steps'])
            max_phone_len = max(len(seq) for seq in batch_data['seq_class_ids'])
            max_transcription_len = max(len(trans) for trans in batch_data['transcriptions'])

            # Pad input features
            padded_features = []
            for features in batch_data['input_features']:
                if features.shape[0] < max_time_steps:
                    padding = np.zeros((max_time_steps - features.shape[0], features.shape[1]),
                                       dtype=np.float32)
                    features = np.vstack([features, padding])
                padded_features.append(features)

            # Pad phoneme sequences
            padded_seq_ids = []
            for seq in batch_data['seq_class_ids']:
                if len(seq) < max_phone_len:
                    padding = np.zeros(max_phone_len - len(seq), dtype=np.int32)
                    seq = np.concatenate([seq, padding])
                padded_seq_ids.append(seq)

            # Pad transcriptions
            padded_transcriptions = []
            for trans in batch_data['transcriptions']:
                if len(trans) < max_transcription_len:
                    padding = np.zeros(max_transcription_len - len(trans), dtype=np.int32)
                    trans = np.concatenate([trans, padding])
                padded_transcriptions.append(trans)

            # Create final batch tensors
            batch = {
                'input_features': np.stack(padded_features),
                'seq_class_ids': np.stack(padded_seq_ids),
                'n_time_steps': np.array(batch_data['n_time_steps'], dtype=np.int32),
                'phone_seq_lens': np.array(batch_data['phone_seq_lens'], dtype=np.int32),
                'day_indices': np.array(batch_data['day_indices'], dtype=np.int32),
                'transcriptions': np.stack(padded_transcriptions),
                'block_nums': np.array(batch_data['block_nums'], dtype=np.int32),
                'trial_nums': np.array(batch_data['trial_nums'], dtype=np.int32)
            }

            yield batch
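    # Shape sketch of a yielded batch (illustrative, assuming 512 channels and
    # batch_size B): every value is a numpy array, padded to the longest trial
    # in that particular batch:
    #
    #   batch['input_features']  -> (B, max_time_steps, 512)   float32
    #   batch['seq_class_ids']   -> (B, max_phone_len)          int32
    #   batch['transcriptions']  -> (B, max_transcription_len)  int32
    #   batch['n_time_steps'], batch['phone_seq_lens'], batch['day_indices'],
    #   batch['block_nums'], batch['trial_nums']  -> (B,)       int32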
    def create_dataset(self) -> tf.data.Dataset:
        """Create optimized tf.data.Dataset for TPU training"""

        # Define output signature for the dataset
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None, None), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(None,), dtype=tf.int32)
        }

        # Create dataset from generator
        dataset = tf.data.Dataset.from_generator(
            self._create_batch_generator,
            output_signature=output_signature
        )

        # No dataset-level shuffle here (same strategy as the GPU version):
        # _create_batch_index_train() already samples batches randomly, and an
        # additional shuffle would blow up memory (e.g. a 1000-batch buffer at
        # 256 trials per batch keeps 256,000 trials in memory at once).
        # if self.split == 'train':
        #     dataset = dataset.shuffle(buffer_size=min(1000, self.n_batches))

        # Prefetch for better performance
        dataset = dataset.prefetch(self.prefetch_buffer)

        return dataset

    def create_individual_dataset(self) -> tf.data.Dataset:
        """
        Create a tf.data.Dataset that yields individual examples with I/O optimization.

        The generator groups trial loading by session file, drastically reducing
        the number of file open/close operations from N_trials to N_sessions,
        which is ideal for slow disk I/O.
        """

        def individual_example_generator():
            """Generator that groups reads by file to minimize disk I/O."""
            # 1. Build a flat list of all trials: [(day, trial), (day, trial), ...]
            #    The iteration order here determines the rough read order;
            #    _create_batch_index_train() has already randomized the batches.
            all_trials_to_load = []
            for batch_idx in sorted(self.batch_indices.keys()):
                batch_index = self.batch_indices[batch_idx]
                for day in batch_index.keys():
                    for trial in batch_index[day]:
                        all_trials_to_load.append((day, trial))

            # 2. Group the trial list by 'day' (i.e. by file).
            #    key=lambda x: x[0] uses the first tuple element (day) as the grouping key.
            for day, group in groupby(sorted(all_trials_to_load, key=lambda x: x[0]),
                                      key=lambda x: x[0]):
                session_path = self.trial_indices[day]['session_path']

                # 3. Open each HDF5 file only once per group (per file)
                try:
                    with h5py.File(session_path, 'r') as f:
                        # 4. While the file is open, read every trial needed from it
                        for current_day, current_trial in group:
                            try:
                                # Read directly from the open file handle 'f' instead of
                                # going through the per-trial loader
                                g = f[f'trial_{current_trial:04d}']

                                input_features = g['input_features'][:]
                                if self.feature_subset:
                                    input_features = input_features[:, self.feature_subset]

                                example = {
                                    'input_features': input_features.astype(np.float32),
                                    'seq_class_ids': g['seq_class_ids'][:].astype(np.int32),
                                    'n_time_steps': np.int32(g.attrs['n_time_steps']),
                                    'phone_seq_lens': np.int32(g.attrs['seq_len']),
                                    'day_indices': np.int32(current_day),
                                    'transcriptions': g['transcription'][:].astype(np.int32),
                                    'block_nums': np.int32(g.attrs['block_num']),
                                    'trial_nums': np.int32(g.attrs['trial_num'])
                                }

                                yield example
                            except KeyError:
                                logging.warning(f"Trial {current_trial} not found in file {session_path}. Skipping.")
                                continue
                except (IOError, FileNotFoundError) as e:
                    logging.error(f"Could not open or read HDF5 file: {session_path}. "
                                  f"Error: {e}. Skipping all trials for this day.")
                    continue

        # Define output signature for individual examples
        output_signature = {
            'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
            'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
            'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
            'day_indices': tf.TensorSpec(shape=(), dtype=tf.int32),
            'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
            'block_nums': tf.TensorSpec(shape=(), dtype=tf.int32),
            'trial_nums': tf.TensorSpec(shape=(), dtype=tf.int32)
        }

        # Create dataset from individual examples
        dataset = tf.data.Dataset.from_generator(
            individual_example_generator,
            output_signature=output_signature
        )

        # Shuffle individual examples if training (more effective than batch-level shuffle).
        # The buffer can be reasonably large now that the I/O path is more efficient.
        if self.split == 'train':
            shuffle_buffer = min(2048, self.n_trials)
            dataset = dataset.shuffle(buffer_size=shuffle_buffer)

        return dataset
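# Minimal usage sketch for the dataset builders above (illustrative call pattern;
# `train_trials` is assumed to come from train_test_split_indices() further below,
# and days_per_batch must not exceed the number of available days):
#
#   ds_builder = BrainToTextDatasetTF(train_trials, n_batches=1000, split='train',
#                                     batch_size=64, days_per_batch=4, random_seed=0)
#   batched_ds = ds_builder.create_dataset()                  # yields pre-padded batches
#   per_example_ds = ds_builder.create_individual_dataset()   # yields single trials
#
#   for batch in batched_ds.take(1):
#       print(batch['input_features'].shape)   # (64, max_time_steps, n_features)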
class DataAugmentationTF:
    """
    TensorFlow data augmentation functions optimized for TPU v5e-8
    """

    @staticmethod
    def gauss_smooth(inputs: tf.Tensor,
                     smooth_kernel_std: float = 2.0,
                     smooth_kernel_size: int = 100) -> tf.Tensor:
        """
        Apply Gaussian smoothing along the time axis using a vectorized TensorFlow operation.

        This implementation uses depthwise_conv2d for optimal TPU performance,
        replacing the inefficient Python for-loop that created 512 separate
        conv1d operations.

        Args:
            inputs: Input tensor [batch_size, time_steps, features]
            smooth_kernel_std: Standard deviation of Gaussian kernel
            smooth_kernel_size: Size of the Gaussian kernel

        Returns:
            Smoothed tensor with same shape as input
        """
        # Create Gaussian kernel using numpy (computed once)
        inp = np.zeros(smooth_kernel_size, dtype=np.float32)
        inp[smooth_kernel_size // 2] = 1
        gauss_kernel = gaussian_filter1d(inp, smooth_kernel_std)
        valid_idx = np.argwhere(gauss_kernel > 0.01)
        gauss_kernel = gauss_kernel[valid_idx].flatten()
        gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
        gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)

        # ========================= OPTIMIZED SOLUTION =========================
        # Get input dimensions
        num_features = tf.shape(inputs)[-1]
        kernel_size = tf.shape(gauss_kernel)[0]

        # Prepare the kernel for depthwise_conv2d.
        # Shape needed: [height, width, in_channels, channel_multiplier]
        # Our case:     [kernel_size, 1, num_features, 1]
        # Each input channel gets its own independent, identical 1D Gaussian kernel.
        kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1, 1])
        kernel = tf.tile(kernel, [1, 1, num_features, 1])

        # Prepare the input for conv2d.
        # Shape needed: [batch, height, width, channels]
        # Our case:     [batch_size, time_steps, 1, num_features]
        # Add a dummy width dimension.
        reshaped_inputs = tf.expand_dims(inputs, axis=2)

        # Execute the depthwise convolution: a single, efficient operation replacing
        # the original Python for-loop
        smoothed = tf.nn.depthwise_conv2d(
            reshaped_inputs,
            kernel,
            strides=[1, 1, 1, 1],
            padding='SAME'
        )

        # Remove the dummy width dimension to restore the original shape
        smoothed = tf.squeeze(smoothed, axis=2)
        # ======================================================================

        return smoothed
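    # Quick sanity-check sketch for gauss_smooth (illustrative):
    #
    #   x = tf.random.normal([2, 50, 4])
    #   y = DataAugmentationTF.gauss_smooth(x, smooth_kernel_std=2.0)
    #   assert y.shape == x.shape   # smoothing preserves [batch, time, features]
    #
    # Each channel is filtered independently with the same truncated, renormalized
    # 1D Gaussian; only the edge handling differs from scipy's gaussian_filter1d
    # ('SAME' zero-padding here vs. scipy's default reflect mode).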
    @staticmethod
    def transform_data(features: tf.Tensor,
                       n_time_steps: tf.Tensor,
                       transform_args: Dict[str, Any],
                       training: bool = True) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Apply data transformations optimized for TPU

        Args:
            features: Input features [batch_size, time_steps, channels]
            n_time_steps: Number of valid time steps per sample
            transform_args: Transformation configuration
            training: Whether to apply training-only augmentations

        Returns:
            Transformed features and updated time steps
        """
        batch_size = tf.shape(features)[0]
        time_steps = tf.shape(features)[1]
        channels = tf.shape(features)[2]

        # Training-only augmentations
        if training:
            # Static gain noise
            if transform_args.get('static_gain_std', 0) > 0:
                gain_std = transform_args['static_gain_std']

                # Create identity matrices for each batch element
                identity_matrices = tf.eye(channels, batch_shape=[batch_size])

                # Add noise to create warp matrices
                noise = tf.random.normal([batch_size, channels, channels]) * gain_std
                warp_matrices = identity_matrices + noise

                # Apply the transformation
                features = tf.linalg.matmul(features, warp_matrices)

            # White noise
            if transform_args.get('white_noise_std', 0) > 0:
                white_noise = tf.random.normal(tf.shape(features)) * transform_args['white_noise_std']
                features = features + white_noise

            # Constant offset noise
            if transform_args.get('constant_offset_std', 0) > 0:
                offset_noise = tf.random.normal([batch_size, 1, channels]) * transform_args['constant_offset_std']
                features = features + offset_noise

            # Random walk noise
            if transform_args.get('random_walk_std', 0) > 0:
                random_walk_noise = tf.random.normal(tf.shape(features)) * transform_args['random_walk_std']
                axis = transform_args.get('random_walk_axis', 1)
                random_walk_noise = tf.cumsum(random_walk_noise, axis=axis)
                features = features + random_walk_noise

            # Random cutoff (simplified for TPU - applied to all samples in the batch)
            if transform_args.get('random_cut', 0) > 0:
                max_cut = transform_args['random_cut']
                cut = tf.random.uniform([], 0, max_cut, dtype=tf.int32)
                features = features[:, cut:, :]
                n_time_steps = n_time_steps - cut

        # Apply Gaussian smoothing (both training and validation)
        if transform_args.get('smooth_data', False):
            features = DataAugmentationTF.gauss_smooth(
                features,
                smooth_kernel_std=transform_args.get('smooth_kernel_std', 2.0),
                smooth_kernel_size=transform_args.get('smooth_kernel_size', 100)
            )

        return features, n_time_steps
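# Illustrative transform_args configuration (the key names are exactly the ones
# transform_data() reads; the values are made-up examples, not recommendations):
#
#   transform_args = {
#       'static_gain_std': 0.0,       # per-channel gain "warp" matrices
#       'white_noise_std': 1.0,       # i.i.d. Gaussian noise on every sample
#       'constant_offset_std': 0.2,   # one offset per channel per trial
#       'random_walk_std': 0.0,       # cumulative-sum noise along `random_walk_axis`
#       'random_walk_axis': 1,
#       'random_cut': 3,              # drop up to N leading time steps per batch
#       'smooth_data': True,          # Gaussian smoothing (train and eval)
#       'smooth_kernel_std': 2.0,
#       'smooth_kernel_size': 100,
#   }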
def train_test_split_indices(file_paths: List[str],
                             test_percentage: float = 0.1,
                             seed: int = -1,
                             bad_trials_dict: Optional[Dict] = None) -> Tuple[Dict, Dict]:
    """
    Split data from file_paths into train and test splits

    Args:
        file_paths: List of HDF5 file paths
        test_percentage: Percentage of trials for testing
        seed: Random seed for reproducibility
        bad_trials_dict: Dictionary of trials to exclude

    Returns:
        Tuple of (train_trials, test_trials) dictionaries
    """
    # Set seed for reproducibility
    if seed != -1:
        np.random.seed(seed)

    # Get the usable trials in each day
    trials_per_day = {}
    for i, path in enumerate(file_paths):
        # Handle both Windows and Unix path separators
        path_parts = path.replace('\\', '/').split('/')
        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]

        good_trial_indices = []
        if os.path.exists(path):
            with h5py.File(path, 'r') as f:
                num_trials = len(list(f.keys()))
                for t in range(num_trials):
                    key = f'trial_{t:04d}'
                    if key not in f:
                        continue

                    block_num = f[key].attrs['block_num']
                    trial_num = f[key].attrs['trial_num']

                    # Check whether the trial should be excluded
                    if (bad_trials_dict is not None
                            and session in bad_trials_dict
                            and str(block_num) in bad_trials_dict[session]
                            and trial_num in bad_trials_dict[session][str(block_num)]):
                        continue

                    good_trial_indices.append(t)

        trials_per_day[i] = {
            'num_trials': len(good_trial_indices),
            'trial_indices': good_trial_indices,
            'session_path': path
        }

    # Split trials into train and test
    train_trials = {}
    test_trials = {}

    for day in trials_per_day.keys():
        num_trials = trials_per_day[day]['num_trials']
        all_trial_indices = trials_per_day[day]['trial_indices']

        if test_percentage == 0:
            train_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
        elif test_percentage == 1:
            train_trials[day] = {
                'trials': [],
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': all_trial_indices,
                'session_path': trials_per_day[day]['session_path']
            }
        else:
            # Calculate the number of test trials
            num_test = max(1, int(num_trials * test_percentage))

            # Randomly select test indices
            test_indices = np.random.choice(all_trial_indices, size=num_test, replace=False).tolist()

            # Remaining indices are used for training
            train_indices = [idx for idx in all_trial_indices if idx not in test_indices]

            train_trials[day] = {
                'trials': train_indices,
                'session_path': trials_per_day[day]['session_path']
            }
            test_trials[day] = {
                'trials': test_indices,
                'session_path': trials_per_day[day]['session_path']
            }

    return train_trials, test_trials


def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
    """
    Analyze the dataset to determine maximum shapes for padded_batch

    Args:
        dataset_tf: Dataset instance to analyze
        sample_size: Number of samples to analyze (set to -1 for all data)

    Returns:
        Dictionary with maximum dimensions
    """
    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")

    max_shapes = {
        'max_time_steps': 0,
        'max_phone_seq_len': 0,
        'max_transcription_len': 0,
        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
    }

    # Sample a subset of the data to determine max sizes
    count = 0
    batch_keys = list(dataset_tf.batch_indices.keys())

    # If sample_size is -1, analyze all data
    if sample_size == -1:
        batches_to_check = batch_keys
        max_trials_per_batch = float('inf')
    else:
        # Sample a reasonable number of batches
        batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
        max_trials_per_batch = max(1, sample_size // len(batches_to_check))

    for batch_idx in batches_to_check:
        if count >= sample_size and sample_size > 0:
            break

        batch_index = dataset_tf.batch_indices[batch_idx]
        for day in batch_index.keys():
            if count >= sample_size and sample_size > 0:
                break

            trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
            for trial in trials_to_check:
                if count >= sample_size and sample_size > 0:
                    break

                try:
                    session_path = dataset_tf.trial_indices[day]['session_path']
                    with h5py.File(session_path, 'r') as f:
                        g = f[f'trial_{trial:04d}']

                        # Check dimensions
                        time_steps = int(g.attrs['n_time_steps'])
                        phone_seq_len = int(g.attrs['seq_len'])
                        transcription_len = len(g['transcription'][:])

                        max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
                        max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
                        max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)

                        count += 1

                        # Show progress for large analyses
                        if count % 50 == 0:
                            print(f"   Analyzed {count} samples... current max: "
                                  f"time={max_shapes['max_time_steps']}, "
                                  f"phone={max_shapes['max_phone_seq_len']}, "
                                  f"trans={max_shapes['max_transcription_len']}")
                except Exception as e:
                    logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
                    continue

    # Add safety margins (20% buffer) to handle edge cases
    original_time_steps = max_shapes['max_time_steps']
    original_phone_seq_len = max_shapes['max_phone_seq_len']
    original_transcription_len = max_shapes['max_transcription_len']

    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)

    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
    print(f"   Number of features: {max_shapes['n_features']}")

    return max_shapes
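# Usage sketch for the two helpers above (illustrative; the glob pattern and
# directory are placeholders):
#
#   file_paths = sorted(glob.glob('/path/to/hdf5_data/t15.20*/data_train.hdf5'))
#   train_trials, val_trials = train_test_split_indices(file_paths, test_percentage=0.1, seed=0)
#   val_builder = BrainToTextDatasetTF(val_trials, n_batches=None, split='test', batch_size=64)
#   shapes = analyze_dataset_shapes(val_builder, sample_size=-1)   # exact maxima + 20% margin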
# Utility functions for TPU-optimized data pipeline

def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True,
                    cache_path: Optional[str] = None,
                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
    """
    Create input function for TPU training with fixed-shape batching and data augmentation

    Args:
        dataset_tf: BrainToTextDatasetTF instance
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance
        auto_analyze_shapes: Whether to automatically analyze the dataset for optimal shapes

    Returns:
        tf.data.Dataset ready for TPU training with fixed shapes
    """
    # Create individual example dataset with file-grouping I/O optimization
    dataset = dataset_tf.create_individual_dataset()

    # ========================= I/O OPTIMIZATION =========================
    # Cache both the training and validation datasets, because:
    # 1. Training set: fully traversed every epoch.
    # 2. Validation set: validated every 200 rounds and used for early-stopping
    #    checks, so it is accessed frequently.
    split_name = "training" if training else "validation"
    if cache_path:
        dataset = dataset.cache(cache_path)
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
    else:
        # If no cache path is given, fall back to an in-memory cache. For large
        # datasets, pass an explicit on-disk cache path instead.
        dataset = dataset.cache()
        print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
    print(f"⚠️ First access will be slow while building the {split_name} cache; subsequent access will be much faster")
    # ====================================================================

    def apply_transforms(example):
        """Apply data transformations to individual examples"""
        features = example['input_features']
        n_time_steps = example['n_time_steps']

        # Apply transformations (transform_data expects a batch dimension)
        features, n_time_steps = DataAugmentationTF.transform_data(
            tf.expand_dims(features, 0),
            tf.expand_dims(n_time_steps, 0),
            transform_args,
            training=training
        )

        # Remove the batch dimension
        example['input_features'] = tf.squeeze(features, 0)
        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)

        return example

    # Apply the (random) augmentations after caching so that every epoch sees a
    # different augmentation of the cached raw examples
    dataset = dataset.map(apply_transforms, num_parallel_calls=tf.data.AUTOTUNE)

    # Number of features depends on the optional feature subset
    n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512

    # Determine padded shapes for TPU compatibility
    if auto_analyze_shapes:
        # Dynamically analyze the dataset to determine fixed shapes
        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
        print(f"🔧 Using auto-analyzed shapes: "
              f"time_steps={shape_info['max_time_steps']}, "
              f"phone_seq={shape_info['max_phone_seq_len']}, "
              f"transcription={shape_info['max_transcription_len']}")

        padded_shapes = {
            'input_features': [shape_info['max_time_steps'], n_features],
            'seq_class_ids': [shape_info['max_phone_seq_len']],
            'n_time_steps': [],          # Scalar
            'phone_seq_lens': [],        # Scalar
            'day_indices': [],           # Scalar
            'transcriptions': [shape_info['max_transcription_len']],
            'block_nums': [],            # Scalar
            'trial_nums': []             # Scalar
        }
    else:
        # Use dynamic shapes for maximum compatibility: let TensorFlow pad each
        # batch to its longest example. This avoids the "pad to a smaller size"
        # error at the cost of variable batch shapes.
        print("🔧 Using dynamic shapes for maximum compatibility")

        padded_shapes = {
            'input_features': tf.TensorShape([None, n_features]),
            'seq_class_ids': tf.TensorShape([None]),
            'n_time_steps': tf.TensorShape([]),      # Scalar
            'phone_seq_lens': tf.TensorShape([]),    # Scalar
            'day_indices': tf.TensorShape([]),       # Scalar
            'transcriptions': tf.TensorShape([None]),
            'block_nums': tf.TensorShape([]),        # Scalar
            'trial_nums': tf.TensorShape([])         # Scalar
        }

    padding_values = {
        'input_features': 0.0,
        'seq_class_ids': 0,
        'n_time_steps': 0,
        'phone_seq_lens': 0,
        'day_indices': 0,
        'transcriptions': 0,
        'block_nums': 0,
        'trial_nums': 0
    }

    # Create fixed-shape batches with padding
    dataset = dataset.padded_batch(
        batch_size=dataset_tf.batch_size,
        padded_shapes=padded_shapes,
        padding_values=padding_values,
        drop_remainder=True  # Critical for TPU: ensures all batches have the same size
    )

    # Prefetch for optimal performance
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
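
if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only): the data directory, glob
    # pattern, and hyper-parameters below are placeholders. Point
    # `example_data_dir` at real t15.20xx/t12.20xx session HDF5 files to
    # actually produce batches; otherwise the demo just reports that no files
    # were found.
    import glob

    example_data_dir = "./data"  # hypothetical location of the session files
    file_paths = sorted(glob.glob(os.path.join(example_data_dir, "**", "*.hdf5"), recursive=True))

    if not file_paths:
        print("No HDF5 session files found - edit `example_data_dir` to run the demo.")
    else:
        train_trials, _ = train_test_split_indices(file_paths, test_percentage=0.1, seed=0)

        train_builder = BrainToTextDatasetTF(
            trial_indices=train_trials,
            n_batches=10,
            split='train',
            batch_size=8,
            days_per_batch=1,
            random_seed=0
        )

        # Assumed augmentation settings; see transform_data() for the full key list
        transform_args = {'white_noise_std': 0.1, 'smooth_data': True}

        train_dataset = create_input_fn(train_builder, transform_args,
                                        training=True, auto_analyze_shapes=False)

        for batch in train_dataset.take(1):
            print({k: v.shape for k, v in batch.items()})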