commit 6e02894a8a
parent 4db3625dc5
Author: Zchen
Date:   2025-10-20 00:13:39 +08:00


@@ -767,7 +767,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
     Returns:
         Dictionary with maximum dimensions
     """
-    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
+    print(f"🔍 Analyzing dataset shapes (sampling {sample_size if sample_size > 0 else 'ALL'} examples)...")
 
     max_shapes = {
         'max_time_steps': 0,
@@ -776,59 +776,53 @@
         'n_features': len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
     }
 
-    # Sample a subset of data to determine max sizes
-    count = 0
-    batch_keys = list(dataset_tf.batch_indices.keys())
-
-    # If sample_size is -1, analyze all data
-    if sample_size == -1:
-        batches_to_check = batch_keys
-        max_trials_per_batch = float('inf')
-    else:
-        # Sample a reasonable number of batches
-        batches_to_check = batch_keys[:min(max(10, sample_size // 10), len(batch_keys))]
-        max_trials_per_batch = max(1, sample_size // len(batches_to_check))
-
-    for batch_idx in batches_to_check:
-        if count >= sample_size and sample_size > 0:
-            break
-
-        batch_index = dataset_tf.batch_indices[batch_idx]
-
-        for day in batch_index.keys():
-            if count >= sample_size and sample_size > 0:
-                break
-
-            trials_to_check = batch_index[day][:min(int(max_trials_per_batch), len(batch_index[day]))]
-
-            for trial in trials_to_check:
-                if count >= sample_size and sample_size > 0:
-                    break
-
-                try:
-                    session_path = dataset_tf.trial_indices[day]['session_path']
-                    with h5py.File(session_path, 'r') as f:
-                        g = f[f'trial_{trial:04d}']
-
-                        # Check dimensions
-                        time_steps = int(g.attrs['n_time_steps'])
-                        phone_seq_len = int(g.attrs['seq_len'])
-                        transcription_data = g['transcription'][:]
-                        transcription_len = len(transcription_data)
-
-                        max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
-                        max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
-                        max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
-
-                        count += 1
-
-                        # Show progress for large analyses
-                        if count % 50 == 0:
-                            print(f"   Analyzed {count} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
-
-                except Exception as e:
-                    logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
-                    continue
+    # 1. Create a flat list of all (day, trial) pairs
+    all_trials = []
+    for batch_idx in sorted(dataset_tf.batch_indices.keys()):
+        batch = dataset_tf.batch_indices[batch_idx]
+        for day, trials in batch.items():
+            for trial in trials:
+                all_trials.append((day, trial))
+
+    # 2. Sample from the list if needed
+    if sample_size > 0 and len(all_trials) > sample_size:
+        import random
+        if dataset_tf.split == 'train':  # Random sampling for training
+            trials_to_check = random.sample(all_trials, sample_size)
+        else:  # Sequential sampling for validation
+            trials_to_check = all_trials[:sample_size]
+    else:
+        trials_to_check = all_trials
+
+    print(f"📋 Will analyze {len(trials_to_check)} trials from {len(all_trials)} total trials")
+
+    # 3. Iterate through the final list
+    count = 0
+    for day, trial in trials_to_check:
+        try:
+            session_path = dataset_tf.trial_indices[day]['session_path']
+            with h5py.File(session_path, 'r') as f:
+                g = f[f'trial_{trial:04d}']
+
+                # Check dimensions
+                time_steps = int(g.attrs['n_time_steps'])
+                phone_seq_len = int(g.attrs['seq_len'])
+                transcription_data = g['transcription'][:]
+                transcription_len = len(transcription_data)
+
+                max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
+                max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
+                max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
+
+                count += 1
+
+                # Show progress for large analyses
+                if count % 100 == 0:
+                    print(f"   Analyzed {count}/{len(trials_to_check)} samples... current max: time={max_shapes['max_time_steps']}, phone={max_shapes['max_phone_seq_len']}, trans={max_shapes['max_transcription_len']}")
+
+        except Exception as e:
+            logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
+            continue
 
     # Add safety margins (20% buffer) to handle edge cases
     original_time_steps = max_shapes['max_time_steps']
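
The sampling change above boils down to a reusable pattern: flatten the nested batch_indices mapping into (day, trial) pairs, then either draw a uniform random sample (train split) or take a sequential prefix (any other split), falling back to the full list when sample_size <= 0. A minimal runnable sketch with made-up indices (the real mapping comes from BrainToTextDatasetTF); the final two lines illustrate the 20% safety buffer referenced in the trailing comment, assuming a simple multiplicative pad:

import random

# Hypothetical stand-in for dataset_tf.batch_indices: {batch: {day: [trial, ...]}}
batch_indices = {
    0: {'day1': [0, 1, 2], 'day2': [0, 1]},
    1: {'day1': [3, 4], 'day3': [0]},
}

def flatten_trials(batch_indices):
    """Flatten {batch: {day: [trial, ...]}} into a flat list of (day, trial) pairs."""
    return [(day, trial)
            for batch_idx in sorted(batch_indices)
            for day, trials in batch_indices[batch_idx].items()
            for trial in trials]

def pick_trials(all_trials, sample_size, split):
    """Random sample for 'train', sequential prefix otherwise; everything if sample_size <= 0."""
    if sample_size > 0 and len(all_trials) > sample_size:
        if split == 'train':
            return random.sample(all_trials, sample_size)
        return all_trials[:sample_size]
    return all_trials

print(pick_trials(flatten_trials(batch_indices), 3, 'train'))

# 20% safety margin on a measured maximum (value here is made up):
measured_max_time_steps = 1000
padded_max_time_steps = int(measured_max_time_steps * 1.2)  # -> 1200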
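
For the per-trial measurement itself, the loop reads two HDF5 attributes and one dataset per trial group. A self-contained h5py sketch that writes a tiny dummy session file and reads it back the same way (file name, attribute values, and lengths are all invented for illustration):

import h5py
import numpy as np

# Write a dummy session file shaped like the ones the loop expects.
with h5py.File('demo_session.hdf5', 'w') as f:
    g = f.create_group('trial_0000')
    g.attrs['n_time_steps'] = 350
    g.attrs['seq_len'] = 42
    g.create_dataset('transcription', data=np.zeros(180, dtype=np.int32))

# Read it back exactly as the analysis loop does (trial = 0 here).
with h5py.File('demo_session.hdf5', 'r') as f:
    g = f[f'trial_{0:04d}']
    time_steps = int(g.attrs['n_time_steps'])
    phone_seq_len = int(g.attrs['seq_len'])
    transcription_len = len(g['transcription'][:])

print(time_steps, phone_seq_len, transcription_len)  # 350 42 180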