This commit is contained in:
Zchen
2025-10-22 00:38:55 +08:00
parent e715d9ac79
commit 52a9b17375

View File

@@ -9,6 +9,7 @@ import pathlib
import sys
from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf
from tf_seq2seq_losses import classic_ctc_loss
# For accurate PER calculation
try:
@@ -17,55 +18,7 @@ except ImportError:
print("Warning: editdistance not available, falling back to approximation")
editdistance = None
# Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach
# for compatibility with "batch first, augment after" data pipeline
def dense_to_sparse(dense_tensor, sequence_lengths):
    """
    Convert a padded dense label tensor into a SparseTensor for CTC loss.

    An entry of ``dense_tensor`` is kept only when both conditions hold:
    its position along axis 1 is smaller than the matching entry of
    ``sequence_lengths``, and its value is not 0. All shapes are read with
    ``tf.shape`` so the conversion works when the batch size and padded
    length are only known at run time.

    NOTE(review): because zeros are dropped even inside the valid length,
    0 is presumably reserved for padding/blank and is never a real label;
    confirm against the data pipeline.

    Args:
        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
        sequence_lengths: Actual sequence lengths [batch_size]

    Returns:
        tf.SparseTensor with int64 indices, int32 values and a dense_shape
        equal to the shape of ``dense_tensor``.
    """
    max_seq_len = tf.shape(dense_tensor)[1]

    # tf.sequence_mask yields the [batch_size, max_seq_len] boolean
    # "position < length" mask. Unlike a manual `tf.range(...) < lengths`
    # comparison it accepts any integer dtype for the lengths, so int64
    # lengths no longer clash with the int32 position range.
    length_mask = tf.sequence_mask(sequence_lengths, maxlen=max_seq_len)

    # Drop zero-valued entries in addition to everything past the length.
    value_mask = tf.not_equal(dense_tensor, 0)
    combined_mask = tf.logical_and(length_mask, value_mask)

    # Coordinates (row, column) of every kept entry, then their values.
    indices = tf.where(combined_mask)
    values = tf.gather_nd(dense_tensor, indices)

    return tf.SparseTensor(
        indices=tf.cast(indices, tf.int64),
        values=tf.cast(values, tf.int32),
        dense_shape=tf.cast(tf.shape(dense_tensor), tf.int64),
    )
# Note: Now using classic_ctc_loss from tf_seq2seq_losses for better compatibility
from rnn_model_tf import (
TripleGRUDecoder,
@@ -603,28 +556,27 @@ class BrainToTextDecoderTrainerTF:
features, day_indices, None, False, 'inference', training=True
)
# Calculate losses using TPU-compatible CTC implementation
# Calculate losses using community CTC implementation
if use_full:
# Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
clean_loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=clean_logits,
label_length=None, # SparseTensor doesn't need label_length
# Clean CTC loss - using new classic_ctc_loss
# Convert to float32 for numerical stability
clean_logits_f32 = tf.cast(clean_logits, tf.float32)
clean_loss = classic_ctc_loss(
labels=labels,
logits=clean_logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
clean_loss = tf.reduce_mean(clean_loss)
# Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor
# Reuse the same sparse_labels from above
noisy_loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=noisy_logits,
label_length=None, # SparseTensor doesn't need label_length
# Noisy CTC loss - using new classic_ctc_loss
noisy_logits_f32 = tf.cast(noisy_logits, tf.float32)
noisy_loss = classic_ctc_loss(
labels=labels,
logits=noisy_logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
noisy_loss = tf.reduce_mean(noisy_loss)
@@ -636,14 +588,14 @@ class BrainToTextDecoderTrainerTF:
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else:
# Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=clean_logits,
label_length=None, # SparseTensor doesn't need label_length
# Standard CTC loss - using new classic_ctc_loss
# Convert to float32 for numerical stability
clean_logits_f32 = tf.cast(clean_logits, tf.float32)
loss = classic_ctc_loss(
labels=labels,
logits=clean_logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
loss = tf.reduce_mean(loss)
@@ -701,14 +653,14 @@ class BrainToTextDecoderTrainerTF:
# Forward pass (inference mode only)
logits = self.model(features, day_indices, None, False, 'inference', training=False)
# Calculate loss using standard tf.nn.ctc_loss with SparseTensor
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
loss = tf.nn.ctc_loss(
labels=sparse_labels,
logits=logits,
label_length=None, # SparseTensor doesn't need label_length
# Calculate loss using community CTC implementation
# Convert to float32 for numerical stability
logits_f32 = tf.cast(logits, tf.float32)
loss = classic_ctc_loss(
labels=labels,
logits=logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0
)
loss = tf.reduce_mean(loss)