diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index e16f13f..7448657 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -1,4 +1,5 @@
 import os
+import sys
 import tensorflow as tf
 import h5py
 import numpy as np
@@ -109,6 +110,24 @@ class BrainToTextDatasetTF:
             self._preload_all_data()
             print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
 
+        # ==================== Feature dimension auto-detection ====================
+        # Store the feature dimension explicitly so padded_batch shapes cannot drift from the data
+        if self.feature_subset:
+            self.feature_dim = len(self.feature_subset)
+            print(f"✅ Using feature subset dimension: {self.feature_dim}")
+        else:
+            # Do not hard-code the dimension; infer it from the first available trial
+            try:
+                first_day = next(iter(self.trial_indices))
+                first_trial = self.trial_indices[first_day]['trials'][0]
+                first_sample = self._load_single_trial_data(first_day, first_trial)
+                self.feature_dim = first_sample['input_features'].shape[1]
+                print(f"✅ Auto-detected feature dimension: {self.feature_dim}")
+            except Exception as e:
+                print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}")
+                self.feature_dim = 512  # last-resort fallback
+        # ================== End of feature dimension auto-detection ================
+
     def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
         """Create training batch indices with random sampling"""
         batch_indices = {}
@@ -953,6 +972,22 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
+    # ==================== Optional shape debugging ====================
+    # A per-sample tf.print slows the input pipeline and floods the log, so it
+    # is only enabled when B2T_DEBUG_SHAPES=1 is set in the environment.
+    if os.environ.get('B2T_DEBUG_SHAPES') == '1':
+        def debug_print_shape(example):
+            """Print each sample's input_features shape before padded_batch."""
+            tf.print("🔍 Sample Shape Debug:",
+                     tf.shape(example['input_features']),
+                     "Expected feature dim:", dataset_tf.feature_dim,
+                     output_stream=sys.stdout)
+            return example
+
+        dataset = dataset.map(debug_print_shape)
+        print(f"⚠️ Debug mode: Will print each sample shape before padded_batch")
+    # ==================================================================
+
     # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
     print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
 
@@ -960,14 +995,20 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
     max_time_steps = max_shapes['max_time_steps']
     max_phone_seq_len = max_shapes['max_phone_seq_len']
     max_transcription_len = max_shapes['max_transcription_len']
-    n_features = max_shapes['n_features']
+
+    # ==================== Use the unified feature dimension ====================
+    # Take the dimension stored on dataset_tf rather than the external max_shapes value
+    n_features = dataset_tf.feature_dim
+    print(f"🔧 Using verified feature dimension from dataset: {n_features}")
+    # ===========================================================================
 
     print(f" Fixed time steps: {max_time_steps}")
     print(f" Fixed phone sequence length: {max_phone_seq_len}")
     print(f" Fixed transcription length: {max_transcription_len}")
     print(f" Number of features: {n_features}")
 
     # Define FIXED padded shapes - NO None values for XLA compatibility
+    # Keep tf.TensorShape here: padded_batch takes shapes (or int64 vectors), not tf.TensorSpec
     padded_shapes = {
         'input_features': tf.TensorShape([max_time_steps, n_features]),
         'seq_class_ids': tf.TensorShape([max_phone_seq_len]),