@@ -852,20 +852,22 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
                     training: bool = True,
-                    cache_path: Optional[str] = None,
-                    auto_analyze_shapes: bool = True) -> tf.data.Dataset:
+                    cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with fixed-shape batching and data augmentation
+    Create input function for TPU training with DYNAMIC padding and data augmentation
+
+    This function uses dynamic shapes to avoid the "pad to a smaller size" error.
+    All variable-length dimensions use tf.TensorShape([None, ...]) to allow
+    TensorFlow to automatically determine the appropriate padding size for each batch.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
         training: Whether this is for training (applies augmentations)
         cache_path: Optional path for disk caching to improve I/O performance
-        auto_analyze_shapes: Whether to automatically analyze dataset for optimal shapes
 
     Returns:
-        tf.data.Dataset ready for TPU training with fixed shapes
+        tf.data.Dataset ready for TPU training with dynamic shapes
     """
 
     # Create individual example dataset with file-grouping I/O optimization
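The "pad to a smaller size" error quoted in the new docstring is TensorFlow's runtime message "Attempted to pad to a smaller size than the input element.", raised when a pre-computed padded shape turns out shorter than an actual example. A minimal standalone sketch (toy data, not from this repository) reproduces the failure mode and the dynamic-shape fix:

import tensorflow as tf

# Toy variable-length examples standing in for [time, features] neural data.
examples = [tf.random.normal([t, 4]) for t in (3, 7, 5)]
ds = tf.data.Dataset.from_generator(
    lambda: iter(examples),
    output_signature=tf.TensorSpec(shape=[None, 4], dtype=tf.float32),
)

# Fixed padded_shapes fail as soon as one example exceeds the analyzed maximum:
# ds.padded_batch(2, padded_shapes=[5, 4])  # raises on the length-7 example

# Dynamic padded_shapes: each batch is padded to its own longest example.
batched = ds.padded_batch(2, padded_shapes=tf.TensorShape([None, 4]))
for batch in batched:
    print(batch.shape)  # (2, 7, 4), then (1, 5, 4)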
@@ -914,58 +916,27 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         num_parallel_calls=tf.data.AUTOTUNE
     )
 
-    # Determine shapes for TPU compatibility
-    if auto_analyze_shapes:
-        # Dynamically analyze dataset to determine optimal shapes
-        shape_info = analyze_dataset_shapes(dataset_tf, sample_size=100)
-        max_time_steps = shape_info['max_time_steps']
-        max_phone_seq_len = shape_info['max_phone_seq_len']
-        max_transcription_len = shape_info['max_transcription_len']
-        n_features = shape_info['n_features']
-        print(f"🔧 Using auto-analyzed shapes: time_steps={max_time_steps}, phone_seq={max_phone_seq_len}, transcription={max_transcription_len}")
-    else:
-        # Use dynamic shapes for maximum compatibility - let TensorFlow handle padding automatically
-        # This avoids the "pad to a smaller size" error by allowing dynamic sizing
-        print(f"🔧 Using dynamic shapes for maximum compatibility")
-        # Calculate number of features based on subset
-        n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
+    # ========================= DYNAMIC SHAPES SOLUTION =========================
+    # Use dynamic shapes to avoid the "pad to a smaller size" error
+    # This is the simplest and most robust solution
+    print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
 
-        padded_shapes = {
-            'input_features': tf.TensorShape([None, n_features]),
-            'seq_class_ids': tf.TensorShape([None]),
-            'n_time_steps': tf.TensorShape([]),  # Scalar
-            'phone_seq_lens': tf.TensorShape([]),  # Scalar
-            'day_indices': tf.TensorShape([]),  # Scalar
-            'transcriptions': tf.TensorShape([None]),
-            'block_nums': tf.TensorShape([]),  # Scalar
-            'trial_nums': tf.TensorShape([])  # Scalar
-        }
+    # Calculate number of features based on subset
+    n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
 
-        # Create fixed-shape batches with dynamic padding
-        dataset = dataset.padded_batch(
-            batch_size=dataset_tf.batch_size,
-            padded_shapes=padded_shapes,
-            padding_values=padding_values,
-            drop_remainder=True  # Critical for TPU: ensures all batches have same size
-        )
-
-        # Prefetch for optimal performance
-        dataset = dataset.prefetch(tf.data.AUTOTUNE)
-
-        return dataset
-
-    # If using auto-analyzed shapes, create fixed-size padded shapes
+    # Define dynamic padded shapes - all variable dimensions use None
     padded_shapes = {
-        'input_features': [max_time_steps, n_features],
-        'seq_class_ids': [max_phone_seq_len],
-        'n_time_steps': [],  # Scalar
-        'phone_seq_lens': [],  # Scalar
-        'day_indices': [],  # Scalar
-        'transcriptions': [max_transcription_len],
-        'block_nums': [],  # Scalar
-        'trial_nums': []  # Scalar
+        'input_features': tf.TensorShape([None, n_features]),  # dynamic time dimension
+        'seq_class_ids': tf.TensorShape([None]),  # dynamic sequence length
+        'n_time_steps': tf.TensorShape([]),  # scalar
+        'phone_seq_lens': tf.TensorShape([]),  # scalar
+        'day_indices': tf.TensorShape([]),  # scalar
+        'transcriptions': tf.TensorShape([None]),  # dynamic transcription length
+        'block_nums': tf.TensorShape([]),  # scalar
+        'trial_nums': tf.TensorShape([])  # scalar
     }
 
+    # Define padding values for each field
     padding_values = {
         'input_features': 0.0,
         'seq_class_ids': 0,
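For the padded_shapes / padding_values pair that the end of the hunk above introduces and the hunk below consumes: both dicts must mirror the element structure key for key, and each padding value's dtype must match its component (0.0 for the float features, 0 for the integer ids). A self-contained sketch with a reduced field set; the generator and the integer dtypes are assumptions for illustration, while the field names and the 512-feature default come from the diff:

import tensorflow as tf

def gen():
    for t in (4, 6):
        yield {
            'input_features': tf.zeros([t, 512], tf.float32),
            'seq_class_ids': tf.zeros([t // 2], tf.int32),  # dtype assumed
            'n_time_steps': tf.constant(t, tf.int32),       # dtype assumed
        }

ds = tf.data.Dataset.from_generator(gen, output_signature={
    'input_features': tf.TensorSpec([None, 512], tf.float32),
    'seq_class_ids': tf.TensorSpec([None], tf.int32),
    'n_time_steps': tf.TensorSpec([], tf.int32),
})

batched = ds.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': tf.TensorShape([None, 512]),
                   'seq_class_ids': tf.TensorShape([None]),
                   'n_time_steps': tf.TensorShape([])},
    padding_values={'input_features': 0.0,  # float component -> float value
                    'seq_class_ids': 0,     # int component -> int value
                    'n_time_steps': 0},     # scalars are never padded
)
print(batched.element_spec)  # every variable dimension appears as None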
@@ -977,7 +948,8 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create fixed-shape batches with padding
+    # Create batches with dynamic padding - TensorFlow will automatically
+    # determine the appropriate padding size for each batch
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
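One design note on the removed branch's comment "Critical for TPU: ensures all batches have same size": drop_remainder=True is what gives the batch dimension a static size, which XLA on TPU requires. The padded time dimension remains dynamic under this scheme, so batches whose padded lengths differ can still trigger separate XLA compilations; that is the cost traded here for robustness against the padding error. The shape difference is easy to see on a toy dataset:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(ds.batch(4, drop_remainder=True).element_spec)  # shape=(4,): static batch dim
print(ds.batch(4).element_spec)                       # shape=(None,): partial final batch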