Add dataset shape analysis function and integrate into input function for TPU optimization

Zchen
2025-10-19 11:04:36 +08:00
parent cfd9653da9
commit 4328114ed6


@@ -756,11 +756,81 @@ def train_test_split_indices(file_paths: List[str],
    return train_trials, test_trials


def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
    """
    Analyze the dataset to determine maximum shapes for padded_batch.

    Args:
        dataset_tf: Dataset instance to analyze
        sample_size: Maximum number of trials to inspect (<= 0 disables this cap;
            the scan is always limited to the first 10 batches and 10 trials per day)

    Returns:
        Dictionary with the maximum dimensions, padded with a 20% safety margin
    """
    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")

    max_shapes = {
        'max_time_steps': 0,
        'max_phone_seq_len': 0,
        'max_transcription_len': 0,
        'n_features': 512  # Fixed for neural features
    }

    # Inspect a bounded subset of trials to keep the scan cheap
    count = 0
    for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
        batch_index = dataset_tf.batch_indices[batch_idx]
        for day in batch_index.keys():
            for trial in batch_index[day][:min(10, len(batch_index[day]))]:
                if count >= sample_size and sample_size > 0:
                    break
                try:
                    session_path = dataset_tf.trial_indices[day]['session_path']
                    with h5py.File(session_path, 'r') as f:
                        g = f[f'trial_{trial:04d}']

                        # Read dimensions from the trial's attributes and datasets
                        time_steps = int(g.attrs['n_time_steps'])
                        phone_seq_len = int(g.attrs['seq_len'])
                        transcription_len = int(g['transcription'].shape[0])

                        max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
                        max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
                        max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
                        count += 1
                except Exception as e:
                    logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
                    continue
            if count >= sample_size and sample_size > 0:
                break
        if count >= sample_size and sample_size > 0:
            break

    # Add a 20% safety margin to absorb longer trials outside the sample
    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)

    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
    print(f"   Max time steps: {max_shapes['max_time_steps']}")
    print(f"   Max phone sequence length: {max_shapes['max_phone_seq_len']}")
    print(f"   Max transcription length: {max_shapes['max_transcription_len']}")

    return max_shapes
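
Because the scan opens HDF5 files trial by trial, it is worth running it once and sharing the result across pipelines rather than re-analyzing on every call. A minimal sketch, assuming hypothetical train_dataset_tf / val_dataset_tf instances of BrainToTextDatasetTF (names not from this commit):

    # Hypothetical usage: analyze once, reuse for both splits.
    shape_info = analyze_dataset_shapes(train_dataset_tf, sample_size=200)
    val_info = analyze_dataset_shapes(val_dataset_tf, sample_size=200)

    # Take the element-wise max so the sampled training maxima cannot
    # underestimate validation-set lengths.
    for key in ('max_time_steps', 'max_phone_seq_len', 'max_transcription_len'):
        shape_info[key] = max(shape_info[key], val_info[key])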

# Utility functions for TPU-optimized data pipeline
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                    transform_args: Dict[str, Any],
                    training: bool = True,
                    cache_path: Optional[str] = None,
                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
    """
    Create an input function for TPU training with fixed-shape batching and data augmentation
@@ -769,6 +839,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
        transform_args: Data transformation configuration
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance
        auto_analyze_shapes: Whether to automatically analyze the dataset for optimal shapes

    Returns:
        tf.data.Dataset ready for TPU training with fixed shapes
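
For orientation, a sketch of a call site using the new flag; the variable names are assumptions for illustration, not part of this commit:

    # Hypothetical call: derive shapes from the data, or pass
    # auto_analyze_shapes=False to fall back to the fixed defaults.
    train_ds = create_input_fn(
        train_dataset_tf,               # assumed BrainToTextDatasetTF instance
        transform_args=transform_args,  # assumed augmentation config dict
        training=True,
        auto_analyze_shapes=True,
    )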
@@ -820,12 +891,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
        num_parallel_calls=tf.data.AUTOTUNE
    )
    # Determine shapes for TPU compatibility
    if auto_analyze_shapes:
        # Dynamically analyze the dataset to determine optimal shapes
        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
        max_time_steps = shape_info['max_time_steps']
        max_phone_seq_len = shape_info['max_phone_seq_len']
        max_transcription_len = shape_info['max_transcription_len']
        n_features = shape_info['n_features']
    else:
        # Conservative fixed shapes for TPU compatibility, increased from the
        # previous 4096 / 256 / 512 defaults; adjust to your actual dataset
        max_time_steps = 8192
        max_phone_seq_len = 512
        max_transcription_len = 1024
        n_features = 512  # Number of neural features
    padded_shapes = {
        'input_features': [max_time_steps, n_features],
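
Fixed padded_shapes matter on TPU because XLA recompiles the graph for every new input shape. A minimal sketch of how such a dict typically feeds padded_batch, with the key names beyond input_features assumed for illustration:

    # Hypothetical continuation: pad every element to the fixed maxima so all
    # batches share one shape and the TPU program compiles exactly once.
    padded_shapes = {
        'input_features': [max_time_steps, n_features],
        'seq_class_ids': [max_phone_seq_len],      # assumed key name
        'transcription': [max_transcription_len],  # assumed key name
    }
    dataset = dataset.padded_batch(
        batch_size,
        padded_shapes=padded_shapes,
        drop_remainder=True,  # keeps the last batch at the fixed shape, as TPUs require
    )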