diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 249ae47..cd20920 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -1,6 +1,5 @@
 import os
 import tensorflow as tf
-import tensorflow.keras.backend as K
 import numpy as np
 import time
 import json
@@ -18,6 +17,9 @@ except ImportError:
     print("Warning: editdistance not available, falling back to approximation")
     editdistance = None
 
+# XLA-compatible CTC loss implementation
+from tf_seq2seq_losses import classic_ctc_loss
+
 from rnn_model_tf import (
     TripleGRUDecoder,
     create_tpu_strategy,
@@ -33,77 +35,8 @@ from dataset_tf import (
 )
 
 
-def ctc_loss_for_tpu(y_true, y_pred, input_length, label_length):
-    """
-    TPU-compatible CTC loss function using Keras backend
-
-    This implementation uses K.ctc_batch_cost which is often more robust
-    for XLA compilation than tf.nn.ctc_loss, especially in complex model graphs.
-
-    Args:
-        y_true: Dense labels [batch_size, max_label_len]
-        y_pred: Logits [batch_size, time_steps, num_classes]
-        input_length: Logit sequence lengths [batch_size]
-        label_length: True label sequence lengths [batch_size]
-
-    Returns:
-        Scalar CTC loss value
-    """
-    # K.ctc_batch_cost requires logits to be time-major [time_steps, batch_size, num_classes]
-    y_pred_time_major = tf.transpose(y_pred, [1, 0, 2])
-
-    # Ensure correct data types for Keras backend
-    y_true = tf.cast(y_true, tf.float32)  # K.ctc_batch_cost expects float32 labels
-    input_length = tf.cast(input_length, tf.int32)
-    label_length = tf.cast(label_length, tf.int32)
-
-    # Calculate CTC loss using Keras backend (more XLA-friendly)
-    loss = K.ctc_batch_cost(y_true, y_pred_time_major, input_length, label_length)
-
-    return tf.reduce_mean(loss)
-
-
-def dense_to_sparse(dense_tensor, sequence_lengths):
-    """
-    Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
-
-    Args:
-        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
-        sequence_lengths: Actual sequence lengths [batch_size]
-
-    Returns:
-        SparseTensor suitable for tf.nn.ctc_loss
-    """
-    # Create mask for valid (non-zero) elements within sequence lengths
-    batch_size = tf.shape(dense_tensor)[0]
-    max_seq_len = tf.shape(dense_tensor)[1]
-
-    # Create range indices
-    batch_indices = tf.range(batch_size)
-    seq_indices = tf.range(max_seq_len)
-
-    # Create meshgrid for batch and sequence dimensions
-    batch_mesh, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
-
-    # Create mask based on sequence lengths and non-zero values
-    length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
-    value_mask = tf.not_equal(dense_tensor, 0)
-    combined_mask = tf.logical_and(length_mask, value_mask)
-
-    # Get indices of valid elements
-    indices = tf.where(combined_mask)
-
-    # Get values at valid indices
-    values = tf.gather_nd(dense_tensor, indices)
-
-    # Create sparse tensor
-    dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
-
-    return tf.SparseTensor(
-        indices=tf.cast(indices, tf.int64),
-        values=tf.cast(values, tf.int32),
-        dense_shape=dense_shape
-    )
 
 
 class BrainToTextDecoderTrainerTF:
@@ -626,20 +559,22 @@ class BrainToTextDecoderTrainerTF:
 
             # Calculate losses using TPU-compatible CTC implementation
             if use_full:
-                # Clean CTC loss - using Keras backend for XLA compatibility
-                clean_loss = ctc_loss_for_tpu(
-                    y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
-                    y_pred=clean_logits,
-                    input_length=adjusted_lens,
-                    label_length=phone_seq_lens
+                # Clean CTC loss - using XLA-compatible classic_ctc_loss
+                clean_loss = classic_ctc_loss(
+                    labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+                    logits=clean_logits,
+                    label_length=phone_seq_lens,
+                    logit_length=adjusted_lens,
+                    blank_index=0
                 )
 
-                # Noisy CTC loss - using Keras backend for XLA compatibility
-                noisy_loss = ctc_loss_for_tpu(
-                    y_true=tf.cast(labels, tf.float32),  # Reuse same dense labels
-                    y_pred=noisy_logits,
-                    input_length=adjusted_lens,
-                    label_length=phone_seq_lens
+                # Noisy CTC loss - using XLA-compatible classic_ctc_loss
+                noisy_loss = classic_ctc_loss(
+                    labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+                    logits=noisy_logits,
+                    label_length=phone_seq_lens,
+                    logit_length=adjusted_lens,
+                    blank_index=0
                 )
 
                 # Optional noise L2 regularization
@@ -649,12 +584,13 @@ class BrainToTextDecoderTrainerTF:
 
                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                # Standard CTC loss - using Keras backend for XLA compatibility
-                loss = ctc_loss_for_tpu(
-                    y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
-                    y_pred=clean_logits,
-                    input_length=adjusted_lens,
-                    label_length=phone_seq_lens
+                # Standard CTC loss - using XLA-compatible classic_ctc_loss
+                loss = classic_ctc_loss(
+                    labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+                    logits=clean_logits,
+                    label_length=phone_seq_lens,
+                    logit_length=adjusted_lens,
+                    blank_index=0
                 )
 
             # AdamW handles weight decay automatically - no manual L2 regularization needed
@@ -710,12 +646,13 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss using TPU-compatible CTC implementation
-        loss = ctc_loss_for_tpu(
-            y_true=tf.cast(labels, tf.float32),  # Dense labels as float32
-            y_pred=logits,
-            input_length=adjusted_lens,
-            label_length=phone_seq_lens
+        # Calculate loss using XLA-compatible classic_ctc_loss
+        loss = classic_ctc_loss(
+            labels=tf.cast(labels, tf.int32),  # Dense labels as int32
+            logits=logits,
+            label_length=phone_seq_lens,
+            logit_length=adjusted_lens,
+            blank_index=0
        )
 
         # Greedy decoding for PER calculation
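
Note on the replacement: `classic_ctc_loss` from the `tf-seq2seq-losses` package is implemented in plain TensorFlow ops, so it compiles under XLA (the motivation the removed `ctc_loss_for_tpu` docstring gives for avoiding `tf.nn.ctc_loss`). Unlike the removed helper, it takes batch-major logits directly (no time-major transpose), expects dense `int32` labels rather than `float32`, and returns one loss value per batch element instead of a pre-reduced scalar. Below is a minimal, self-contained sketch of the new call path; the toy shapes, the `jit_compile=True` wrapper, and the final `tf.reduce_mean` are illustrative assumptions for this sketch, not code from the diff:

```python
import tensorflow as tf
from tf_seq2seq_losses import classic_ctc_loss  # pip install tf-seq2seq-losses

@tf.function(jit_compile=True)  # XLA compilation, as required for TPU training
def ctc_loss_fn(labels, logits, label_length, logit_length):
    # classic_ctc_loss returns one loss per batch element; reduce to a
    # scalar objective here (an assumption - the diff assigns it directly).
    per_example = classic_ctc_loss(
        labels=labels,              # int32 dense labels [batch, max_label_len]
        logits=logits,              # float32 logits [batch, time, num_classes]
        label_length=label_length,  # int32 [batch], true label lengths
        logit_length=logit_length,  # int32 [batch], valid logit time steps
        blank_index=0,              # blank token id, matching the diff
    )
    return tf.reduce_mean(per_example)

# Toy example: batch of 2, 5 time steps, 4 classes (class 0 = blank).
logits = tf.random.normal([2, 5, 4])
labels = tf.constant([[1, 2, 3], [2, 3, 0]], dtype=tf.int32)
loss = ctc_loss_fn(
    labels, logits,
    label_length=tf.constant([3, 2], dtype=tf.int32),
    logit_length=tf.constant([5, 5], dtype=tf.int32),
)
```

Because the per-example losses come back unreduced, any caller that previously relied on `ctc_loss_for_tpu`'s internal `tf.reduce_mean` needs an explicit reduction (or `tf.nn.compute_average_loss` when running under a distribution strategy).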