From f84d6254e372142769311be66b5cfca9ebb6d2b8 Mon Sep 17 00:00:00 2001
From: Zchen <161216199+ZH-CEN@users.noreply.github.com>
Date: Thu, 16 Oct 2025 00:53:42 +0800
Subject: [PATCH] tf environment issue
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 model_training_nnn_tpu/rnn_model_tf.py | 33 ++++++++++++--------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/model_training_nnn_tpu/rnn_model_tf.py b/model_training_nnn_tpu/rnn_model_tf.py
index 7f29924..b7e8342 100644
--- a/model_training_nnn_tpu/rnn_model_tf.py
+++ b/model_training_nnn_tpu/rnn_model_tf.py
@@ -795,34 +795,31 @@ def create_tpu_strategy():
         print(f"  {key}={value}")
 
     try:
-        # Try different TPU resolver configurations
-        resolver = None
+        # Use the official TPU initialization pattern (simplified and reliable)
+        print("🚀 Using official TensorFlow TPU initialization...")
 
-        if tpu_address:
-            print(f"🚀 Attempting TPU connection with address: {tpu_address}")
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
-        elif tpu_name:
-            print(f"🚀 Attempting TPU connection with name: {tpu_name}")
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
-        else:
-            # Try auto-detection
-            print("🚀 Attempting TPU auto-detection...")
-            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
-
-        # Initialize TPU
-        print("⚡ Connecting to TPU cluster...")
+        # Official TensorFlow TPU initialization sequence
+        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
         tf.config.experimental_connect_to_cluster(resolver)
-
-        print("🔧 Initializing TPU system...")
+        # This is the TPU initialization code that has to be at the beginning.
         tf.tpu.experimental.initialize_tpu_system(resolver)
 
+        # Verify TPU devices (following the official example)
+        tpu_devices = tf.config.list_logical_devices('TPU')
+        print("All devices: ", tpu_devices)
+
+        if not tpu_devices:
+            raise RuntimeError("No TPU devices found!")
+
+        print(f"✅ Found {len(tpu_devices)} TPU devices")
+
         # Create TPU strategy
         print("🎯 Creating TPU strategy...")
         strategy = tf.distribute.TPUStrategy(resolver)
 
         print(f"✅ TPU initialized successfully!")
         print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
-        print(f"🏃 TPU cluster: {resolver.cluster_spec()}")
+
         return strategy
 
     except Exception as e:
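
Usage note (not part of the patch): a minimal sketch of how the returned strategy
would typically be consumed downstream. Only create_tpu_strategy() comes from the
patched module; the model definition, shapes, and batch sizes below are
illustrative assumptions, not code from rnn_model_tf.py.

    import tensorflow as tf
    from rnn_model_tf import create_tpu_strategy  # patched module, assumed importable

    strategy = create_tpu_strategy()

    # Scale the global batch size with the replica count so each TPU core
    # receives a constant per-core batch.
    per_core_batch = 32
    global_batch = per_core_batch * strategy.num_replicas_in_sync

    # Variables must be created inside strategy.scope() so the weights are
    # mirrored across all TPU cores.
    with strategy.scope():
        model = tf.keras.Sequential([
            tf.keras.layers.Input(shape=(100, 512)),  # (timesteps, features): placeholder shape
            tf.keras.layers.GRU(256, return_sequences=True),
            tf.keras.layers.Dense(41),                 # output width is a placeholder
        ])
        model.compile(
            optimizer="adam",
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        )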