f
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
import tensorflow as tf
|
||||
import h5py
|
||||
import numpy as np
|
||||
@@ -109,6 +110,24 @@ class BrainToTextDatasetTF:
|
||||
self._preload_all_data()
|
||||
print(f"✅ Preloading completed - {len(self.data_cache)} trials cached")
|
||||
|
||||
# ========================= 特征维度自动检测 =========================
|
||||
# 明确地计算并存储特征维度,避免 padded_batch 时的维度不匹配
|
||||
if self.feature_subset:
|
||||
self.feature_dim = len(self.feature_subset)
|
||||
print(f"✅ Using feature subset dimension: {self.feature_dim}")
|
||||
else:
|
||||
# 不要硬编码!尝试从数据中推断实际特征维度
|
||||
try:
|
||||
first_day = next(iter(self.trial_indices))
|
||||
first_trial = self.trial_indices[first_day]['trials'][0]
|
||||
first_sample = self._load_single_trial_data(first_day, first_trial)
|
||||
self.feature_dim = first_sample['input_features'].shape[1]
|
||||
print(f"✅ Auto-detected feature dimension: {self.feature_dim}")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Could not auto-detect feature dimension, falling back to 512. Error: {e}")
|
||||
self.feature_dim = 512 # 作为最后的备用方案
|
||||
# ========================= 特征维度检测结束 =========================
|
||||
|
||||
def _create_batch_index_train(self) -> Dict[int, Dict[int, List[int]]]:
|
||||
"""Create training batch indices with random sampling"""
|
||||
batch_indices = {}
|
||||
@@ -953,6 +972,20 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
num_parallel_calls=tf.data.AUTOTUNE
|
||||
)
|
||||
|
||||
# ========================= 终极调试代码 =========================
|
||||
def debug_print_shape(example):
|
||||
"""调试函数:在 padded_batch 之前打印每个样本的形状"""
|
||||
tf.print("🔍 Sample Shape Debug:",
|
||||
tf.shape(example['input_features']),
|
||||
"Expected feature dim:", dataset_tf.feature_dim,
|
||||
output_stream=sys.stdout)
|
||||
return example
|
||||
|
||||
# 添加形状调试 - 这会在图执行时打印信息
|
||||
dataset = dataset.map(debug_print_shape)
|
||||
print(f"⚠️ Debug mode: Will print each sample shape before padded_batch")
|
||||
# =============================================================
|
||||
|
||||
# Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
|
||||
print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
|
||||
|
||||
@@ -960,23 +993,28 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||
max_time_steps = max_shapes['max_time_steps']
|
||||
max_phone_seq_len = max_shapes['max_phone_seq_len']
|
||||
max_transcription_len = max_shapes['max_transcription_len']
|
||||
n_features = max_shapes['n_features']
|
||||
|
||||
# ========================= 使用统一的特征维度 =========================
|
||||
# 使用 dataset_tf 对象中存储的、经过验证的特征维度,而不是依赖外部参数
|
||||
n_features = dataset_tf.feature_dim # <--- 关键修改:使用自动检测的特征维度
|
||||
print(f"🔧 Using verified feature dimension from dataset: {n_features}")
|
||||
# ========================= 特征维度修改结束 =========================
|
||||
|
||||
print(f" Fixed time steps: {max_time_steps}")
|
||||
print(f" Fixed phone sequence length: {max_phone_seq_len}")
|
||||
print(f" Fixed transcription length: {max_transcription_len}")
|
||||
print(f" Number of features: {n_features}")
|
||||
|
||||
# Define FIXED padded shapes - NO None values for XLA compatibility
|
||||
# Define FIXED padded shapes with TensorSpec for better type safety
|
||||
padded_shapes = {
|
||||
'input_features': tf.TensorShape([max_time_steps, n_features]),
|
||||
'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
|
||||
'n_time_steps': tf.TensorShape([]), # scalar
|
||||
'phone_seq_lens': tf.TensorShape([]), # scalar
|
||||
'day_indices': tf.TensorShape([]), # scalar
|
||||
'transcriptions': tf.TensorShape([max_transcription_len]),
|
||||
'block_nums': tf.TensorShape([]), # scalar
|
||||
'trial_nums': tf.TensorShape([]) # scalar
|
||||
'input_features': tf.TensorSpec(shape=[max_time_steps, n_features], dtype=tf.float32),
|
||||
'seq_class_ids': tf.TensorSpec(shape=[max_phone_seq_len], dtype=tf.int32),
|
||||
'n_time_steps': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
|
||||
'phone_seq_lens': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
|
||||
'day_indices': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
|
||||
'transcriptions': tf.TensorSpec(shape=[max_transcription_len], dtype=tf.int32),
|
||||
'block_nums': tf.TensorSpec(shape=[], dtype=tf.int32), # scalar
|
||||
'trial_nums': tf.TensorSpec(shape=[], dtype=tf.int32) # scalar
|
||||
}
|
||||
|
||||
# Define padding values for each field
|
||||
|
Reference in New Issue
Block a user