From 7358ff3d796ad95be34c2c41aeb58ffd030705ed Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Mon, 20 Oct 2025 11:22:13 +0800
Subject: [PATCH] Enable soft device placement for CTC operations and update related comments

---
 model_training_nnn_tpu/trainer_tf.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/model_training_nnn_tpu/trainer_tf.py b/model_training_nnn_tpu/trainer_tf.py
index 77243ca..154b557 100644
--- a/model_training_nnn_tpu/trainer_tf.py
+++ b/model_training_nnn_tpu/trainer_tf.py
@@ -92,6 +92,10 @@ class BrainToTextDecoderTrainerTF:
         self.args = args
         self.logger = None
 
+        # Enable soft device placement for XLA-unsupported ops (like CTC)
+        tf.config.set_soft_device_placement(True)
+        print("✅ Enabled soft device placement for CTC operations")
+
         # Initialize TPU strategy
         self.strategy = create_tpu_strategy()
         if self.strategy is None:
@@ -101,8 +105,8 @@ class BrainToTextDecoderTrainerTF:
         print(f"Strategy type: {type(self.strategy).__name__}")
         print("💡 Using tf.data.AUTOTUNE for optimal data pipeline performance")
         print("📝 Ensure create_input_fn uses AUTOTUNE for .map() and .prefetch() operations")
-        print("⚠️ For best TPU performance, ensure create_input_fn uses padded_batch with fixed shapes")
-        print("   and drop_remainder=True to avoid dynamic shape warnings")
+        print("⚠️ CTC operations will automatically fall back to CPU (expected behavior)")
+        print("   This has minimal performance impact since CTC is a small portion of the computation")
 
         # Configure mixed precision for TPU v5e-8
         if args.get('use_amp', True):
@@ -593,8 +597,7 @@ class BrainToTextDecoderTrainerTF:
             # Convert dense labels to sparse for dynamic shapes
             sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-            # Clean CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
-            # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+            # Clean CTC loss - will automatically fall back to CPU with soft device placement
             clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
             clean_loss = tf.nn.ctc_loss(
                 labels=sparse_labels,
@@ -606,7 +609,7 @@ class BrainToTextDecoderTrainerTF:
             )
             clean_loss = tf.reduce_mean(clean_loss)
 
-            # Noisy CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
+            # Noisy CTC loss - will automatically fall back to CPU with soft device placement
             noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
             noisy_loss = tf.nn.ctc_loss(
                 labels=sparse_labels,  # Reuse same sparse labels
@@ -628,7 +631,7 @@ class BrainToTextDecoderTrainerTF:
             # Convert dense labels to sparse for dynamic shapes
             sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-            # Standard CTC loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
+            # Standard CTC loss - will automatically fall back to CPU with soft device placement
             logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
             loss = tf.nn.ctc_loss(
                 labels=sparse_labels,
@@ -696,8 +699,7 @@ class BrainToTextDecoderTrainerTF:
             # Convert dense labels to sparse for dynamic shapes
             sparse_labels = dense_to_sparse(labels, phone_seq_lens)
 
-            # Calculate loss - use tf.nn.ctc_loss with sparse labels (dynamic shapes)
-            # tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
+            # Calculate loss - will automatically fall back to CPU with soft device placement
             logits_time_major = tf.transpose(logits, [1, 0, 2])
             loss = tf.nn.ctc_loss(
                 labels=sparse_labels,
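
For reviewers, here is a minimal, self-contained sketch (not part of the patch) of the pattern the diff relies on: tf.config.set_soft_device_placement(True) lets tf.nn.ctc_loss, which has no TPU/XLA kernel, fall back to CPU while the rest of the graph stays on the accelerator. The dense_to_sparse stand-in, blank_index=0, and the toy shapes below are illustrative assumptions and are not taken from trainer_tf.py.

import tensorflow as tf

# Illustrative sizes only; the real trainer derives these from `args` and the dataset.
BATCH, MAX_TIME, MAX_LABEL, NUM_CLASSES = 4, 50, 12, 41

# Allow ops without an XLA/TPU kernel (such as CTC) to fall back to CPU.
tf.config.set_soft_device_placement(True)

def dense_to_sparse(dense, seq_lens):
    # Minimal stand-in for the trainer's dense_to_sparse helper:
    # keep the first seq_lens[i] labels of each row and return a tf.SparseTensor.
    mask = tf.sequence_mask(seq_lens, maxlen=tf.shape(dense)[1])
    indices = tf.where(mask)
    values = tf.gather_nd(dense, indices)
    return tf.SparseTensor(indices, values, tf.shape(dense, out_type=tf.int64))

@tf.function
def ctc_loss_step(logits, labels, phone_seq_lens, logit_lengths):
    sparse_labels = dense_to_sparse(labels, phone_seq_lens)
    # tf.nn.ctc_loss expects time-major logits: [max_time, batch_size, num_classes]
    logits_time_major = tf.transpose(logits, [1, 0, 2])
    loss = tf.nn.ctc_loss(
        labels=sparse_labels,
        logits=logits_time_major,
        label_length=None,      # lengths are implied by the sparse labels
        logit_length=logit_lengths,
        logits_time_major=True,
        blank_index=0,          # assumption for this sketch; match the model's blank id
    )
    return tf.reduce_mean(loss)

# Toy usage with random tensors standing in for model outputs and a data batch.
logits = tf.random.normal([BATCH, MAX_TIME, NUM_CLASSES])
labels = tf.random.uniform([BATCH, MAX_LABEL], minval=1, maxval=NUM_CLASSES, dtype=tf.int32)
phone_seq_lens = tf.constant([12, 9, 7, 10], dtype=tf.int32)
logit_lengths = tf.fill([BATCH], MAX_TIME)
print(ctc_loss_step(logits, labels, phone_seq_lens, logit_lengths).numpy())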