f
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import tensorflow.keras.backend as K
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
@@ -18,6 +17,9 @@ except ImportError:
|
|||||||
print("Warning: editdistance not available, falling back to approximation")
|
print("Warning: editdistance not available, falling back to approximation")
|
||||||
editdistance = None
|
editdistance = None
|
||||||
|
|
||||||
|
# XLA-compatible CTC loss implementation
|
||||||
|
from tf_seq2seq_losses import classic_ctc_loss
|
||||||
|
|
||||||
from rnn_model_tf import (
|
from rnn_model_tf import (
|
||||||
TripleGRUDecoder,
|
TripleGRUDecoder,
|
||||||
create_tpu_strategy,
|
create_tpu_strategy,
|
||||||
@@ -33,77 +35,8 @@ from dataset_tf import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def ctc_loss_for_tpu(y_true, y_pred, input_length, label_length):
|
|
||||||
"""
|
|
||||||
TPU-compatible CTC loss function using Keras backend
|
|
||||||
|
|
||||||
This implementation uses K.ctc_batch_cost which is often more robust
|
|
||||||
for XLA compilation than tf.nn.ctc_loss, especially in complex model graphs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_true: Dense labels [batch_size, max_label_len]
|
|
||||||
y_pred: Logits [batch_size, time_steps, num_classes]
|
|
||||||
input_length: Logit sequence lengths [batch_size]
|
|
||||||
label_length: True label sequence lengths [batch_size]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Scalar CTC loss value
|
|
||||||
"""
|
|
||||||
# K.ctc_batch_cost requires logits to be time-major [time_steps, batch_size, num_classes]
|
|
||||||
y_pred_time_major = tf.transpose(y_pred, [1, 0, 2])
|
|
||||||
|
|
||||||
# Ensure correct data types for Keras backend
|
|
||||||
y_true = tf.cast(y_true, tf.float32) # K.ctc_batch_cost expects float32 labels
|
|
||||||
input_length = tf.cast(input_length, tf.int32)
|
|
||||||
label_length = tf.cast(label_length, tf.int32)
|
|
||||||
|
|
||||||
# Calculate CTC loss using Keras backend (more XLA-friendly)
|
|
||||||
loss = K.ctc_batch_cost(y_true, y_pred_time_major, input_length, label_length)
|
|
||||||
|
|
||||||
return tf.reduce_mean(loss)
|
|
||||||
|
|
||||||
|
|
||||||
def dense_to_sparse(dense_tensor, sequence_lengths):
|
|
||||||
"""
|
|
||||||
Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
|
|
||||||
sequence_lengths: Actual sequence lengths [batch_size]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SparseTensor suitable for tf.nn.ctc_loss
|
|
||||||
"""
|
|
||||||
# Create mask for valid (non-zero) elements within sequence lengths
|
|
||||||
batch_size = tf.shape(dense_tensor)[0]
|
|
||||||
max_seq_len = tf.shape(dense_tensor)[1]
|
|
||||||
|
|
||||||
# Create range indices
|
|
||||||
batch_indices = tf.range(batch_size)
|
|
||||||
seq_indices = tf.range(max_seq_len)
|
|
||||||
|
|
||||||
# Create meshgrid for batch and sequence dimensions
|
|
||||||
batch_mesh, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
|
|
||||||
|
|
||||||
# Create mask based on sequence lengths and non-zero values
|
|
||||||
length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
|
|
||||||
value_mask = tf.not_equal(dense_tensor, 0)
|
|
||||||
combined_mask = tf.logical_and(length_mask, value_mask)
|
|
||||||
|
|
||||||
# Get indices of valid elements
|
|
||||||
indices = tf.where(combined_mask)
|
|
||||||
|
|
||||||
# Get values at valid indices
|
|
||||||
values = tf.gather_nd(dense_tensor, indices)
|
|
||||||
|
|
||||||
# Create sparse tensor
|
|
||||||
dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
|
|
||||||
|
|
||||||
return tf.SparseTensor(
|
|
||||||
indices=tf.cast(indices, tf.int64),
|
|
||||||
values=tf.cast(values, tf.int32),
|
|
||||||
dense_shape=dense_shape
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BrainToTextDecoderTrainerTF:
|
class BrainToTextDecoderTrainerTF:
|
||||||
@@ -626,20 +559,22 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
|
|
||||||
# Calculate losses using TPU-compatible CTC implementation
|
# Calculate losses using TPU-compatible CTC implementation
|
||||||
if use_full:
|
if use_full:
|
||||||
# Clean CTC loss - using Keras backend for XLA compatibility
|
# Clean CTC loss - using XLA-compatible classic_ctc_loss
|
||||||
clean_loss = ctc_loss_for_tpu(
|
clean_loss = classic_ctc_loss(
|
||||||
y_true=tf.cast(labels, tf.float32), # Dense labels as float32
|
labels=tf.cast(labels, tf.int32), # Dense labels as int32
|
||||||
y_pred=clean_logits,
|
logits=clean_logits,
|
||||||
input_length=adjusted_lens,
|
label_length=phone_seq_lens,
|
||||||
label_length=phone_seq_lens
|
logit_length=adjusted_lens,
|
||||||
|
blank_index=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Noisy CTC loss - using Keras backend for XLA compatibility
|
# Noisy CTC loss - using XLA-compatible classic_ctc_loss
|
||||||
noisy_loss = ctc_loss_for_tpu(
|
noisy_loss = classic_ctc_loss(
|
||||||
y_true=tf.cast(labels, tf.float32), # Reuse same dense labels
|
labels=tf.cast(labels, tf.int32), # Dense labels as int32
|
||||||
y_pred=noisy_logits,
|
logits=noisy_logits,
|
||||||
input_length=adjusted_lens,
|
label_length=phone_seq_lens,
|
||||||
label_length=phone_seq_lens
|
logit_length=adjusted_lens,
|
||||||
|
blank_index=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional noise L2 regularization
|
# Optional noise L2 regularization
|
||||||
@@ -649,12 +584,13 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
|
|
||||||
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
|
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
|
||||||
else:
|
else:
|
||||||
# Standard CTC loss - using Keras backend for XLA compatibility
|
# Standard CTC loss - using XLA-compatible classic_ctc_loss
|
||||||
loss = ctc_loss_for_tpu(
|
loss = classic_ctc_loss(
|
||||||
y_true=tf.cast(labels, tf.float32), # Dense labels as float32
|
labels=tf.cast(labels, tf.int32), # Dense labels as int32
|
||||||
y_pred=clean_logits,
|
logits=clean_logits,
|
||||||
input_length=adjusted_lens,
|
label_length=phone_seq_lens,
|
||||||
label_length=phone_seq_lens
|
logit_length=adjusted_lens,
|
||||||
|
blank_index=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# AdamW handles weight decay automatically - no manual L2 regularization needed
|
# AdamW handles weight decay automatically - no manual L2 regularization needed
|
||||||
@@ -710,12 +646,13 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
# Forward pass (inference mode only)
|
# Forward pass (inference mode only)
|
||||||
logits = self.model(features, day_indices, None, False, 'inference', training=False)
|
logits = self.model(features, day_indices, None, False, 'inference', training=False)
|
||||||
|
|
||||||
# Calculate loss using TPU-compatible CTC implementation
|
# Calculate loss using XLA-compatible classic_ctc_loss
|
||||||
loss = ctc_loss_for_tpu(
|
loss = classic_ctc_loss(
|
||||||
y_true=tf.cast(labels, tf.float32), # Dense labels as float32
|
labels=tf.cast(labels, tf.int32), # Dense labels as int32
|
||||||
y_pred=logits,
|
logits=logits,
|
||||||
input_length=adjusted_lens,
|
label_length=phone_seq_lens,
|
||||||
label_length=phone_seq_lens
|
logit_length=adjusted_lens,
|
||||||
|
blank_index=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Greedy decoding for PER calculation
|
# Greedy decoding for PER calculation
|
||||||
|
Reference in New Issue
Block a user