Zchen
2025-10-19 20:16:23 +08:00
parent 40d0fc50de
commit 4b373ab317


@@ -852,20 +852,22 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None,
-                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with fixed-shape batching and data augmentation
+    Create input function for TPU training with DYNAMIC padding and data augmentation
+
+    This function uses dynamic shapes to avoid the "pad to a smaller size" error.
+    All variable-length dimensions use tf.TensorShape([None, ...]) to allow
+    TensorFlow to automatically determine the appropriate padding size for each batch.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
-        auto_analyze_shapes: Whether to automatically analyze dataset for optimal shapes
 
     Returns:
-        tf.data.Dataset ready for TPU training with fixed shapes
+        tf.data.Dataset ready for TPU training with dynamic shapes
     """
 
     # Create individual example dataset with file-grouping I/O optimization
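For readers unfamiliar with the error this commit targets: padded_batch cannot shrink an element, so a fixed padded_shapes entry smaller than an actual example fails at iteration time, while a None dimension is padded to each batch's own maximum. A minimal self-contained repro (toy sizes, not the project's real dimensions):

import tensorflow as tf

# One example with 6 time steps and 4 features.
ds = tf.data.Dataset.from_tensors(tf.ones([6, 4], tf.float32))

# Fixed shape of only 5 time steps: padded_batch cannot shrink data, so
# iterating raises an InvalidArgumentError about padding to a smaller size.
fixed = ds.padded_batch(1, padded_shapes=[5, 4])
try:
    next(iter(fixed))
except tf.errors.InvalidArgumentError as err:
    print("fixed shapes failed:", type(err).__name__)

# Dynamic shape: None is resolved per batch, so the same element works.
dynamic = ds.padded_batch(1, padded_shapes=[None, 4])
print(next(iter(dynamic)).shape)  # (1, 6, 4)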
@@ -914,58 +916,27 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # Determine shapes for TPU compatibility
-    if auto_analyze_shapes:
-        # Dynamically analyze dataset to determine optimal shapes
-        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
-        max_time_steps = shape_info['max_time_steps']
-        max_phone_seq_len = shape_info['max_phone_seq_len']
-        max_transcription_len = shape_info['max_transcription_len']
-        n_features = shape_info['n_features']
-        print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
-    else:
-        # Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
-        # This avoids the "pad to a smaller size" error by allowing dynamic sizing
-        print(f"🔧 Using dynamic shapes for maximum compatibility")
-        # Calculate number of features based on subset
-        n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
-        padded_shapes = {
-            'input_features': tf.TensorShape([None, n_features]),
-            'seq_class_ids': tf.TensorShape([None]),
-            'n_time_steps': tf.TensorShape([]),      # Scalar
-            'phone_seq_lens': tf.TensorShape([]),    # Scalar
-            'day_indices': tf.TensorShape([]),       # Scalar
-            'transcriptions': tf.TensorShape([None]),
-            'block_nums': tf.TensorShape([]),        # Scalar
-            'trial_nums': tf.TensorShape([])         # Scalar
-        }
-        # Create fixed-shape batches with dynamic padding
-        dataset = dataset.padded_batch(
-            batch_size=dataset_tf.batch_size,
-            padded_shapes=padded_shapes,
-            padding_values=padding_values,
-            drop_remainder=True  # Critical for TPU: ensures all batches have same size
-        )
-        # Prefetch for optimal performance
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-        return dataset
-
-    # If using auto-analyzed shapes, create fixed-size padded shapes
-    padded_shapes = {
-        'input_features': [max_time_steps, n_features],
-        'seq_class_ids': [max_phone_seq_len],
-        'n_time_steps': [],      # Scalar
-        'phone_seq_lens': [],    # Scalar
-        'day_indices': [],       # Scalar
-        'transcriptions': [max_transcription_len],
-        'block_nums': [],        # Scalar
-        'trial_nums': []         # Scalar
-    }
+    # ========================= DYNAMIC SHAPES SOLUTION =========================
+    # Use dynamic shapes to avoid the "pad to a smaller size" error.
+    # This is the simplest and most robust solution.
+    print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
+
+    # Calculate number of features based on subset
+    n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+
+    # Define dynamic padded shapes - all variable dimensions use None
+    padded_shapes = {
+        'input_features': tf.TensorShape([None, n_features]),  # dynamic time dimension
+        'seq_class_ids': tf.TensorShape([None]),                # dynamic sequence length
+        'n_time_steps': tf.TensorShape([]),                     # scalar
+        'phone_seq_lens': tf.TensorShape([]),                   # scalar
+        'day_indices': tf.TensorShape([]),                      # scalar
+        'transcriptions': tf.TensorShape([None]),               # dynamic transcription length
+        'block_nums': tf.TensorShape([]),                       # scalar
+        'trial_nums': tf.TensorShape([])                        # scalar
+    }
+
+    # Define padding values for each field
     padding_values = {
         'input_features': 0.0,
         'seq_class_ids': 0,
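A note on these padding values: padded_batch requires each padding value's dtype to match its component exactly, which is why 'input_features' gets the float literal 0.0 while the ID fields get integer 0. A hedged, explicit-dtype version (the float32/int32 dtypes are assumptions inferred from the field names, not confirmed by this diff):

import tensorflow as tf

# Explicit-dtype padding values; a dtype mismatch makes padded_batch
# fail when the dataset is built rather than at iteration time.
padding_values = {
    'input_features': tf.constant(0.0, dtype=tf.float32),  # float neural features
    'seq_class_ids': tf.constant(0, dtype=tf.int32),       # padded phone IDs
    'n_time_steps': tf.constant(0, dtype=tf.int32),        # scalar: value unused
    'phone_seq_lens': tf.constant(0, dtype=tf.int32),      # scalar: value unused
    'day_indices': tf.constant(0, dtype=tf.int32),         # scalar: value unused
    'transcriptions': tf.constant(0, dtype=tf.int32),      # padded transcription IDs
    'block_nums': tf.constant(0, dtype=tf.int32),          # scalar: value unused
    'trial_nums': tf.constant(0, dtype=tf.int32),          # scalar: value unused
}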
@@ -977,7 +948,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create fixed-shape batches with padding
+    # Create batches with dynamic padding - TensorFlow will automatically
+    # determine the appropriate padding size for each batch
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
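Putting the pieces together outside the project: a self-contained toy pipeline that batches dict-structured, variable-length elements with the same dynamic-shape pattern this commit adopts. The 512-feature width and field names mirror the diff; the generator data and trial lengths are invented for illustration:

import tensorflow as tf

N_FEATURES = 512  # mirrors the fallback feature count in the diff

def gen():
    # Three fake trials with differing numbers of time steps.
    for t in (120, 95, 140):
        yield {
            'input_features': tf.zeros([t, N_FEATURES], tf.float32),
            'seq_class_ids': tf.zeros([t // 10], tf.int32),
            'n_time_steps': tf.constant(t, tf.int32),
        }

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'input_features': tf.TensorSpec([None, N_FEATURES], tf.float32),
        'seq_class_ids': tf.TensorSpec([None], tf.int32),
        'n_time_steps': tf.TensorSpec([], tf.int32),
    },
)

batched = ds.padded_batch(
    batch_size=2,
    padded_shapes={
        'input_features': tf.TensorShape([None, N_FEATURES]),
        'seq_class_ids': tf.TensorShape([None]),
        'n_time_steps': tf.TensorShape([]),
    },
    padding_values={
        'input_features': 0.0,
        'seq_class_ids': 0,
        'n_time_steps': 0,
    },
    drop_remainder=True,  # the original code calls this critical for TPU batch consistency
).prefetch(tf.data.AUTOTUNE)

for batch in batched:
    # The time dimension equals this batch's longest trial: here 120.
    print(batch['input_features'].shape)  # (2, 120, 512)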