Enable soft device placement for CTC operations and update related comments
@@ -92,6 +92,10 @@ class BrainToTextDecoderTrainerTF:
         self.args = args
         self.logger = None
 
+        # Enable soft device placement for XLA unsupported ops (like CTC)
+        tf.config.set_soft_device_placement(True)
+        print("✅ Enabled soft device placement for CTC operations")
+
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
         if self.strategy is None:
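Note: soft device placement is a process-wide switch and has to be set before any ops run, which is why the commit places it at the top of `__init__`. A minimal standalone sketch (not from this repository) of the fallback behavior the commit relies on:

```python
import tensorflow as tf

tf.config.set_soft_device_placement(True)

# With soft placement on, requesting a device that is absent or that lacks a
# kernel for the op makes TensorFlow fall back to CPU instead of raising
# InvalidArgumentError. On a CPU-only machine this prints a CPU device path.
with tf.device('/GPU:0'):
    x = tf.random.normal([4, 3])
print(x.device)
```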
@@ -101,8 +105,8 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
-        print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
-        print("   and drop_remainder=True to avoid dynamic shape warnings")
+        print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)")
+        print("   This has minimal performance impact as CTC is a small portion of computation")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
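Note: `create_input_fn` itself is not part of this diff; the sketch below is a hypothetical pipeline illustrating the AUTOTUNE usage the log messages ask for, with `drop_remainder=True` keeping batch shapes static as the replaced message recommended. `parse_example` and its feature spec are placeholders, not the repo's real parser.

```python
import tensorflow as tf

def parse_example(serialized):
    # Placeholder feature spec; the real parser lives in the repo's data code.
    feats = tf.io.parse_single_example(
        serialized, {'x': tf.io.FixedLenFeature([256], tf.float32)})
    return feats['x']

def create_input_fn(file_pattern: str, batch_size: int) -> tf.data.Dataset:
    ds = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
    ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    # drop_remainder=True keeps batch shapes static, which XLA/TPU prefers.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)
```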
@@ -593,8 +597,7 @@ class BrainToTextDecoderTrainerTF:
         # Convert dense labels to sparse for dynamic shapes
         sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-        # Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
-        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        # Clean CTC loss - will auto-fallback to CPU with soft device placement
         clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
         clean_loss = tf.nn.ctc_loss(
             labels=sparse_labels,
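Note: `dense_to_sparse` is a repository helper whose definition lies outside this diff. A plausible implementation, assuming `labels` arrives as a padded `[batch, max_len]` integer tensor and `phone_seq_lens` gives the valid length of each row:

```python
import tensorflow as tf

def dense_to_sparse(dense_labels: tf.Tensor, seq_lens: tf.Tensor) -> tf.SparseTensor:
    # Keep only the first seq_lens[i] entries of each padded row.
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(dense_labels)[1])
    indices = tf.where(mask)                      # int64 [N, 2] coordinates
    values = tf.gather_nd(dense_labels, indices)  # the unpadded label ids
    return tf.SparseTensor(
        indices=indices,
        values=values,
        dense_shape=tf.cast(tf.shape(dense_labels), tf.int64),
    )
```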
@@ -606,7 +609,7 @@ class BrainToTextDecoderTrainerTF:
         )
         clean_loss = tf.reduce_mean(clean_loss)
 
-        # Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
+        # Noisy CTC loss - will auto-fallback to CPU with soft device placement
         noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
         noisy_loss = tf.nn.ctc_loss(
             labels=sparse_labels,  # Reuse same sparse labels
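Note: the hunks above truncate the `tf.nn.ctc_loss` call after the `labels` argument. For reference, a self-contained call with the full keyword set; `blank_index=0` and the toy shapes are assumptions for illustration, not values taken from the repo:

```python
import tensorflow as tf

batch, max_time, num_classes = 2, 50, 41
logits = tf.random.normal([batch, max_time, num_classes])
logit_lens = tf.constant([50, 42], dtype=tf.int32)
sparse_labels = tf.SparseTensor(
    indices=[[0, 0], [0, 1], [1, 0]],
    values=tf.constant([3, 7, 5], dtype=tf.int32),
    dense_shape=[2, 2],
)

loss = tf.nn.ctc_loss(
    labels=sparse_labels,
    logits=tf.transpose(logits, [1, 0, 2]),  # time-major [T, B, C]
    label_length=None,        # lengths are implicit in the SparseTensor
    logit_length=logit_lens,
    logits_time_major=True,
    blank_index=0,            # assumption; the repo's blank id may differ
)
loss = tf.reduce_mean(loss)   # per-example losses -> scalar, as in the diff
```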
@@ -628,7 +631,7 @@ class BrainToTextDecoderTrainerTF:
         # Convert dense labels to sparse for dynamic shapes
         sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-        # Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
+        # Standard CTC loss - will auto-fallback to CPU with soft device placement
         logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
         loss = tf.nn.ctc_loss(
             labels=sparse_labels,
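Note: every hunk transposes batch-major logits to time-major before the loss. An equivalent formulation skips the transpose and declares the layout instead; which to use is purely a style choice:

```python
import tensorflow as tf

logits = tf.random.normal([8, 120, 41])       # [batch, time, classes]

# As in the diff: transpose, then pass logits_time_major=True (the default).
time_major = tf.transpose(logits, [1, 0, 2])  # [time, batch, classes]

# Equivalent alternative: keep batch-major logits and declare the layout:
#     tf.nn.ctc_loss(..., logits=logits, logits_time_major=False)
```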
@@ -696,8 +699,7 @@ class BrainToTextDecoderTrainerTF:
         # Convert dense labels to sparse for dynamic shapes
         sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-        # Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
-        # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+        # Calculate loss - will auto-fallback to CPU with soft device placement
         logits_time_major = tf.transpose(logits, [1, 0, 2])
         loss = tf.nn.ctc_loss(
             labels=sparse_labels,
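Note: the updated log messages call the CPU fallback "expected behavior". A quick way to confirm where the CTC kernels actually land is TensorFlow's device-placement logging; it is too verbose for regular training runs, but useful as a one-off sanity check:

```python
import tensorflow as tf

# Print the device chosen for every op. Run a single training step with this
# enabled: the CTC-related ops should appear on /device:CPU:0 while the rest
# of the model stays on the accelerator.
tf.debugging.set_log_device_placement(True)
tf.config.set_soft_device_placement(True)
```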
|