From fed5fd8251aee7cf5e9488c1aedf96aa8390a0f2 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Sun, 19 Oct 2025 22:25:21 +0800
Subject: [PATCH] f

---
 model_training_nnn_tpu/trainer_tf.py | 72 +++++++++++++++++-----------
 1 file changed, 43 insertions(+), 29 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 8e65b54..9e7dee7 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -19,7 +19,6 @@ except ImportError:
 
 from rnn_model_tf import (
     TripleGRUDecoder,
-    CTCLoss,
     create_tpu_strategy,
     build_model_for_tpu,
     configure_mixed_precision
@@ -125,7 +124,7 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer ready for TPU training")
 
         self.lr_scheduler = self._create_lr_scheduler()
-        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+        # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible)
 
         # Create unified checkpoint management
         self.ckpt = tf.train.Checkpoint(
@@ -551,22 +550,29 @@ class BrainToTextDecoderTrainerTF:
 
         # Calculate losses
         if use_full:
-            # Clean CTC loss
-            clean_loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
+            # Clean CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+            clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+            clean_loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=clean_logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
             clean_loss = tf.reduce_mean(clean_loss)
 
-            # Noisy CTC loss
-            noisy_loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
+            # Noisy CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
+            noisy_loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=noisy_logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
             noisy_loss = tf.reduce_mean(noisy_loss)
 
             # Optional noise L2 regularization
@@ -576,12 +582,16 @@
 
             loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
         else:
-            loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            loss = self.ctc_loss(loss_input, clean_logits)
+            # Standard CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+            loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
             loss = tf.reduce_mean(loss)
 
         # AdamW handles weight decay automatically - no manual L2 regularization needed
@@ -642,13 +652,17 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss
-        loss_input = {
-            'labels': labels,
-            'input_lengths': adjusted_lens,
-            'label_lengths': phone_seq_lens
-        }
-        loss = self.ctc_loss(loss_input, logits)
+        # Calculate loss - use tf.nn.ctc_loss (TPU-compatible)
+        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        logits_time_major = tf.transpose(logits, [1, 0, 2])
+        loss = tf.nn.ctc_loss(
+            labels=tf.cast(labels, tf.int32),
+            logits=logits_time_major,
+            label_length=tf.cast(phone_seq_lens, tf.int32),
+            logit_length=tf.cast(adjusted_lens, tf.int32),
+            blank_index=0,
+            logits_time_major=True
+        )
         loss = tf.reduce_mean(loss)
 
         # Greedy decoding for PER calculation
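
Note appended after the patch (not intended to be applied): a minimal, self-contained sketch of the tf.nn.ctc_loss call pattern the hunks above switch to. The shapes, the dummy labels, and num_classes = 41 are illustrative assumptions; the real trainer supplies labels, phone_seq_lens, and adjusted_lens from its batches.

import tensorflow as tf

# Illustrative shapes only (assumed): 2 sequences, 50 frames, 41 output classes.
batch_size, max_time, num_classes = 2, 50, 41
logits = tf.random.normal([batch_size, max_time, num_classes])       # batch-major [B, T, C]
labels = tf.constant([[7, 12, 3, 0, 0], [5, 9, 0, 0, 0]], tf.int32)  # dense, zero-padded
label_length = tf.constant([3, 2], tf.int32)                         # valid labels per sequence
logit_length = tf.fill([batch_size], max_time)                       # frames per sequence

per_sequence_loss = tf.nn.ctc_loss(
    labels=labels,
    logits=tf.transpose(logits, [1, 0, 2]),  # to time-major [T, B, C], as in the patch
    label_length=label_length,
    logit_length=logit_length,
    blank_index=0,
    logits_time_major=True,
)
print(tf.reduce_mean(per_sequence_loss))     # per-sequence losses averaged, mirroring tf.reduce_mean above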