From e715d9ac790e16ec8fb9ecdf559faf586e84a47b Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Wed, 22 Oct 2025 00:28:10 +0800
Subject: [PATCH] Enhance error handling and deprecate batch generation
 methods in BrainToTextDatasetTF

- Improved error logging when loading trial data fails, and ensured the
  dummy fallback data uses the correct feature dimension.
- Marked the _create_batch_generator and create_dataset methods as
  deprecated, recommending create_input_fn for better performance.
- Sized the maximum parallel workers in analyze_dataset_shapes from the
  available CPU cores instead of a hard-coded 224.
---
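
Note: CPython filters DeprecationWarning by default when it is not raised
from __main__, so callers of the two deprecated methods may never see the
warnings this patch adds. Below is a minimal sketch of how a caller could
surface them while migrating; the constructor arguments and the exact
create_input_fn() signature are assumptions, since the patch recommends the
method but does not show its API:

    import warnings

    # Opt in to the DeprecationWarning added by this patch (hidden by
    # default outside __main__):
    warnings.simplefilter("default", DeprecationWarning)

    # Hypothetical usage -- adjust to the real constructor and API:
    # dataset = BrainToTextDatasetTF(...)
    # ds = dataset.create_dataset()          # deprecated: warns, then behaves as before
    # input_fn = dataset.create_input_fn()   # recommended replacement
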
 model_training_nnn_tpu/dataset_tf.py | 61 +++++++++++++++++++++++-----
 1 file changed, 50 insertions(+), 11 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 74ff95f..2ebb3be 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -273,9 +273,13 @@ class BrainToTextDatasetTF:
             return trial_data
 
         except Exception as e:
-            # Return dummy data for failed loads
+            # Log the error and return dummy data with the correct feature dimension
+            logging.warning(f"Failed to load trial {day}_{trial} from {session_path}. Error: {e}. Returning dummy data.")
+
+            # Use self.feature_dim to ensure dimension consistency
+            feature_dim = self.feature_dim
             return {
-                'input_features': np.zeros((100, 512), dtype=np.float32),
+                'input_features': np.zeros((100, feature_dim), dtype=np.float32),
                 'seq_class_ids': np.zeros((10,), dtype=np.int32),
                 'transcription': np.zeros((50,), dtype=np.int32),
                 'n_time_steps': 100,
@@ -304,7 +308,19 @@ class BrainToTextDatasetTF:
         return trial_data
 
     def _create_batch_generator(self):
-        """Generator function that yields individual batches with optimized loading"""
+        """
+        Generator function that yields individual batches with optimized loading
+
+        ⚠️ DEPRECATED: Use create_input_fn() instead for better performance
+        and TPU compatibility. This method will be removed in a future version.
+        """
+        import warnings
+        warnings.warn(
+            "_create_batch_generator is deprecated. Use create_input_fn() instead for better performance.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+
         import time
         from concurrent.futures import ThreadPoolExecutor
 
@@ -421,7 +437,18 @@ class BrainToTextDatasetTF:
             yield batch
 
     def create_dataset(self) -> tf.data.Dataset:
-        """Create optimized tf.data.Dataset for TPU training"""
+        """
+        Create optimized tf.data.Dataset for TPU training
+
+        ⚠️ DEPRECATED: Use create_input_fn() instead for better performance
+        and TPU compatibility. This method will be removed in a future version.
+        """
+        import warnings
+        warnings.warn(
+            "create_dataset is deprecated. Use create_input_fn() instead for better performance.",
+            DeprecationWarning,
+            stacklevel=2
+        )
 
         # Define output signature for the dataset
         output_signature = {
@@ -516,7 +543,7 @@ class BrainToTextDatasetTF:
 
         # Define output signature for individual examples
         output_signature = {
-            'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
+            'input_features': tf.TensorSpec(shape=(None, self.feature_dim), dtype=tf.float32),
             'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
             'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
             'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
@@ -698,9 +725,19 @@ def train_test_split_indices(file_paths: List[str],
     # Get trials in each day
     trials_per_day = {}
     for i, path in enumerate(file_paths):
-        # Handle both Windows and Unix path separators
-        path_parts = path.replace('\\', '/').split('/')
-        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
+        try:
+            # Handle both Windows and Unix path separators
+            path_parts = path.replace('\\', '/').split('/')
+            session_candidates = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))]
+
+            if not session_candidates:
+                logging.error(f"Could not parse session name from path: {path}. Skipping this file.")
+                continue
+
+            session = session_candidates[0]
+        except Exception as e:
+            logging.error(f"Error parsing path {path}: {e}. Skipping this file.")
+            continue
 
         good_trial_indices = []
 
@@ -791,7 +828,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
     start_time = time.time()
 
     # 1. Collect all (day, trial) pairs to analyze, avoiding duplicates
@@ -841,7 +878,9 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         return None  # Return None to indicate failure
 
     # 3. Process trials in parallel with a ThreadPoolExecutor
-    max_workers = min(224, len(trials_to_check))  # No more than the actual number of tasks
+    # Use dynamic calculation based on CPU cores with a reasonable upper limit
+    cpu_count = os.cpu_count() or 4  # Fall back to 4 if cpu_count() returns None
+    max_workers = min(32, cpu_count, len(trials_to_check))
     local_max_shapes = []
 
     print(f"🔧 Using {max_workers} parallel workers for analysis...")
@@ -878,7 +917,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'max_time_steps': int(np.max(unzipped_shapes[0])),
         'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
         'max_transcription_len': int(np.max(unzipped_shapes[2])),
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+        'n_features': dataset_tf.feature_dim
     }
 
     # 5. Add a safety margin (10% buffer)
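
Note on imports: the hunks above call logging.warning, logging.error, and
os.cpu_count(), but none of them adds the corresponding imports. Whether
dataset_tf.py already imports logging and os at module level is not visible
in this diff; if it does not, the new code paths will raise NameError at
runtime. The assumed prerequisite, shown only as a sketch:

    import logging
    import os

    # Optional: send warnings/errors to the console with timestamps so the
    # new trial-load failure messages show up in training logs
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(message)s")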