commit fed5fd8251
parent 4b373ab317
Author: Zchen
Date:   2025-10-19 22:25:21 +08:00


@@ -19,7 +19,6 @@ except ImportError:
 from rnn_model_tf import (
     TripleGRUDecoder,
-    CTCLoss,
     create_tpu_strategy,
     build_model_for_tpu,
     configure_mixed_precision
@@ -125,7 +124,7 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer ready for TPU training")
         self.lr_scheduler = self._create_lr_scheduler()
-        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+        # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible)
 
         # Create unified checkpoint management
         self.ckpt = tf.train.Checkpoint(
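
Note: tf.nn.ctc_loss returns one loss value per batch element, matching the removed CTCLoss(..., reduction='none') behavior, which is why the new code still applies an explicit tf.reduce_mean. A minimal standalone sketch of the call, with illustrative shapes and values only:

    import tensorflow as tf

    # Illustrative shapes: batch of 2, 50 frames, 41 classes, blank at index 0
    batch_size, max_time, num_classes = 2, 50, 41
    logits = tf.random.normal([batch_size, max_time, num_classes])  # batch-major model output
    labels = tf.constant([[7, 3, 12, 0], [5, 9, 0, 0]], tf.int32)   # dense; trailing zeros are padding
    label_length = tf.constant([3, 2], tf.int32)                    # padding is skipped via lengths
    logit_length = tf.fill([batch_size], max_time)                  # int32 vector of frame counts

    per_example_loss = tf.nn.ctc_loss(
        labels=labels,
        logits=tf.transpose(logits, [1, 0, 2]),  # -> time-major [max_time, batch, num_classes]
        label_length=label_length,
        logit_length=logit_length,
        blank_index=0,
        logits_time_major=True,
    )
    loss = tf.reduce_mean(per_example_loss)  # same explicit mean reduction as the new code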
@@ -551,22 +550,29 @@ class BrainToTextDecoderTrainerTF:
             # Calculate losses
             if use_full:
-                # Clean CTC loss
-                clean_loss_input = {
-                    'labels': labels,
-                    'input_lengths': adjusted_lens,
-                    'label_lengths': phone_seq_lens
-                }
-                clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
+                # Clean CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+                # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+                clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+                clean_loss = tf.nn.ctc_loss(
+                    labels=tf.cast(labels, tf.int32),
+                    logits=clean_logits_time_major,
+                    label_length=tf.cast(phone_seq_lens, tf.int32),
+                    logit_length=tf.cast(adjusted_lens, tf.int32),
+                    blank_index=0,
+                    logits_time_major=True
+                )
                 clean_loss = tf.reduce_mean(clean_loss)
 
-                # Noisy CTC loss
-                noisy_loss_input = {
-                    'labels': labels,
-                    'input_lengths': adjusted_lens,
-                    'label_lengths': phone_seq_lens
-                }
-                noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
+                # Noisy CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+                noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
+                noisy_loss = tf.nn.ctc_loss(
+                    labels=tf.cast(labels, tf.int32),
+                    logits=noisy_logits_time_major,
+                    label_length=tf.cast(phone_seq_lens, tf.int32),
+                    logit_length=tf.cast(adjusted_lens, tf.int32),
+                    blank_index=0,
+                    logits_time_major=True
+                )
                 noisy_loss = tf.reduce_mean(noisy_loss)
 
                 # Optional noise L2 regularization
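
Side note: this transpose-plus-tf.nn.ctc_loss pattern now appears three times in the file. A hypothetical helper, not part of this commit, could factor it out:

    import tensorflow as tf

    def ctc_loss_batch_major(labels, logits, label_lengths, logit_lengths, blank_index=0):
        """Mean tf.nn.ctc_loss for batch-major logits [batch, max_time, num_classes]."""
        per_example = tf.nn.ctc_loss(
            labels=tf.cast(labels, tf.int32),
            logits=tf.transpose(logits, [1, 0, 2]),  # -> [max_time, batch, num_classes]
            label_length=tf.cast(label_lengths, tf.int32),
            logit_length=tf.cast(logit_lengths, tf.int32),
            blank_index=blank_index,
            logits_time_major=True,
        )
        return tf.reduce_mean(per_example)

    # e.g. clean_loss = ctc_loss_batch_major(labels, clean_logits, phone_seq_lens, adjusted_lens)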
@@ -576,12 +582,16 @@ class BrainToTextDecoderTrainerTF:
                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                loss_input = {
-                    'labels': labels,
-                    'input_lengths': adjusted_lens,
-                    'label_lengths': phone_seq_lens
-                }
-                loss = self.ctc_loss(loss_input, clean_logits)
+                # Standard CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+                logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+                loss = tf.nn.ctc_loss(
+                    labels=tf.cast(labels, tf.int32),
+                    logits=logits_time_major,
+                    label_length=tf.cast(phone_seq_lens, tf.int32),
+                    logit_length=tf.cast(adjusted_lens, tf.int32),
+                    blank_index=0,
+                    logits_time_major=True
+                )
                 loss = tf.reduce_mean(loss)
 
             # AdamW handles weight decay automatically - no manual L2 regularization needed
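
One caveat, since the file also imports configure_mixed_precision: if the active policy makes the model emit float16/bfloat16 logits, running CTC in float32 is generally safer numerically. A hedged sketch (the model's output dtype here is an assumption, not something this diff shows):

    import tensorflow as tf

    def stable_ctc_loss(labels, logits, label_lengths, logit_lengths):
        """CTC loss computed in float32 regardless of the compute policy dtype."""
        logits_f32 = tf.cast(logits, tf.float32)  # no-op when already float32
        return tf.nn.ctc_loss(
            labels=tf.cast(labels, tf.int32),
            logits=tf.transpose(logits_f32, [1, 0, 2]),
            label_length=tf.cast(label_lengths, tf.int32),
            logit_length=tf.cast(logit_lengths, tf.int32),
            blank_index=0,
            logits_time_major=True,
        )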
@@ -642,13 +652,17 @@ class BrainToTextDecoderTrainerTF:
             # Forward pass (inference mode only)
             logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-            # Calculate loss
-            loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            loss = self.ctc_loss(loss_input, logits)
+            # Calculate loss - use tf.nn.ctc_loss (TPU-compatible)
+            # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+            logits_time_major = tf.transpose(logits, [1, 0, 2])
+            loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
             loss = tf.reduce_mean(loss)
 
             # Greedy decoding for PER calculation
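
The hunk ends just before the greedy decoding used for the PER calculation, which this commit does not touch. For orientation only, a hypothetical best-path decoder in the same batch-major convention (argmax on device, collapse repeats and drop blanks on the host; not the file's actual implementation):

    import tensorflow as tf

    def greedy_ctc_decode(logits, logit_lengths, blank_index=0):
        """Best-path CTC decode: collapse repeated symbols, then remove blanks."""
        # logits: batch-major [batch, max_time, num_classes], as the model emits them
        best_path = tf.argmax(logits, axis=-1, output_type=tf.int32).numpy()
        decoded = []
        for seq, length in zip(best_path, logit_lengths.numpy()):
            seq = seq[:length]
            decoded.append([int(p) for i, p in enumerate(seq)
                            if p != blank_index and (i == 0 or p != seq[i - 1])])
        return decoded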