f
This commit is contained in:
@@ -767,7 +767,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
Returns:
|
Returns:
|
||||||
Dictionary with maximum dimensions
|
Dictionary with maximum dimensions
|
||||||
"""
|
"""
|
||||||
print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
|
print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
|
||||||
|
|
||||||
max_shapes = {
|
max_shapes = {
|
||||||
'max_time_steps': 0,
|
'max_time_steps': 0,
|
||||||
@@ -776,59 +776,53 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
||||||
}
|
}
|
||||||
|
|
||||||
# Sample a subset of data to determine max sizes
|
# 1. Create a flat list of all (day, trial) pairs
|
||||||
count = 0
|
all_trials = []
|
||||||
batch_keys = list(dataset_tf.batch_indices.keys())
|
for batch_idx in sorted(dataset_tf.batch_indices.keys()):
|
||||||
|
batch = dataset_tf.batch_indices[batch_idx]
|
||||||
|
for day, trials in batch.items():
|
||||||
|
for trial in trials:
|
||||||
|
all_trials.append((day, trial))
|
||||||
|
|
||||||
# If sample_size is -1, analyze all data
|
# 2. Sample from the list if needed
|
||||||
if sample_size == -1:
|
if sample_size > 0 and len(all_trials) > sample_size:
|
||||||
batches_to_check = batch_keys
|
import random
|
||||||
max_trials_per_batch = float('inf')
|
if dataset_tf.split == 'train': # Random sampling for training
|
||||||
|
trials_to_check = random.sample(all_trials, sample_size)
|
||||||
|
else: # Sequential sampling for validation
|
||||||
|
trials_to_check = all_trials[:sample_size]
|
||||||
else:
|
else:
|
||||||
# Sample a reasonable number of batches
|
trials_to_check = all_trials
|
||||||
batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
|
|
||||||
max_trials_per_batch = max(1, sample_size // len(batches_to_check))
|
|
||||||
|
|
||||||
for batch_idx in batches_to_check:
|
print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
|
||||||
if count >= sample_size and sample_size > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
batch_index = dataset_tf.batch_indices[batch_idx]
|
# 3. Iterate through the final list
|
||||||
|
count = 0
|
||||||
|
for day, trial in trials_to_check:
|
||||||
|
try:
|
||||||
|
session_path = dataset_tf.trial_indices[day]['session_path']
|
||||||
|
with h5py.File(session_path, 'r') as f:
|
||||||
|
g = f[f'trial_{trial:04d}']
|
||||||
|
|
||||||
for day in batch_index.keys():
|
# Check dimensions
|
||||||
if count >= sample_size and sample_size > 0:
|
time_steps = int(g.attrs['n_time_steps'])
|
||||||
break
|
phone_seq_len = int(g.attrs['seq_len'])
|
||||||
|
transcription_data = g['transcription'][:]
|
||||||
|
transcription_len = len(transcription_data)
|
||||||
|
|
||||||
trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
|
max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
|
||||||
|
max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
|
||||||
|
max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
|
||||||
|
|
||||||
for trial in trials_to_check:
|
count += 1
|
||||||
if count >= sample_size and sample_size > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
try:
|
# Show progress for large analyses
|
||||||
session_path = dataset_tf.trial_indices[day]['session_path']
|
if count % 100 == 0:
|
||||||
with h5py.File(session_path, 'r') as f:
|
print(f" Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
|
||||||
g = f[f'trial_{trial:04d}']
|
|
||||||
|
|
||||||
# Check dimensions
|
except Exception as e:
|
||||||
time_steps = int(g.attrs['n_time_steps'])
|
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
|
||||||
phone_seq_len = int(g.attrs['seq_len'])
|
continue
|
||||||
transcription_data = g['transcription'][:]
|
|
||||||
transcription_len = len(transcription_data)
|
|
||||||
|
|
||||||
max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
|
|
||||||
max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
|
|
||||||
max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
|
|
||||||
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
# Show progress for large analyses
|
|
||||||
if count % 50 == 0:
|
|
||||||
print(f" Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Add safety margins (20% buffer) to handle edge cases
|
# Add safety margins (20% buffer) to handle edge cases
|
||||||
original_time_steps = max_shapes['max_time_steps']
|
original_time_steps = max_shapes['max_time_steps']
|
||||||
|
Reference in New Issue
Block a user