@@ -767,7 +767,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
+    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")

     max_shapes = {
         'max_time_steps': 0,
@@ -776,35 +776,29 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
     }

-    # Sample a subset of data to determine max sizes
-    count = 0
-    batch_keys = list(dataset_tf.batch_indices.keys())
+    # 1. Create a flat list of all (day, trial) pairs
+    all_trials = []
+    for batch_idx in sorted(dataset_tf.batch_indices.keys()):
+        batch = dataset_tf.batch_indices[batch_idx]
+        for day, trials in batch.items():
+            for trial in trials:
+                all_trials.append((day, trial))

-    # If sample_size is -1, analyze all data
-    if sample_size == -1:
-        batches_to_check = batch_keys
-        max_trials_per_batch = float('inf')
+    # 2. Sample from the list if needed
+    if sample_size > 0 and len(all_trials) > sample_size:
+        import random
+        if dataset_tf.split == 'train': # Random sampling for training
+            trials_to_check = random.sample(all_trials, sample_size)
+        else: # Sequential sampling for validation
+            trials_to_check = all_trials[:sample_size]
     else:
-        # Sample a reasonable number of batches
-        batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
-        max_trials_per_batch = max(1, sample_size // len(batches_to_check))
+        trials_to_check = all_trials

-    for batch_idx in batches_to_check:
-        if count >= sample_size and sample_size > 0:
-            break
-
-        batch_index = dataset_tf.batch_indices[batch_idx]
-
-        for day in batch_index.keys():
-            if count >= sample_size and sample_size > 0:
-                break
-
-            trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
-
-            for trial in trials_to_check:
-                if count >= sample_size and sample_size > 0:
-                    break
+    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
+
+    # 3. Iterate through the final list
+    count = 0
+    for day, trial in trials_to_check:
+        try:
             session_path = dataset_tf.trial_indices[day]['session_path']
             with h5py.File(session_path, 'r') as f:
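The rewrite in this hunk replaces the old approximate per-batch caps (batch_keys, max_trials_per_batch) with an exact one: every (day, trial) pair is flattened into a single list before sampling, so sample_size now bounds the number of analyzed trials precisely, randomly for the train split and as a deterministic prefix for validation. A minimal, self-contained sketch of that selection logic; select_trials, its arguments, and the toy index are hypothetical stand-ins for the dataset_tf attributes used in the commit:

import random

def select_trials(batch_indices, split, sample_size):
    # 1. Flatten the nested {batch: {day: [trials]}} index into (day, trial) pairs.
    all_trials = []
    for batch_idx in sorted(batch_indices.keys()):
        for day, trials in batch_indices[batch_idx].items():
            for trial in trials:
                all_trials.append((day, trial))

    # 2. Sample only when a positive cap is given and the dataset exceeds it.
    if sample_size > 0 and len(all_trials) > sample_size:
        if split == 'train':
            return random.sample(all_trials, sample_size)  # random, without replacement
        return all_trials[:sample_size]  # deterministic prefix for validation
    return all_trials  # sample_size <= 0 (e.g. the -1 convention) means analyze everything

# Example: 2 batches x 2 days x 3 trials = 12 pairs, capped at 5
index = {0: {'day1': [0, 1, 2], 'day2': [3, 4, 5]},
         1: {'day1': [6, 7, 8], 'day2': [9, 10, 11]}}
print(select_trials(index, 'val', 5))         # first five pairs, in order
print(len(select_trials(index, 'train', 5)))  # 5, randomly chosen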
@@ -823,8 +817,8 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
                 count += 1

                 # Show progress for large analyses
-                if count % 50 == 0:
-                    print(f"   Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
+                if count % 100 == 0:
+                    print(f"   Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")

         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
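The loop body between these two hunks is unchanged and therefore elided from the diff, but the max_shapes keys and the progress print imply that it folds each trial's array lengths into running maxima. A generic sketch of that pattern; the HDF5 group and dataset names below ('trial_0000', 'input_features', 'seq_class_ids', 'transcription') are hypothetical, since the real file layout is not visible here:

import h5py

def update_max_shapes(max_shapes, session_path, trial):
    # Fold one trial's array lengths into the running maxima.
    # Group/dataset names are hypothetical stand-ins for the dataset's real layout.
    with h5py.File(session_path, 'r') as f:
        g = f[f'trial_{trial:04d}']
        max_shapes['max_time_steps'] = max(
            max_shapes['max_time_steps'], g['input_features'].shape[0])
        max_shapes['max_phone_seq_len'] = max(
            max_shapes['max_phone_seq_len'], g['seq_class_ids'].shape[0])
        max_shapes['max_transcription_len'] = max(
            max_shapes['max_transcription_len'], g['transcription'].shape[0])
    return max_shapes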