Enable soft device placement for CTC operations and update related comments

Author: Zchen
Date:   2025-10-20 11:22:13 +08:00
parent f8fb4d7133
commit 7358ff3d79


@@ -92,6 +92,10 @@ class BrainToTextDecoderTrainerTF:
         self.args = args
         self.logger = None
 
+        # Enable soft device placement for XLA-unsupported ops (like CTC)
+        tf.config.set_soft_device_placement(True)
+        print("✅ Enabled soft device placement for CTC operations")
+
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
         if self.strategy is None:
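
Note: a minimal sketch of what this flag does. With soft placement on, ops that lack a TPU/XLA kernel (such as the CTC loss internals) are transparently re-placed on CPU instead of raising a placement error. The tf.debugging logging call is an optional extra for verification, not part of this commit.

import tensorflow as tf

# Allow TensorFlow to move ops without a TPU/XLA kernel onto CPU
# instead of failing with an InvalidArgumentError.
tf.config.set_soft_device_placement(True)

# Optional: log each op's actual device to confirm the CPU fallback.
tf.debugging.set_log_device_placement(True)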
@@ -101,8 +105,8 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
-        print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
-        print("   and drop_remainder=True to avoid dynamic shape warnings")
+        print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)")
+        print("   This has minimal performance impact as CTC is a small portion of computation")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
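
Note: the log lines above reference a create_input_fn data pipeline that is not part of this diff. A hypothetical sketch of the shape-static pattern those messages describe (the generator, feature width, and max lengths are illustrative, not from this repo): fixed padded_shapes plus drop_remainder=True keep every batch the same shape, which is what the TPU compiler wants.

import tensorflow as tf

# Toy variable-length sequences standing in for neural features / labels.
def gen():
    for n in (3, 5, 4, 6):
        yield tf.ones([n, 8]), tf.range(n % 3 + 1, dtype=tf.int32)

ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(
        tf.TensorSpec([None, 8], tf.float32),
        tf.TensorSpec([None], tf.int32),
    ),
)
ds = ds.map(lambda x, y: (x * 0.1, y),            # placeholder preprocessing
            num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.padded_batch(
    2,
    padded_shapes=([16, 8], [4]),                  # fixed max shapes -> static TPU shapes
    drop_remainder=True,                           # no smaller final batch
).prefetch(tf.data.AUTOTUNE)

for feats, labs in ds:
    print(feats.shape, labs.shape)                 # always (2, 16, 8) and (2, 4)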
@@ -593,8 +597,7 @@ class BrainToTextDecoderTrainerTF:
         # Convert dense labels to sparse for dynamic shapes
         sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-        # Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
-        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        # Clean CTC loss - will auto-fallback to CPU with soft device placement
         clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
         clean_loss = tf.nn.ctc_loss(
             labels=sparse_labels,
@@ -606,7 +609,7 @@ class BrainToTextDecoderTrainerTF:
         )
         clean_loss = tf.reduce_mean(clean_loss)
 
-        # Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
+        # Noisy CTC loss - will auto-fallback to CPU with soft device placement
         noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
         noisy_loss = tf.nn.ctc_loss(
             labels=sparse_labels,  # Reuse same sparse labels
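
Note: dense_to_sparse is a project helper not shown in this diff. A plausible self-contained sketch of the sparse-label CTC pattern these hunks use; the helper body and blank_index=0 are assumptions, since the diff truncates the tf.nn.ctc_loss call before its remaining arguments.

import tensorflow as tf

def dense_to_sparse(dense_labels, seq_lens):
    # Keep only the first seq_lens[i] entries of each padded row.
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(dense_labels)[1])
    indices = tf.where(mask)                        # int64 [nnz, 2]
    values = tf.gather_nd(dense_labels, indices)    # label dtype preserved
    return tf.SparseTensor(indices, values,
                           tf.shape(dense_labels, out_type=tf.int64))

labels = tf.constant([[1, 2, 0], [3, 0, 0]], tf.int32)
sparse_labels = dense_to_sparse(labels, tf.constant([2, 1]))

logits = tf.random.normal([2, 5, 4])                # [batch, time, classes]
loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=tf.transpose(logits, [1, 0, 2]),         # time-major
    label_length=None,                              # implied by the SparseTensor
    logit_length=tf.fill([2], 5),
    logits_time_major=True,
    blank_index=0,                                  # assumption for this sketch
)
loss = tf.reduce_mean(loss)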
@@ -628,7 +631,7 @@ class BrainToTextDecoderTrainerTF:
         # Convert dense labels to sparse for dynamic shapes
         sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-        # Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
+        # Standard CTC loss - will auto-fallback to CPU with soft device placement
         logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
         loss = tf.nn.ctc_loss(
             labels=sparse_labels,
@@ -696,8 +699,7 @@ class BrainToTextDecoderTrainerTF:
         # Convert dense labels to sparse for dynamic shapes
         sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-        # Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
-        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        # Calculate loss - will auto-fallback to CPU with soft device placement
        logits_time_major = tf.transpose(logits, [1, 0, 2])
         loss = tf.nn.ctc_loss(
             labels=sparse_labels,
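
Note: with soft placement enabled, the calls above need no explicit device scope. The equivalent manual alternative, shown here as a sketch with toy shapes (blank_index=0 again assumed), would be to pin the loss to CPU yourself:

import tensorflow as tf

batch, time, num_classes = 2, 10, 5
logits = tf.random.normal([time, batch, num_classes])   # time-major
sparse_labels = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([1, 2, 3], tf.int32),
    dense_shape=[batch, 2],
)

# Explicit pinning; tf.config.set_soft_device_placement(True) makes
# this automatic whenever the op has no TPU kernel.
with tf.device('/CPU:0'):
    loss = tf.nn.ctc_loss(
        labels=sparse_labels,
        logits=logits,
        label_length=None,
        logit_length=tf.fill([batch], time),
        logits_time_major=True,
        blank_index=0,   # assumption, mirroring the sketch above
    )
print(float(tf.reduce_mean(loss)))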