@@ -9,6 +9,7 @@ import pathlib
 import sys
 from typing import Dict, Any, Tuple, Optional, List
 from omegaconf import OmegaConf
+from tf_seq2seq_losses import classic_ctc_loss
 
 # For accurate PER calculation
 try:
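
The only new dependency is the community package tf_seq2seq_losses. As a sketch (the PyPI name tf-seq2seq-losses and the fallback behavior are assumptions, mirroring the editdistance guard above), the import could be made optional the same way:

    # Hypothetical optional import, modeled on the editdistance guard above.
    # PyPI package name assumed to be tf-seq2seq-losses; verify before use.
    try:
        from tf_seq2seq_losses import classic_ctc_loss
    except ImportError:
        classic_ctc_loss = None
        print("Warning: tf_seq2seq_losses not available, CTC loss cannot be computed")
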
@@ -17,55 +18,7 @@ except ImportError:
     print("Warning: editdistance not available, falling back to approximation")
     editdistance = None
 
-# Note: Reverted to standard tf.nn.ctc_loss + SparseTensor approach
-# for compatibility with "batch first, augment after" data pipeline
-
-
-def dense_to_sparse(dense_tensor, sequence_lengths):
-    """
-    Convert dense tensor to sparse tensor for CTC loss with dynamic shapes
-
-    This function is essential for the "batch first, augment after" approach
-    as it handles the conversion from dynamic dense tensors to SparseTensor
-    format required by tf.nn.ctc_loss.
-
-    Args:
-        dense_tensor: Dense tensor with shape [batch_size, max_seq_len]
-        sequence_lengths: Actual sequence lengths [batch_size]
-
-    Returns:
-        SparseTensor suitable for tf.nn.ctc_loss
-    """
-    # Create mask for valid (non-zero) elements within sequence lengths
-    batch_size = tf.shape(dense_tensor)[0]
-    max_seq_len = tf.shape(dense_tensor)[1]
-
-    # Create range indices
-    batch_indices = tf.range(batch_size)
-    seq_indices = tf.range(max_seq_len)
-
-    # Create meshgrid for sequence dimensions
-    _, seq_mesh = tf.meshgrid(batch_indices, seq_indices, indexing='ij')
-
-    # Create mask based on sequence lengths and non-zero values
-    length_mask = seq_mesh < tf.expand_dims(sequence_lengths, 1)
-    value_mask = tf.not_equal(dense_tensor, 0)
-    combined_mask = tf.logical_and(length_mask, value_mask)
-
-    # Get indices of valid elements
-    indices = tf.where(combined_mask)
-
-    # Get values at valid indices
-    values = tf.gather_nd(dense_tensor, indices)
-
-    # Create sparse tensor
-    dense_shape = tf.cast(tf.shape(dense_tensor), tf.int64)
-
-    return tf.SparseTensor(
-        indices=tf.cast(indices, tf.int64),
-        values=tf.cast(values, tf.int32),
-        dense_shape=dense_shape
-    )
+# Note: Now using classic_ctc_loss from tf_seq2seq_losses for better compatibility
 
 from rnn_model_tf import (
     TripleGRUDecoder,
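
For orientation: the old path converted the dense, zero-padded label batch to a tf.SparseTensor via the removed dense_to_sparse helper and passed label_length=None, while classic_ctc_loss takes the padded dense labels plus an explicit length tensor. A minimal sketch of the old conversion using the built-in tf.sparse.from_dense (toy shapes and values, not taken from the trainer):

    import tensorflow as tf

    # Toy batch: 2 sequences, zero-padded to length 6; 0 doubles as blank/padding.
    labels = tf.constant([[3, 1, 4, 1, 0, 0],
                          [2, 7, 0, 0, 0, 0]], dtype=tf.int32)

    # tf.sparse.from_dense drops zeros, which is most of what dense_to_sparse
    # did; the removed helper additionally masked positions beyond each
    # sequence length, which only matters if a real label can be 0.
    sparse_labels = tf.sparse.from_dense(labels)

    logits = tf.random.normal([2, 12, 8])            # [batch, time, num_classes]
    logit_lens = tf.constant([12, 9], dtype=tf.int32)

    old_loss = tf.nn.ctc_loss(
        labels=sparse_labels,
        logits=logits,
        label_length=None,        # not needed when labels are sparse
        logit_length=logit_lens,
        logits_time_major=False,
        blank_index=0,
    )                             # shape [2]: one loss per batch element
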
@@ -603,28 +556,27 @@ class BrainToTextDecoderTrainerTF:
                 features, day_indices, None, False, 'inference', training=True
             )
 
-            # Calculate losses using TPU-compatible CTC implementation
+            # Calculate losses using community CTC implementation
             if use_full:
-                # Clean CTC loss - using standard tf.nn.ctc_loss with SparseTensor
-                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-                clean_loss = tf.nn.ctc_loss(
-                    labels=sparse_labels,
-                    logits=clean_logits,
-                    label_length=None,  # SparseTensor doesn't need label_length
+                # Clean CTC loss - using new classic_ctc_loss
+                # Convert to float32 for numerical stability
+                clean_logits_f32 = tf.cast(clean_logits, tf.float32)
+                clean_loss = classic_ctc_loss(
+                    labels=labels,
+                    logits=clean_logits_f32,
+                    label_length=phone_seq_lens,  # Direct use of length tensor
                     logit_length=adjusted_lens,
                     logits_time_major=False,
                     blank_index=0
                 )
                 clean_loss = tf.reduce_mean(clean_loss)
 
-                # Noisy CTC loss - using standard tf.nn.ctc_loss with SparseTensor
-                # Reuse the same sparse_labels from above
-                noisy_loss = tf.nn.ctc_loss(
-                    labels=sparse_labels,
-                    logits=noisy_logits,
-                    label_length=None,  # SparseTensor doesn't need label_length
+                # Noisy CTC loss - using new classic_ctc_loss
+                noisy_logits_f32 = tf.cast(noisy_logits, tf.float32)
+                noisy_loss = classic_ctc_loss(
+                    labels=labels,
+                    logits=noisy_logits_f32,
+                    label_length=phone_seq_lens,  # Direct use of length tensor
                     logit_length=adjusted_lens,
                     logits_time_major=False,
                     blank_index=0
                 )
                 noisy_loss = tf.reduce_mean(noisy_loss)
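
Both new branches cast logits to float32 before the loss. The likely reason (an assumption; this diff doesn't show the precision policy) is mixed-precision training, where model outputs arrive in float16/bfloat16 and CTC's log-space reductions are safer in float32. A minimal sketch of the pattern:

    import tensorflow as tf

    # Assumed setup: a global mixed-precision policy (not shown in this diff).
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

    head = tf.keras.layers.Dense(41)                # toy stand-in for the model head
    logits = head(tf.random.normal([2, 50, 16]))    # bfloat16 under the policy

    # Cast the activations (not the weights) to float32 before numerically
    # sensitive ops such as CTC, which is exactly what the diff does.
    logits_f32 = tf.cast(logits, tf.float32)
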
@@ -636,14 +588,14 @@ class BrainToTextDecoderTrainerTF:
 
                 loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
             else:
-                # Standard CTC loss - using standard tf.nn.ctc_loss with SparseTensor
-                sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-                loss = tf.nn.ctc_loss(
-                    labels=sparse_labels,
-                    logits=clean_logits,
-                    label_length=None,  # SparseTensor doesn't need label_length
+                # Standard CTC loss - using new classic_ctc_loss
+                # Convert to float32 for numerical stability
+                clean_logits_f32 = tf.cast(clean_logits, tf.float32)
+                loss = classic_ctc_loss(
+                    labels=labels,
+                    logits=clean_logits_f32,
+                    label_length=phone_seq_lens,  # Direct use of length tensor
                     logit_length=adjusted_lens,
                     logits_time_major=False,
                     blank_index=0
                 )
                 loss = tf.reduce_mean(loss)
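
In the adversarial branch above, the two CTC terms are combined with an L2 penalty on the injected noise. A toy recombination for orientation (the weight values are invented; the real ones come from the trainer's config fields adv_noisy_loss_weight and adv_noise_l2_weight):

    import tensorflow as tf

    # Invented weights; the real values come from the trainer's config.
    adv_noisy_loss_weight = 0.2
    adv_noise_l2_weight = 1e-4

    clean_loss = tf.constant(2.0)   # stand-ins for the reduced CTC losses
    noisy_loss = tf.constant(2.5)
    noise_l2 = tf.constant(3.0)     # L2 penalty on the injected noise

    loss = clean_loss + adv_noisy_loss_weight * noisy_loss \
        + adv_noise_l2_weight * noise_l2
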
@@ -701,14 +653,14 @@ class BrainToTextDecoderTrainerTF:
         # Forward pass (inference mode only)
         logits = self.model(features, day_indices, None, False, 'inference', training=False)
 
-        # Calculate loss using standard tf.nn.ctc_loss with SparseTensor
-        sparse_labels = dense_to_sparse(labels, phone_seq_lens)
-        loss = tf.nn.ctc_loss(
-            labels=sparse_labels,
-            logits=logits,
-            label_length=None,  # SparseTensor doesn't need label_length
+        # Calculate loss using community CTC implementation
+        # Convert to float32 for numerical stability
+        logits_f32 = tf.cast(logits, tf.float32)
+        loss = classic_ctc_loss(
+            labels=labels,
+            logits=logits_f32,
+            label_length=phone_seq_lens,  # Direct use of length tensor
             logit_length=adjusted_lens,
             logits_time_major=False,
             blank_index=0
         )
         loss = tf.reduce_mean(loss)
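
Finally, a self-contained smoke test of the new loss path (shapes and values invented, variable names borrowed from the trainer for readability). It omits the logits_time_major=False argument the trainer passes, on the assumption that batch-major logits are the library's default; verify against the installed version of tf_seq2seq_losses:

    import tensorflow as tf
    from tf_seq2seq_losses import classic_ctc_loss

    batch, max_time, num_classes, max_labels = 2, 60, 41, 25

    logits = tf.random.normal([batch, max_time, num_classes])
    labels = tf.random.uniform([batch, max_labels], minval=1,
                               maxval=num_classes, dtype=tf.int32)
    phone_seq_lens = tf.constant([25, 18], dtype=tf.int32)   # label lengths
    adjusted_lens = tf.constant([60, 50], dtype=tf.int32)    # valid frame counts

    loss = classic_ctc_loss(
        labels=labels,
        logits=tf.cast(logits, tf.float32),   # same float32 cast as the trainer
        label_length=phone_seq_lens,
        logit_length=adjusted_lens,
        blank_index=0,
    )
    print(float(tf.reduce_mean(loss)))        # scalar batch loss, as in the trainer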