From 4328114ed621fe0850b311e5bb723fecb6665891 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 19 Oct 2025 11:04:36 +0800
Subject: [PATCH] Add dataset shape analysis function and integrate into input
 function for TPU optimization

---
 model_training_nnn_tpu/dataset_tf.py | 94 +++++++++++++++++++++++++---
 1 file changed, 87 insertions(+), 7 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 2898764..5fbc977 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -756,11 +756,81 @@ def train_test_split_indices(file_paths: List[str],
     return train_trials, test_trials
 
 
+def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
+    """
+    Analyze dataset to determine maximum shapes for padded_batch
+
+    Args:
+        dataset_tf: Dataset instance to analyze
+        sample_size: Number of samples to analyze (set to -1 for all data)
+
+    Returns:
+        Dictionary with maximum dimensions
+    """
+    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
+
+    max_shapes = {
+        'max_time_steps': 0,
+        'max_phone_seq_len': 0,
+        'max_transcription_len': 0,
+        'n_features': 512  # Fixed for neural features
+    }
+
+    # Sample a subset of data to determine max sizes
+    count = 0
+    for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
+        batch_index = dataset_tf.batch_indices[batch_idx]
+
+        for day in batch_index.keys():
+            for trial in batch_index[day][:min(10, len(batch_index[day]))]:
+                if count >= sample_size and sample_size > 0:
+                    break
+
+                try:
+                    session_path = dataset_tf.trial_indices[day]['session_path']
+                    with h5py.File(session_path, 'r') as f:
+                        g = f[f'trial_{trial:04d}']
+
+                        # Check dimensions
+                        time_steps = int(g.attrs['n_time_steps'])
+                        phone_seq_len = int(g.attrs['seq_len'])
+                        transcription_data = g['transcription'][:]
+                        transcription_len = len(transcription_data)
+
+                        max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
+                        max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
+                        max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
+
+                    count += 1
+
+                except Exception as e:
+                    logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
+                    continue
+
+            if count >= sample_size and sample_size > 0:
+                break
+        if count >= sample_size and sample_size > 0:
+            break
+
+    # Add safety margins (20% buffer) to handle edge cases
+    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
+    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
+    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
+
+    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
+    print(f"   Max time steps: {max_shapes['max_time_steps']}")
+    print(f"   Max phone sequence length: {max_shapes['max_phone_seq_len']}")
+    print(f"   Max transcription length: {max_shapes['max_transcription_len']}")
+
+    return max_shapes
+
+
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None,
+                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
     """
     Create input function for TPU training with fixed-shape batching and data augmentation
 
@@ -769,6 +839,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
+        auto_analyze_shapes: Whether to automatically analyze dataset for optimal shapes
 
     Returns:
         tf.data.Dataset ready for TPU training with fixed shapes
@@ -820,12 +891,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # Define fixed shapes for TPU compatibility
-    # These should match the maximum expected sizes in your dataset
-    max_time_steps = 4096  # Adjust based on your data
-    max_phone_seq_len = 256  # Adjust based on your data
-    max_transcription_len = 512  # Adjust based on your data
-    n_features = 512  # Number of neural features
+    # Determine shapes for TPU compatibility
+    if auto_analyze_shapes:
+        # Dynamically analyze dataset to determine optimal shapes
+        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
+        max_time_steps = shape_info['max_time_steps']
+        max_phone_seq_len = shape_info['max_phone_seq_len']
+        max_transcription_len = shape_info['max_transcription_len']
+        n_features = shape_info['n_features']
+    else:
+        # Use conservative fixed shapes for TPU compatibility
+        # Increased sizes to handle larger data - adjust based on your actual dataset
+        max_time_steps = 8192  # Increased from 4096 - adjust based on your data
+        max_phone_seq_len = 512  # Increased from 256 - adjust based on your data
+        max_transcription_len = 1024  # Increased from 512 - adjust based on your data
+        n_features = 512  # Number of neural features
 
     padded_shapes = {
         'input_features': [max_time_steps, n_features],
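
Editor's note (not part of the patch): a minimal sketch of how the new auto_analyze_shapes
flag added by this commit might be exercised from a training script. The
BrainToTextDatasetTF constructor arguments and the contents of transform_args are
placeholders, since neither is defined in this diff; only the function names, parameter
names, and the module file (dataset_tf.py) come from the patch itself.

    from dataset_tf import (BrainToTextDatasetTF, analyze_dataset_shapes,
                            create_input_fn)

    # Placeholders: build the dataset and augmentation config the same way the
    # rest of the repository does; this patch does not define either.
    dataset = BrainToTextDatasetTF(...)   # constructor args not shown here
    transform_args = {}                   # real augmentation settings go here

    # Option 1: let create_input_fn sample up to 100 trials and derive the
    # padded shapes (with the 20% safety margin) automatically.
    train_ds = create_input_fn(dataset, transform_args, training=True,
                               auto_analyze_shapes=True)

    # Option 2: inspect the sampled maxima first, then fall back to the fixed
    # conservative shapes by passing auto_analyze_shapes=False.
    shapes = analyze_dataset_shapes(dataset, sample_size=100)
    print(shapes)   # {'max_time_steps': ..., 'max_phone_seq_len': ..., ...}
    eval_ds = create_input_fn(dataset, transform_args, training=False,
                              auto_analyze_shapes=False)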