tpu maintenance
@@ -177,26 +177,43 @@ class BrainToTextDecoderTrainerTF:
         # Get strategy info
         num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1

-        # Try to get TPU memory info (HBM)
+        # Get TPU memory info using the working /device:TPU:X format
         try:
-            # Attempt to get TPU memory usage for each device
-            memory_info = tf.config.experimental.get_memory_info('/TPU:0')
-            if memory_info and 'current' in memory_info:
-                current_mb = memory_info['current'] // (1024 * 1024)
-                peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
-                hbm_info = f"HBM: {current_mb}MB({peak_mb}MB peak)"
+            # Check all TPU devices for memory usage
+            active_cores = 0
+            total_current_mb = 0
+            max_peak_mb = 0
+
+            for device in tpu_devices:
+                try:
+                    memory_info = tf.config.experimental.get_memory_info(device.name)
+                    if memory_info and 'current' in memory_info:
+                        current_mb = memory_info['current'] // (1024 * 1024)
+                        peak_mb = memory_info.get('peak', memory_info['current']) // (1024 * 1024)
+
+                        if current_mb > 1:  # >1MB considered active
+                            active_cores += 1
+                            total_current_mb += current_mb
+                            max_peak_mb = max(max_peak_mb, peak_mb)
+                except:
+                    continue
+
+            if active_cores > 0:
+                if active_cores == 1:
+                    hbm_info = f"HBM:{total_current_mb}MB(peak:{max_peak_mb}MB)"
+                else:
+                    hbm_info = f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"
             else:
-                hbm_info = "HBM: unknown"
+                hbm_info = "HBM:idle"

         except Exception:
             # Fallback: simple TPU activity check
             try:
-                # Test TPU responsiveness
                 with tf.device('/TPU:0'):
-                    test_tensor = tf.constant([1.0, 2.0])
-                    _ = tf.reduce_sum(test_tensor)
-                hbm_info = "HBM: active"
+                    _ = tf.constant(1.0)
+                hbm_info = "HBM:active"
             except Exception:
-                hbm_info = "HBM: inactive"
+                hbm_info = "HBM:inactive"

         return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
                 f"{hbm_info}")
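For reference, the per-core query pattern this hunk switches to can be written as a small standalone helper. This is a hedged sketch, not code from the commit: the name summarize_tpu_hbm and the 1 MB activity threshold are illustrative, and it assumes a TPU runtime is already initialized and that the installed TensorFlow exposes tf.config.experimental.get_memory_info for TPU logical devices.

import tensorflow as tf

def summarize_tpu_hbm(min_active_bytes=1024 * 1024):
    """Summarize current/peak HBM use across visible TPU cores."""
    tpu_devices = tf.config.list_logical_devices('TPU')
    active_cores = 0
    total_current_mb = 0
    max_peak_mb = 0

    for device in tpu_devices:
        try:
            # Returns a dict with 'current' and 'peak' byte counts on supported runtimes
            info = tf.config.experimental.get_memory_info(device.name)
        except Exception:
            continue  # this core does not expose memory stats
        current = info.get('current', 0)
        peak = info.get('peak', current)
        if current > min_active_bytes:  # treat >1MB resident as "active"
            active_cores += 1
            total_current_mb += current // (1024 * 1024)
            max_peak_mb = max(max_peak_mb, peak // (1024 * 1024))

    if active_cores == 0:
        return "HBM:idle"
    return f"HBM:{total_current_mb}MB/{active_cores}cores(peak:{max_peak_mb}MB)"

Summing 'current' across cores while keeping only the maximum 'peak' mirrors the status-string format used in the hunk above.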
@@ -217,17 +234,37 @@ class BrainToTextDecoderTrainerTF:
         num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
         strategy_type = type(self.strategy).__name__

-        # Get TPU HBM memory info
+        # Get TPU HBM memory info using working device format
         try:
-            memory_info = tf.config.experimental.get_memory_info('/TPU:0')
-            if memory_info and 'current' in memory_info:
-                current_gb = memory_info['current'] // (1024 * 1024 * 1024)
-                peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
-                # TPU v5e-8 has ~32GB HBM per chip, 8 chips total = ~256GB
-                estimated_total_gb = 32 * len(tpu_devices)
-                hbm_usage = f"HBM: {current_gb}GB/{estimated_total_gb}GB (peak: {peak_gb}GB)"
+            active_cores = 0
+            total_current_gb = 0
+            max_peak_gb = 0
+            memory_details = []
+
+            for i, device in enumerate(tpu_devices):
+                try:
+                    memory_info = tf.config.experimental.get_memory_info(device.name)
+                    if memory_info and 'current' in memory_info:
+                        current_gb = memory_info['current'] // (1024 * 1024 * 1024)
+                        peak_gb = memory_info.get('peak', memory_info['current']) // (1024 * 1024 * 1024)
+
+                        if current_gb > 0 or memory_info['current'] > 1024*1024:  # >1MB
+                            active_cores += 1
+                            total_current_gb += current_gb
+                            max_peak_gb = max(max_peak_gb, peak_gb)
+                            if current_gb > 0:
+                                memory_details.append(f"Core{i}:{current_gb}GB")
+                except:
+                    continue
+
+            if active_cores > 0:
+                # Based on your test: TPU:0 peaked at 14.5GB, suggesting ~16GB per core
+                estimated_per_core = 16  # Conservative estimate
+                estimated_total_gb = estimated_per_core * len(tpu_devices)
+                hbm_usage = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB) active:{active_cores}cores"
             else:
-                hbm_usage = "HBM: unknown"
+                hbm_usage = "HBM: 0GB/256GB (idle)"

         except Exception:
             hbm_usage = "HBM: unavailable"
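The GB-level summary in this hunk performs the same per-core scan but adds a rough capacity denominator. A minimal sketch of that idea, assuming (as the commit does) roughly 16 GB of HBM per core; TensorFlow does not report total HBM capacity here, so the denominator is only an estimate, and the helper name format_hbm_usage is illustrative.

import tensorflow as tf

def format_hbm_usage(per_core_gb=16):
    """Report summed HBM use against an assumed per-core capacity."""
    tpu_devices = tf.config.list_logical_devices('TPU')
    estimated_total_gb = per_core_gb * len(tpu_devices)  # assumption, not queried from TF
    total_current_gb = 0
    max_peak_gb = 0
    memory_details = []

    for i, device in enumerate(tpu_devices):
        try:
            info = tf.config.experimental.get_memory_info(device.name)
        except Exception:
            continue  # skip cores that do not expose memory stats
        current_gb = info.get('current', 0) // (1024 ** 3)
        peak_gb = info.get('peak', 0) // (1024 ** 3)
        total_current_gb += current_gb
        max_peak_gb = max(max_peak_gb, peak_gb)
        if current_gb > 0:
            memory_details.append(f"Core{i}:{current_gb}GB")

    summary = f"HBM: {total_current_gb}GB/{estimated_total_gb}GB (peak: {max_peak_gb}GB)"
    return summary, memory_details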
@@ -559,39 +596,43 @@ class BrainToTextDecoderTrainerTF:
            self.args['dataset']['data_transforms'],
            training=True
        )

        val_dataset = create_input_fn(
            self.val_dataset_tf,
            self.args['dataset']['data_transforms'],
            training=False
        )

        # Distribute datasets
        train_dist_dataset = self.strategy.experimental_distribute_dataset(train_dataset)
        val_dist_dataset = self.strategy.experimental_distribute_dataset(val_dataset)

        self.logger.info("Created distributed training and validation datasets")
        # Training metrics
        train_losses = []
        val_losses = []
        val_pers = []
        val_results = []
        val_steps_since_improvement = 0

        self.logger.info("Training time count beginning...")
        train_start_time = time.time()

        # Training loop
        step = 0
        for batch in train_dist_dataset:
            if step >= self.args['num_training_batches']:
                self.logger.info("Reached maximum training batches, stopping training")
                break

            start_time = time.time()

            # Distributed training step
            self.logger.info("Running distributed training step...")
            per_replica_losses, per_replica_grad_norms = self.strategy.run(
                self._train_step, args=(batch, step)
            )

            # Reduce across replicas
            self.logger.info("Reducing results across replicas...")
            loss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
            grad_norm = self.strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_grad_norms, axis=None)
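The loop body above uses the standard tf.distribute idiom: strategy.run executes the step function once per replica on its shard of the distributed batch, and strategy.reduce folds the per-replica scalars back into a single host-side value. A minimal sketch with a toy step function standing in for _train_step (which in the trainer computes the model loss and gradient norm); the names toy_step and run_one_step are illustrative only.

import tensorflow as tf

def toy_step(batch):
    # Stand-in for _train_step: produce one scalar per replica.
    return tf.reduce_mean(tf.cast(batch, tf.float32))

def run_one_step(strategy, dist_iterator):
    # Each replica runs the step on its shard of the batch...
    per_replica_loss = strategy.run(toy_step, args=(next(dist_iterator),))
    # ...and the per-replica scalars are averaged into one host-side value.
    return strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)

Usage would look like: iterator = iter(train_dist_dataset); loss = run_one_step(strategy, iterator), assuming strategy is an initialized tf.distribute.TPUStrategy and train_dist_dataset came from strategy.experimental_distribute_dataset(...).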