Remove custom CTC loss implementation for TPU from the TripleGRUDecoder class
This commit is contained in:
@@ -689,73 +689,6 @@ class TripleGRUDecoder(keras.Model):
|
||||
self.training_mode = mode
|
||||
|
||||
|
||||
# Custom CTC Loss for TensorFlow TPU
|
||||
class CTCLoss(keras.losses.Loss):
|
||||
"""
|
||||
Custom CTC Loss optimized for TPU v5e-8
|
||||
"""
|
||||
|
||||
def __init__(self, blank_index=0, reduction='none', **kwargs):
|
||||
super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
|
||||
self.blank_index = blank_index
|
||||
|
||||
def call(self, y_true, y_pred):
|
||||
"""
|
||||
Args:
|
||||
y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths'
|
||||
y_pred: Logits tensor [batch_size, time_steps, num_classes]
|
||||
"""
|
||||
labels = y_true['labels']
|
||||
input_lengths = y_true['input_lengths']
|
||||
label_lengths = y_true['label_lengths']
|
||||
|
||||
# Ensure correct data types
|
||||
labels = tf.cast(labels, tf.int32)
|
||||
input_lengths = tf.cast(input_lengths, tf.int32)
|
||||
label_lengths = tf.cast(label_lengths, tf.int32)
|
||||
|
||||
# Convert logits to log probabilities
|
||||
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
|
||||
|
||||
# Transpose for CTC: [time_steps, batch_size, num_classes]
|
||||
log_probs = tf.transpose(log_probs, [1, 0, 2])
|
||||
|
||||
# Convert dense labels to sparse format for CTC using TensorFlow operations
|
||||
def dense_to_sparse(dense_tensor, sequence_lengths):
|
||||
"""Convert dense tensor to sparse tensor for CTC"""
|
||||
batch_size = tf.shape(dense_tensor)[0]
|
||||
max_len = tf.shape(dense_tensor)[1]
|
||||
|
||||
# Create mask for non-zero elements
|
||||
mask = tf.not_equal(dense_tensor, 0)
|
||||
|
||||
# Get indices of non-zero elements
|
||||
indices = tf.where(mask)
|
||||
|
||||
# Get values at those indices
|
||||
values = tf.gather_nd(dense_tensor, indices)
|
||||
|
||||
# Create sparse tensor
|
||||
dense_shape = tf.cast([batch_size, max_len], tf.int64)
|
||||
|
||||
return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
|
||||
|
||||
# Convert labels to sparse format
|
||||
sparse_labels = dense_to_sparse(labels, label_lengths)
|
||||
|
||||
# Compute CTC loss
|
||||
loss = tf.nn.ctc_loss(
|
||||
labels=sparse_labels,
|
||||
logits=log_probs,
|
||||
label_length=None, # Not needed for sparse format
|
||||
logit_length=input_lengths,
|
||||
blank_index=self.blank_index,
|
||||
logits_time_major=True
|
||||
)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
# TPU Strategy Helper Functions
|
||||
def create_tpu_strategy():
|
||||
"""Create TPU strategy for distributed training on TPU v5e-8"""
|
||||
|
Reference in New Issue
Block a user