From 4b373ab317a82d7af57b63317469e8790a1605a7 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 19 Oct 2025 20:16:23 +0800
Subject: [PATCH] dataset_tf: always use dynamic padded shapes in create_input_fn

Replace the auto_analyze_shapes fixed-shape path in create_input_fn with
unconditional dynamic padding: every variable-length dimension is declared
as tf.TensorShape([None, ...]) so padded_batch pads each batch to its own
longest example, avoiding the "pad to a smaller size" error.
---
 model_training_nnn_tpu/dataset_tf.py | 78 +++++++++-------------------
 1 file changed, 25 insertions(+), 53 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 2b2fd53..bb44c53 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -852,20 +852,22 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None,
-                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with fixed-shape batching and data augmentation
+    Create input function for TPU training with DYNAMIC padding and data augmentation
+
+    This function uses dynamic shapes to avoid the "pad to a smaller size" error.
+    All variable-length dimensions use tf.TensorShape([None, ...]) to allow
+    TensorFlow to automatically determine the appropriate padding size for each batch.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
-        auto_analyze_shapes: Whether to automatically analyze dataset for optimal shapes
 
     Returns:
-        tf.data.Dataset ready for TPU training with fixed shapes
+        tf.data.Dataset ready for TPU training with dynamic shapes
     """
 
     # Create individual example dataset with file-grouping I/O optimization
@@ -914,58 +916,27 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # Determine shapes for TPU compatibility
-    if auto_analyze_shapes:
-        # Dynamically analyze dataset to determine optimal shapes
-        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
-        max_time_steps = shape_info['max_time_steps']
-        max_phone_seq_len = shape_info['max_phone_seq_len']
-        max_transcription_len = shape_info['max_transcription_len']
-        n_features = shape_info['n_features']
-        print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
-    else:
-        # Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
-        # This avoids the "pad to a smaller size" error by allowing dynamic sizing
-        print(f"🔧 Using dynamic shapes for maximum compatibility")
-        # Calculate number of features based on subset
-        n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    # ========================= DYNAMIC SHAPES SOLUTION =========================
+    # Use dynamic shapes to avoid the "pad to a smaller size" error.
+    # This is the simplest and most robust solution.
+    print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
 
-        padded_shapes = {
-            'input_features': tf.TensorShape([None, n_features]),
-            'seq_class_ids': tf.TensorShape([None]),
-            'n_time_steps': tf.TensorShape([]),  # Scalar
-            'phone_seq_lens': tf.TensorShape([]),  # Scalar
-            'day_indices': tf.TensorShape([]),  # Scalar
-            'transcriptions': tf.TensorShape([None]),
-            'block_nums': tf.TensorShape([]),  # Scalar
-            'trial_nums': tf.TensorShape([])  # Scalar
-        }
+    # Calculate number of features based on subset
+    n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
 
-        # Create fixed-shape batches with dynamic padding
-        dataset = dataset.padded_batch(
-            batch_size=dataset_tf.batch_size,
-            padded_shapes=padded_shapes,
-            padding_values=padding_values,
-            drop_remainder=True  # Critical for TPU: ensures all batches have same size
-        )
-
-        # Prefetch for optimal performance
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-
-        return dataset
-
-    # If using auto-analyzed shapes, create fixed-size padded shapes
+    # Define dynamic padded shapes - all variable dimensions use None
     padded_shapes = {
-        'input_features': [max_time_steps, n_features],
-        'seq_class_ids': [max_phone_seq_len],
-        'n_time_steps': [],  # Scalar
-        'phone_seq_lens': [],  # Scalar
-        'day_indices': [],  # Scalar
-        'transcriptions': [max_transcription_len],
-        'block_nums': [],  # Scalar
-        'trial_nums': []  # Scalar
+        'input_features': tf.TensorShape([None, n_features]),  # dynamic time dimension
+        'seq_class_ids': tf.TensorShape([None]),                # dynamic sequence length
+        'n_time_steps': tf.TensorShape([]),                     # scalar
+        'phone_seq_lens': tf.TensorShape([]),                   # scalar
+        'day_indices': tf.TensorShape([]),                      # scalar
+        'transcriptions': tf.TensorShape([None]),               # dynamic transcription length
+        'block_nums': tf.TensorShape([]),                       # scalar
+        'trial_nums': tf.TensorShape([])                        # scalar
     }
 
+    # Define padding values for each field
     padding_values = {
         'input_features': 0.0,
         'seq_class_ids': 0,
@@ -977,7 +948,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create fixed-shape batches with padding
+    # Create batches with dynamic padding - TensorFlow will automatically
+    # determine the appropriate padding size for each batch
    dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
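
For reference, below is a self-contained sketch of the dynamic-padding behavior
this patch relies on. The field names mirror the diff; the toy generator, the
feature count, and the batch size are illustrative assumptions, not values from
the repository.

    import tensorflow as tf

    N_FEATURES = 4  # stand-in for the 512 neural features in dataset_tf.py

    def gen():
        # Variable-length "trials" of 3, 7, and 5 time steps
        for t in (3, 7, 5):
            yield {
                'input_features': tf.ones([t, N_FEATURES], tf.float32),
                'n_time_steps': t,
            }

    ds = tf.data.Dataset.from_generator(
        gen,
        output_signature={
            'input_features': tf.TensorSpec([None, N_FEATURES], tf.float32),
            'n_time_steps': tf.TensorSpec([], tf.int32),
        },
    )

    # None dimensions let padded_batch pad each batch to its own longest
    # element, so no example can overflow a precomputed maximum (the
    # "pad to a smaller size" error). drop_remainder=True keeps the batch
    # dimension itself static.
    ds = ds.padded_batch(
        batch_size=3,
        padded_shapes={'input_features': [None, N_FEATURES], 'n_time_steps': []},
        padding_values={'input_features': 0.0, 'n_time_steps': 0},
        drop_remainder=True,
    )

    for batch in ds:
        print(batch['input_features'].shape)  # (3, 7, 4): padded to the batch max

One trade-off: the removed auto-analysis path produced fully static shapes,
which XLA on TPU generally prefers because each new batch shape can trigger a
recompilation; dynamic padding trades that compile-time stability for
robustness against examples longer than any sampled maximum.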