Remove custom CTC loss implementation for TPU from the TripleGRUDecoder class
This commit is contained in:
@@ -689,73 +689,6 @@ class TripleGRUDecoder(keras.Model):
|
|||||||
self.training_mode = mode
|
self.training_mode = mode
|
||||||
|
|
||||||
|
|
||||||
# Custom CTC Loss for TensorFlow TPU
|
|
||||||
class CTCLoss(keras.losses.Loss):
|
|
||||||
"""
|
|
||||||
Custom CTC Loss optimized for TPU v5e-8
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, blank_index=0, reduction='none', **kwargs):
|
|
||||||
super(CTCLoss, self).__init__(reduction=reduction, **kwargs)
|
|
||||||
self.blank_index = blank_index
|
|
||||||
|
|
||||||
def call(self, y_true, y_pred):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths'
|
|
||||||
y_pred: Logits tensor [batch_size, time_steps, num_classes]
|
|
||||||
"""
|
|
||||||
labels = y_true['labels']
|
|
||||||
input_lengths = y_true['input_lengths']
|
|
||||||
label_lengths = y_true['label_lengths']
|
|
||||||
|
|
||||||
# Ensure correct data types
|
|
||||||
labels = tf.cast(labels, tf.int32)
|
|
||||||
input_lengths = tf.cast(input_lengths, tf.int32)
|
|
||||||
label_lengths = tf.cast(label_lengths, tf.int32)
|
|
||||||
|
|
||||||
# Convert logits to log probabilities
|
|
||||||
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
|
|
||||||
|
|
||||||
# Transpose for CTC: [time_steps, batch_size, num_classes]
|
|
||||||
log_probs = tf.transpose(log_probs, [1, 0, 2])
|
|
||||||
|
|
||||||
# Convert dense labels to sparse format for CTC using TensorFlow operations
|
|
||||||
def dense_to_sparse(dense_tensor, sequence_lengths):
|
|
||||||
"""Convert dense tensor to sparse tensor for CTC"""
|
|
||||||
batch_size = tf.shape(dense_tensor)[0]
|
|
||||||
max_len = tf.shape(dense_tensor)[1]
|
|
||||||
|
|
||||||
# Create mask for non-zero elements
|
|
||||||
mask = tf.not_equal(dense_tensor, 0)
|
|
||||||
|
|
||||||
# Get indices of non-zero elements
|
|
||||||
indices = tf.where(mask)
|
|
||||||
|
|
||||||
# Get values at those indices
|
|
||||||
values = tf.gather_nd(dense_tensor, indices)
|
|
||||||
|
|
||||||
# Create sparse tensor
|
|
||||||
dense_shape = tf.cast([batch_size, max_len], tf.int64)
|
|
||||||
|
|
||||||
return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
|
|
||||||
|
|
||||||
# Convert labels to sparse format
|
|
||||||
sparse_labels = dense_to_sparse(labels, label_lengths)
|
|
||||||
|
|
||||||
# Compute CTC loss
|
|
||||||
loss = tf.nn.ctc_loss(
|
|
||||||
labels=sparse_labels,
|
|
||||||
logits=log_probs,
|
|
||||||
label_length=None, # Not needed for sparse format
|
|
||||||
logit_length=input_lengths,
|
|
||||||
blank_index=self.blank_index,
|
|
||||||
logits_time_major=True
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss
|
|
||||||
|
|
||||||
|
|
||||||
# TPU Strategy Helper Functions
|
# TPU Strategy Helper Functions
|
||||||
def create_tpu_strategy():
|
def create_tpu_strategy():
|
||||||
"""Create TPU strategy for distributed training on TPU v5e-8"""
|
"""Create TPU strategy for distributed training on TPU v5e-8"""
|
||||||
|
|||||||
Reference in New Issue
Block a user