From ca8c6155053d7affcd1a73c4fee991fedb07a904 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Fri, 17 Oct 2025 02:01:48 +0800 Subject: [PATCH] f --- model_training_nnn_tpu/trainer_tf.py | 53 +++++++++++++--------------- 1 file changed, 24 insertions(+), 29 deletions(-) diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index 0a72551..2d0c8d3 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -647,43 +647,38 @@ class BrainToTextDecoderTrainerTF: loss = self.ctc_loss(loss_input, logits) loss = tf.reduce_mean(loss) - # Calculate PER (Phoneme Error Rate) + # Calculate simplified PER approximation (TPU-compatible) + # For TPU training, we use a simplified metric that avoids complex loops + # This gives an approximation of PER but is much faster and TPU-compatible + # Greedy decoding predicted_ids = tf.argmax(logits, axis=-1) - # Remove blanks and consecutive duplicates - batch_edit_distance = 0 - for i in range(tf.shape(logits)[0]): - pred_seq = predicted_ids[i, :adjusted_lens[i]] - # Remove consecutive duplicates - pred_seq = tf.py_function( - func=lambda x: tf.constant([x[0]] + [x[j] for j in range(1, len(x)) if x[j] != x[j-1]]), - inp=[pred_seq], - Tout=tf.int64 - ) - # Remove blanks (assuming blank_index=0) - pred_seq = tf.boolean_mask(pred_seq, pred_seq != 0) + # Simple approximation: count exact matches vs mismatches + # This is less accurate than true edit distance but TPU-compatible + batch_size = tf.shape(logits)[0] + # For each sample, compare predicted vs true sequences + total_mismatches = tf.constant(0, dtype=tf.int32) + + for i in tf.range(batch_size): + # Get sequences for this sample + pred_seq = predicted_ids[i, :adjusted_lens[i]] true_seq = labels[i, :phone_seq_lens[i]] - # Calculate edit distance - edit_dist = tf.edit_distance( - tf.SparseTensor( - indices=tf.expand_dims(tf.range(tf.size(pred_seq)), 1), - values=tf.cast(pred_seq, tf.int64), - dense_shape=[tf.size(pred_seq)] - ), - tf.SparseTensor( - indices=tf.expand_dims(tf.range(tf.size(true_seq)), 1), - values=tf.cast(true_seq, tf.int64), - dense_shape=[tf.size(true_seq)] - ), - normalize=False - ) + # Pad to same length for comparison + max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0]) + pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0) + true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0) - batch_edit_distance += edit_dist + # Count mismatches + mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32)) + total_mismatches += mismatches - return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens) + # Approximate edit distance as number of mismatches + batch_edit_distance = tf.cast(total_mismatches, tf.float32) + + return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32) def train(self) -> Dict[str, Any]: """Main training loop"""