Remove custom CTC loss implementation for TPU from the TripleGRUDecoder class

This commit is contained in:
Zchen
2025-10-20 01:16:50 +08:00
parent 06ddbc6ac2
commit 0a0e07a193

View File

@@ -689,73 +689,6 @@ class TripleGRUDecoder(keras.Model):
self.training_mode = mode
# Custom CTC Loss for TensorFlow TPU
class CTCLoss(keras.losses.Loss):
    """Connectionist Temporal Classification loss built on ``tf.nn.ctc_loss``.

    Labels are handed to ``tf.nn.ctc_loss`` as a dense, padded tensor together
    with their per-example lengths. The TensorFlow documentation states that
    only dense padded labels are supported on TPU, so no ``tf.SparseTensor``
    is constructed (building one also needs a dynamically shaped ``tf.where``).

    Args:
        blank_index: Class index reserved for the CTC blank symbol.
        reduction: Reduction forwarded to ``keras.losses.Loss``. The default
            ``'none'`` leaves the loss unreduced.
        **kwargs: Extra keyword arguments forwarded to ``keras.losses.Loss``.
    """

    def __init__(self, blank_index=0, reduction='none', **kwargs):
        super().__init__(reduction=reduction, **kwargs)
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        """Compute the CTC loss for a batch.

        Args:
            y_true: Dictionary containing ``'labels'`` (dense, padded label
                ids), ``'input_lengths'`` (number of valid time steps per
                example) and ``'label_lengths'`` (number of valid labels per
                example).
            y_pred: Logits tensor ``[batch_size, time_steps, num_classes]``.

        Returns:
            The tensor produced by ``tf.nn.ctc_loss``.
        """
        # tf.nn.ctc_loss expects int32 label ids and lengths.
        labels = tf.cast(y_true['labels'], tf.int32)
        input_lengths = tf.cast(y_true['input_lengths'], tf.int32)
        label_lengths = tf.cast(y_true['label_lengths'], tf.int32)

        # Dense labels + explicit lengths: padding is delimited by
        # ``label_lengths`` instead of being inferred from "label != 0", which
        # was hard-wired to 0 regardless of the configured ``blank_index``.
        # NOTE(review): assumes each row of ``labels`` holds its valid ids
        # first, followed by padding -- confirm against the data pipeline.
        #
        # ``y_pred`` is passed as raw batch-major logits: the op takes
        # unnormalized logits, so a separate log_softmax is redundant, and
        # ``logits_time_major=False`` removes the need for a manual transpose.
        return tf.nn.ctc_loss(
            labels=labels,
            logits=y_pred,
            label_length=label_lengths,
            logit_length=input_lengths,
            logits_time_major=False,
            blank_index=self.blank_index,
        )
# TPU Strategy Helper Functions
def create_tpu_strategy():
"""Create TPU strategy for distributed training on TPU v5e-8"""