Enhance error handling and deprecate batch generation methods in BrainToTextDatasetTF
- Improved error logging when loading trial data fails, ensuring correct feature dimensions in dummy data.
- Marked _create_batch_generator and create_dataset as deprecated, recommending create_input_fn for better performance.
- Adjusted the maximum number of parallel workers in analyze_dataset_shapes based on CPU cores.
@@ -273,9 +273,13 @@ class BrainToTextDatasetTF:
             return trial_data

         except Exception as e:
-            # Return dummy data for failed loads
+            # Log the error and return dummy data with correct feature dimension
+            logging.warning(f"Failed to load trial {day}_{trial} from {session_path}. Error: {e}. Returning dummy data.")
+
+            # Use self.feature_dim to ensure dimension consistency
+            feature_dim = self.feature_dim
             return {
-                'input_features': np.zeros((100, 512), dtype=np.float32),
+                'input_features': np.zeros((100, feature_dim), dtype=np.float32),
                 'seq_class_ids': np.zeros((10,), dtype=np.int32),
                 'transcription': np.zeros((50,), dtype=np.int32),
                 'n_time_steps': 100,
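The important change here is that the dummy tensors now use self.feature_dim instead of a hard-coded 512, so a fallback trial can never poison a batch with mismatched shapes. For context, a minimal sketch of how such an attribute might be derived in the constructor (hypothetical; only the feature_subset fallback expression is taken from the last hunk of this diff):

```python
class BrainToTextDatasetTF:
    DEFAULT_N_FEATURES = 512  # full neural feature dimension (assumed default)

    def __init__(self, feature_subset=None):
        self.feature_subset = feature_subset
        # Single source of truth for the feature dimension: dummy data,
        # TensorSpecs, and shape analysis all read self.feature_dim.
        self.feature_dim = (
            len(feature_subset) if feature_subset else self.DEFAULT_N_FEATURES
        )
```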
@@ -304,7 +308,19 @@ class BrainToTextDatasetTF:
             return trial_data

     def _create_batch_generator(self):
-        """Generator function that yields individual batches with optimized loading"""
+        """
+        Generator function that yields individual batches with optimized loading
+
+        ⚠️ DEPRECATED: This method is deprecated. Use create_input_fn() instead for better performance
+        and TPU compatibility. This method will be removed in a future version.
+        """
+        import warnings
+        warnings.warn(
+            "_create_batch_generator is deprecated. Use create_input_fn() instead for better performance.",
+            DeprecationWarning,
+            stacklevel=2
+        )
+
         import time
         from concurrent.futures import ThreadPoolExecutor

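Both deprecated methods repeat the same warn-then-proceed boilerplate; if a third deprecation ever lands, it could be factored into a decorator. A minimal sketch using only the standard library (the decorator itself is hypothetical, not part of this commit):

```python
import functools
import warnings

def deprecated(replacement: str):
    """Emit a DeprecationWarning that points callers at the replacement API."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated. Use {replacement}() instead "
                "for better performance.",
                DeprecationWarning,
                stacklevel=2,  # attribute the warning to the caller's line
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator

# Hypothetical usage:
# @deprecated("create_input_fn")
# def create_dataset(self) -> tf.data.Dataset: ...
```

Note that Python silences DeprecationWarning by default in most contexts; running with -W default::DeprecationWarning (or under a test runner that enables it) surfaces the message.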
@@ -421,7 +437,18 @@ class BrainToTextDatasetTF:
                 yield batch

     def create_dataset(self) -> tf.data.Dataset:
-        """Create optimized tf.data.Dataset for TPU training"""
+        """
+        Create optimized tf.data.Dataset for TPU training
+
+        ⚠️ DEPRECATED: This method is deprecated. Use create_input_fn() instead for better performance
+        and TPU compatibility. This method will be removed in a future version.
+        """
+        import warnings
+        warnings.warn(
+            "create_dataset is deprecated. Use create_input_fn() instead for better performance.",
+            DeprecationWarning,
+            stacklevel=2
+        )

         # Define output signature for the dataset
         output_signature = {
@@ -516,7 +543,7 @@ class BrainToTextDatasetTF:

         # Define output signature for individual examples
         output_signature = {
-            'input_features': tf.TensorSpec(shape=(None, None), dtype=tf.float32),
+            'input_features': tf.TensorSpec(shape=(None, self.feature_dim), dtype=tf.float32),
             'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
             'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
             'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
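Pinning the feature axis to self.feature_dim in the TensorSpec lets tf.data validate element shapes when the pipeline is built rather than failing mid-training. A small sketch of how a signature like this is typically consumed, assuming a from_generator pipeline (the generator here is a stand-in, not the class's real loader):

```python
import numpy as np
import tensorflow as tf

feature_dim = 512  # in the class this would be self.feature_dim

output_signature = {
    'input_features': tf.TensorSpec(shape=(None, feature_dim), dtype=tf.float32),
    'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
    'n_time_steps': tf.TensorSpec(shape=(), dtype=tf.int32),
    'phone_seq_lens': tf.TensorSpec(shape=(), dtype=tf.int32),
}

def example_generator():
    # Stand-in for the real per-trial loader.
    yield {
        'input_features': np.zeros((100, feature_dim), dtype=np.float32),
        'seq_class_ids': np.zeros((10,), dtype=np.int32),
        'n_time_steps': 100,
        'phone_seq_lens': 10,
    }

# Elements whose feature axis is not `feature_dim` now fail fast.
dataset = tf.data.Dataset.from_generator(
    example_generator, output_signature=output_signature
)
```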
@@ -698,9 +725,19 @@ def train_test_split_indices(file_paths: List[str],
     # Get trials in each day
     trials_per_day = {}
     for i, path in enumerate(file_paths):
-        # Handle both Windows and Unix path separators
-        path_parts = path.replace('\\', '/').split('/')
-        session = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))][0]
+        try:
+            # Handle both Windows and Unix path separators
+            path_parts = path.replace('\\', '/').split('/')
+            session_candidates = [s for s in path_parts if (s.startswith('t15.20') or s.startswith('t12.20'))]
+
+            if not session_candidates:
+                logging.error(f"Could not parse session name from path: {path}. Skipping this file.")
+                continue
+
+            session = session_candidates[0]
+        except Exception as e:
+            logging.error(f"Error parsing path {path}: {e}. Skipping this file.")
+            continue

         good_trial_indices = []

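The old code indexed [0] on the comprehension result, so one oddly named directory raised an unhandled IndexError and killed the whole split. The same parsing, pulled out as a standalone helper that is easy to unit-test (the function name and return convention are illustrative, not from the commit):

```python
import logging
from typing import Optional

def parse_session_name(path: str) -> Optional[str]:
    """Extract a session directory (one starting with 't15.20' or 't12.20')
    from a file path.

    Returns None instead of raising so callers can skip a bad file and
    keep processing the rest of the dataset.
    """
    # Handle both Windows and Unix path separators
    path_parts = path.replace('\\', '/').split('/')
    candidates = [s for s in path_parts if s.startswith(('t15.20', 't12.20'))]
    if not candidates:
        logging.error(f"Could not parse session name from path: {path}.")
        return None
    return candidates[0]
```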
@@ -791,7 +828,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size})...")
     start_time = time.time()

     # 1. Collect all (day, trial) pairs that need analysis, avoiding duplicates
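The translated step-1 comment describes deduplication before sampling. A rough sketch of what that step might look like; the {day: [trial, ...]} layout of the trial index is an assumption, not taken from the commit:

```python
# Stand-in trial index with duplicate entries: {day: [trial, ...]}
trial_index = {0: [0, 1, 1, 2], 1: [0, 0, 3]}
sample_size = -1  # -1 means analyze ALL trials, mirroring the print above

seen = set()
trials_to_check = []
for day, trials in trial_index.items():
    for trial in trials:
        if (day, trial) not in seen:  # analyze each (day, trial) pair once
            seen.add((day, trial))
            trials_to_check.append((day, trial))

if sample_size != -1:
    trials_to_check = trials_to_check[:sample_size]
```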
@@ -841,7 +878,9 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
             return None  # Return None to signal failure

     # 3. Use a ThreadPoolExecutor for parallel processing
-    max_workers = min(224, len(trials_to_check))  # Do not exceed the actual number of tasks
+    # Use dynamic calculation based on CPU cores with reasonable upper limit
+    cpu_count = os.cpu_count() or 4  # Fallback to 4 if cpu_count() returns None
+    max_workers = min(32, cpu_count, len(trials_to_check))
     local_max_shapes = []

     print(f"🔧 Using {max_workers} parallel workers for analysis...")
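The new cap mirrors the heuristic ThreadPoolExecutor itself uses by default (min(32, os.cpu_count() + 4) since Python 3.8) and, unlike the hard-coded 224, will not oversubscribe a small machine. A self-contained sketch of the fan-out/collect pattern; probe_shapes is a stand-in for the real per-trial analysis:

```python
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

def probe_shapes(task):
    # Stand-in: the real function opens the trial and reads its array shapes.
    day, trial = task
    return (100, 10, 50)  # (n_time_steps, phone_seq_len, transcription_len)

tasks = [(0, t) for t in range(1000)]
cpu_count = os.cpu_count() or 4  # os.cpu_count() can return None
max_workers = min(32, cpu_count, len(tasks))

local_max_shapes = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(probe_shapes, t) for t in tasks]
    for future in as_completed(futures):
        shapes = future.result()
        if shapes is not None:  # None signals a failed load; skip it
            local_max_shapes.append(shapes)
```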
@@ -878,7 +917,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'max_time_steps': int(np.max(unzipped_shapes[0])),
         'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
         'max_transcription_len': int(np.max(unzipped_shapes[2])),
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+        'n_features': dataset_tf.feature_dim
     }

     # 5. Add a safety margin (10% buffer)
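The step-5 comment is the last visible line of the hunk; the buffer code itself is elided. A plausible sketch of a 10% margin over the dictionary above (the rounding choice and the exemption of n_features are assumptions):

```python
import math

max_shapes = {  # illustrative values standing in for the analysis result
    'max_time_steps': 812,
    'max_phone_seq_len': 64,
    'max_transcription_len': 180,
    'n_features': 512,
}

# Pad sampled maxima by 10% so slightly longer unsampled trials still fit;
# the feature dimension is exact and needs no buffer.
for key in ('max_time_steps', 'max_phone_seq_len', 'max_transcription_len'):
    max_shapes[key] = math.ceil(max_shapes[key] * 1.1)
```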