f
This commit is contained in:
@@ -110,21 +110,35 @@ class BrainToTextDatasetTF:
|
|||||||
self._preload_all_data()
|
self._preload_all_data()
|
||||||
print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
|
print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
|
||||||
|
|
||||||
# ========================= 特征维度自动检测 =========================
|
# ========================= 特征维度自动检测 (更稳健的版本) =========================
|
||||||
# 明确地计算并存储特征维度,避免 padded_batch 时的维度不匹配
|
# 明确地计算并存储特征维度,避免 padded_batch 时的维度不匹配
|
||||||
if self.feature_subset:
|
if self.feature_subset:
|
||||||
self.feature_dim = len(self.feature_subset)
|
self.feature_dim = len(self.feature_subset)
|
||||||
print(f"✅ Using feature subset dimension: {self.feature_dim}")
|
print(f"✅ Using feature subset dimension: {self.feature_dim}")
|
||||||
else:
|
else:
|
||||||
# 不要硬编码!尝试从数据中推断实际特征维度
|
# 稳健地从数据中推断实际特征维度
|
||||||
|
# 遍历所有 day,直到找到第一个包含 trial 的 day
|
||||||
|
detected_dim = None
|
||||||
|
for day in self.trial_indices:
|
||||||
|
# 检查 trial 列表是否非空
|
||||||
|
if self.trial_indices[day]['trials']:
|
||||||
try:
|
try:
|
||||||
first_day = next(iter(self.trial_indices))
|
# 列表非空,尝试加载第一个 trial
|
||||||
first_trial = self.trial_indices[first_day]['trials'][0]
|
first_valid_trial = self.trial_indices[day]['trials'][0]
|
||||||
first_sample = self._load_single_trial_data(first_day, first_trial)
|
first_sample = self._load_single_trial_data(day, first_valid_trial)
|
||||||
self.feature_dim = first_sample['input_features'].shape[1]
|
detected_dim = first_sample['input_features'].shape[1]
|
||||||
print(f"✅ Auto-detected feature dimension: {self.feature_dim}")
|
print(f"✅ Auto-detected feature dimension from day {day}: {detected_dim}")
|
||||||
|
break # 成功检测到维度,跳出循环
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}")
|
# 如果加载这个 trial 失败,则继续尝试下一个 day
|
||||||
|
print(f"⚠️ Warning: Could not load trial {first_valid_trial} from day {day} for dimension check. Error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if detected_dim is not None:
|
||||||
|
self.feature_dim = detected_dim
|
||||||
|
else:
|
||||||
|
# 如果遍历完所有 day 都没有找到任何有效的 trial,则报错或回退
|
||||||
|
print(f"⚠️ CRITICAL: Could not auto-detect feature dimension after checking all days. No valid trials found in the dataset split '{self.split}'. Falling back to 512.")
|
||||||
self.feature_dim = 512 # 作为最后的备用方案
|
self.feature_dim = 512 # 作为最后的备用方案
|
||||||
# ========================= 特征维度检测结束 =========================
|
# ========================= 特征维度检测结束 =========================
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user