Enhance dataset shape analysis by implementing parallel processing and improving sampling logic
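The change replaces the sequential per-trial HDF5 scan with a fan-out/aggregate pattern: per-trial shape reads are submitted to a thread pool, collected with as_completed, and max-reduced at the end. Below is a minimal, self-contained sketch of that pattern only; read_shapes, max_shapes_parallel, and the fake trial dicts are illustrative stand-ins for analyze_single_trial and the real trial data, not code from this repository.

from concurrent.futures import ThreadPoolExecutor, as_completed

# Placeholder loader: in the commit this role is played by analyze_single_trial(),
# which reuses the dataset's cached trial loader instead of reopening HDF5 files.
def read_shapes(trial):
    return (trial['time_steps'], trial['phone_len'], trial['transcription_len'])

def max_shapes_parallel(trials, max_workers=8):
    results = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(read_shapes, t) for t in trials]
        for future in as_completed(futures):
            result = future.result()
            if result is not None:
                results.append(result)
    # Max-reduce each dimension over all successfully analyzed trials
    return tuple(max(dim) for dim in zip(*results))

# Example: two fake trials -> (120, 14, 40)
print(max_shapes_parallel([
    {'time_steps': 120, 'phone_len': 10, 'transcription_len': 40},
    {'time_steps': 80, 'phone_len': 14, 'transcription_len': 33},
]))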
@@ -4,9 +4,13 @@ import h5py
 import numpy as np
 import math
 import logging
+import time
+import random
+import multiprocessing
 from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 
 class BrainToTextDatasetTF:
@@ -758,7 +762,8 @@ def train_test_split_indices(file_paths: List[str],
 
 def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
     """
-    Analyze dataset to determine maximum shapes for padded_batch
+    Analyzes dataset shapes in parallel to determine maximum dimensions for padded_batch,
+    utilizing multiple CPU cores and the dataset's caching mechanism.
 
     Args:
         dataset_tf: Dataset instance to analyze
@@ -767,79 +772,118 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size == -1 else sample_size}, max workers: 224)...")
+    start_time = time.time()
 
-    max_shapes = {
-        'max_time_steps': 0,
-        'max_phone_seq_len': 0,
-        'max_transcription_len': 0,
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
-    }
-
-    # 1. Create a flat list of all (day, trial) pairs
+    # 1. Collect all (day, trial) pairs to analyze, avoiding duplicates
     all_trials = []
+    unique_trials = set()
     for batch_idx in sorted(dataset_tf.batch_indices.keys()):
         batch = dataset_tf.batch_indices[batch_idx]
         for day, trials in batch.items():
             for trial in trials:
-                all_trials.append((day, trial))
+                if (day, trial) not in unique_trials:
+                    unique_trials.add((day, trial))
+                    all_trials.append((day, trial))
 
     # 2. Sample from the list if needed
-    if sample_size > 0 and len(all_trials) > sample_size:
-        import random
-        if dataset_tf.split == 'train': # Random sampling for training
-            trials_to_check = random.sample(all_trials, sample_size)
-        else: # Sequential sampling for validation
-            trials_to_check = all_trials[:sample_size]
+    if 0 < sample_size < len(all_trials):
+        # Seed for reproducibility (if needed)
+        random.seed(42)
+        trials_to_check = random.sample(all_trials, sample_size)
     else:
         trials_to_check = all_trials
 
-    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
+    total_trials_to_analyze = len(trials_to_check)
+    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
 
-    # 3. Iterate through the final list
-    count = 0
-    for day, trial in trials_to_check:
+    # Helper function executed by each worker thread
+    def analyze_single_trial(day_trial_pair):
+        """Loads and analyzes a single trial, returns its shapes."""
+        day, trial = day_trial_pair
         try:
-            session_path = dataset_tf.trial_indices[day]['session_path']
-            with h5py.File(session_path, 'r') as f:
-                g = f[f'trial_{trial:04d}']
-
-                # Check dimensions
-                time_steps = int(g.attrs['n_time_steps'])
-                phone_seq_len = int(g.attrs['seq_len'])
-                transcription_data = g['transcription'][:]
-                transcription_len = len(transcription_data)
-
-                max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
-                max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
-                max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
-
-                count += 1
-
-                # Show progress for large analyses
-                if count % 100 == 0:
-                    print(f"   Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
-
+            # Reuse dataset_tf's loading and caching logic
+            trial_data = dataset_tf._load_trial_data(day, trial)
+
+            # Read the dimensions directly from the loaded data
+            time_steps = int(trial_data['n_time_steps'])
+            phone_seq_len = int(trial_data['phone_seq_lens'])
+
+            # Handle the transcription data - it may be an array or a scalar
+            transcription_data = trial_data['transcription']
+            if hasattr(transcription_data, '__len__'):
+                transcription_len = len(transcription_data)
+            else:
+                transcription_len = 1  # scalar, so length 1
+
+            return (time_steps, phone_seq_len, transcription_len)
         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-            continue
+            return None  # None signals failure
 
-    # Add safety margins (20% buffer) to handle edge cases
-    original_time_steps = max_shapes['max_time_steps']
-    original_phone_seq_len = max_shapes['max_phone_seq_len']
-    original_transcription_len = max_shapes['max_transcription_len']
-
-    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
-    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
-    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
-
-    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
-    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
-    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
-    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
-    print(f"   Number of features: {max_shapes['n_features']}")
-
-    return max_shapes
+    # 3. Run the analysis in parallel with a ThreadPoolExecutor
+    max_workers = min(224, len(trials_to_check))  # never more workers than tasks
+    local_max_shapes = []
+
+    print(f"🔧 Using {max_workers} parallel workers for analysis...")
+
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit all analysis tasks
+        future_to_trial = {
+            executor.submit(analyze_single_trial, trial_pair): trial_pair
+            for trial_pair in trials_to_check
+        }
+
+        count = 0
+        progress_interval = max(1, total_trials_to_analyze // 20)  # ~20 progress updates
+
+        for future in as_completed(future_to_trial):
+            result = future.result()
+            if result:
+                local_max_shapes.append(result)
+
+            count += 1
+            if count % progress_interval == 0 or count == total_trials_to_analyze:
+                elapsed = time.time() - start_time
+                rate = count / elapsed if elapsed > 0 else 0
+                print(f"   📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")
+
+    # 4. Aggregate the results from all threads
+    if not local_max_shapes:
+        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")
+
+    # Turn [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
+    unzipped_shapes = list(zip(*local_max_shapes))
+
+    original_max_shapes = {
+        'max_time_steps': int(np.max(unzipped_shapes[0])),
+        'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
+        'max_transcription_len': int(np.max(unzipped_shapes[2])),
+        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    }
+
+    # 5. Add a safety margin (10% buffer)
+    safety_margin = 1.1
+
+    final_max_shapes = {
+        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
+        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
+        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
+        'n_features': original_max_shapes['n_features']
+    }
+
+    analysis_time = time.time() - start_time
+    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
+
+    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
+    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
+    print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
+    print(f"   Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
+    print(f"   Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
+    print(f"   Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
+    print(f"   Number of features: {final_max_shapes['n_features']}")
+
+    return final_max_shapes
 
 
 # Utility functions for TPU-optimized data pipeline
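As the docstring notes, the returned dictionary exists to size padded_batch. A hedged usage sketch follows; the element keys ('neural_features', 'phone_seq', 'transcription'), the tf_dataset variable (a tf.data.Dataset assumed to be built elsewhere from dataset_tf), and the batch size are illustrative assumptions, not taken from this file.

# Hypothetical consumer of analyze_dataset_shapes(); element keys and tf_dataset are assumed.
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # -1 analyzes ALL trials

padded_shapes = {
    'neural_features': [max_shapes['max_time_steps'], max_shapes['n_features']],
    'phone_seq': [max_shapes['max_phone_seq_len']],
    'transcription': [max_shapes['max_transcription_len']],
}

# Fixed, analysis-derived shapes give every batch the same static shape,
# which is what padded_batch needs in a TPU-friendly input pipeline.
batched = tf_dataset.padded_batch(32, padded_shapes=padded_shapes)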