tpu not found
@@ -759,20 +759,74 @@ class CTCLoss(keras.losses.Loss):
 # TPU Strategy Helper Functions
 
 def create_tpu_strategy():
     """Create TPU strategy for distributed training on TPU v5e-8"""
+    import os
+
+    print("🔍 Detecting TPU environment...")
+
+    # Check for various TPU environment variables
+    tpu_address = None
+    tpu_name = None
+
+    # Check common TPU environment variables
+    if 'COLAB_TPU_ADDR' in os.environ:
+        tpu_address = os.environ['COLAB_TPU_ADDR']
+        print(f"📍 Found Colab TPU address: {tpu_address}")
+    elif 'TPU_NAME' in os.environ:
+        tpu_name = os.environ['TPU_NAME']
+        print(f"📍 Found TPU name: {tpu_name}")
+    elif 'TPU_WORKER_ID' in os.environ:
+        # Kaggle TPU environment
+        worker_id = os.environ.get('TPU_WORKER_ID', '0')
+        tpu_address = f'grpc://10.0.0.2:8470'  # Default Kaggle TPU address
+        print(f"📍 Kaggle TPU detected, worker ID: {worker_id}, address: {tpu_address}")
+
+    # Print all TPU-related environment variables for debugging
+    print("🔧 TPU environment variables:")
+    tpu_vars = {k: v for k, v in os.environ.items() if 'TPU' in k or 'COLAB' in k}
+    for key, value in tpu_vars.items():
+        print(f"  {key}={value}")
+
     try:
+        # Try different TPU resolver configurations
+        resolver = None
+
+        if tpu_address:
+            print(f"🚀 Attempting TPU connection with address: {tpu_address}")
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
+        elif tpu_name:
+            print(f"🚀 Attempting TPU connection with name: {tpu_name}")
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
+        else:
+            # Try auto-detection
+            print("🚀 Attempting TPU auto-detection...")
+            resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
+
         # Initialize TPU
-        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
+        print("⚡ Connecting to TPU cluster...")
         tf.config.experimental_connect_to_cluster(resolver)
+
+        print("🔧 Initializing TPU system...")
         tf.tpu.experimental.initialize_tpu_system(resolver)
+
         # Create TPU strategy
+        print("🎯 Creating TPU strategy...")
         strategy = tf.distribute.TPUStrategy(resolver)
-        print(f"TPU initialized successfully. Number of replicas: {strategy.num_replicas_in_sync}")
+
+        print(f"✅ TPU initialized successfully!")
+        print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
+        print(f"🏃 TPU cluster: {resolver.cluster_spec()}")
         return strategy
-    except ValueError as e:
-        print(f"Failed to initialize TPU: {e}")
-        print("Falling back to default strategy")
+
+    except Exception as e:
+        print(f"❌ Failed to initialize TPU: {e}")
+        print(f"🔍 Error type: {type(e).__name__}")
+
+        # Enhanced error reporting
+        if "Please provide a TPU Name" in str(e):
+            print("💡 Hint: TPU name/address not found in environment variables")
+            print("   Common variables: COLAB_TPU_ADDR, TPU_NAME, TPU_WORKER_ID")
+
+        print("🔄 Falling back to default strategy (CPU/GPU)")
         return tf.distribute.get_strategy()
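For context, here is a minimal sketch of how a helper like this is typically consumed by the surrounding training script. It is not part of this commit: the `build_model()` factory and the batch sizes below are hypothetical placeholders. The key points it illustrates are that the function returns either a TPUStrategy or the default strategy, that variable creation (model and optimizer) must happen inside `strategy.scope()`, and that the global batch size is usually scaled by `strategy.num_replicas_in_sync` (8 cores on a TPU v5e-8 slice).

import tensorflow as tf

# Hypothetical model factory; replace with the project's actual model builder.
def build_model():
    return tf.keras.Sequential([
        tf.keras.layers.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10),
    ])

# Returns a TPUStrategy on TPU, otherwise the default (CPU/GPU) strategy.
strategy = create_tpu_strategy()

# Scale the batch size with the number of replicas so each core sees a fixed per-core batch.
per_replica_batch_size = 32
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

with strategy.scope():
    # Weights and optimizer slots must be created under the strategy scope.
    model = build_model()
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )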