diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py index cc6f7ba..985f06e 100644 --- a/model_training_nnn_tpu/trainer_tf.py +++ b/model_training_nnn_tpu/trainer_tf.py @@ -9,6 +9,7 @@ import pathlib import sys from typing import Dict, Any, Tuple, Optional, List from omegaconf import OmegaConf +from tf_seq2seq_losses import classic_ctc_loss # For accurate PER calculation try: @@ -17,55 +18,7 @@ except ImportError: print("Warning: editdistance not available, falling back to approximation") editdistance = None -# Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach -# for compatibility with "batch first, augment after" data pipeline - - -def dense_to_sparse(dense_tensor, sequence_lengths): - """ - Convert dense tensor to sparse tensor for CTC loss with dynamic shapes - - This function is essential for the "batch first, augment after" approach - as it handles the conversion from dynamic dense tensors to SparseTensor - format required by tf.nn.ctc_loss. - - Args: - dense_tensor: Dense tensor with shape [batch_size, max_seq_len] - sequence_lengths: Actual sequence lengths [batch_size] - - Returns: - SparseTensor suitable for tf.nn.ctc_loss - """ - # Create mask for valid (non-zero) elements within sequence lengths - batch_size = tf.shape(dense_tensor)[0] - max_seq_len = tf.shape(dense_tensor)[1] - - # Create range indices - batch_indices = tf.range(batch_size) - seq_indices = tf.range(max_seq_len) - - # Create meshgrid for sequence dimensions - _, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij') - - # Create mask based on sequence lengths and non-zero values - length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1) - value_mask = tf.not_equal(dense_tensor, 0) - combined_mask = tf.logical_and(length_mask, value_mask) - - # Get indices of valid elements - indices = tf.where(combined_mask) - - # Get values at valid indices - values = tf.gather_nd(dense_tensor, indices) - - # Create sparse tensor - dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64) - - return tf.SparseTensor( - indices=tf.cast(indices, tf.int64), - values=tf.cast(values, tf.int32), - dense_shape=dense_shape - ) +# Note: Now using classic_ctc_loss from tf_seq2seq_losses for better compatibility from rnn_model_tf import ( TripleGRUDecoder, @@ -603,28 +556,27 @@ class BrainToTextDecoderTrainerTF: features, day_indices, None, False, 'inference', training=True ) - # Calculate losses using TPU-compatible CTC implementation + # Calculate losses using community CTC implementation if use_full: - # Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor - sparse_labels = dense_to_sparse(labels, phone_seq_lens) - clean_loss = tf.nn.ctc_loss( - labels=sparse_labels, - logits=clean_logits, - label_length=None, # SparseTensor doesn't need label_length + # Clean CTC loss - using new classic_ctc_loss + # Convert to float32 for numerical stability + clean_logits_f32 = tf.cast(clean_logits, tf.float32) + clean_loss = classic_ctc_loss( + labels=labels, + logits=clean_logits_f32, + label_length=phone_seq_lens, # Direct use of length tensor logit_length=adjusted_lens, - logits_time_major=False, blank_index=0 ) clean_loss = tf.reduce_mean(clean_loss) - # Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor - # Reuse the same sparse_labels from above - noisy_loss = tf.nn.ctc_loss( - labels=sparse_labels, - logits=noisy_logits, - label_length=None, # SparseTensor doesn't need label_length + # Noisy CTC loss - using new classic_ctc_loss + noisy_logits_f32 = tf.cast(noisy_logits, tf.float32) + noisy_loss = classic_ctc_loss( + labels=labels, + logits=noisy_logits_f32, + label_length=phone_seq_lens, # Direct use of length tensor logit_length=adjusted_lens, - logits_time_major=False, blank_index=0 ) noisy_loss = tf.reduce_mean(noisy_loss) @@ -636,14 +588,14 @@ class BrainToTextDecoderTrainerTF: loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2 else: - # Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor - sparse_labels = dense_to_sparse(labels, phone_seq_lens) - loss = tf.nn.ctc_loss( - labels=sparse_labels, - logits=clean_logits, - label_length=None, # SparseTensor doesn't need label_length + # Standard CTC loss - using new classic_ctc_loss + # Convert to float32 for numerical stability + clean_logits_f32 = tf.cast(clean_logits, tf.float32) + loss = classic_ctc_loss( + labels=labels, + logits=clean_logits_f32, + label_length=phone_seq_lens, # Direct use of length tensor logit_length=adjusted_lens, - logits_time_major=False, blank_index=0 ) loss = tf.reduce_mean(loss) @@ -701,14 +653,14 @@ class BrainToTextDecoderTrainerTF: # Forward pass (inference mode only) logits = self.model(features, day_indices, None, False, 'inference', training=False) - # Calculate loss using standard tf.nn.ctc_loss with SparseTensor - sparse_labels = dense_to_sparse(labels, phone_seq_lens) - loss = tf.nn.ctc_loss( - labels=sparse_labels, - logits=logits, - label_length=None, # SparseTensor doesn't need label_length + # Calculate loss using community CTC implementation + # Convert to float32 for numerical stability + logits_f32 = tf.cast(logits, tf.float32) + loss = classic_ctc_loss( + labels=labels, + logits=logits_f32, + label_length=phone_seq_lens, # Direct use of length tensor logit_length=adjusted_lens, - logits_time_major=False, blank_index=0 ) loss = tf.reduce_mean(loss)