f
This commit is contained in:
@@ -110,21 +110,35 @@ class BrainToTextDatasetTF:
|
||||
self._preload_all_data()
|
||||
print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
|
||||
|
||||
# ========================= 特征维度自动检测 =========================
|
||||
# ========================= 特征维度自动检测 (更稳健的版本) =========================
|
||||
# 明确地计算并存储特征维度,避免 padded_batch 时的维度不匹配
|
||||
if self.feature_subset:
|
||||
self.feature_dim = len(self.feature_subset)
|
||||
print(f"✅ Using feature subset dimension: {self.feature_dim}")
|
||||
else:
|
||||
# 不要硬编码!尝试从数据中推断实际特征维度
|
||||
try:
|
||||
first_day = next(iter(self.trial_indices))
|
||||
first_trial = self.trial_indices[first_day]['trials'][0]
|
||||
first_sample = self._load_single_trial_data(first_day, first_trial)
|
||||
self.feature_dim = first_sample['input_features'].shape[1]
|
||||
print(f"✅ Auto-detected feature dimension: {self.feature_dim}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}")
|
||||
# 稳健地从数据中推断实际特征维度
|
||||
# 遍历所有 day,直到找到第一个包含 trial 的 day
|
||||
detected_dim = None
|
||||
for day in self.trial_indices:
|
||||
# 检查 trial 列表是否非空
|
||||
if self.trial_indices[day]['trials']:
|
||||
try:
|
||||
# 列表非空,尝试加载第一个 trial
|
||||
first_valid_trial = self.trial_indices[day]['trials'][0]
|
||||
first_sample = self._load_single_trial_data(day, first_valid_trial)
|
||||
detected_dim = first_sample['input_features'].shape[1]
|
||||
print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}")
|
||||
break # 成功检测到维度,跳出循环
|
||||
except Exception as e:
|
||||
# 如果加载这个 trial 失败,则继续尝试下一个 day
|
||||
print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} for dimension check. Error: {e}")
|
||||
continue
|
||||
|
||||
if detected_dim is not None:
|
||||
self.feature_dim = detected_dim
|
||||
else:
|
||||
# 如果遍历完所有 day 都没有找到任何有效的 trial,则报错或回退
|
||||
print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. No valid trials found in the dataset split '{self.split}'. Falling back to 512.")
|
||||
self.feature_dim = 512 # 作为最后的备用方案
|
||||
# ========================= 特征维度检测结束 =========================
|
||||
|
||||
|
Reference in New Issue
Block a user