Add dataset shape analysis function and integrate into input function for TPU optimization
@@ -756,11 +756,81 @@ def train_test_split_indices(file_paths: List[str],
     return train_trials, test_trials
 
 
+def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 100) -> Dict[str, int]:
+    """
+    Analyze dataset to determine maximum shapes for padded_batch
+
+    Args:
+        dataset_tf: Dataset instance to analyze
+        sample_size: Number of samples to analyze (set to -1 for all data)
+
+    Returns:
+        Dictionary with maximum dimensions
+    """
+    print(f"🔍 Analyzing dataset shapes (sampling {sample_size} examples)...")
+
+    max_shapes = {
+        'max_time_steps': 0,
+        'max_phone_seq_len': 0,
+        'max_transcription_len': 0,
+        'n_features': 512  # Fixed for neural features
+    }
+
+    # Sample a subset of data to determine max sizes
+    count = 0
+    for batch_idx in list(dataset_tf.batch_indices.keys())[:min(10, len(dataset_tf.batch_indices))]:
+        batch_index = dataset_tf.batch_indices[batch_idx]
+
+        for day in batch_index.keys():
+            for trial in batch_index[day][:min(10, len(batch_index[day]))]:
+                if count >= sample_size and sample_size > 0:
+                    break
+
+                try:
+                    session_path = dataset_tf.trial_indices[day]['session_path']
+                    with h5py.File(session_path, 'r') as f:
+                        g = f[f'trial_{trial:04d}']
+
+                        # Check dimensions
+                        time_steps = int(g.attrs['n_time_steps'])
+                        phone_seq_len = int(g.attrs['seq_len'])
+                        transcription_data = g['transcription'][:]
+                        transcription_len = len(transcription_data)
+
+                        max_shapes['max_time_steps'] = max(max_shapes['max_time_steps'], time_steps)
+                        max_shapes['max_phone_seq_len'] = max(max_shapes['max_phone_seq_len'], phone_seq_len)
+                        max_shapes['max_transcription_len'] = max(max_shapes['max_transcription_len'], transcription_len)
+
+                    count += 1
+
+                except Exception as e:
+                    logging.warning(f"Failed to analyze trial {day}_{trial}: {e}")
+                    continue
+
+            if count >= sample_size and sample_size > 0:
+                break
+        if count >= sample_size and sample_size > 0:
+            break
+
+    # Add safety margins (20% buffer) to handle edge cases
+    max_shapes['max_time_steps'] = int(max_shapes['max_time_steps'] * 1.2)
+    max_shapes['max_phone_seq_len'] = int(max_shapes['max_phone_seq_len'] * 1.2)
+    max_shapes['max_transcription_len'] = int(max_shapes['max_transcription_len'] * 1.2)
+
+    print(f"📊 Dataset analysis complete (analyzed {count} samples):")
+    print(f"   Max time steps: {max_shapes['max_time_steps']}")
+    print(f"   Max phone sequence length: {max_shapes['max_phone_seq_len']}")
+    print(f"   Max transcription length: {max_shapes['max_transcription_len']}")
+
+    return max_shapes
+
+
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None,
+                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
     """
     Create input function for TPU training with fixed-shape batching and data augmentation
 
@@ -769,6 +839,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
+        auto_analyze_shapes: Whether to automatically analyze the dataset for optimal shapes
 
     Returns:
         tf.data.Dataset ready for TPU training with fixed shapes
@@ -820,12 +891,21 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # Define fixed shapes for TPU compatibility
-    # These should match the maximum expected sizes in your dataset
-    max_time_steps = 4096  # Adjust based on your data
-    max_phone_seq_len = 256  # Adjust based on your data
-    max_transcription_len = 512  # Adjust based on your data
-    n_features = 512  # Number of neural features
+    # Determine shapes for TPU compatibility
+    if auto_analyze_shapes:
+        # Dynamically analyze the dataset to determine optimal shapes
+        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
+        max_time_steps = shape_info['max_time_steps']
+        max_phone_seq_len = shape_info['max_phone_seq_len']
+        max_transcription_len = shape_info['max_transcription_len']
+        n_features = shape_info['n_features']
+    else:
+        # Use conservative fixed shapes for TPU compatibility
+        # Increased sizes to handle larger data - adjust based on your actual dataset
+        max_time_steps = 8192  # Increased from 4096 - adjust based on your data
+        max_phone_seq_len = 512  # Increased from 256 - adjust based on your data
+        max_transcription_len = 1024  # Increased from 512 - adjust based on your data
+        n_features = 512  # Number of neural features
 
     padded_shapes = {
         'input_features': [max_time_steps, n_features],
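
Note for reviewers: the HDF5 layout the new function reads (trial_NNNN groups carrying n_time_steps and seq_len attributes plus a transcription dataset) can be exercised standalone. A minimal sketch of that scan, with a throwaway file name and toy values that are illustrative only, not project code:

# Self-contained sketch of the HDF5 scan analyze_dataset_shapes performs.
# File name 'toy_session.h5' and all values are toy stand-ins.
import h5py
import numpy as np

with h5py.File('toy_session.h5', 'w') as f:
    for i, (t, s, tr) in enumerate([(120, 14, 40), (300, 22, 65)]):
        g = f.create_group(f'trial_{i:04d}')
        g.attrs['n_time_steps'] = t
        g.attrs['seq_len'] = s
        g.create_dataset('transcription', data=np.zeros(tr, dtype=np.int32))

max_time_steps = max_phone_seq_len = max_transcription_len = 0
with h5py.File('toy_session.h5', 'r') as f:
    for name in f:
        g = f[name]
        max_time_steps = max(max_time_steps, int(g.attrs['n_time_steps']))
        max_phone_seq_len = max(max_phone_seq_len, int(g.attrs['seq_len']))
        max_transcription_len = max(max_transcription_len, len(g['transcription'][:]))

# Same 20% safety margin the commit applies: e.g. 300 -> int(300 * 1.2) = 360.
print(int(max_time_steps * 1.2), int(max_phone_seq_len * 1.2), int(max_transcription_len * 1.2))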
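Why the commit pads to fixed maxima at all: XLA on TPU compiles one program per distinct tensor shape, so every batch must have the same static shape or training stalls on recompiles. A minimal, runnable sketch of that pattern in plain TensorFlow, with toy sizes standing in for the analyzed maxima (this is not the project's pipeline):

# Fixed-shape padded batching for TPU/XLA, on toy data.
import tensorflow as tf

MAX_TIME_STEPS = 16  # toy stand-in for the analyzed max_time_steps
N_FEATURES = 4       # toy stand-in for the 512 neural features

def toy_trials():
    # Variable-length "neural feature" sequences, like the HDF5 trials.
    for length in (5, 9, 16):
        yield {'input_features': tf.random.normal([length, N_FEATURES])}

ds = tf.data.Dataset.from_generator(
    toy_trials,
    output_signature={'input_features': tf.TensorSpec([None, N_FEATURES], tf.float32)},
)

# Fixed padded_shapes plus drop_remainder=True give every batch the same
# static shape, which is what the TPU pipeline in this commit needs.
ds = ds.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': [MAX_TIME_STEPS, N_FEATURES]},
    drop_remainder=True,
)

for batch in ds:
    print(batch['input_features'].shape)  # always (2, 16, 4)

drop_remainder=True matters for the same reason as the fixed maxima: a smaller final batch would be a second shape and trigger a second compile.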