Enhance error handling and deprecate batch generation methods in BrainToTextDatasetTF

- Improved error logging when loading trial data fails, and made the dummy fallback data use the dataset's actual feature dimension.
- Marked the _create_batch_generator and create_dataset methods as deprecated, recommending create_input_fn for better performance and TPU compatibility.
- Derived the maximum number of parallel workers in analyze_dataset_shapes from the available CPU cores instead of a hard-coded 224.
Author: Zchen
Date:   2025-10-22 00:28:10 +08:00
Commit: e715d9ac79 (parent a031972ba6)


@@ -273,9 +273,13 @@ class BrainToTextDatasetTF:
             return trial_data
         except Exception as e:
-            # Return dummy data for failed loads
+            # Log the error and return dummy data with the correct feature dimension
+            logging.warning(f"Failed to load trial {day}_{trial} from {session_path}. Error: {e}. Returning dummy data.")
+            # Use self.feature_dim to ensure dimension consistency
+            feature_dim = self.feature_dim
             return {
-                'input_features': np.zeros((100, 512), dtype=np.float32),
+                'input_features': np.zeros((100, feature_dim), dtype=np.float32),
                 'seq_class_ids': np.zeros((10,), dtype=np.int32),
                 'transcription': np.zeros((50,), dtype=np.int32),
                 'n_time_steps': 100,
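
Why the dummy fallback must use self.feature_dim: downstream batching can pad the time axis, but every element has to agree on the feature axis, so a hard-coded 512 breaks any run that uses a feature subset. A minimal numpy sketch (shapes are illustrative; 256 stands in for a feature_subset run):

    import numpy as np

    feature_dim = 256  # illustrative: a feature_subset run where feature_dim != 512
    real = np.zeros((80, feature_dim), dtype=np.float32)
    dummy_old = np.zeros((100, 512), dtype=np.float32)          # hard-coded width
    dummy_new = np.zeros((100, feature_dim), dtype=np.float32)  # matches the run

    np.concatenate([real, dummy_new])      # OK: feature axes agree
    try:
        np.concatenate([real, dummy_old])  # ValueError: 256 vs 512 on axis 1
    except ValueError as e:
        print('feature-dim mismatch:', e)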
@@ -304,7 +308,19 @@ class BrainToTextDatasetTF:
         return trial_data

     def _create_batch_generator(self):
-        """Generator function that yields individual batches with optimized loading"""
+        """
+        Generator function that yields individual batches with optimized loading
+
+        ⚠️ DEPRECATED: This method is deprecated. Use create_input_fn() instead for better
+        performance and TPU compatibility. This method will be removed in a future version.
+        """
+        import warnings
+        warnings.warn(
+            "_create_batch_generator is deprecated. Use create_input_fn() instead for better performance.",
+            DeprecationWarning,
+            stacklevel=2
+        )
         import time
         from concurrent.futures import ThreadPoolExecutor
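
One caveat worth noting: _create_batch_generator contains yield, so it is a generator function, and code before the first yield (including this warnings.warn) only runs when iteration begins, not when the method is called. A self-contained sketch of the difference (toy functions, not from the source; run with -W always to see both warnings):

    import warnings

    def late_warning():
        warnings.warn("deprecated", DeprecationWarning, stacklevel=2)
        yield 1  # generator function: the warn() above runs on the first next()

    def eager_warning():
        warnings.warn("deprecated", DeprecationWarning, stacklevel=2)
        def _impl():
            yield 1
        return _impl()  # plain function: the warn() already ran when this returned

    g1 = late_warning()   # no warning emitted yet
    g2 = eager_warning()  # DeprecationWarning emitted here
    next(g1)              # the warning from late_warning() appears only now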
@@ -421,7 +437,18 @@ class BrainToTextDatasetTF:
             yield batch

     def create_dataset(self) -> tf.data.Dataset:
-        """Create optimized tf.data.Dataset for TPU training"""
+        """
+        Create optimized tf.data.Dataset for TPU training
+
+        ⚠️ DEPRECATED: This method is deprecated. Use create_input_fn() instead for better
+        performance and TPU compatibility. This method will be removed in a future version.
+        """
+        import warnings
+        warnings.warn(
+            "create_dataset is deprecated. Use create_input_fn() instead for better performance.",
+            DeprecationWarning,
+            stacklevel=2
+        )

         # Define output signature for the dataset
         output_signature = {
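
Since Python filters DeprecationWarning out by default outside of __main__ code, callers who want these new warnings to surface (for example in CI) need to opt in explicitly; the script name below is illustrative:

    import warnings

    warnings.simplefilter("always", DeprecationWarning)
    # or from the command line:
    #   python -W always::DeprecationWarning train_model.py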
@@ -516,7 +543,7 @@ class BrainToTextDatasetTF:
         # Define output signature for individual examples
         output_signature = {
-            'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
+            'input_features': tf.TensorSpec(shape=(None, self.feature_dim), dtype=tf.float32),
             'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
             'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
             'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
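
With a concrete feature axis in the signature, the static shape propagates through the pipeline instead of (None, None). A minimal sketch, with the signature trimmed to two keys and feature_dim=512 standing in for the instance's self.feature_dim:

    import numpy as np
    import tensorflow as tf

    feature_dim = 512  # illustrative; the source uses self.feature_dim here

    def gen():
        yield {
            'input_features': np.zeros((120, feature_dim), np.float32),
            'n_time_steps': np.int32(120),
        }

    ds = tf.data.Dataset.from_generator(
        gen,
        output_signature={
            'input_features': tf.TensorSpec(shape=(None, feature_dim), dtype=tf.float32),
            'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
        },
    )
    print(ds.element_spec['input_features'].shape)  # (None, 512) rather than (None, None)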
@@ -698,9 +725,19 @@ def train_test_split_indices(file_paths: List[str],
     # Get trials in each day
     trials_per_day = {}
     for i, path in enumerate(file_paths):
-        # Handle both Windows and Unix path separators
-        path_parts = path.replace('\\', '/').split('/')
-        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
+        try:
+            # Handle both Windows and Unix path separators
+            path_parts = path.replace('\\', '/').split('/')
+            session_candidates = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))]
+            if not session_candidates:
+                logging.error(f"Could not parse session name from path: {path}. Skipping this file.")
+                continue
+            session = session_candidates[0]
+        except Exception as e:
+            logging.error(f"Error parsing path {path}: {e}. Skipping this file.")
+            continue

         good_trial_indices = []
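
A quick standalone check of the hardened parsing (the paths are made up; the t15.20/t12.20 session prefixes come from the source, and startswith with a tuple is equivalent to the two chained calls above):

    paths = [
        r'C:\data\t15.2023.08.11\trial_0001.hdf5',
        '/data/t12.2022.04.28/trial_0002.hdf5',
        '/data/misc/notes.txt',  # no session component -> skipped
    ]
    for path in paths:
        parts = path.replace('\\', '/').split('/')
        candidates = [s for s in parts if s.startswith(('t15.20', 't12.20'))]
        print(path, '->', candidates[0] if candidates else 'skipped')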
@@ -791,7 +828,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
     start_time = time.time()

     # 1. Collect all (day, trial) pairs to analyze, avoiding duplicates
@@ -841,7 +878,9 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
             return None  # Return None to signal failure

     # 3. Use a ThreadPoolExecutor for parallel processing
-    max_workers = min(224, len(trials_to_check))  # no more than the actual number of tasks
+    # Size the pool from the CPU core count, with a reasonable upper limit
+    cpu_count = os.cpu_count() or 4  # Fall back to 4 if cpu_count() returns None
+    max_workers = min(32, cpu_count, len(trials_to_check))
     local_max_shapes = []

     print(f"🔧 Using {max_workers} parallel workers for analysis...")
@@ -878,7 +917,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'max_time_steps': int(np.max(unzipped_shapes[0])),
         'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
         'max_transcription_len': int(np.max(unzipped_shapes[2])),
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+        'n_features': dataset_tf.feature_dim
     }

     # 5. Add a 10% safety-margin buffer
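
Step 5 then scales each observed maximum by the 10% margin named in the source comment; a sketch with example values (the exact rounding used in the source is assumed):

    max_shapes = {'max_time_steps': 812, 'max_phone_seq_len': 64}  # example values
    buffered = {k: int(v * 1.1) for k, v in max_shapes.items()}
    print(buffered)  # {'max_time_steps': 893, 'max_phone_seq_len': 70}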