@@ -19,7 +19,6 @@ except ImportError:
 from rnn_model_tf import (
     TripleGRUDecoder,
-    CTCLoss,
     create_tpu_strategy,
     build_model_for_tpu,
     configure_mixed_precision
 )
@@ -125,7 +124,7 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer ready for TPU training")

         self.lr_scheduler = self._create_lr_scheduler()
-        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+        # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible)

         # Create unified checkpoint management
         self.ckpt = tf.train.Checkpoint(
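The removed `CTCLoss` layer is replaced throughout by direct `tf.nn.ctc_loss` calls. For context, a minimal self-contained sketch of that call, with dummy shapes assumed purely for illustration (the real model emits batch-major `[batch, time, num_classes]` logits):

```python
import tensorflow as tf

# Illustrative shapes only: 2 sequences, 50 frames, 41 classes, blank class = 0
batch, max_time, num_classes = 2, 50, 41
logits = tf.random.normal([batch, max_time, num_classes])   # batch-major model output
labels = tf.constant([[5, 12, 7, 0, 0],                     # dense labels, zero-padded;
                      [3, 9, 0, 0, 0]], tf.int32)           # padding is masked by label_length
label_length = tf.constant([3, 2], tf.int32)
logit_length = tf.fill([batch], max_time)                   # all frames valid in this toy case

# tf.nn.ctc_loss takes time-major logits when logits_time_major=True,
# hence the [1, 0, 2] transpose used throughout this commit
per_example_loss = tf.nn.ctc_loss(
    labels=labels,
    logits=tf.transpose(logits, [1, 0, 2]),
    label_length=label_length,
    logit_length=tf.cast(logit_length, tf.int32),
    blank_index=0,
    logits_time_major=True,
)
print(tf.reduce_mean(per_example_loss))  # scalar batch mean, as in the training step below
```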
@@ -551,22 +550,29 @@ class BrainToTextDecoderTrainerTF:

         # Calculate losses
         if use_full:
-            # Clean CTC loss
-            clean_loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
+            # Clean CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+            clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+            clean_loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=clean_logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
+            clean_loss = tf.reduce_mean(clean_loss)

-            # Noisy CTC loss
-            noisy_loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
+            # Noisy CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
+            noisy_loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=noisy_logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
+            noisy_loss = tf.reduce_mean(noisy_loss)

             # Optional noise L2 regularization
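Two details worth noting about the branch above. First, the removed `CTCLoss` was constructed with `reduction='none'`, i.e. per-example losses; `tf.nn.ctc_loss` likewise returns a `[batch_size]` vector, which the new code collapses immediately with `tf.reduce_mean`. Second, the `[1, 0, 2]` transpose converts the model's batch-major output into the time-major layout `tf.nn.ctc_loss` expects. A quick self-contained shape check, with hypothetical sizes:

```python
import tensorflow as tf

# batch-major [batch, time, classes] -> time-major [time, batch, classes]
clean_logits = tf.zeros([32, 200, 41])              # hypothetical shapes, not from the commit
time_major = tf.transpose(clean_logits, [1, 0, 2])
print(clean_logits.shape, '->', time_major.shape)   # (32, 200, 41) -> (200, 32, 41)
```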
@@ -576,12 +582,16 @@ class BrainToTextDecoderTrainerTF:

             loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
         else:
-            loss_input = {
-                'labels': labels,
-                'input_lengths': adjusted_lens,
-                'label_lengths': phone_seq_lens
-            }
-            loss = self.ctc_loss(loss_input, clean_logits)
+            # Standard CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+            logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+            loss = tf.nn.ctc_loss(
+                labels=tf.cast(labels, tf.int32),
+                logits=logits_time_major,
+                label_length=tf.cast(phone_seq_lens, tf.int32),
+                logit_length=tf.cast(adjusted_lens, tf.int32),
+                blank_index=0,
+                logits_time_major=True
+            )
+            loss = tf.reduce_mean(loss)

         # AdamW handles weight decay automatically - no manual L2 regularization needed
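The same transpose-call-reduce pattern now appears three times (clean, noisy, and standard paths). A hypothetical helper, not part of this commit, could centralize it:

```python
import tensorflow as tf

def tpu_ctc_loss(logits, labels, label_lens, logit_lens):
    """Hypothetical helper (not in the commit): batch-major logits -> mean CTC loss."""
    per_example = tf.nn.ctc_loss(
        labels=tf.cast(labels, tf.int32),
        logits=tf.transpose(logits, [1, 0, 2]),   # to [max_time, batch, num_classes]
        label_length=tf.cast(label_lens, tf.int32),
        logit_length=tf.cast(logit_lens, tf.int32),
        blank_index=0,
        logits_time_major=True,
    )
    return tf.reduce_mean(per_example)
```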
@@ -642,13 +652,17 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)

-        # Calculate loss
-        loss_input = {
-            'labels': labels,
-            'input_lengths': adjusted_lens,
-            'label_lengths': phone_seq_lens
-        }
-        loss = self.ctc_loss(loss_input, logits)
+        # Calculate loss - use tf.nn.ctc_loss (TPU-compatible)
+        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        logits_time_major = tf.transpose(logits, [1, 0, 2])
+        loss = tf.nn.ctc_loss(
+            labels=tf.cast(labels, tf.int32),
+            logits=logits_time_major,
+            label_length=tf.cast(phone_seq_lens, tf.int32),
+            logit_length=tf.cast(adjusted_lens, tf.int32),
+            blank_index=0,
+            logits_time_major=True
+        )
+        loss = tf.reduce_mean(loss)

         # Greedy decoding for PER calculation
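The hunk ends just before the decoding code. For reference, a minimal blank-collapsing greedy decode consistent with `blank_index=0` might look like the sketch below (hypothetical eager-mode code, not the commit's implementation); the decoded sequences would then be compared against the reference phoneme labels with an edit distance to compute PER:

```python
import tensorflow as tf

def greedy_ctc_decode(logits, seq_lens, blank_index=0):
    """Greedy CTC decode: per-frame argmax, collapse repeats, drop blanks.

    logits: [batch, max_time, num_classes] (batch-major, as the model emits);
    seq_lens: [batch] count of valid frames per sequence.
    Returns a list of Python lists of label ids.
    """
    best = tf.argmax(logits, axis=-1, output_type=tf.int32).numpy()
    decoded = []
    for frames, n in zip(best, seq_lens.numpy()):
        out, prev = [], -1
        for t in frames[:n]:          # only the valid frames
            if t != prev and t != blank_index:
                out.append(int(t))
            prev = t
        decoded.append(out)
    return decoded
```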