@@ -852,20 +852,22 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None,
-                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with fixed-shape batching and data augmentation
+    Create input function for TPU training with DYNAMIC padding and data augmentation
+
+    This function uses dynamic shapes to avoid the "pad to a smaller size" error.
+    All variable-length dimensions use tf.TensorShape([None, ...]) to allow
+    TensorFlow to automatically determine the appropriate padding size for each batch.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
-        auto_analyze_shapes: Whether to automatically analyze dataset for optimal shapes
 
     Returns:
-        tf.data.Dataset ready for TPU training with fixed shapes
+        tf.data.Dataset ready for TPU training with dynamic shapes
     """
 
     # Create individual example dataset with file-grouping I/O optimization
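The change above is easiest to see in isolation. Below is a minimal, self-contained sketch (illustrative code, not from this repo; the 4-feature width and toy trial lengths are invented) of why a fixed `padded_shapes` entry fails once a trial exceeds the preset size, while `None` dimensions let each batch pad to its own maximum:

```python
import tensorflow as tf

def gen():
    # Variable-length "trials": 3, 5, and 9 time steps of 4 features each
    for n in (3, 5, 9):
        yield {'input_features': tf.ones([n, 4]),
               'n_time_steps': tf.constant(n, tf.int32)}

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature={
        'input_features': tf.TensorSpec([None, 4], tf.float32),
        'n_time_steps': tf.TensorSpec([], tf.int32),
    })

# Fixed shapes: the 9-step trial cannot be squeezed into 8 steps, so
# iterating raises InvalidArgumentError ("Attempted to pad to a smaller
# size than the input element").
# bad = ds.padded_batch(3, padded_shapes={'input_features': [8, 4],
#                                         'n_time_steps': []})

# Dynamic shapes: each batch is padded to its own longest element
good = ds.padded_batch(3, padded_shapes={
    'input_features': tf.TensorShape([None, 4]),
    'n_time_steps': tf.TensorShape([]),
}, drop_remainder=True)

for batch in good:
    print(batch['input_features'].shape)  # (3, 9, 4)
```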
@@ -914,58 +916,27 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # Determine shapes for TPU compatibility
-    if auto_analyze_shapes:
-        # Dynamically analyze dataset to determine optimal shapes
-        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
-        max_time_steps = shape_info['max_time_steps']
-        max_phone_seq_len = shape_info['max_phone_seq_len']
-        max_transcription_len = shape_info['max_transcription_len']
-        n_features = shape_info['n_features']
-        print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
-    else:
-        # Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
-        # This avoids the "pad to a smaller size" error by allowing dynamic sizing
-        print(f"🔧 Using dynamic shapes for maximum compatibility")
-        # Calculate number of features based on subset
-        n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    # ========================= DYNAMIC SHAPES SOLUTION =========================
+    # Use dynamic shapes to avoid the "pad to a smaller size" error
+    # This is the simplest and most robust solution
+    print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
 
-        padded_shapes = {
-            'input_features': tf.TensorShape([None, n_features]),
-            'seq_class_ids': tf.TensorShape([None]),
-            'n_time_steps': tf.TensorShape([]),  # Scalar
-            'phone_seq_lens': tf.TensorShape([]),  # Scalar
-            'day_indices': tf.TensorShape([]),  # Scalar
-            'transcriptions': tf.TensorShape([None]),
-            'block_nums': tf.TensorShape([]),  # Scalar
-            'trial_nums': tf.TensorShape([])  # Scalar
-        }
+    # Calculate number of features based on subset
+    n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+
-        # Create fixed-shape batches with dynamic padding
-        dataset = dataset.padded_batch(
-            batch_size=dataset_tf.batch_size,
-            padded_shapes=padded_shapes,
-            padding_values=padding_values,
-            drop_remainder=True  # Critical for TPU: ensures all batches have same size
-        )
-
-        # Prefetch for optimal performance
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-
-        return dataset
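The removed branch passes `drop_remainder=True` with the comment "Critical for TPU", and that requirement is easy to demonstrate: XLA on TPU compiles for static shapes, so only dropping the final partial batch keeps the leading batch dimension constant. A standalone sketch (independent of this repo):

```python
import tensorflow as tf

# Without drop_remainder, the last batch of range(10) has only one
# element, so the static batch dimension is unknown (None)
print(tf.data.Dataset.range(10).batch(3).element_spec)
# TensorSpec(shape=(None,), dtype=tf.int64, name=None)

# With drop_remainder, every batch has exactly 3 elements, giving the
# fixed leading dimension that TPU/XLA compilation expects
print(tf.data.Dataset.range(10).batch(3, drop_remainder=True).element_spec)
# TensorSpec(shape=(3,), dtype=tf.int64, name=None)
```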
-
-    # If using auto-analyzed shapes, create fixed-size padded shapes
+    # Define dynamic padded shapes - all variable dimensions use None
     padded_shapes = {
-        'input_features': [max_time_steps, n_features],
-        'seq_class_ids': [max_phone_seq_len],
-        'n_time_steps': [],  # Scalar
-        'phone_seq_lens': [],  # Scalar
-        'day_indices': [],  # Scalar
-        'transcriptions': [max_transcription_len],
-        'block_nums': [],  # Scalar
-        'trial_nums': []  # Scalar
+        'input_features': tf.TensorShape([None, n_features]),  # dynamic time dimension
+        'seq_class_ids': tf.TensorShape([None]),  # dynamic sequence length
+        'n_time_steps': tf.TensorShape([]),  # scalar
+        'phone_seq_lens': tf.TensorShape([]),  # scalar
+        'day_indices': tf.TensorShape([]),  # scalar
+        'transcriptions': tf.TensorShape([None]),  # dynamic transcription length
+        'block_nums': tf.TensorShape([]),  # scalar
+        'trial_nums': tf.TensorShape([])  # scalar
     }
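Since each batch is now padded to a different maximum length, the scalar length fields (`n_time_steps`, `phone_seq_lens`) are what allow downstream code to distinguish real steps from padding. A hedged sketch of the usual masking pattern (`mask_padded_features` is hypothetical, not a helper in this repo; it assumes a batch dict with the shapes defined above):

```python
import tensorflow as tf

def mask_padded_features(batch):
    """Zero out padded time steps using the per-trial lengths."""
    # [batch, max_time] boolean mask: True on real steps, False on padding
    mask = tf.sequence_mask(batch['n_time_steps'],
                            maxlen=tf.shape(batch['input_features'])[1])
    # Broadcast over the feature axis and zero the padded tail
    return batch['input_features'] * tf.cast(mask[:, :, None], tf.float32)
```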
 
     # Define padding values for each field
     padding_values = {
         'input_features': 0.0,
         'seq_class_ids': 0,
@@ -977,7 +948,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
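One pitfall with `padding_values`: each value must match its component's dtype, or `padded_batch` rejects the structure when the pipeline is built. The plain literals above work because TensorFlow converts `0.0` to a float and `0` to an integer. Spelled out explicitly (a partial sketch; int32 is an assumption here, the actual dataset may emit int64 ids):

```python
import tensorflow as tf

# Explicit equivalents of the literals above: 0.0 must stay floating
# point for the float32 features, 0 an integer for the id/counter fields
padding_values = {
    'input_features': tf.constant(0.0, tf.float32),
    'seq_class_ids': tf.constant(0, tf.int32),
    # ... remaining scalar fields take integer zeros as well
}
```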
 
-    # Create fixed-shape batches with padding
+    # Create batches with dynamic padding - TensorFlow will automatically
+    # determine the appropriate padding size for each batch
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,