Enhance dataset shape analysis by implementing parallel processing and improving sampling logic
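In outline, the commit replaces the sequential per-trial scan with a thread pool that analyzes trials concurrently, then reduces the per-trial shapes to element-wise maxima. A minimal sketch of the pattern (illustrative only; `load_trial_shape`, the trial list, and the worker count are stand-ins rather than code from this commit):

import random
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_trial_shape(trial):
    """Stand-in for the real per-trial read; returns (time_steps, phone_len, transcription_len)."""
    return (trial % 400 + 1, trial % 40 + 1, trial % 80 + 1)

all_trials = list(range(2000))
random.seed(42)                          # fixed seed -> reproducible subsample
sample = random.sample(all_trials, 100)  # sample_size = 100

with ThreadPoolExecutor(max_workers=8) as pool:
    futures = [pool.submit(load_trial_shape, t) for t in sample]
    results = [f.result() for f in as_completed(futures)]
shapes = [r for r in results if r is not None]  # failed trials drop out, as in the commit

# Element-wise maxima over the sampled trials, then a padding safety margin
max_time, max_phone, max_trans = (int(max(col) * 1.1) for col in zip(*shapes))
print(max_time, max_phone, max_trans)

As in the committed code, a trial that fails to load simply falls out of the reduction instead of aborting the whole analysis.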
@@ -4,9 +4,13 @@ import h5py
 import numpy as np
 import math
 import logging
+import time
+import random
+import multiprocessing
 from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
+from concurrent.futures import ThreadPoolExecutor, as_completed


 class BrainToTextDatasetTF:
@@ -758,7 +762,8 @@ def train_test_split_indices(file_paths: List[str],

 def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
     """
-    Analyze dataset to determine maximum shapes for padded_batch
+    Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
+    utilizing multiple CPU cores and the dataset's caching mechanism.

     Args:
         dataset_tf: Dataset instance to analyze
@@ -767,79 +772,118 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
+    start_time = time.time()
 
-    max_shapes = {
-        'max_time_steps': 0,
-        'max_phone_seq_len': 0,
-        'max_transcription_len': 0,
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
-    }
-
-    # 1. Create a flat list of all (day, trial) pairs
+    # 1. Collect all (day, trial) pairs that need analysis, avoiding duplicates
     all_trials = []
     unique_trials = set()
     for batch_idx in sorted(dataset_tf.batch_indices.keys()):
         batch = dataset_tf.batch_indices[batch_idx]
         for day, trials in batch.items():
             for trial in trials:
                 if (day, trial) not in unique_trials:
                     unique_trials.add((day, trial))
                     all_trials.append((day, trial))
 
-    # 2. Sample from the list if needed
-    if sample_size > 0 and len(all_trials) > sample_size:
-        import random
-        if dataset_tf.split == 'train':  # Random sampling for training
-            trials_to_check = random.sample(all_trials, sample_size)
-        else:  # Sequential sampling for validation
-            trials_to_check = all_trials[:sample_size]
+    # 2. Sample from the list if needed
+    if 0 < sample_size < len(all_trials):
+        # Fixed seed so the sampled subset is reproducible across runs
+        random.seed(42)
+        trials_to_check = random.sample(all_trials, sample_size)
     else:
         trials_to_check = all_trials
 
-    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
+    total_trials_to_analyze = len(trials_to_check)
+    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
 
-    # 3. Iterate through the final list
-    count = 0
-    for day, trial in trials_to_check:
-        try:
-            session_path = dataset_tf.trial_indices[day]['session_path']
-            with h5py.File(session_path, 'r') as f:
-                g = f[f'trial_{trial:04d}']
-
-                # Check dimensions
-                time_steps = int(g.attrs['n_time_steps'])
-                phone_seq_len = int(g.attrs['seq_len'])
-                transcription_data = g['transcription'][:]
-                if hasattr(transcription_data, '__len__'):
-                    transcription_len = len(transcription_data)
-                else:
-                    transcription_len = 1  # a scalar counts as length 1
-
-                max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
-                max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
-                max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
-
-                count += 1
-
-                # Show progress for large analyses
-                if count % 100 == 0:
-                    print(f"   Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
-        except Exception as e:
-            logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-            continue
-
-    # Add safety margins (20% buffer) to handle edge cases
-    original_time_steps = max_shapes['max_time_steps']
-    original_phone_seq_len = max_shapes['max_phone_seq_len']
-    original_transcription_len = max_shapes['max_transcription_len']
-
-    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
-    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
-    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
-
-    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
-    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
-    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
-    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
-    print(f"   Number of features: {max_shapes['n_features']}")
-
-    return max_shapes
+    # Helper function run by each worker thread
+    def analyze_single_trial(day_trial_pair):
+        """Loads and analyzes a single trial, returns its shapes."""
+        day, trial = day_trial_pair
+        try:
+            # Reuse dataset_tf's loading and caching logic
+            trial_data = dataset_tf._load_trial_data(day, trial)
+
+            # Read the dimensions directly from the loaded data
+            time_steps = int(trial_data['n_time_steps'])
+            phone_seq_len = int(trial_data['phone_seq_lens'])
+
+            # Handle the transcription data - it may be an array
+            transcription_data = trial_data['transcription']
+            if hasattr(transcription_data, '__len__'):
+                transcription_len = len(transcription_data)
+            else:
+                transcription_len = 1  # a scalar counts as length 1
+
+            return (time_steps, phone_seq_len, transcription_len)
+        except Exception as e:
+            logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
+            return None  # None marks a failed trial
+
+    # 3. Run the analysis in parallel with a ThreadPoolExecutor
+    max_workers = min(224, len(trials_to_check))  # never more workers than tasks
+    local_max_shapes = []
+
+    print(f"🔧 Using {max_workers} parallel workers for analysis...")
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all analysis tasks
+        future_to_trial = {
+            executor.submit(analyze_single_trial, trial_pair): trial_pair
+            for trial_pair in trials_to_check
+        }
+
+        count = 0
+        progress_interval = max(1, total_trials_to_analyze // 20)  # ~20 progress updates
+
+        for future in as_completed(future_to_trial):
+            result = future.result()
+            if result:
+                local_max_shapes.append(result)
+
+            count += 1
+            if count % progress_interval == 0 or count == total_trials_to_analyze:
+                elapsed = time.time() - start_time
+                rate = count / elapsed if elapsed > 0 else 0
+                print(f"   📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")
+
+    # 4. Aggregate the results from all threads
+    if not local_max_shapes:
+        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")
+
+    # Turn [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
+    unzipped_shapes = list(zip(*local_max_shapes))
+
+    original_max_shapes = {
+        'max_time_steps': int(np.max(unzipped_shapes[0])),
+        'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
+        'max_transcription_len': int(np.max(unzipped_shapes[2])),
+        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    }
+
+    # 5. Add a safety margin (10% buffer)
+    safety_margin = 1.1
+
+    final_max_shapes = {
+        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
+        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
+        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
+        'n_features': original_max_shapes['n_features']
+    }
+
+    analysis_time = time.time() - start_time
+    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
+
+    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
+    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
+    print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
+    print(f"   Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
+    print(f"   Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
+    print(f"   Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
+    print(f"   Number of features: {final_max_shapes['n_features']}")
+
+    return final_max_shapes
 
 
 # Utility functions for TPU-optimized data pipeline
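Per the docstring, the returned maxima exist to pin static shapes for padded_batch in the TPU-oriented input pipeline. A hedged sketch of that hand-off (the element keys, tensor specs, and numbers below are assumptions for illustration, not part of this commit):

import tensorflow as tf

# shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # -1 scans every trial
shapes = {'max_time_steps': 440, 'max_phone_seq_len': 44,      # hypothetical output
          'max_transcription_len': 88, 'n_features': 512}

# Hypothetical element structure; the real keys live in BrainToTextDatasetTF.
def gen():
    yield {'input_features': tf.zeros([100, shapes['n_features']]),
           'seq_class_ids': tf.zeros([10], tf.int32)}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={'input_features': tf.TensorSpec([None, shapes['n_features']], tf.float32),
                      'seq_class_ids': tf.TensorSpec([None], tf.int32)})

# Fixed padded_shapes give every batch identical dimensions, which XLA/TPU compilation needs.
ds = ds.padded_batch(
    batch_size=32,
    padded_shapes={'input_features': [shapes['max_time_steps'], shapes['n_features']],
                   'seq_class_ids': [shapes['max_phone_seq_len']]})

Without a buffer, a single longer trial encountered later would overflow the padded shape; that is the role of the 10% safety margin applied above the observed maxima.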