@@ -19,7 +19,6 @@ except ImportError:
 from rnn_model_tf import (
     TripleGRUDecoder,
-    CTCLoss,
     create_tpu_strategy,
     build_model_for_tpu,
     configure_mixed_precision
@@ -125,7 +124,7 @@ class BrainToTextDecoderTrainerTF:
         print("✅ Optimizer ready for TPU training")

         self.lr_scheduler = self._create_lr_scheduler()
-        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
+        # CTC loss is now handled using tf.nn.ctc_loss (TPU-compatible)

         # Create unified checkpoint management
         self.ckpt = tf.train.Checkpoint(
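For reference, here is a minimal, self-contained sketch of the pattern this commit adopts (sizes and variable names are illustrative, not from the repository). tf.nn.ctc_loss takes dense int32 labels plus explicit label and logit lengths, and returns one loss value per batch element, which the call sites below average with tf.reduce_mean:

    import tensorflow as tf

    # Illustrative sizes: 4 sequences, 50 frames, 41 classes with blank at index 0
    batch, max_time, num_classes = 4, 50, 41
    logits = tf.random.normal([batch, max_time, num_classes])   # batch-major model output
    labels = tf.random.uniform([batch, 12], minval=1, maxval=num_classes, dtype=tf.int32)
    label_length = tf.fill([batch], 12)
    logit_length = tf.fill([batch], max_time)

    per_example = tf.nn.ctc_loss(
        labels=labels,
        logits=tf.transpose(logits, [1, 0, 2]),   # batch-major -> time-major
        label_length=label_length,
        logit_length=logit_length,
        blank_index=0,
        logits_time_major=True,
    )
    loss = tf.reduce_mean(per_example)            # scalar mean over the batch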
@@ -551,22 +550,29 @@ class BrainToTextDecoderTrainerTF:

             # Calculate losses
             if use_full:
-                # Clean CTC loss
-                clean_loss_input = {
-                    'labels': labels,
-                    'input_lengths': adjusted_lens,
-                    'label_lengths': phone_seq_lens
-                }
-                clean_loss = self.ctc_loss(clean_loss_input, clean_logits)
+                # Clean CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+                # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+                clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+                clean_loss = tf.nn.ctc_loss(
+                    labels=tf.cast(labels, tf.int32),
+                    logits=clean_logits_time_major,
+                    label_length=tf.cast(phone_seq_lens, tf.int32),
+                    logit_length=tf.cast(adjusted_lens, tf.int32),
+                    blank_index=0,
+                    logits_time_major=True
+                )
                 clean_loss = tf.reduce_mean(clean_loss)

-                # Noisy CTC loss
-                noisy_loss_input = {
-                    'labels': labels,
-                    'input_lengths': adjusted_lens,
-                    'label_lengths': phone_seq_lens
-                }
-                noisy_loss = self.ctc_loss(noisy_loss_input, noisy_logits)
+                # Noisy CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+                noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
+                noisy_loss = tf.nn.ctc_loss(
+                    labels=tf.cast(labels, tf.int32),
+                    logits=noisy_logits_time_major,
+                    label_length=tf.cast(phone_seq_lens, tf.int32),
+                    logit_length=tf.cast(adjusted_lens, tf.int32),
+                    blank_index=0,
+                    logits_time_major=True
+                )
                 noisy_loss = tf.reduce_mean(noisy_loss)

                 # Optional noise L2 regularization
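The clean, noisy, and standard branches repeat the same call with only the logits changing. A hypothetical refactor (the helper name dense_ctc_loss is mine, not the repository's) could factor out the boilerplate:

    def dense_ctc_loss(labels, logits, label_lens, input_lens, blank_index=0):
        """Mean CTC loss over a batch, given batch-major logits and dense labels."""
        per_example = tf.nn.ctc_loss(
            labels=tf.cast(labels, tf.int32),
            logits=tf.transpose(logits, [1, 0, 2]),     # to time-major
            label_length=tf.cast(label_lens, tf.int32),
            logit_length=tf.cast(input_lens, tf.int32),
            blank_index=blank_index,
            logits_time_major=True,
        )
        return tf.reduce_mean(per_example)

Each branch would then reduce to a one-liner, e.g. clean_loss = dense_ctc_loss(labels, clean_logits, phone_seq_lens, adjusted_lens).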
@@ -576,12 +582,16 @@ class BrainToTextDecoderTrainerTF:

                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                loss_input = {
-                    'labels': labels,
-                    'input_lengths': adjusted_lens,
-                    'label_lengths': phone_seq_lens
-                }
-                loss = self.ctc_loss(loss_input, clean_logits)
+                # Standard CTC loss - use tf.nn.ctc_loss (TPU-compatible)
+                logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
+                loss = tf.nn.ctc_loss(
+                    labels=tf.cast(labels, tf.int32),
+                    logits=logits_time_major,
+                    label_length=tf.cast(phone_seq_lens, tf.int32),
+                    logit_length=tf.cast(adjusted_lens, tf.int32),
+                    blank_index=0,
+                    logits_time_major=True
+                )
                 loss = tf.reduce_mean(loss)

             # AdamW handles weight decay automatically - no manual L2 regularization needed
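One subtlety carried over from the removed CTCLoss(blank_index=0, reduction='none'): with blank_index=0, class 0 is reserved for the blank symbol, so real label ids must lie in 1..num_classes-1 (dense padding past label_length is ignored by the loss). A hypothetical sanity check, not in the repository, that inspects only the unpadded positions:

    # Verify no real (unpadded) label uses the blank id 0.
    mask = tf.sequence_mask(phone_seq_lens, maxlen=tf.shape(labels)[1])
    tf.debugging.assert_none_equal(
        tf.boolean_mask(tf.cast(labels, tf.int32), mask),
        tf.constant(0, tf.int32),
        message="labels must not contain the blank id when blank_index=0",
    )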
@@ -642,13 +652,17 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)

-        # Calculate loss
-        loss_input = {
-            'labels': labels,
-            'input_lengths': adjusted_lens,
-            'label_lengths': phone_seq_lens
-        }
-        loss = self.ctc_loss(loss_input, logits)
+        # Calculate loss - use tf.nn.ctc_loss (TPU-compatible)
+        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        logits_time_major = tf.transpose(logits, [1, 0, 2])
+        loss = tf.nn.ctc_loss(
+            labels=tf.cast(labels, tf.int32),
+            logits=logits_time_major,
+            label_length=tf.cast(phone_seq_lens, tf.int32),
+            logit_length=tf.cast(adjusted_lens, tf.int32),
+            blank_index=0,
+            logits_time_major=True
+        )
         loss = tf.reduce_mean(loss)

         # Greedy decoding for PER calculation
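The context ends just before the greedy-decoding block, which is not part of this diff. For illustration only, CTC greedy decoding from the same batch-major logits can be done host-side (eager mode) with a per-frame argmax followed by collapsing repeats and dropping blanks; the names logits and adjusted_lens are assumed from the surrounding code:

    # Hypothetical host-side greedy decode for PER (illustration, not repo code).
    frame_ids = tf.argmax(logits, axis=-1, output_type=tf.int32).numpy()  # [batch, max_time]
    hypotheses = []
    for b, T in enumerate(adjusted_lens.numpy()):
        seq, prev = [], -1
        for t in range(int(T)):
            c = int(frame_ids[b, t])
            if c != prev and c != 0:   # collapse repeats, skip blank (id 0)
                seq.append(c)
            prev = c
        hypotheses.append(seq)         # greedy phoneme sequence for PER scoring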