TPU maintenance

Zchen
2025-10-16 13:39:05 +08:00
parent 5a1e446219
commit a545cc5648
4 changed files with 708 additions and 26 deletions


@@ -177,26 +177,43 @@ class BrainToTextDecoderTrainerTF:
# Get strategy info
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
# Try to get TPU memory info (HBM)
# Get TPU memory info using the working /device:TPU:X format
try:
# Attempt to get TPU memory usage for each device
memory_info = tf.config.experimental.get_memory_info('/TPU:0')
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
# Check all TPU devices for memory usage
active_cores = 0
total_current_mb = 0
max_peak_mb = 0
for device in tpu_devices:
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_mb = memory_info['current'] // (1024 * 1024)
peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
if current_mb > 1: # >1MB considered active
active_cores += 1
total_current_mb += current_mb
max_peak_mb = max(max_peak_mb, peak_mb)
except:
continue
if active_cores > 0:
if active_cores == 1:
hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)"
else:
hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"
else:
hbm_info = "HBM: unknown"
hbm_info = "HBM:idle"
except Exception:
# Fallback: simple TPU activity check
try:
# Test TPU responsiveness
with tf.device('/TPU:0'):
test_tensor = tf.constant([1.0, 2.0])
_ = tf.reduce_sum(test_tensor)
hbm_info = "HBM: active"
_ = tf.constant(1.0)
hbm_info = "HBM:active"
except Exception:
hbm_info = "HBM: inactive"
hbm_info = "HBM:inactive"
return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
f"{hbm_info}")
@@ -217,17 +234,37 @@ class BrainToTextDecoderTrainerTF:
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
strategy_type = type(self.strategy).__name__
# Get TPU HBM memory info
# Get TPU HBM memory info using working device format
try:
memory_info = tf.config.experimental.get_memory_info('/TPU:0')
if memory_info and 'current' in memory_info:
current_gb = memory_info['current'] // (1024 * 1024 * 1024)
peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
# TPU v5e-8 has ~32GB HBM per chip, 8 chips total = ~256GB
estimated_total_gb = 32 * len(tpu_devices)
hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)"
active_cores = 0
total_current_gb = 0
max_peak_gb = 0
memory_details = []
for i, device in enumerate(tpu_devices):
try:
memory_info = tf.config.experimental.get_memory_info(device.name)
if memory_info and 'current' in memory_info:
current_gb = memory_info['current'] // (1024 * 1024 * 1024)
peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
if current_gb > 0 or memory_info['current'] > 1024*1024: # >1MB
active_cores += 1
total_current_gb += current_gb
max_peak_gb = max(max_peak_gb, peak_gb)
if current_gb > 0:
memory_details.append(f"Core{i}:{current_gb}GB")
except:
continue
if active_cores > 0:
# Based on your test: TPU:0 peaked at 14.5GB, suggesting ~16GB per core
estimated_per_core = 16 # Conservative estimate
estimated_total_gb = estimated_per_core * len(tpu_devices)
hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores"
else:
hbm_usage = "HBM: unknown"
hbm_usage = "HBM: 0GB/256GB (idle)"
except Exception:
hbm_usage = "HBM: unavailable"
@@ -559,39 +596,43 @@ class BrainToTextDecoderTrainerTF:
self.args['dataset']['data_transforms'],
training=True
)
val_dataset = create_input_fn(
self.val_dataset_tf,
self.args['dataset']['data_transforms'],
training=False
)
# Distribute datasets
train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)
self.logger.info("Created distributed training and validation datasets")
# Training metrics
train_losses = []
val_losses = []
val_pers = []
val_results = []
val_steps_since_improvement = 0
self.logger.info("Training time count beginning...")
train_start_time = time.time()
# Training loop
step = 0
for batch in train_dist_dataset:
if step >= self.args['num_training_batches']:
self.logger.info("Reached maximum training batches, stopping training")
break
start_time = time.time()
# Distributed training step
self.logger.info("Running distributed training step...")
per_replica_losses, per_replica_grad_norms = self.strategy.run(
self._train_step, args=(batch, step)
)
# Reduce across replicas
self.logger.info("Reducing results across replicas...")
loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)
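The loop above follows the standard tf.distribute pattern: distribute the dataset with `experimental_distribute_dataset`, run the per-replica step via `strategy.run`, then collapse the per-replica results with `strategy.reduce`. A self-contained sketch of that pattern with a toy model under the default strategy; none of the names below come from this repository:

```python
import tensorflow as tf

# Default strategy stands in for the TPUStrategy used by the trainer.
strategy = tf.distribute.get_strategy()

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    optimizer = tf.keras.optimizers.Adam(1e-3)

@tf.function
def train_step(batch):
    x, y = batch
    with tf.GradientTape() as tape:
        pred = model(x, training=True)
        loss = tf.reduce_mean(tf.square(pred - y))  # per-replica loss
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([64, 4]), tf.random.normal([64, 1]))).batch(8)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

for batch in dist_dataset:
    per_replica_loss = strategy.run(train_step, args=(batch,))
    # Collapse per-replica values into one scalar, as the trainer does above.
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
```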