From fabf70cfa982c55c3da0ac230339c68233339154 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Mon, 20 Oct 2025 00:35:17 +0800
Subject: [PATCH] Enhance dataset shape analysis by implementing parallel
 processing and improving sampling logic

---
 model_training_nnn_tpu/dataset_tf.py | 150 +++++++++++++++++----------
 1 file changed, 97 insertions(+), 53 deletions(-)

diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py
index 47f9cea..2a42f24 100644
--- a/model_training_nnn_tpu/dataset_tf.py
+++ b/model_training_nnn_tpu/dataset_tf.py
@@ -4,9 +4,13 @@ import h5py
 import numpy as np
 import math
 import logging
+import time
+import random
+import multiprocessing
 from itertools import groupby
 from typing import Dict, List, Tuple, Optional, Any
 from scipy.ndimage import gaussian_filter1d
+from concurrent.futures import ThreadPoolExecutor, as_completed
 
 
 class BrainToTextDatasetTF:
@@ -758,7 +762,8 @@ def train_test_split_indices(file_paths: List[str],
 
 def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
     """
-    Analyze dataset to determine maximum shapes for padded_batch
+    Analyze dataset shapes in parallel to determine maximum dimensions for
+    padded_batch, using multiple CPU cores and the dataset's caching mechanism.
 
     Args:
         dataset_tf: Dataset instance to analyze
@@ -767,79 +772,118 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
+    print(f"🚀 Starting parallel dataset analysis (sampling: {'ALL' if sample_size <= 0 else sample_size}, max workers: 224)...")
+    start_time = time.time()
 
-    max_shapes = {
-        'max_time_steps': 0,
-        'max_phone_seq_len': 0,
-        'max_transcription_len': 0,
-        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
-    }
-
-    # 1. Create a flat list of all (day, trial) pairs
+    # 1. Collect every (day, trial) pair that needs analysis, skipping duplicates
     all_trials = []
+    unique_trials = set()
     for batch_idx in sorted(dataset_tf.batch_indices.keys()):
         batch = dataset_tf.batch_indices[batch_idx]
         for day, trials in batch.items():
             for trial in trials:
-                all_trials.append((day, trial))
+                if (day, trial) not in unique_trials:
+                    unique_trials.add((day, trial))
+                    all_trials.append((day, trial))
 
-    # 2. Sample from the list if needed
-    if sample_size > 0 and len(all_trials) > sample_size:
-        import random
-        if dataset_tf.split == 'train': # Random sampling for training
-            trials_to_check = random.sample(all_trials, sample_size)
-        else: # Sequential sampling for validation
-            trials_to_check = all_trials[:sample_size]
+    # 2. Sample from the list if requested
+    if 0 < sample_size < len(all_trials):
+        # Fixed seed so the sample is reproducible (if determinism is desired)
+        random.seed(42)
+        trials_to_check = random.sample(all_trials, sample_size)
     else:
         trials_to_check = all_trials
 
-    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
+    total_trials_to_analyze = len(trials_to_check)
+    print(f"📊 Total unique trials to analyze: {total_trials_to_analyze}")
 
-    # 3. Iterate through the final list
-    count = 0
-    for day, trial in trials_to_check:
+    # Helper function executed by each worker thread
+    def analyze_single_trial(day_trial_pair):
+        """Load and analyze a single trial, returning its shape tuple."""
+        day, trial = day_trial_pair
         try:
-            session_path = dataset_tf.trial_indices[day]['session_path']
-            with h5py.File(session_path, 'r') as f:
-                g = f[f'trial_{trial:04d}']
+            # Reuse dataset_tf's loading and caching logic
+            trial_data = dataset_tf._load_trial_data(day, trial)
 
-                # Check dimensions
-                time_steps = int(g.attrs['n_time_steps'])
-                phone_seq_len = int(g.attrs['seq_len'])
-                transcription_data = g['transcription'][:]
+            # Read the dimensions directly from the loaded data
+            time_steps = int(trial_data['n_time_steps'])
+            phone_seq_len = int(trial_data['phone_seq_lens'])
+
+            # Handle the transcription data, which may be an array
+            transcription_data = trial_data['transcription']
+            if hasattr(transcription_data, '__len__'):
                 transcription_len = len(transcription_data)
+            else:
+                transcription_len = 1  # scalar values count as length 1
 
-                max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
-                max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
-                max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
-
-            count += 1
-
-            # Show progress for large analyses
-            if count % 100 == 0:
-                print(f"   Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
-
+            return (time_steps, phone_seq_len, transcription_len)
         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-            continue
+            return None  # None signals a failed trial
 
-    # Add safety margins (20% buffer) to handle edge cases
-    original_time_steps = max_shapes['max_time_steps']
-    original_phone_seq_len = max_shapes['max_phone_seq_len']
-    original_transcription_len = max_shapes['max_transcription_len']
+    # 3. Run the analysis in parallel with a ThreadPoolExecutor
+    max_workers = min(224, len(trials_to_check))  # never more workers than tasks
+    local_max_shapes = []
 
-    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
-    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
-    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
+    print(f"🔧 Using {max_workers} parallel workers for analysis...")
 
-    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
-    print(f"   Original max time steps: {original_time_steps} → Padded: {max_shapes['max_time_steps']}")
-    print(f"   Original max phone sequence length: {original_phone_seq_len} → Padded: {max_shapes['max_phone_seq_len']}")
-    print(f"   Original max transcription length: {original_transcription_len} → Padded: {max_shapes['max_transcription_len']}")
-    print(f"   Number of features: {max_shapes['n_features']}")
+    with ThreadPoolExecutor(max_workers=max_workers) as executor:
+        # Submit one analysis task per trial
+        future_to_trial = {
+            executor.submit(analyze_single_trial, trial_pair): trial_pair
+            for trial_pair in trials_to_check
+        }
 
-    return max_shapes
+        count = 0
+        progress_interval = max(1, total_trials_to_analyze // 20)  # ~20 progress updates
+
+        for future in as_completed(future_to_trial):
+            result = future.result()
+            if result:
+                local_max_shapes.append(result)
+
+            count += 1
+            if count % progress_interval == 0 or count == total_trials_to_analyze:
+                elapsed = time.time() - start_time
+                rate = count / elapsed if elapsed > 0 else 0
+                print(f"   📈 Analyzed {count}/{total_trials_to_analyze} trials... ({count/total_trials_to_analyze:.1%}) [{rate:.1f} trials/sec]")
+
+    # 4. Aggregate the results from all worker threads
+    if not local_max_shapes:
+        raise ValueError("Dataset analysis failed: No trials could be successfully analyzed.")
+
+    # Transpose [(t1, p1, tr1), (t2, p2, tr2), ...] into ([t1, t2, ...], [p1, p2, ...], ...)
+    unzipped_shapes = list(zip(*local_max_shapes))
+
+    original_max_shapes = {
+        'max_time_steps': int(np.max(unzipped_shapes[0])),
+        'max_phone_seq_len': int(np.max(unzipped_shapes[1])),
+        'max_transcription_len': int(np.max(unzipped_shapes[2])),
+        'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    }
+
+    # 5. Add a safety margin (10% buffer)
+    safety_margin = 1.1
+
+    final_max_shapes = {
+        'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
+        'max_phone_seq_len': int(original_max_shapes['max_phone_seq_len'] * safety_margin),
+        'max_transcription_len': int(original_max_shapes['max_transcription_len'] * safety_margin),
+        'n_features': original_max_shapes['n_features']
+    }
+
+    analysis_time = time.time() - start_time
+    successful_rate = len(local_max_shapes) / total_trials_to_analyze * 100
+
+    print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
+    print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
+    print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
+    print(f"   Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
+    print(f"   Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
+    print(f"   Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
+    print(f"   Number of features: {final_max_shapes['n_features']}")
+
+    return final_max_shapes
 
 
 # Utility functions for TPU-optimized data pipeline
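-- 
Note, outside the patch: a minimal sketch of how the shape dictionary returned
by analyze_dataset_shapes() might feed padded_batch on TPU. The numeric values
and the feature keys below are illustrative assumptions, not names taken from
this repository; on real data, `shapes` would come from the function above.

    import tensorflow as tf

    # Hypothetical output of analyze_dataset_shapes(); real values come from
    # the parallel analysis in the patch above.
    shapes = {
        'max_time_steps': 1320,
        'max_phone_seq_len': 92,
        'max_transcription_len': 440,
        'n_features': 512,
    }

    # Fixed maxima give padded_batch static shapes, so XLA compiles a single
    # graph instead of recompiling whenever the longest trial in a batch
    # changes. The dict keys here are assumed example names.
    padded_shapes = {
        'neural_features': tf.TensorShape([shapes['max_time_steps'], shapes['n_features']]),
        'phone_seq': tf.TensorShape([shapes['max_phone_seq_len']]),
        'transcription': tf.TensorShape([shapes['max_transcription_len']]),
    }

    # A dataset of matching dicts could then be batched as, e.g.:
    # dataset.padded_batch(32, padded_shapes=padded_shapes)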