This commit is contained in:
Zchen
2025-10-21 00:31:59 +08:00
parent e7c9b95b00
commit ab12d0b7ee
2 changed files with 128 additions and 94 deletions

View File

@@ -17,8 +17,55 @@ except ImportError:
print("Warning: editdistance not available, falling back to approximation")
editdistance = None
# XLA-compatible CTC loss implementation
from tf_seq2seq_losses import classic_ctc_loss
# Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach
# for compatibility with "batch first, augment after" data pipeline
def dense_to_sparse(dense_tensor, sequence_lengths):
    """
    Convert a padded dense label tensor to a SparseTensor for tf.nn.ctc_loss.

    This supports the "batch first, augment after" data pipeline, where the
    labels arrive as dynamically shaped dense tensors.

    An element is kept only when its column index is below the row's
    sequence length AND its value is non-zero; all other elements are
    omitted from the sparse result.

    Args:
        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
        sequence_lengths: Actual sequence lengths [batch_size] (integer tensor)

    Returns:
        tf.SparseTensor with int64 indices, int32 values and a dense_shape
        equal to the shape of dense_tensor.
    """
    max_seq_len = tf.shape(dense_tensor)[1]

    # tf.sequence_mask takes an integer lengths tensor directly, so int64
    # lengths do not trip the int32-vs-int64 comparison error that a
    # hand-built tf.range/meshgrid comparison can raise.
    length_mask = tf.sequence_mask(sequence_lengths, maxlen=max_seq_len)

    # Zero entries are treated as padding even inside the valid length.
    value_mask = tf.not_equal(dense_tensor, 0)
    combined_mask = tf.logical_and(length_mask, value_mask)

    # Coordinates ([row, col]) of every retained element, in row-major order.
    indices = tf.where(combined_mask)
    values = tf.gather_nd(dense_tensor, indices)

    return tf.SparseTensor(
        indices=tf.cast(indices, tf.int64),
        values=tf.cast(values, tf.int32),
        dense_shape=tf.cast(tf.shape(dense_tensor), tf.int64),
    )
from rnn_model_tf import (
TripleGRUDecoder,
@@ -559,23 +606,29 @@ class BrainToTextDecoderTrainerTF:
# Calculate losses using TPU-compatible CTC implementation
if use_full:
# Clean CTC loss - using XLA-compatible classic_ctc_loss
clean_loss = classic_ctc_loss(
labels=tf.cast(labels, tf.int32), # Dense labels as int32
# Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
clean_loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=clean_logits,
label_length=phone_seq_lens,
label_length=None, # SparseTensor doesn't need label_length
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
clean_loss = tf.reduce_mean(clean_loss)
# Noisy CTC loss - using XLA-compatible classic_ctc_loss
noisy_loss = classic_ctc_loss(
labels=tf.cast(labels, tf.int32), # Dense labels as int32
# Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor
# Reuse the same sparse_labels from above
noisy_loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=noisy_logits,
label_length=phone_seq_lens,
label_length=None, # SparseTensor doesn't need label_length
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
noisy_loss = tf.reduce_mean(noisy_loss)
# Optional noise L2 regularization
noise_l2 = tf.constant(0.0, dtype=clean_loss.dtype)
@@ -584,14 +637,17 @@ class BrainToTextDecoderTrainerTF:
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
# Standard CTC loss - using XLA-compatible classic_ctc_loss
loss = classic_ctc_loss(
labels=tf.cast(labels, tf.int32), # Dense labels as int32
# Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=clean_logits,
label_length=phone_seq_lens,
label_length=None, # SparseTensor doesn't need label_length
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
loss = tf.reduce_mean(loss)
# AdamW handles weight decay automatically - no manual L2 regularization needed
# TensorFlow mixed precision needs no manual loss scaling; the Keras policy handles it automatically
@@ -646,14 +702,17 @@ class BrainToTextDecoderTrainerTF:
# Forward pass (inference mode only)
logits = self.model(features, day_indices, None, False, 'inference', training=False)
# Calculate loss using XLA-compatible classic_ctc_loss
loss = classic_ctc_loss(
labels=tf.cast(labels, tf.int32), # Dense labels as int32
# Calculate loss using standard tf.nn.ctc_loss with SparseTensor
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=logits,
label_length=phone_seq_lens,
label_length=None, # SparseTensor doesn't need label_length
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
loss = tf.reduce_mean(loss)
# Greedy decoding for PER calculation
predicted_ids = tf.argmax(logits, axis=-1, output_type=tf.int32)