简单修复
This commit is contained in:
@@ -135,10 +135,15 @@ class BrainToTextDecoderTrainerTF:
|
||||
print(f"💻 Available CPU cores: {available_cores}")
|
||||
|
||||
# Optimize for data pipeline parallelism
|
||||
# Use ~1/4 of cores for inter-op (between operations)
|
||||
# Use ~1/8 of cores for intra-op (within operations)
|
||||
inter_op_threads = min(32, available_cores // 4)
|
||||
intra_op_threads = min(16, available_cores // 8)
|
||||
# For 224 cores, use more threads for better data loading performance
|
||||
if available_cores >= 200: # High core count system
|
||||
inter_op_threads = min(64, available_cores // 3) # More aggressive for 224 cores
|
||||
intra_op_threads = min(32, available_cores // 6)
|
||||
else:
|
||||
# Use ~1/4 of cores for inter-op (between operations)
|
||||
# Use ~1/8 of cores for intra-op (within operations)
|
||||
inter_op_threads = min(32, available_cores // 4)
|
||||
intra_op_threads = min(16, available_cores // 8)
|
||||
|
||||
tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
|
||||
tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
|
||||
@@ -148,6 +153,63 @@ class BrainToTextDecoderTrainerTF:
|
||||
print(f" Intra-op parallelism: {intra_op_threads} threads")
|
||||
print(f" This will accelerate data loading and preprocessing while TPU handles training")
|
||||
|
||||
def _get_tpu_status(self) -> str:
|
||||
"""Get current TPU status and utilization info"""
|
||||
try:
|
||||
# Get TPU devices
|
||||
tpu_devices = tf.config.list_logical_devices('TPU')
|
||||
|
||||
if not tpu_devices:
|
||||
return "TPU: No devices"
|
||||
|
||||
# Get strategy info
|
||||
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
|
||||
|
||||
# Get memory usage (simplified)
|
||||
import psutil
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
|
||||
f"RAM: {memory.percent:.1f}%")
|
||||
|
||||
except Exception as e:
|
||||
return f"TPU: status_error({str(e)[:20]})"
|
||||
|
||||
def _get_detailed_tpu_status(self) -> str:
|
||||
"""Get detailed TPU status for training start"""
|
||||
try:
|
||||
# Get TPU devices
|
||||
tpu_devices = tf.config.list_logical_devices('TPU')
|
||||
|
||||
if not tpu_devices:
|
||||
return "❌ No TPU devices detected"
|
||||
|
||||
# Get strategy info
|
||||
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
|
||||
strategy_type = type(self.strategy).__name__
|
||||
|
||||
# Get memory info
|
||||
import psutil
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
# Simple TPU test
|
||||
try:
|
||||
with tf.device('/TPU:0'):
|
||||
test_result = tf.constant([1.0, 2.0])
|
||||
_ = tf.reduce_sum(test_result)
|
||||
tpu_test = "✅ responsive"
|
||||
except Exception as e:
|
||||
tpu_test = f"❌ test_failed({str(e)[:15]})"
|
||||
|
||||
return (f"TPU Devices: {len(tpu_devices)} | "
|
||||
f"Strategy: {strategy_type} | "
|
||||
f"Cores: {num_replicas} | "
|
||||
f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
|
||||
f"Test: {tpu_test}")
|
||||
|
||||
except Exception as e:
|
||||
return f"❌ TPU status check failed: {str(e)[:50]}"
|
||||
|
||||
def _initialize_datasets(self):
|
||||
"""Initialize training and validation datasets"""
|
||||
# Create file paths
|
||||
@@ -448,6 +510,10 @@ class BrainToTextDecoderTrainerTF:
|
||||
"""Main training loop"""
|
||||
self.logger.info("Starting training loop...")
|
||||
|
||||
# Log initial TPU status
|
||||
initial_tpu_status = self._get_detailed_tpu_status()
|
||||
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
|
||||
|
||||
# Create distributed datasets
|
||||
train_dataset = create_input_fn(
|
||||
self.train_dataset_tf,
|
||||
@@ -493,12 +559,14 @@ class BrainToTextDecoderTrainerTF:
|
||||
train_step_duration = time.time() - start_time
|
||||
train_losses.append(float(loss.numpy()))
|
||||
|
||||
# Log training progress
|
||||
# Log training progress with TPU status
|
||||
if step % self.args['batches_per_train_log'] == 0:
|
||||
tpu_status = self._get_tpu_status()
|
||||
self.logger.info(f'Train batch {step}: '
|
||||
f'loss: {float(loss.numpy()):.2f} '
|
||||
f'grad norm: {float(grad_norm.numpy()):.2f} '
|
||||
f'time: {train_step_duration:.3f}')
|
||||
f'time: {train_step_duration:.3f}s '
|
||||
f'| {tpu_status}')
|
||||
|
||||
# Validation step
|
||||
if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
|
||||
@@ -508,10 +576,12 @@ class BrainToTextDecoderTrainerTF:
|
||||
val_metrics = self._validate(val_dist_dataset)
|
||||
val_step_duration = time.time() - val_start_time
|
||||
|
||||
tpu_status = self._get_tpu_status()
|
||||
self.logger.info(f'Val batch {step}: '
|
||||
f'PER (avg): {val_metrics["avg_per"]:.4f} '
|
||||
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
|
||||
f'time: {val_step_duration:.3f}')
|
||||
f'time: {val_step_duration:.3f}s '
|
||||
f'| {tpu_status}')
|
||||
|
||||
val_pers.append(val_metrics['avg_per'])
|
||||
val_losses.append(val_metrics['avg_loss'])
|
||||
|
Reference in New Issue
Block a user