This commit is contained in:
Zchen
2025-10-22 00:54:20 +08:00
parent 52a9b17375
commit 57f07434ac

View File

@@ -110,21 +110,35 @@ class BrainToTextDatasetTF:
self._preload_all_data()
print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
# ========================= 特征维度自动检测 =========================
# ========================= 特征维度自动检测 (更稳健的版本) =========================
# 明确地计算并存储特征维度,避免 padded_batch 时的维度不匹配
if self.feature_subset:
self.feature_dim = len(self.feature_subset)
print(f"✅ Using feature subset dimension: {self.feature_dim}")
else:
# 不要硬编码!尝试从数据中推断实际特征维度
try:
first_day = next(iter(self.trial_indices))
first_trial = self.trial_indices[first_day]['trials'][0]
first_sample = self._load_single_trial_data(first_day, first_trial)
self.feature_dim = first_sample['input_features'].shape[1]
print(f"✅ Auto-detected feature dimension: {self.feature_dim}")
except Exception as e:
print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}")
# 稳健地从数据中推断实际特征维度
# 遍历所有 day,直到找到第一个包含 trial 的 day
detected_dim = None
for day in self.trial_indices:
# 检查 trial 列表是否非空
if self.trial_indices[day]['trials']:
try:
# 列表非空,尝试加载第一个 trial
first_valid_trial = self.trial_indices[day]['trials'][0]
first_sample = self._load_single_trial_data(day, first_valid_trial)
detected_dim = first_sample['input_features'].shape[1]
print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}")
break # 成功检测到维度,跳出循环
except Exception as e:
# 如果加载这个 trial 失败,则继续尝试下一个 day
print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} for dimension check. Error: {e}")
continue
if detected_dim is not None:
self.feature_dim = detected_dim
else:
# 如果遍历完所有 day 都没有找到任何有效的 trial,则打印警告并回退到默认值 512
print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. No valid trials found in the dataset split '{self.split}'. Falling back to 512.")
self.feature_dim = 512 # 作为最后的备用方案
# ========================= 特征维度检测结束 =========================