This commit is contained in:
Zchen
2025-10-20 00:13:39 +08:00
parent 4db3625dc5
commit 6e02894a8a

View File

@@ -767,7 +767,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
Returns:
Dictionary with maximum dimensions
"""
print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
max_shapes = {
'max_time_steps': 0,
@@ -776,35 +776,29 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
}
# Sample a subset of data to determine max sizes
count = 0
batch_keys = list(dataset_tf.batch_indices.keys())
# 1. Create a flat list of all (day, trial) pairs
all_trials = []
for batch_idx in sorted(dataset_tf.batch_indices.keys()):
batch = dataset_tf.batch_indices[batch_idx]
for day, trials in batch.items():
for trial in trials:
all_trials.append((day, trial))
# If sample_size is -1, analyze all data
if sample_size == -1:
batches_to_check = batch_keys
max_trials_per_batch = float('inf')
# 2. Sample from the list if needed
if sample_size > 0 and len(all_trials) > sample_size:
import random
if dataset_tf.split == 'train': # Random sampling for training
trials_to_check = random.sample(all_trials, sample_size)
else: # Sequential sampling for validation
trials_to_check = all_trials[:sample_size]
else:
# Sample a reasonable number of batches
batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
max_trials_per_batch = max(1, sample_size // len(batches_to_check))
trials_to_check = all_trials
for batch_idx in batches_to_check:
if count >= sample_size and sample_size > 0:
break
batch_index = dataset_tf.batch_indices[batch_idx]
for day in batch_index.keys():
if count >= sample_size and sample_size > 0:
break
trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
for trial in trials_to_check:
if count >= sample_size and sample_size > 0:
break
print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
# 3. Iterate through the final list
count = 0
for day, trial in trials_to_check:
try:
session_path = dataset_tf.trial_indices[day]['session_path']
with h5py.File(session_path, 'r') as f:
@@ -823,8 +817,8 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
count += 1
# Show progress for large analyses
if count % 50 == 0:
print(f" Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
if count % 100 == 0:
print(f" Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
except Exception as e:
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")