tpu
This commit is contained in:
@@ -12,7 +12,7 @@ def gradient_reverse(x, lambd=1.0):
|
||||
Backward: multiply incoming gradient by -lambda
|
||||
"""
|
||||
def grad(dy):
|
||||
return -lambd * dy, None
|
||||
return -lambd * dy # Only return gradient w.r.t. x, not lambd
|
||||
|
||||
return tf.identity(x), grad
|
||||
|
||||
@@ -709,17 +709,45 @@ class CTCLoss(keras.losses.Loss):
|
||||
input_lengths = y_true['input_lengths']
|
||||
label_lengths = y_true['label_lengths']
|
||||
|
||||
# Ensure correct data types
|
||||
labels = tf.cast(labels, tf.int32)
|
||||
input_lengths = tf.cast(input_lengths, tf.int32)
|
||||
label_lengths = tf.cast(label_lengths, tf.int32)
|
||||
|
||||
# Convert logits to log probabilities
|
||||
log_probs = tf.nn.log_softmax(y_pred, axis=-1)
|
||||
|
||||
# Transpose for CTC: [time_steps, batch_size, num_classes]
|
||||
log_probs = tf.transpose(log_probs, [1, 0, 2])
|
||||
|
||||
# Convert dense labels to sparse format for CTC using TensorFlow operations
|
||||
def dense_to_sparse(dense_tensor, sequence_lengths):
|
||||
"""Convert dense tensor to sparse tensor for CTC"""
|
||||
batch_size = tf.shape(dense_tensor)[0]
|
||||
max_len = tf.shape(dense_tensor)[1]
|
||||
|
||||
# Create mask for non-zero elements (entries equal to 0 are left out of the sparse tensor)
|
||||
mask = tf.not_equal(dense_tensor, 0)
|
||||
|
||||
# Get indices of non-zero elements
|
||||
indices = tf.where(mask)
|
||||
|
||||
# Get values at those indices
|
||||
values = tf.gather_nd(dense_tensor, indices)
|
||||
|
||||
# Create sparse tensor
|
||||
dense_shape = tf.cast([batch_size, max_len], tf.int64)
|
||||
|
||||
return tf.SparseTensor(indices=indices, values=values, dense_shape=dense_shape)
|
||||
|
||||
# Convert labels to sparse format
|
||||
sparse_labels = dense_to_sparse(labels, label_lengths)
|
||||
|
||||
# Compute CTC loss
|
||||
loss = tf.nn.ctc_loss(
|
||||
labels=labels,
|
||||
labels=sparse_labels,
|
||||
logits=log_probs,
|
||||
label_length=label_lengths,
|
||||
label_length=None, # Not needed for sparse format
|
||||
logit_length=input_lengths,
|
||||
blank_index=self.blank_index,
|
||||
logits_time_major=True
|
||||
|
Reference in New Issue
Block a user