Simple fix

Author: Zchen
Date: 2025-10-16 10:53:42 +08:00
parent df4a914bbd
commit 0ff6634192
3 changed files with 106 additions and 10 deletions


@@ -135,10 +135,15 @@ class BrainToTextDecoderTrainerTF:
print(f"💻 Available CPU cores: {available_cores}")
# Optimize for data pipeline parallelism
# Use ~1/4 of cores for inter-op (between operations)
# Use ~1/8 of cores for intra-op (within operations)
inter_op_threads = min(32, available_cores // 4)
intra_op_threads = min(16, available_cores // 8)
# For 224 cores, use more threads for better data loading performance
if available_cores >= 200: # High core count system
inter_op_threads = min(64, available_cores // 3) # More aggressive for 224 cores
intra_op_threads = min(32, available_cores // 6)
else:
# Use ~1/4 of cores for inter-op (between operations)
# Use ~1/8 of cores for intra-op (within operations)
inter_op_threads = min(32, available_cores // 4)
intra_op_threads = min(16, available_cores // 8)
tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
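As a quick check of the new branch's arithmetic (a standalone sketch, not part of the commit):

    # Sketch reproducing the thread-count logic from the hunk above.
    def thread_counts(available_cores):
        if available_cores >= 200:  # high core count system
            return min(64, available_cores // 3), min(32, available_cores // 6)
        return min(32, available_cores // 4), min(16, available_cores // 8)

    print(thread_counts(224))  # (64, 32): 224 // 3 = 74 and 224 // 6 = 37, both hit the caps
    print(thread_counts(96))   # (24, 12): falls through to the ~1/4 and ~1/8 heuristics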
@@ -148,6 +153,63 @@ class BrainToTextDecoderTrainerTF:
print(f" Intra-op parallelism: {intra_op_threads} threads")
print(f" This will accelerate data loading and preprocessing while TPU handles training")
def _get_tpu_status(self) -> str:
"""Get current TPU status and utilization info"""
try:
# Get TPU devices
tpu_devices = tf.config.list_logical_devices('TPU')
if not tpu_devices:
return "TPU: No devices"
# Get strategy info
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
# Get memory usage (simplified)
import psutil
memory = psutil.virtual_memory()
return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
f"RAM: {memory.percent:.1f}%")
except Exception as e:
return f"TPU: status_error({str(e)[:20]})"
def _get_detailed_tpu_status(self) -> str:
"""Get detailed TPU status for training start"""
try:
# Get TPU devices
tpu_devices = tf.config.list_logical_devices('TPU')
if not tpu_devices:
return "❌ No TPU devices detected"
# Get strategy info
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
strategy_type = type(self.strategy).__name__
# Get memory info
import psutil
memory = psutil.virtual_memory()
# Simple TPU test
try:
with tf.device('/TPU:0'):
test_result = tf.constant([1.0, 2.0])
_ = tf.reduce_sum(test_result)
tpu_test = "✅ responsive"
except Exception as e:
tpu_test = f"❌ test_failed({str(e)[:15]})"
return (f"TPU Devices: {len(tpu_devices)} | "
f"Strategy: {strategy_type} | "
f"Cores: {num_replicas} | "
f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
f"Test: {tpu_test}")
except Exception as e:
return f"❌ TPU status check failed: {str(e)[:50]}"
def _initialize_datasets(self):
"""Initialize training and validation datasets"""
# Create file paths
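Illustratively (values invented for the example), on a healthy 8-core TPU host the detailed status line rendered by the f-string in `_get_detailed_tpu_status` would read:

    TPU Devices: 8 | Strategy: TPUStrategy | Cores: 8 | RAM: 12.3% (41GB/334GB) | Test: ✅ responsive

while the compact `_get_tpu_status` form used during training would read `TPU: 8dev 8cores RAM: 12.3%`. Note that the RAM figure comes from `psutil.virtual_memory()`, i.e. host memory, not TPU HBM.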
@@ -448,6 +510,10 @@ class BrainToTextDecoderTrainerTF:
"""Main training loop"""
self.logger.info("Starting training loop...")
# Log initial TPU status
initial_tpu_status = self._get_detailed_tpu_status()
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
# Create distributed datasets
train_dataset = create_input_fn(
self.train_dataset_tf,
@@ -493,12 +559,14 @@ class BrainToTextDecoderTrainerTF:
             train_step_duration = time.time() - start_time
             train_losses.append(float(loss.numpy()))

-            # Log training progress
+            # Log training progress with TPU status
             if step % self.args['batches_per_train_log'] == 0:
+                tpu_status = self._get_tpu_status()
                 self.logger.info(f'Train batch {step}: '
                                  f'loss: {float(loss.numpy()):.2f} '
                                  f'grad norm: {float(grad_norm.numpy()):.2f} '
-                                 f'time: {train_step_duration:.3f}')
+                                 f'time: {train_step_duration:.3f}s '
+                                 f'| {tpu_status}')

             # Validation step
             if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
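With invented numbers, a train log line produced by this hunk would look like:

    Train batch 200: loss: 3.41 grad norm: 1.27 time: 0.185s | TPU: 8dev 8cores RAM: 12.3%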
@@ -508,10 +576,12 @@ class BrainToTextDecoderTrainerTF:
                 val_metrics = self._validate(val_dist_dataset)
                 val_step_duration = time.time() - val_start_time

+                tpu_status = self._get_tpu_status()
                 self.logger.info(f'Val batch {step}: '
                                  f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                  f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
-                                 f'time: {val_step_duration:.3f}')
+                                 f'time: {val_step_duration:.3f}s '
+                                 f'| {tpu_status}')

                 val_pers.append(val_metrics['avg_per'])
                 val_losses.append(val_metrics['avg_loss'])
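Since both logging paths now call `_get_tpu_status`, the helpers can be sanity-checked on a CPU-only machine (a sketch; without TPUs the device list is empty, so both helpers take their no-device branch):

    import tensorflow as tf
    print(tf.config.list_logical_devices('TPU'))  # [] on a CPU-only host
    # _get_tpu_status()          -> "TPU: No devices"
    # _get_detailed_tpu_status() -> "❌ No TPU devices detected"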