This commit is contained in:
Zchen
2025-10-17 02:01:48 +08:00
parent 49700456b8
commit ca8c615505

View File

@@ -647,43 +647,38 @@ class BrainToTextDecoderTrainerTF:
loss = self.ctc_loss(loss_input, logits)
loss = tf.reduce_mean(loss)
# Calculate PER (Phoneme Error Rate)
# Calculate simplified PER approximation (TPU-compatible)
# For TPU training, we use a simplified metric that avoids complex loops
# This gives an approximation of PER but is much faster and TPU-compatible
# Greedy decoding
predicted_ids = tf.argmax(logits, axis=-1)
# Remove blanks and consecutive duplicates
batch_edit_distance = 0
for i in range(tf.shape(logits)[0]):
pred_seq = predicted_ids[i, :adjusted_lens[i]]
# Remove consecutive duplicates
pred_seq = tf.py_function(
func=lambda x: tf.constant([x[0]] + [x[j] for j in range(1, len(x)) if x[j] != x[j-1]]),
inp=[pred_seq],
Tout=tf.int64
)
# Remove blanks (assuming blank_index=0)
pred_seq = tf.boolean_mask(pred_seq, pred_seq != 0)
# Simple approximation: count exact matches vs mismatches
# This is less accurate than true edit distance but TPU-compatible
batch_size = tf.shape(logits)[0]
# For each sample, compare predicted vs true sequences
total_mismatches = tf.constant(0, dtype=tf.int32)
for i in tf.range(batch_size):
# Get sequences for this sample
pred_seq = predicted_ids[i, :adjusted_lens[i]]
true_seq = labels[i, :phone_seq_lens[i]]
# Calculate edit distance
edit_dist = tf.edit_distance(
tf.SparseTensor(
indices=tf.expand_dims(tf.range(tf.size(pred_seq)), 1),
values=tf.cast(pred_seq, tf.int64),
dense_shape=[tf.size(pred_seq)]
),
tf.SparseTensor(
indices=tf.expand_dims(tf.range(tf.size(true_seq)), 1),
values=tf.cast(true_seq, tf.int64),
dense_shape=[tf.size(true_seq)]
),
normalize=False
)
# Pad to same length for comparison
max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0])
pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0)
true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0)
batch_edit_distance += edit_dist
# Count mismatches
mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32))
total_mismatches += mismatches
return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)
# Approximate edit distance as number of mismatches
batch_edit_distance = tf.cast(total_mismatches, tf.float32)
return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32)
def train(self) -> Dict[str, Any]:
"""Main training loop"""