diff --git a/model_training_nnn_tpu/rnn_model_tf.py b/model_training_nnn_tpu/rnn_model_tf.py index 7f29924..b7e8342 100644 --- a/model_training_nnn_tpu/rnn_model_tf.py +++ b/model_training_nnn_tpu/rnn_model_tf.py @@ -795,34 +795,31 @@ def create_tpu_strategy(): print(f" {key}={value}") try: - # Try different TPU resolver configurations - resolver = None + # Use official TPU initialization pattern (simplified and reliable) + print("🚀 Using official TensorFlow TPU initialization...") - if tpu_address: - print(f"🚀 Attempting TPU connection with address: {tpu_address}") - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address) - elif tpu_name: - print(f"🚀 Attempting TPU connection with name: {tpu_name}") - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name) - else: - # Try auto-detection - print("🚀 Attempting TPU auto-detection...") - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - - # Initialize TPU - print("⚡ Connecting to TPU cluster...") + # Use your tested official TPU initialization code + resolver = tf.distribute.cluster_resolver.TPUClusterResolver() tf.config.experimental_connect_to_cluster(resolver) - - print("🔧 Initializing TPU system...") + # This is the TPU initialization code that has to be at the beginning. tf.tpu.experimental.initialize_tpu_system(resolver) + # Verify TPU devices (following official example) + tpu_devices = tf.config.list_logical_devices('TPU') + print("All devices: ", tpu_devices) + + if not tpu_devices: + raise RuntimeError("No TPU devices found!") + + print(f"✅ Found {len(tpu_devices)} TPU devices") + # Create TPU strategy print("🎯 Creating TPU strategy...") strategy = tf.distribute.TPUStrategy(resolver) print(f"✅ TPU initialized successfully!") print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}") - print(f"🏃 TPU cluster: {resolver.cluster_spec()}") + return strategy except Exception as e: