This commit is contained in:
Zchen
2025-10-19 20:16:23 +08:00
parent 40d0fc50de
commit 4b373ab317

View File

@@ -852,20 +852,22 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
transform_args: Dict[str, Any],
training: bool = True,
cache_path: Optional[str] = None,
auto_analyze_shapes: bool = True) -> tf.data.Dataset:
cache_path: Optional[str] = None) -> tf.data.Dataset:
"""
Create input function for TPU training with fixed-shape batching and data augmentation
Create input function for TPU training with DYNAMIC padding and data augmentation
This function uses dynamic shapes to avoid the "pad to a smaller size" error.
All variable-length dimensions use tf.TensorShape([None, ...]) to allow
TensorFlow to automatically determine the appropriate padding size for each batch.
Args:
dataset_tf: BrainToTextDatasetTF instance
transform_args: Data transformation configuration
training: Whether this is for training (applies augmentations)
cache_path: Optional path for disk caching to improve I/O performance
auto_analyze_shapes: Whether to automatically analyze dataset for optimal shapes
Returns:
tf.data.Dataset ready for TPU training with fixed shapes
tf.data.Dataset ready for TPU training with dynamic shapes
"""
# Create individual example dataset with file-grouping I/O optimization
@@ -914,58 +916,27 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
num_parallel_calls=tf.data.AUTOTUNE
)
# Determine shapes for TPU compatibility
if auto_analyze_shapes:
# Dynamically analyze dataset to determine optimal shapes
shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
max_time_steps = shape_info['max_time_steps']
max_phone_seq_len = shape_info['max_phone_seq_len']
max_transcription_len = shape_info['max_transcription_len']
n_features = shape_info['n_features']
print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
else:
# Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
# This avoids the "pad to a smaller size" error by allowing dynamic sizing
print(f"🔧 Using dynamic shapes for maximum compatibility")
# ========================= DYNAMIC SHAPES SOLUTION =========================
# 使用动态形状避免 "pad to a smaller size" 错误
# 这是最简单、最健壮的解决方案
print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
# Calculate number of features based on subset
n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
# Define dynamic padded shapes - all variable dimensions use None
padded_shapes = {
'input_features': tf.TensorShape([None, n_features]),
'seq_class_ids': tf.TensorShape([None]),
'n_time_steps': tf.TensorShape([]), # Scalar
'phone_seq_lens': tf.TensorShape([]), # Scalar
'day_indices': tf.TensorShape([]), # Scalar
'transcriptions': tf.TensorShape([None]),
'block_nums': tf.TensorShape([]), # Scalar
'trial_nums': tf.TensorShape([]) # Scalar
}
# Create fixed-shape batches with dynamic padding
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,
padding_values=padding_values,
drop_remainder=True # Critical for TPU: ensures all batches have same size
)
# Prefetch for optimal performance
dataset = dataset.prefetch(tf.data.AUTOTUNE)
return dataset
# If using auto-analyzed shapes, create fixed-size padded shapes
padded_shapes = {
'input_features': [max_time_steps, n_features],
'seq_class_ids': [max_phone_seq_len],
'n_time_steps': [], # Scalar
'phone_seq_lens': [], # Scalar
'day_indices': [], # Scalar
'transcriptions': [max_transcription_len],
'block_nums': [], # Scalar
'trial_nums': [] # Scalar
'input_features': tf.TensorShape([None, n_features]), # 时间维度动态
'seq_class_ids': tf.TensorShape([None]), # 序列长度动态
'n_time_steps': tf.TensorShape([]), # 标量
'phone_seq_lens': tf.TensorShape([]), # 标量
'day_indices': tf.TensorShape([]), # 标量
'transcriptions': tf.TensorShape([None]), # 转录长度动态
'block_nums': tf.TensorShape([]), # 标量
'trial_nums': tf.TensorShape([]) # 标量
}
# Define padding values for each field
padding_values = {
'input_features': 0.0,
'seq_class_ids': 0,
@@ -977,7 +948,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
'trial_nums': 0
}
# Create fixed-shape batches with padding
# Create batches with dynamic padding - TensorFlow will automatically
# determine the appropriate padding size for each batch
dataset = dataset.padded_batch(
batch_size=dataset_tf.batch_size,
padded_shapes=padded_shapes,