Enable soft device placement for CTC operations and update related comments
This commit is contained in:
@@ -92,6 +92,10 @@ class BrainToTextDecoderTrainerTF:
|
||||
self.args = args
|
||||
self.logger = None
|
||||
|
||||
# Enable soft device placement so that ops XLA does not support (such as CTC) can fall back to CPU
|
||||
tf.config.set_soft_device_placement(True)
|
||||
print("✅ Enabled soft device placement for CTC operations")
|
||||
|
||||
# Initialize TPU strategy
|
||||
self.strategy = create_tpu_strategy()
|
||||
if self.strategy is None:
|
||||
@@ -101,8 +105,8 @@ class BrainToTextDecoderTrainerTF:
|
||||
print(f"Strategy type: {type(self.strategy).__name__}")
|
||||
print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
|
||||
print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
|
||||
print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
|
||||
print(" and drop_remainder=True to avoid dynamic shape warnings")
|
||||
print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)")
|
||||
print(" This has minimal performance impact as CTC is a small portion of computation")
|
||||
|
||||
# Configure mixed precision for TPU v5e-8
|
||||
if args.get('use_amp', True):
|
||||
@@ -593,8 +597,7 @@ class BrainToTextDecoderTrainerTF:
|
||||
# Convert dense labels to sparse for dynamic shapes
|
||||
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
|
||||
|
||||
# Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
|
||||
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
|
||||
# Clean CTC loss - falls back to CPU automatically when soft device placement is enabled
|
||||
clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
|
||||
clean_loss = tf.nn.ctc_loss(
|
||||
labels=sparse_labels,
|
||||
@@ -606,7 +609,7 @@ class BrainToTextDecoderTrainerTF:
|
||||
)
|
||||
clean_loss = tf.reduce_mean(clean_loss)
|
||||
|
||||
# Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
|
||||
# Noisy CTC loss - falls back to CPU automatically when soft device placement is enabled
|
||||
noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
|
||||
noisy_loss = tf.nn.ctc_loss(
|
||||
labels=sparse_labels, # Reuse same sparse labels
|
||||
@@ -628,7 +631,7 @@ class BrainToTextDecoderTrainerTF:
|
||||
# Convert dense labels to sparse for dynamic shapes
|
||||
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
|
||||
|
||||
# Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
|
||||
# Standard CTC loss - falls back to CPU automatically when soft device placement is enabled
|
||||
logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
|
||||
loss = tf.nn.ctc_loss(
|
||||
labels=sparse_labels,
|
||||
@@ -696,8 +699,7 @@ class BrainToTextDecoderTrainerTF:
|
||||
# Convert dense labels to sparse for dynamic shapes
|
||||
sparse_labels = dense_to_sparse(labels, phone_seq_lens)
|
||||
|
||||
# Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
|
||||
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
|
||||
# Calculate loss - falls back to CPU automatically when soft device placement is enabled
|
||||
logits_time_major = tf.transpose(logits, [1, 0, 2])
|
||||
loss = tf.nn.ctc_loss(
|
||||
labels=sparse_labels,
|
||||
|
Reference in New Issue
Block a user