tf 环境问题
This commit is contained in:
@@ -795,34 +795,31 @@ def create_tpu_strategy():
|
|||||||
print(f" {key}={value}")
|
print(f" {key}={value}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Try different TPU resolver configurations
|
# Use official TPU initialization pattern (simplified and reliable)
|
||||||
resolver = None
|
print("🚀 Using official TensorFlow TPU initialization...")
|
||||||
|
|
||||||
if tpu_address:
|
# Use your tested official TPU initialization code
|
||||||
print(f"🚀 Attempting TPU connection with address: {tpu_address}")
|
|
||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
|
|
||||||
elif tpu_name:
|
|
||||||
print(f"🚀 Attempting TPU connection with name: {tpu_name}")
|
|
||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
|
|
||||||
else:
|
|
||||||
# Try auto-detection
|
|
||||||
print("🚀 Attempting TPU auto-detection...")
|
|
||||||
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
|
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
|
||||||
|
|
||||||
# Initialize TPU
|
|
||||||
print("⚡ Connecting to TPU cluster...")
|
|
||||||
tf.config.experimental_connect_to_cluster(resolver)
|
tf.config.experimental_connect_to_cluster(resolver)
|
||||||
|
# This is the TPU initialization code that has to be at the beginning.
|
||||||
print("🔧 Initializing TPU system...")
|
|
||||||
tf.tpu.experimental.initialize_tpu_system(resolver)
|
tf.tpu.experimental.initialize_tpu_system(resolver)
|
||||||
|
|
||||||
|
# Verify TPU devices (following official example)
|
||||||
|
tpu_devices = tf.config.list_logical_devices('TPU')
|
||||||
|
print("All devices: ", tpu_devices)
|
||||||
|
|
||||||
|
if not tpu_devices:
|
||||||
|
raise RuntimeError("No TPU devices found!")
|
||||||
|
|
||||||
|
print(f"✅ Found {len(tpu_devices)} TPU devices")
|
||||||
|
|
||||||
# Create TPU strategy
|
# Create TPU strategy
|
||||||
print("🎯 Creating TPU strategy...")
|
print("🎯 Creating TPU strategy...")
|
||||||
strategy = tf.distribute.TPUStrategy(resolver)
|
strategy = tf.distribute.TPUStrategy(resolver)
|
||||||
|
|
||||||
print(f"✅ TPU initialized successfully!")
|
print(f"✅ TPU initialized successfully!")
|
||||||
print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
|
print(f"🎉 Number of TPU cores: {strategy.num_replicas_in_sync}")
|
||||||
print(f"🏃 TPU cluster: {resolver.cluster_spec()}")
|
|
||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Reference in New Issue
Block a user