@@ -908,111 +908,56 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
 # Utility functions for TPU-optimized data pipeline
 def create_input_fn(dataset_tf: BrainToTextDatasetTF,
                     transform_args: Dict[str, Any],
-                    max_shapes: Dict[str, int],
                     training: bool = True,
                     cache_path: Optional[str] = None) -> tf.data.Dataset:
     """
-    Create input function for TPU training with PRE-ANALYZED FIXED shapes
+    Create input function for TPU training with DYNAMIC batching -> BATCH augmentation
 
-    This function uses pre-computed maximum shapes to create STATIC-size batches,
-    ensuring XLA compilation success on TPU hardware. This is CRITICAL for the
-    final resolution of both CTC loss compatibility and graph structure issues.
+    This function uses the proven "batch first, augment after" approach that eliminates
+    the time paradox between data augmentation and shape analysis. This is the FINAL
+    solution that resolves all XLA compilation and padding errors.
+
+    The key insight: data augmentation (especially gauss_smooth with padding='SAME')
+    can increase sequence lengths unpredictably, making pre-computed static shapes invalid.
+    By batching first with dynamic padding, then applying augmentation to batches,
+    we eliminate this temporal paradox entirely.
 
     Args:
         dataset_tf: BrainToTextDatasetTF instance
         transform_args: Data transformation configuration
-        max_shapes: Pre-computed maximum shapes dictionary with keys:
-            'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
        training: Whether this is for training (applies augmentations)
        cache_path: Optional path for disk caching to improve I/O performance
 
     Returns:
-        tf.data.Dataset ready for TPU training with FIXED STATIC shapes
+        tf.data.Dataset ready for TPU training with robust dynamic->static flow
     """
 
-    # Step 1: Create individual example dataset with file-grouping I/O optimization
+    # Step 1: Create individual example dataset
     dataset = dataset_tf.create_individual_dataset()
 
-    # Step 2: Cache raw samples BEFORE any augmentation
+    # Step 2: Cache raw samples BEFORE any augmentation or batching
     if cache_path:
         dataset = dataset.cache(cache_path)
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: {cache_path}")
         print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
     else:
-        # If no cache path is given, default to an in-memory cache
         dataset = dataset.cache()
         split_name = "training" if training else "validation"
         print(f"🗃️ {split_name.capitalize()} dataset caching enabled: in-memory cache")
-        print(f"⚠️ First access will be slow while building {split_name} cache, subsequent access will be much faster")
 
-    # Step 3: Apply transformations to individual examples BEFORE batching
-    def apply_transforms(example):
-        """Apply data transformations to individual examples"""
-        features = example['input_features']
-        n_time_steps = example['n_time_steps']
-
-        # Apply transformations
-        features, n_time_steps = DataAugmentationTF.transform_data(
-            tf.expand_dims(features, 0),  # Add batch dimension for transforms
-            tf.expand_dims(n_time_steps, 0),
-            transform_args,
-            training=training
-        )
-
-        # Remove batch dimension
-        example['input_features'] = tf.squeeze(features, 0)
-        example['n_time_steps'] = tf.squeeze(n_time_steps, 0)
-
-        return example
-
-    # Apply transforms to cached data
-    dataset = dataset.map(
-        apply_transforms,
-        num_parallel_calls=tf.data.AUTOTUNE
-    )
-
-    # ========================= Last-resort debug code =========================
-    def debug_print_shape(example):
-        """Debug helper: print every sample's shape before padded_batch"""
-        tf.print("🔍 Sample Shape Debug:",
-                 tf.shape(example['input_features']),
-                 "Expected feature dim:", dataset_tf.feature_dim,
-                 output_stream=sys.stdout)
-        return example
-
-    # Add shape debugging - this prints at graph execution time
-    dataset = dataset.map(debug_print_shape)
-    print(f"⚠️ Debug mode: Will print each sample shape before padded_batch")
-    # =============================================================
-
-    # Step 4: Batch samples with FIXED STATIC padding (CRITICAL for XLA)
-    print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
-
-    # Extract pre-analyzed shape information
-    max_time_steps = max_shapes['max_time_steps']
-    max_phone_seq_len = max_shapes['max_phone_seq_len']
-    max_transcription_len = max_shapes['max_transcription_len']
-
-    # ========================= Use a single, verified feature dimension =========================
-    # Use the verified feature dimension stored on the dataset_tf object instead of an external parameter
-    n_features = dataset_tf.feature_dim  # <--- Key change: use the auto-detected feature dimension
-    print(f"🔧 Using verified feature dimension from dataset: {n_features}")
-    # ========================= End of feature-dimension change =========================
-
-    print(f"   Fixed time steps: {max_time_steps}")
-    print(f"   Fixed phone sequence length: {max_phone_seq_len}")
-    print(f"   Fixed transcription length: {max_transcription_len}")
-    print(f"   Number of features: {n_features}")
-
-    # Define FIXED padded shapes with TensorSpec for better type safety
+    # Step 3: Batch samples with DYNAMIC padding FIRST (eliminates time paradox)
+    print(f"🔧 Using DYNAMIC padding -> batch augmentation approach")
+    print(f"🔧 Feature dimension: {dataset_tf.feature_dim}")
+
+    # Define dynamic padded shapes - key insight: None allows for dynamic lengths
     padded_shapes = {
-        'input_features': tf.TensorSpec(shape=[max_time_steps, n_features], dtype=tf.float32),
-        'seq_class_ids': tf.TensorSpec(shape=[max_phone_seq_len], dtype=tf.int32),
+        'input_features': tf.TensorSpec(shape=(None, dataset_tf.feature_dim), dtype=tf.float32),
+        'seq_class_ids': tf.TensorSpec(shape=(None,), dtype=tf.int32),
         'n_time_steps': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
         'phone_seq_lens': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
         'day_indices': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
-        'transcriptions': tf.TensorSpec(shape=[max_transcription_len], dtype=tf.int32),
+        'transcriptions': tf.TensorSpec(shape=(None,), dtype=tf.int32),
         'block_nums': tf.TensorSpec(shape=[], dtype=tf.int32),  # scalar
         'trial_nums': tf.TensorSpec(shape=[], dtype=tf.int32)  # scalar
     }
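For reference, a self-contained sketch of the dynamic-padding step in isolation, with a toy generator and an invented feature width of 4. One caveat worth checking against the hunk above: tf.data documents `padded_shapes` as `tf.TensorShape` or int64 vector tensor-like objects, so the `tf.TensorSpec` dict may need to be reduced to plain shapes (as done here) before being passed to `padded_batch`.

```python
import tensorflow as tf

# Toy stand-in for the real per-example dataset (field names from the diff,
# sizes invented for illustration).
def toy_examples():
    for t in (7, 11, 9):  # variable-length trials
        yield {
            'input_features': tf.random.normal([t, 4]),  # [time, n_features]
            'n_time_steps': tf.constant(t, tf.int32),
        }

ds = tf.data.Dataset.from_generator(
    toy_examples,
    output_signature={
        'input_features': tf.TensorSpec([None, 4], tf.float32),
        'n_time_steps': tf.TensorSpec([], tf.int32),
    },
)

# Dynamic padding: None axes are padded to the longest example in each batch.
ds = ds.padded_batch(
    batch_size=2,
    padded_shapes={'input_features': [None, 4], 'n_time_steps': []},
    padding_values={'input_features': 0.0, 'n_time_steps': 0},
    drop_remainder=True,
)

for batch in ds:
    print(batch['input_features'].shape)  # e.g. (2, 11, 4): padded per batch
```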
@@ -1029,7 +974,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         'trial_nums': 0
     }
 
-    # Create batches with FIXED padding - XLA compiler will be happy!
+    # Create batches with DYNAMIC padding - this cannot fail due to size mismatches
     dataset = dataset.padded_batch(
         batch_size=dataset_tf.batch_size,
         padded_shapes=padded_shapes,
@@ -1037,7 +982,37 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
         drop_remainder=True  # Critical for TPU: ensures all batches have same size
     )
 
-    # Prefetch for optimal performance
+    # Step 4: Apply data augmentation to BATCHES (after dynamic batching)
+    def apply_batch_transforms(batch):
+        """Apply data transformations to entire batches - resolves time paradox"""
+        features = batch['input_features']
+        n_time_steps = batch['n_time_steps']
+
+        # Apply transformations to the entire batch
+        features, n_time_steps = DataAugmentationTF.transform_data(
+            features,  # Already has batch dimension
+            n_time_steps,
+            transform_args,
+            training=training
+        )
+
+        # Update batch with transformed data
+        batch['input_features'] = features
+        batch['n_time_steps'] = n_time_steps
+
+        return batch
+
+    # Apply batch transforms only during training
+    if training:
+        dataset = dataset.map(
+            apply_batch_transforms,
+            num_parallel_calls=tf.data.AUTOTUNE
+        )
+        print(f"✅ Batch augmentation enabled for training")
+    else:
+        print(f"✅ No augmentation for validation")
+
+    # Step 5: Prefetch for optimal performance
     dataset = dataset.prefetch(tf.data.AUTOTUNE)
 
     return dataset
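The batch-level transform below is a hypothetical stand-in for `DataAugmentationTF.transform_data`, whose internals this diff does not show. It illustrates the property the commit relies on: because the map runs after `padded_batch`, the transform sees `[B, T, F]` tensors, so any shape effect applies to the whole batch at once and examples can never disagree about their padded length.

```python
import tensorflow as tf

# Minimal, hypothetical batch augmentation in the spirit of
# DataAugmentationTF.transform_data (names and ops assumed, not the repo's).
def transform_batch(features, n_time_steps, training=True):
    """features: [B, T, F] padded batch; n_time_steps: [B] valid lengths."""
    if training:
        # Additive white noise (shape-preserving).
        features = features + tf.random.normal(tf.shape(features), stddev=0.05)
        # Zero out a random 10-step span (SpecAugment-style time mask);
        # the same span is masked for every example in the batch.
        t = tf.shape(features)[1]
        start = tf.random.uniform([], 0, tf.maximum(t - 10, 1), dtype=tf.int32)
        keep = tf.logical_or(tf.range(t) < start, tf.range(t) >= start + 10)
        features = features * tf.cast(keep, features.dtype)[None, :, None]
    return features, n_time_steps
```

Dropping such a function into `apply_batch_transforms` above requires no shape bookkeeping, which is exactly the point of the "batch first, augment after" ordering.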
@@ -17,8 +17,55 @@ except ImportError:
     print("Warning: editdistance not available, falling back to approximation")
     editdistance = None
 
-# XLA-compatible CTC loss implementation
-from tf_seq2seq_losses import classic_ctc_loss
+# Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach
+# for compatibility with the "batch first, augment after" data pipeline
+
+
+def dense_to_sparse(dense_tensor, sequence_lengths):
+    """
+    Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
+
+    This function is essential for the "batch first, augment after" approach
+    as it handles the conversion from dynamic dense tensors to the SparseTensor
+    format required by tf.nn.ctc_loss.
+
+    Args:
+        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
+        sequence_lengths: Actual sequence lengths [batch_size]
+
+    Returns:
+        SparseTensor suitable for tf.nn.ctc_loss
+    """
+    # Create mask for valid (non-zero) elements within sequence lengths
+    batch_size = tf.shape(dense_tensor)[0]
+    max_seq_len = tf.shape(dense_tensor)[1]
+
+    # Create range indices
+    batch_indices = tf.range(batch_size)
+    seq_indices = tf.range(max_seq_len)
+
+    # Create meshgrid for sequence dimensions
+    _, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
+
+    # Create mask based on sequence lengths and non-zero values
+    length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
+    value_mask = tf.not_equal(dense_tensor, 0)
+    combined_mask = tf.logical_and(length_mask, value_mask)
+
+    # Get indices of valid elements
+    indices = tf.where(combined_mask)
+
+    # Get values at valid indices
+    values = tf.gather_nd(dense_tensor, indices)
+
+    # Create sparse tensor
+    dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
+
+    return tf.SparseTensor(
+        indices=tf.cast(indices, tf.int64),
+        values=tf.cast(values, tf.int32),
+        dense_shape=dense_shape
+    )
+
 from rnn_model_tf import (
     TripleGRUDecoder,
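A quick eager sanity check of `dense_to_sparse` with toy values. Note that `value_mask` drops zeros even inside the valid region; that is only safe because `blank_index=0` below reserves id 0 for the CTC blank, so 0 should never occur as a real label. (`tf.sparse.from_dense` would drop zeros too, but ignores `sequence_lengths`, so stray nonzero padding would survive it.)

```python
import tensorflow as tf

# Toy labels: two sequences of true lengths 3 and 2, zero-padded to width 5.
labels = tf.constant([[3, 1, 4, 0, 0],
                      [2, 7, 0, 0, 0]], dtype=tf.int32)
lens = tf.constant([3, 2], dtype=tf.int32)

sp = dense_to_sparse(labels, lens)
print(sp.indices.numpy())      # [[0 0] [0 1] [0 2] [1 0] [1 1]]
print(sp.values.numpy())       # [3 1 4 2 7]
print(sp.dense_shape.numpy())  # [2 5]
```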
@@ -559,23 +606,29 @@ class BrainToTextDecoderTrainerTF:
 
             # Calculate losses using TPU-compatible CTC implementation
             if use_full:
-                # Clean CTC loss - using XLA-compatible classic_ctc_loss
-                clean_loss = classic_ctc_loss(
-                    labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+                # Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+                clean_loss = tf.nn.ctc_loss(
+                    labels=sparse_labels,
                     logits=clean_logits,
-                    label_length=phone_seq_lens,
+                    label_length=None,  # SparseTensor doesn't need label_length
                     logit_length=adjusted_lens,
+                    logits_time_major=False,
                     blank_index=0
                 )
+                clean_loss = tf.reduce_mean(clean_loss)
 
-                # Noisy CTC loss - using XLA-compatible classic_ctc_loss
-                noisy_loss = classic_ctc_loss(
-                    labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+                # Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor
+                # Reuse the same sparse_labels from above
+                noisy_loss = tf.nn.ctc_loss(
+                    labels=sparse_labels,
                     logits=noisy_logits,
-                    label_length=phone_seq_lens,
+                    label_length=None,  # SparseTensor doesn't need label_length
                     logit_length=adjusted_lens,
+                    logits_time_major=False,
                     blank_index=0
                 )
+                noisy_loss = tf.reduce_mean(noisy_loss)
 
                 # Optional noise L2 regularization
                 noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
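Putting the pieces together, a minimal sketch of the sparse CTC call with random logits; the dimensions (41 classes, the lengths) are invented for illustration:

```python
import tensorflow as tf

# Toy shapes: batch of 2, 12 time steps, 41 classes (index 0 = CTC blank).
B, T, C = 2, 12, 41
logits = tf.random.normal([B, T, C])
adjusted_lens = tf.constant([12, 10], tf.int32)    # valid logit lengths
labels = tf.constant([[3, 1, 4, 0, 0],
                      [2, 7, 0, 0, 0]], tf.int32)  # zero-padded labels
phone_seq_lens = tf.constant([3, 2], tf.int32)

sparse_labels = dense_to_sparse(labels, phone_seq_lens)
loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=logits,
    label_length=None,        # lengths come from the SparseTensor itself
    logit_length=adjusted_lens,
    logits_time_major=False,  # logits are [batch, time, classes]
    blank_index=0,
)
print(tf.reduce_mean(loss).numpy())  # per-example losses averaged to a scalar
```

`tf.nn.ctc_loss` returns one loss per batch element, hence the explicit `tf.reduce_mean` the commit adds after each call.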
@@ -584,14 +637,17 @@ class BrainToTextDecoderTrainerTF:
 
                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                # Standard CTC loss - using XLA-compatible classic_ctc_loss
-                loss = classic_ctc_loss(
-                    labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+                # Standard CTC loss - using tf.nn.ctc_loss with SparseTensor
+                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+                loss = tf.nn.ctc_loss(
+                    labels=sparse_labels,
                     logits=clean_logits,
-                    label_length=phone_seq_lens,
+                    label_length=None,  # SparseTensor doesn't need label_length
                     logit_length=adjusted_lens,
+                    logits_time_major=False,
                     blank_index=0
                 )
+                loss = tf.reduce_mean(loss)
 
             # AdamW handles weight decay automatically - no manual L2 regularization needed
             # TensorFlow mixed-precision handling: no manual scaling needed, the Keras policy handles it automatically
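The mixed-precision comment above refers to the Keras global policy, presumably configured once at startup elsewhere in the repo, along these lines:

```python
import tensorflow as tf

# With a global policy set, Keras layers compute in bfloat16 while keeping
# float32 variables; bfloat16 shares float32's exponent range, so no manual
# loss scaling is needed on TPU.
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')
```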
@@ -646,14 +702,17 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss using XLA-compatible classic_ctc_loss
-        loss = classic_ctc_loss(
-            labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+        # Calculate loss using standard tf.nn.ctc_loss with SparseTensor
+        sparse_labels = dense_to_sparse(labels, phone_seq_lens)
+        loss = tf.nn.ctc_loss(
+            labels=sparse_labels,
            logits=logits,
-            label_length=phone_seq_lens,
+            label_length=None,  # SparseTensor doesn't need label_length
            logit_length=adjusted_lens,
+            logits_time_major=False,
            blank_index=0
        )
+        loss = tf.reduce_mean(loss)
 
         # Greedy decoding for PER calculation
         predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)
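Downstream of the greedy argmax, the phoneme error rate is presumably computed by collapsing repeats, dropping blanks, and taking edit distance against the reference labels. A sketch with a hypothetical helper (`per_from_greedy` is not in this diff; `editdistance` is the optional dependency guarded by the import at the top of the file):

```python
import editdistance  # optional dependency, see the import guard above

# Hypothetical helper: greedy CTC collapse + phoneme error rate (PER).
def per_from_greedy(predicted_ids, adjusted_lens, labels, phone_seq_lens):
    total_edits, total_ref = 0, 0
    for pred, t, ref, n in zip(predicted_ids.numpy(), adjusted_lens.numpy(),
                               labels.numpy(), phone_seq_lens.numpy()):
        # Collapse adjacent repeats and drop the blank (id 0) in one pass.
        seq = [int(p) for i, p in enumerate(pred[:t])
               if p != 0 and (i == 0 or p != pred[i - 1])]
        total_edits += editdistance.eval(seq, list(ref[:n]))
        total_ref += int(n)
    return total_edits / max(total_ref, 1)
```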