This commit is contained in:
Zchen
2025-10-22 00:38:55 +08:00
parent e715d9ac79
commit 52a9b17375

View File

@@ -9,6 +9,7 @@ import pathlib
import sys import sys
from typing import Dict, Any, Tuple, Optional, List from typing import Dict, Any, Tuple, Optional, List
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tf_seq2seq_losses import classic_ctc_loss
# For accurate PER calculation # For accurate PER calculation
try: try:
@@ -17,55 +18,7 @@ except ImportError:
print("Warning: editdistance not available, falling back to approximation") print("Warning: editdistance not available, falling back to approximation")
editdistance = None editdistance = None
# Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach # Note: Now using classic_ctc_loss from tf_seq2seq_losses for better compatibility
# for compatibility with "batch first, augment after" data pipeline
def dense_to_sparse(dense_tensor, sequence_lengths):
    """
    Convert a padded dense label tensor to a SparseTensor for CTC loss.

    An element is kept only when it lies inside its row's sequence length
    AND is non-zero; zeros are treated as padding and dropped even when
    they fall inside the valid length.

    Args:
        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
        sequence_lengths: Actual sequence lengths [batch_size]

    Returns:
        tf.SparseTensor with int64 indices, int32 values, and a dense_shape
        equal to the (dynamic) shape of dense_tensor.
    """
    # Dynamic shape so this works when the padded length varies per batch.
    max_seq_len = tf.shape(dense_tensor)[1]

    # mask[i, j] = j < sequence_lengths[i]. tf.sequence_mask accepts any
    # integer dtype for the lengths, whereas a hand-built int32 tf.range
    # grid fails to compare against int64 lengths.
    length_mask = tf.sequence_mask(sequence_lengths, maxlen=max_seq_len)

    # Zero entries are padding and must not become sparse entries.
    value_mask = tf.not_equal(dense_tensor, 0)
    combined_mask = tf.logical_and(length_mask, value_mask)

    # Coordinates (row, col) of every kept element, in row-major order.
    indices = tf.where(combined_mask)
    values = tf.gather_nd(dense_tensor, indices)

    return tf.SparseTensor(
        indices=tf.cast(indices, tf.int64),
        values=tf.cast(values, tf.int32),
        dense_shape=tf.cast(tf.shape(dense_tensor), tf.int64),
    )
from rnn_model_tf import ( from rnn_model_tf import (
TripleGRUDecoder, TripleGRUDecoder,
@@ -603,28 +556,27 @@ class BrainToTextDecoderTrainerTF:
features, day_indices, None, False, 'inference', training=True features, day_indices, None, False, 'inference', training=True
) )
# Calculate losses using TPU-compatible CTC implementation # Calculate losses using community CTC implementation
if use_full: if use_full:
# Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor # Clean CTC loss - using new classic_ctc_loss
sparse_labels = dense_to_sparse(labels, phone_seq_lens) # Convert to float32 for numerical stability
clean_loss = tf.nn.ctc_loss( clean_logits_f32 = tf.cast(clean_logits, tf.float32)
labels=sparse_labels, clean_loss = classic_ctc_loss(
logits=clean_logits, labels=labels,
label_length=None, # SparseTensor doesn't need label_length logits=clean_logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens, logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0 blank_index=0
) )
clean_loss = tf.reduce_mean(clean_loss) clean_loss = tf.reduce_mean(clean_loss)
# Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor # Noisy CTC loss - using new classic_ctc_loss
# Reuse the same sparse_labels from above noisy_logits_f32 = tf.cast(noisy_logits, tf.float32)
noisy_loss = tf.nn.ctc_loss( noisy_loss = classic_ctc_loss(
labels=sparse_labels, labels=labels,
logits=noisy_logits, logits=noisy_logits_f32,
label_length=None, # SparseTensor doesn't need label_length label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens, logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0 blank_index=0
) )
noisy_loss = tf.reduce_mean(noisy_loss) noisy_loss = tf.reduce_mean(noisy_loss)
@@ -636,14 +588,14 @@ class BrainToTextDecoderTrainerTF:
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
else: else:
# Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor # Standard CTC loss - using new classic_ctc_loss
sparse_labels = dense_to_sparse(labels, phone_seq_lens) # Convert to float32 for numerical stability
loss = tf.nn.ctc_loss( clean_logits_f32 = tf.cast(clean_logits, tf.float32)
labels=sparse_labels, loss = classic_ctc_loss(
logits=clean_logits, labels=labels,
label_length=None, # SparseTensor doesn't need label_length logits=clean_logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens, logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0 blank_index=0
) )
loss = tf.reduce_mean(loss) loss = tf.reduce_mean(loss)
@@ -701,14 +653,14 @@ class BrainToTextDecoderTrainerTF:
# Forward pass (inference mode only) # Forward pass (inference mode only)
logits = self.model(features, day_indices, None, False, 'inference', training=False) logits = self.model(features, day_indices, None, False, 'inference', training=False)
# Calculate loss using standard tf.nn.ctc_loss with SparseTensor # Calculate loss using community CTC implementation
sparse_labels = dense_to_sparse(labels, phone_seq_lens) # Convert to float32 for numerical stability
loss = tf.nn.ctc_loss( logits_f32 = tf.cast(logits, tf.float32)
labels=sparse_labels, loss = classic_ctc_loss(
logits=logits, labels=labels,
label_length=None, # SparseTensor doesn't need label_length logits=logits_f32,
label_length=phone_seq_lens, # Direct use of length tensor
logit_length=adjusted_lens, logit_length=adjusted_lens,
logits_time_major=False,
blank_index=0 blank_index=0
) )
loss = tf.reduce_mean(loss) loss = tf.reduce_mean(loss)