diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index 2ebb3be..6200535 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -110,21 +110,35 @@ class BrainToTextDatasetTF: self._preload_all_data() print(f"✅ Preloading completed - {len(self.data_cache)} trials cached") - # ========================= 特征维度自动检测 ========================= + # ========================= 特征维度自动检测 (更稳健的版本) ========================= # 明确地计算并存储特征维度,避免 padded_batch 时的维度不匹配 if self.feature_subset: self.feature_dim = len(self.feature_subset) print(f"✅ Using feature subset dimension: {self.feature_dim}") else: - # 不要硬编码!尝试从数据中推断实际特征维度 - try: - first_day = next(iter(self.trial_indices)) - first_trial = self.trial_indices[first_day]['trials'][0] - first_sample = self._load_single_trial_data(first_day, first_trial) - self.feature_dim = first_sample['input_features'].shape[1] - print(f"✅ Auto-detected feature dimension: {self.feature_dim}") - except Exception as e: - print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}") + # 稳健地从数据中推断实际特征维度 + # 遍历所有 day,直到找到第一个包含 trial 的 day + detected_dim = None + for day in self.trial_indices: + # 检查 trial 列表是否非空 + if self.trial_indices[day]['trials']: + try: + # 列表非空,尝试加载第一个 trial + first_valid_trial = self.trial_indices[day]['trials'][0] + first_sample = self._load_single_trial_data(day, first_valid_trial) + detected_dim = first_sample['input_features'].shape[1] + print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}") + break # 成功检测到维度,跳出循环 + except Exception as e: + # 如果加载这个 trial 失败,则继续尝试下一个 day + print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} for dimension check. Error: {e}") + continue + + if detected_dim is not None: + self.feature_dim = detected_dim + else: + # 如果遍历完所有 day 都没有找到任何有效的 trial,则报错或回退 + print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. No valid trials found in the dataset split '{self.split}'. Falling back to 512.") self.feature_dim = 512 # 作为最后的备用方案 # ========================= 特征维度检测结束 =========================