@@ -767,7 +767,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
+    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
 
     max_shapes = {
         'max_time_steps': 0,
@@ -776,35 +776,29 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
     }
 
-    # Sample a subset of data to determine max sizes
-    count = 0
-    batch_keys = list(dataset_tf.batch_indices.keys())
+    # 1. Create a flat list of all (day, trial) pairs
+    all_trials = []
+    for batch_idx in sorted(dataset_tf.batch_indices.keys()):
+        batch = dataset_tf.batch_indices[batch_idx]
+        for day, trials in batch.items():
+            for trial in trials:
+                all_trials.append((day, trial))
 
-    # If sample_size is -1, analyze all data
-    if sample_size == -1:
-        batches_to_check = batch_keys
-        max_trials_per_batch = float('inf')
+    # 2. Sample from the list if needed
+    if sample_size > 0 and len(all_trials) > sample_size:
+        import random
+        if dataset_tf.split == 'train':  # Random sampling for training
+            trials_to_check = random.sample(all_trials, sample_size)
+        else:  # Sequential sampling for validation
+            trials_to_check = all_trials[:sample_size]
     else:
-        # Sample a reasonable number of batches
-        batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
-        max_trials_per_batch = max(1, sample_size // len(batches_to_check))
+        trials_to_check = all_trials
 
-    for batch_idx in batches_to_check:
-        if count >= sample_size and sample_size > 0:
-            break
-
-        batch_index = dataset_tf.batch_indices[batch_idx]
-
-        for day in batch_index.keys():
-            if count >= sample_size and sample_size > 0:
-                break
-
-            trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
-
-            for trial in trials_to_check:
-                if count >= sample_size and sample_size > 0:
-                    break
+    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
 
+    # 3. Iterate through the final list
+    count = 0
+    for day, trial in trials_to_check:
         try:
             session_path = dataset_tf.trial_indices[day]['session_path']
             with h5py.File(session_path, 'r') as f:
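
Review note on the hunk above: the old sampler capped trials per batch (`max_trials_per_batch`), so trials in large batches were less likely to be inspected; flattening first makes every `(day, trial)` pair an equal candidate, and only `split == 'train'` samples randomly, while validation takes the first `sample_size` pairs deterministically. A minimal sketch of the flatten-then-sample pattern, using a made-up `batch_indices` layout (`{batch: {day: [trial, ...]}}`) inferred from the diff:

```python
import random

# Hypothetical stand-in for dataset_tf.batch_indices: batch -> {day: [trial ids]}
batch_indices = {
    0: {'day0': [0, 1, 2], 'day1': [0, 1]},
    1: {'day0': [3, 4], 'day1': [2, 3, 4, 5]},
}

# 1. Flatten into (day, trial) pairs so every trial is an equal sampling candidate
all_trials = [
    (day, trial)
    for batch_idx in sorted(batch_indices)
    for day, trials in batch_indices[batch_idx].items()
    for trial in trials
]

# 2. Uniformly sample when a positive sample_size is requested;
#    otherwise keep everything (mirrors the sample_size <= 0 path)
sample_size = 5
if 0 < sample_size < len(all_trials):
    trials_to_check = random.sample(all_trials, sample_size)
else:
    trials_to_check = all_trials

print(f"checking {len(trials_to_check)} of {len(all_trials)} trials")
```

This also removes the old triple-nested early-exit bookkeeping (`count >= sample_size` checks at three loop depths): the sample is fixed up front, so the analysis loop only has to iterate.
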
@@ -823,8 +817,8 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
             count += 1
 
             # Show progress for large analyses
-            if count % 50 == 0:
-                print(f"  Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
+            if count % 100 == 0:
+                print(f"  Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
 
         except Exception as e:
             logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
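
Hedged usage sketch (the function's default `sample_size` and its caller are outside these hunks, so the values below are assumptions, not the repository's actual invocation):

```python
# Assumed calls; any sample_size <= 0 falls through to analyzing every trial,
# matching the 'ALL' branch of the updated print statement.
max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=1000)     # uniform sample of 1000 trials
max_shapes_full = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # full scan
```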