Enhance dataset shape analysis by implementing parallel processing and improving sampling logic

This commit is contained in:
Zchen
2025-10-20 00:35:17 +08:00
parent e1669b5a4c
commit fabf70cfa9

@@ -4,9 +4,13 @@ import h5py
 import numpy as np
 import math
 import logging
+import time
+import random
+import multiprocessing
 from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
+from concurrent.futures import ThreadPoolExecutor, as_completed

 class BrainToTextDatasetTF:
@@ -758,7 +762,8 @@ def train_test_split_indices(file_paths: List[str],
 def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
     """
-    Analyze dataset to determine maximum shapes for padded_batch
+    Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
+    utilizing multiple CPU cores and the dataset's caching mechanism.

     Args:
         dataset_tf: Dataset instance to analyze
@@ -767,79 +772,118 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
-
-    max_shapes = {
-        'max_time_steps': 0,
-        'max_phone_seq_len': 0,
-        'max_transcription_len': 0,
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
-    }
-
-    # 1. Create a flat list of all (day, trial) pairs
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
+    start_time = time.time()
+
+    # 1. Collect all (day, trial) pairs that need analysis, avoiding duplicates
     all_trials = []
+    unique_trials = set()
     for batch_idx in sorted(dataset_tf.batch_indices.keys()):
         batch = dataset_tf.batch_indices[batch_idx]
         for day, trials in batch.items():
             for trial in trials:
-                all_trials.append((day, trial))
+                if (day, trial) not in unique_trials:
+                    unique_trials.add((day, trial))
+                    all_trials.append((day, trial))

     # 2. Sample from the list if needed
-    if sample_size > 0 and len(all_trials) > sample_size:
-        import random
-        if dataset_tf.split == 'train': # Random sampling for training
-            trials_to_check = random.sample(all_trials, sample_size)
-        else: # Sequential sampling for validation
-            trials_to_check = all_trials[:sample_size]
+    if 0 < sample_size < len(all_trials):
+        # Seed for reproducibility (if needed)
+        random.seed(42)
+        trials_to_check = random.sample(all_trials, sample_size)
     else:
         trials_to_check = all_trials

-    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
+    total_trials_to_analyze = len(trials_to_check)
+    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")

-    # 3. Iterate through the final list
-    count = 0
-    for day, trial in trials_to_check:
+    # Helper function executed by each worker thread
+    def analyze_single_trial(day_trial_pair):
+        """Loads and analyzes a single trial, returns its shapes."""
+        day, trial = day_trial_pair
         try:
-            session_path = dataset_tf.trial_indices[day]['session_path']
-            with h5py.File(session_path, 'r') as f:
-                g = f[f'trial_{trial:04d}']
-
-                # Check dimensions
-                time_steps = int(g.attrs['n_time_steps'])
-                phone_seq_len = int(g.attrs['seq_len'])
-                transcription_data = g['transcription'][:]
+            # Reuse dataset_tf's loading and caching logic
+            trial_data = dataset_tf._load_trial_data(day, trial)
+
+            # Read the dimensions directly from the loaded data
+            time_steps = int(trial_data['n_time_steps'])
+            phone_seq_len = int(trial_data['phone_seq_lens'])
+
+            # Handle the transcription data - it may be an array
+            transcription_data = trial_data['transcription']
+            if hasattr(transcription_data, '__len__'):
                 transcription_len = len(transcription_data)
-
-                max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
-                max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
-                max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
-
-                count += 1
-                # Show progress for large analyses
-                if count % 100 == 0:
-                    print(f"   Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
+            else:
+                transcription_len = 1  # A scalar counts as length 1
+
+            return (time_steps, phone_seq_len, transcription_len)
         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-            continue
+            return None  # None signals failure

-    # Add safety margins (20% buffer) to handle edge cases
-    original_time_steps = max_shapes['max_time_steps']
-    original_phone_seq_len = max_shapes['max_phone_seq_len']
-    original_transcription_len = max_shapes['max_transcription_len']
-
-    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
-    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
-    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
-
-    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
-    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
-    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
-    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
-    print(f"   Number of features: {max_shapes['n_features']}")
-
-    return max_shapes
+    # 3. Run the analysis in parallel with a ThreadPoolExecutor
+    max_workers = min(224, len(trials_to_check))  # No more workers than actual tasks
+    local_max_shapes = []
+
+    print(f"🔧 Using {max_workers} parallel workers for analysis...")
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all analysis tasks
+        future_to_trial = {
+            executor.submit(analyze_single_trial, trial_pair): trial_pair
+            for trial_pair in trials_to_check
+        }
+
+        count = 0
+        progress_interval = max(1, total_trials_to_analyze // 20)  # ~20 progress updates
+        for future in as_completed(future_to_trial):
+            result = future.result()
+            if result:
+                local_max_shapes.append(result)
+
+            count += 1
+            if count % progress_interval == 0 or count == total_trials_to_analyze:
+                elapsed = time.time() - start_time
+                rate = count / elapsed if elapsed > 0 else 0
+                print(f"   📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")
+
+    # 4. Aggregate the results from all threads
+    if not local_max_shapes:
+        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")
+
+    # Turn [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
+    unzipped_shapes = list(zip(*local_max_shapes))
+    original_max_shapes = {
+        'max_time_steps': int(np.max(unzipped_shapes[0])),
+        'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
+        'max_transcription_len': int(np.max(unzipped_shapes[2])),
+        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    }
+
+    # 5. Add a safety margin (10% buffer)
+    safety_margin = 1.1
+    final_max_shapes = {
+        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
+        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
+        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
+        'n_features': original_max_shapes['n_features']
+    }
+
+    analysis_time = time.time() - start_time
+    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
+    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
+    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
+    print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
+    print(f"   Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
+    print(f"   Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
+    print(f"   Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
+    print(f"   Number of features: {final_max_shapes['n_features']}")
+
+    return final_max_shapes

 # Utility functions for TPU-optimized data pipeline
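
The reworked sampling path seeds the RNG before drawing, so the analyzed subset is stable across runs. A minimal standalone sketch of that behavior (the trial tuples here are made up for illustration):

import random

pool = [('day0', i) for i in range(1000)]  # stand-in for the (day, trial) list

random.seed(42)
first = random.sample(pool, 5)
random.seed(42)
second = random.sample(pool, 5)
assert first == second  # same seed, same subset on every run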
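
Step 4's aggregation leans on zip(*...) to transpose the list of per-trial (time_steps, phone_seq_len, transcription_len) tuples into one sequence per dimension before taking the maxima; a self-contained example:

# Three fake worker results, one tuple per trial
per_trial = [(100, 12, 45), (250, 9, 60), (180, 15, 30)]

times, phones, trans = zip(*per_trial)  # transpose into per-dimension columns
assert max(times) == 250
assert max(phones) == 15
assert max(trans) == 60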
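
Downstream, the returned dictionary is intended to drive padded_batch, per the docstring. A hypothetical sketch of that wiring, assuming dict-structured dataset elements with neural_features / phone_seq / transcription keys and a raw_dataset pipeline (none of these names appear in this diff):

import tensorflow as tf

max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=500)

# Key names below are illustrative assumptions, not taken from this commit.
padded_shapes = {
    'neural_features': (max_shapes['max_time_steps'], max_shapes['n_features']),
    'phone_seq': (max_shapes['max_phone_seq_len'],),
    'transcription': (max_shapes['max_transcription_len'],),
}

# Fixed maxima keep padded_batch output shapes static, which avoids per-batch
# recompilation on TPU.
batched_ds = raw_dataset.padded_batch(batch_size=32, padded_shapes=padded_shapes)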