Enhance dataset shape analysis by implementing parallel processing and improving sampling logic

Zchen
2025-10-20 00:35:17 +08:00
parent e1669b5a4c
commit fabf70cfa9


@@ -4,9 +4,13 @@ import h5py
import numpy as np
import math
import logging
import time
import random
import multiprocessing
from itertools import groupby
from typing import Dict, List, Tuple, Optional, Any
from scipy.ndimage import gaussian_filter1d
from concurrent.futures import ThreadPoolExecutor, as_completed
class BrainToTextDatasetTF:
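The added imports support the thread-pool pattern used in analyze_dataset_shapes below: submit one task per trial, then collect results with as_completed. The following is a minimal, self-contained sketch of that pattern only; the worker function and its inputs are illustrative placeholders, not code from this repository.

from concurrent.futures import ThreadPoolExecutor, as_completed

def measure(item):
    # Placeholder worker: return a per-item measurement, or None on failure
    try:
        return len(str(item))
    except Exception:
        return None

items = list(range(1000))
results = []
with ThreadPoolExecutor(max_workers=8) as executor:
    # Map each pending future back to the item it was submitted for
    future_to_item = {executor.submit(measure, it): it for it in items}
    for future in as_completed(future_to_item):
        result = future.result()
        if result is not None:
            results.append(result)

print(max(results))  # aggregate once all workers have finished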
@@ -758,7 +762,8 @@ def train_test_split_indices(file_paths: List[str],
def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
"""
Analyze dataset to determine maximum shapes for padded_batch
Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
utilizing multiple CPU cores and the dataset's caching mechanism.
Args:
dataset_tf: Dataset instance to analyze
@@ -767,79 +772,118 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
Returns:
Dictionary with maximum dimensions
"""
print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
start_time = time.time()
max_shapes = {
'max_time_steps': 0,
'max_phone_seq_len': 0,
'max_transcription_len': 0,
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
}
# 1. Create a flat list of all (day, trial) pairs
# 1. Collect all (day, trial) pairs to analyze, avoiding duplicates
all_trials = []
unique_trials = set()
for batch_idx in sorted(dataset_tf.batch_indices.keys()):
batch = dataset_tf.batch_indices[batch_idx]
for day, trials in batch.items():
for trial in trials:
if (day, trial) not in unique_trials:
unique_trials.add((day, trial))
all_trials.append((day, trial))
# 2. Sample from the list if needed
if sample_size > 0 and len(all_trials) > sample_size:
import random
if dataset_tf.split == 'train': # Random sampling for training
# 2. Sample from the list if sampling is requested
if 0 < sample_size < len(all_trials):
# Seed the RNG for reproducibility (if desired)
random.seed(42)
trials_to_check = random.sample(all_trials, sample_size)
else: # Sequential sampling for validation
trials_to_check = all_trials[:sample_size]
else:
trials_to_check = all_trials
print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
total_trials_to_analyze = len(trials_to_check)
print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
# 3. Iterate through the final list
count = 0
for day, trial in trials_to_check:
# Helper function executed by each worker thread
def analyze_single_trial(day_trial_pair):
"""Loads and analyzes a single trial, returns its shapes."""
day, trial = day_trial_pair
try:
session_path = dataset_tf.trial_indices[day]['session_path']
with h5py.File(session_path, 'r') as f:
g = f[f'trial_{trial:04d}']
# Reuse dataset_tf's loading and caching logic
trial_data = dataset_tf._load_trial_data(day, trial)
# Check dimensions
time_steps = int(g.attrs['n_time_steps'])
phone_seq_len = int(g.attrs['seq_len'])
transcription_data = g['transcription'][:]
# Read the shape information directly from the loaded data
time_steps = int(trial_data['n_time_steps'])
phone_seq_len = int(trial_data['phone_seq_lens'])
# Handle the transcription data - it may be an array
transcription_data = trial_data['transcription']
if hasattr(transcription_data, '__len__'):
transcription_len = len(transcription_data)
else:
transcription_len = 1 # a scalar counts as length 1
max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
count += 1
# Show progress for large analyses
if count % 100 == 0:
print(f" Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
return (time_steps, phone_seq_len, transcription_len)
except Exception as e:
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
continue
return None # None signals failure
# Add safety margins (20% buffer) to handle edge cases
original_time_steps = max_shapes['max_time_steps']
original_phone_seq_len = max_shapes['max_phone_seq_len']
original_transcription_len = max_shapes['max_transcription_len']
# 3. Run the analysis in parallel with ThreadPoolExecutor
max_workers = min(224, len(trials_to_check)) # never more workers than there are tasks
local_max_shapes = []
max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
print(f"🔧 Using {max_workers} parallel workers for analysis...")
print(f"📊 Dataset analysis complete (analyzed {count} samples):")
print(f" Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
print(f" Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
print(f" Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
print(f" Number of features: {max_shapes['n_features']}")
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all analysis tasks
future_to_trial = {
executor.submit(analyze_single_trial, trial_pair): trial_pair
for trial_pair in trials_to_check
}
return max_shapes
count = 0
progress_interval = max(1, total_trials_to_analyze // 20) # roughly 20 progress updates
for future in as_completed(future_to_trial):
result = future.result()
if result:
local_max_shapes.append(result)
count += 1
if count % progress_interval == 0 or count == total_trials_to_analyze:
elapsed = time.time() - start_time
rate = count / elapsed if elapsed > 0 else 0
print(f" 📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")
# 4. Aggregate the results from all worker threads
if not local_max_shapes:
raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")
# Transpose [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
unzipped_shapes = list(zip(*local_max_shapes))
original_max_shapes = {
'max_time_steps': int(np.max(unzipped_shapes[0])),
'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
'max_transcription_len': int(np.max(unzipped_shapes[2])),
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
}
# 5. Add a safety margin (10% buffer)
safety_margin = 1.1
final_max_shapes = {
'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
'n_features': original_max_shapes['n_features']
}
analysis_time = time.time() - start_time
successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
print(f" Time steps: {original_max_shapes['max_time_steps']}{final_max_shapes['max_time_steps']}")
print(f" Phone sequence length: {original_max_shapes['max_phone_seq_len']}{final_max_shapes['max_phone_seq_len']}")
print(f" Transcription length: {original_max_shapes['max_transcription_len']}{final_max_shapes['max_transcription_len']}")
print(f" Number of features: {final_max_shapes['n_features']}")
return final_max_shapes
# Utility functions for TPU-optimized data pipeline
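For context on how the returned dictionary is typically consumed, below is a minimal, hypothetical sketch of feeding these maxima into tf.data's padded_batch. The element structure and the toy generator are assumptions for illustration only; the repository's actual pipeline may differ.

import numpy as np
import tensorflow as tf

# Assumed analysis output; real values come from analyze_dataset_shapes(dataset_tf)
max_shapes = {'max_time_steps': 600, 'max_phone_seq_len': 80,
              'max_transcription_len': 200, 'n_features': 512}

def toy_trials():
    # Hypothetical generator standing in for the real per-trial tensors
    for t, p, s in [(500, 60, 150), (420, 75, 190)]:
        yield (np.zeros((t, max_shapes['n_features']), np.float32),
               np.zeros((p,), np.int32),
               np.zeros((s,), np.int32))

ds = tf.data.Dataset.from_generator(
    toy_trials,
    output_signature=(
        tf.TensorSpec((None, max_shapes['n_features']), tf.float32),
        tf.TensorSpec((None,), tf.int32),
        tf.TensorSpec((None,), tf.int32)))

# Fixed padded shapes keep every batch the same size, which TPU compilation favors
ds = ds.padded_batch(
    2,
    padded_shapes=((max_shapes['max_time_steps'], max_shapes['n_features']),
                   (max_shapes['max_phone_seq_len'],),
                   (max_shapes['max_transcription_len'],)),
    drop_remainder=True)

for feats, phones, trans in ds:
    print(feats.shape, phones.shape, trans.shape)  # (2, 600, 512) (2, 80) (2, 200)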