From 0a0e07a1931f69e3801c3d235068b52f24d4d13d Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Mon, 20 Oct 2025 01:16:50 +0800 Subject: [PATCH] Remove custom CTC loss implementation for TPU from the TripleGRUDecoder class --- model_training_nnn_tpu/rnn_model_tf.py | 67 -------------------------- 1 file changed, 67 deletions(-) diff --git a/model_training_nnn_tpu/rnn_model_tf.py b/model_training_nnn_tpu/rnn_model_tf.py index 6a0e7c8..42d0302 100644 --- a/model_training_nnn_tpu/rnn_model_tf.py +++ b/model_training_nnn_tpu/rnn_model_tf.py @@ -689,73 +689,6 @@ class TripleGRUDecoder(keras.Model): self.training_mode = mode -# Custom CTC Loss for TensorFlow TPU -class CTCLoss(keras.losses.Loss): - """ - Custom CTC Loss optimized for TPU v5e-8 - """ - - def __init__(self, blank_index=0, reduction='none', **kwargs): - super(CTCLoss, self).__init__(reduction=reduction, **kwargs) - self.blank_index = blank_index - - def call(self, y_true, y_pred): - """ - Args: - y_true: Dictionary containing 'labels', 'input_lengths', 'label_lengths' - y_pred: Logits tensor [batch_size, time_steps, num_classes] - """ - labels = y_true['labels'] - input_lengths = y_true['input_lengths'] - label_lengths = y_true['label_lengths'] - - # Ensure correct data types - labels = tf.cast(labels, tf.int32) - input_lengths = tf.cast(input_lengths, tf.int32) - label_lengths = tf.cast(label_lengths, tf.int32) - - # Convert logits to log probabilities - log_probs = tf.nn.log_softmax(y_pred, axis=-1) - - # Transpose for CTC: [time_steps, batch_size, num_classes] - log_probs = tf.transpose(log_probs, [1, 0, 2]) - - # Convert dense labels to sparse format for CTC using TensorFlow operations - def dense_to_sparse(dense_tensor, sequence_lengths): - """Convert dense tensor to sparse tensor for CTC""" - batch_size = tf.shape(dense_tensor)[0] - max_len = tf.shape(dense_tensor)[1] - - # Create mask for non-zero elements - mask = tf.not_equal(dense_tensor, 0) - - # Get indices of non-zero elements - indices = tf.where(mask) - - # Get values at those indices - values = tf.gather_nd(dense_tensor, indices) - - # Create sparse tensor - dense_shape = tf.cast([batch_size, max_len], tf.int64) - - return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape) - - # Convert labels to sparse format - sparse_labels = dense_to_sparse(labels, label_lengths) - - # Compute CTC loss - loss = tf.nn.ctc_loss( - labels=sparse_labels, - logits=log_probs, - label_length=None, # Not needed for sparse format - logit_length=input_lengths, - blank_index=self.blank_index, - logits_time_major=True - ) - - return loss - - # TPU Strategy Helper Functions def create_tpu_strategy(): """Create TPU strategy for distributed training on TPU v5e-8"""