This commit is contained in:
Zchen
2025-10-15 20:45:25 +08:00
parent 3b242b908d
commit e8f0308fef
5 changed files with 409 additions and 19 deletions

View File

@@ -12,7 +12,7 @@ def gradient_reverse(x, lambd=1.0):
Backward: multiply incoming gradient by -lambda
"""
def grad(dy):
return -lambd * dy, None
return -lambd * dy # Only return gradient w.r.t. x, not lambd
return tf.identity(x), grad
@@ -709,17 +709,45 @@ class CTCLoss(keras.losses.Loss):
input_lengths = y_true['input_lengths']
label_lengths = y_true['label_lengths']
# Ensure correct data types
labels = tf.cast(labels, tf.int32)
input_lengths = tf.cast(input_lengths, tf.int32)
label_lengths = tf.cast(label_lengths, tf.int32)
# Convert logits to log probabilities
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
# Transpose for CTC: [time_steps, batch_size, num_classes]
log_probs = tf.transpose(log_probs, [1, 0, 2])
# Convert dense labels to sparse format for CTC using TensorFlow operations
def dense_to_sparse(dense_tensor, sequence_lengths):
"""Convert dense tensor to sparse tensor for CTC"""
batch_size = tf.shape(dense_tensor)[0]
max_len = tf.shape(dense_tensor)[1]
# Create mask for non-zero elements
mask = tf.not_equal(dense_tensor, 0)
# Get indices of non-zero elements
indices = tf.where(mask)
# Get values at those indices
values = tf.gather_nd(dense_tensor, indices)
# Create sparse tensor
dense_shape = tf.cast([batch_size, max_len], tf.int64)
return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
# Convert labels to sparse format
sparse_labels = dense_to_sparse(labels, label_lengths)
# Compute CTC loss
loss = tf.nn.ctc_loss(
labels=labels,
labels=sparse_labels,
logits=log_probs,
label_length=label_lengths,
label_length=None, # Not needed for sparse format
logit_length=input_lengths,
blank_index=self.blank_index,
logits_time_major=True