Simple fixes

Zchen
2025-10-16 10:53:42 +08:00
parent df4a914bbd
commit 0ff6634192
3 changed files with 106 additions and 10 deletions

View File

@@ -94,7 +94,11 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
             return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
         indices = tf.range(tf.shape(inputs)[-1])
-        smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
+        smoothed_features_tensor = tf.map_fn(
+            smooth_single_feature,
+            indices,
+            fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
+        )
         smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
         smoothed = tf.squeeze(smoothed, axis=-1)
     else:
@@ -109,7 +113,25 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
     return smoothed
 ```
-## 4. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
+## 4. ✅ TensorFlow Deprecation Warning Fix (`dataset_tf.py`)
+
+**Problem**: `calling map_fn_v2 with dtype is deprecated and will be removed in a future version`
+
+**Solution**: Replaced the deprecated `dtype` parameter with `fn_output_signature` in `tf.map_fn`:
+
+```python
+# Before (deprecated):
+smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
+
+# After (modern API):
+smoothed_features_tensor = tf.map_fn(
+    smooth_single_feature,
+    indices,
+    fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
+)
+```
+
+## 5. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
 **Problem**: `cannot access local variable 'expected_features' where it is not associated with a value`
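
For readers outside the diff, here is a minimal, self-contained sketch of the `fn_output_signature` migration described in section 4 above; the shapes and the stand-in body of `smooth_single_feature` are illustrative assumptions, not code from the repository:

```python
import tensorflow as tf

# Illustrative shapes; the real code smooths [batch, time, features] neural features.
batch, time_steps, num_features = 2, 50, 4
inputs = tf.random.normal([batch, time_steps, num_features])

def smooth_single_feature(idx):
    # Stand-in for the real per-feature conv1d smoothing; like the original,
    # it maps one feature index to a [batch, time, 1] tensor.
    return tf.expand_dims(inputs[:, :, idx], axis=-1)

indices = tf.range(num_features)

# Deprecated form: tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
# Modern form: declare the per-element output structure explicitly.
smoothed_features_tensor = tf.map_fn(
    smooth_single_feature,
    indices,
    fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32),
)
print(smoothed_features_tensor.shape)  # (4, 2, 50, 1): [features, batch, time, 1]
```

Declaring a `tf.TensorSpec` tells `tf.map_fn` the per-element output structure up front, which is what the deprecated `dtype=` argument used to convey only implicitly.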

View File

@@ -362,7 +362,11 @@ class DataAugmentationTF:
             # Use tf.map_fn for dynamic features
             indices = tf.range(num_features)
-            smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
+            smoothed_features_tensor = tf.map_fn(
+                smooth_single_feature,
+                indices,
+                fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
+            )
             # Transpose to get [batch_size, time_steps, features]
             smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
             smoothed = tf.squeeze(smoothed, axis=-1)
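
As a shape walk-through of the reassembly above: `tf.map_fn` stacks the per-feature results along a new leading axis, and the transpose/squeeze pair is what restores the `[batch, time, features]` layout. A minimal sketch with illustrative sizes:

```python
import tensorflow as tf

# map_fn over feature indices yields [num_features, batch, time, 1]
stacked = tf.zeros([4, 2, 50, 1])

# [1, 2, 0, 3] moves batch and time in front of the feature axis:
# [features, batch, time, 1] -> [batch, time, features, 1]
reordered = tf.transpose(stacked, [1, 2, 0, 3])

# Dropping the trailing singleton channel restores [batch, time, features]
smoothed = tf.squeeze(reordered, axis=-1)
print(reordered.shape, smoothed.shape)  # (2, 50, 4, 1) (2, 50, 4)
```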

View File

@@ -135,10 +135,15 @@ class BrainToTextDecoderTrainerTF:
         print(f"💻 Available CPU cores: {available_cores}")
 
         # Optimize for data pipeline parallelism
-        # Use ~1/4 of cores for inter-op (between operations)
-        # Use ~1/8 of cores for intra-op (within operations)
-        inter_op_threads = min(32, available_cores // 4)
-        intra_op_threads = min(16, available_cores // 8)
+        # For 224 cores, use more threads for better data loading performance
+        if available_cores >= 200:  # High core count system
+            inter_op_threads = min(64, available_cores // 3)  # More aggressive for 224 cores
+            intra_op_threads = min(32, available_cores // 6)
+        else:
+            # Use ~1/4 of cores for inter-op (between operations)
+            # Use ~1/8 of cores for intra-op (within operations)
+            inter_op_threads = min(32, available_cores // 4)
+            intra_op_threads = min(16, available_cores // 8)
 
         tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
         tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
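
Extracted into a standalone function, the new heuristic looks roughly like the sketch below; the 200-core threshold and the divisors mirror the diff, while the `configure_cpu_threading` name and the `os.cpu_count()` fallback are assumptions for illustration:

```python
import os
from typing import Optional

import tensorflow as tf

def configure_cpu_threading(available_cores: Optional[int] = None):
    """Pick inter-/intra-op thread counts using the diff's heuristic (name is illustrative)."""
    if available_cores is None:
        available_cores = os.cpu_count() or 1  # assumption: take the core count from the OS
    if available_cores >= 200:
        # High core count system: more aggressive parallelism for data loading
        inter_op = min(64, available_cores // 3)
        intra_op = min(32, available_cores // 6)
    else:
        # Default: ~1/4 of cores between ops, ~1/8 within an op
        inter_op = min(32, available_cores // 4)
        intra_op = min(16, available_cores // 8)
    # These must run before TensorFlow executes its first op
    tf.config.threading.set_inter_op_parallelism_threads(inter_op)
    tf.config.threading.set_intra_op_parallelism_threads(intra_op)
    return inter_op, intra_op

print(configure_cpu_threading(224))  # -> (64, 32) on a 224-core host
```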
@@ -148,6 +153,63 @@ class BrainToTextDecoderTrainerTF:
         print(f" Intra-op parallelism: {intra_op_threads} threads")
         print(f" This will accelerate data loading and preprocessing while TPU handles training")
 
+    def _get_tpu_status(self) -> str:
+        """Get current TPU status and utilization info"""
+        try:
+            # Get TPU devices
+            tpu_devices = tf.config.list_logical_devices('TPU')
+            if not tpu_devices:
+                return "TPU: No devices"
+
+            # Get strategy info
+            num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
+
+            # Get memory usage (simplified)
+            import psutil
+            memory = psutil.virtual_memory()
+
+            return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
+                    f"RAM: {memory.percent:.1f}%")
+        except Exception as e:
+            return f"TPU: status_error({str(e)[:20]})"
+
+    def _get_detailed_tpu_status(self) -> str:
+        """Get detailed TPU status for training start"""
+        try:
+            # Get TPU devices
+            tpu_devices = tf.config.list_logical_devices('TPU')
+            if not tpu_devices:
+                return "❌ No TPU devices detected"
+
+            # Get strategy info
+            num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
+            strategy_type = type(self.strategy).__name__
+
+            # Get memory info
+            import psutil
+            memory = psutil.virtual_memory()
+
+            # Simple TPU test
+            try:
+                with tf.device('/TPU:0'):
+                    test_result = tf.constant([1.0, 2.0])
+                    _ = tf.reduce_sum(test_result)
+                tpu_test = "✅ responsive"
+            except Exception as e:
+                tpu_test = f"❌ test_failed({str(e)[:15]})"
+
+            return (f"TPU Devices: {len(tpu_devices)} | "
+                    f"Strategy: {strategy_type} | "
+                    f"Cores: {num_replicas} | "
+                    f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
+                    f"Test: {tpu_test}")
+        except Exception as e:
+            return f"❌ TPU status check failed: {str(e)[:50]}"
+
     def _initialize_datasets(self):
         """Initialize training and validation datasets"""
         # Create file paths
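
The "Simple TPU test" inside `_get_detailed_tpu_status` is a device responsiveness probe: run a trivial op under `tf.device` and report whether placement succeeds. A minimal sketch, generalized to any device string so it also runs on CPU-only machines (the `probe_device` helper is hypothetical):

```python
import tensorflow as tf

def probe_device(device: str = '/TPU:0') -> str:
    """Run a trivial op on the device to verify it is responsive (hypothetical helper)."""
    try:
        with tf.device(device):
            # A tiny computation forces placement on the target device
            _ = tf.reduce_sum(tf.constant([1.0, 2.0]))
        return "✅ responsive"
    except Exception as e:  # same truncation style as the trainer's status string
        return f"❌ test_failed({str(e)[:15]})"

print(probe_device('/CPU:0'))  # the CPU probe succeeds on any machine
```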
@@ -448,6 +510,10 @@ class BrainToTextDecoderTrainerTF:
         """Main training loop"""
         self.logger.info("Starting training loop...")
 
+        # Log initial TPU status
+        initial_tpu_status = self._get_detailed_tpu_status()
+        self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
+
         # Create distributed datasets
         train_dataset = create_input_fn(
             self.train_dataset_tf,
@@ -493,12 +559,14 @@ class BrainToTextDecoderTrainerTF:
             train_step_duration = time.time() - start_time
             train_losses.append(float(loss.numpy()))
 
-            # Log training progress
+            # Log training progress with TPU status
             if step % self.args['batches_per_train_log'] == 0:
+                tpu_status = self._get_tpu_status()
                 self.logger.info(f'Train batch {step}: '
                                  f'loss: {float(loss.numpy()):.2f} '
                                  f'grad norm: {float(grad_norm.numpy()):.2f} '
-                                 f'time: {train_step_duration:.3f}')
+                                 f'time: {train_step_duration:.3f}s '
+                                 f'| {tpu_status}')
 
             # Validation step
             if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
@@ -508,10 +576,12 @@ class BrainToTextDecoderTrainerTF:
                 val_metrics = self._validate(val_dist_dataset)
                 val_step_duration = time.time() - val_start_time
 
+                tpu_status = self._get_tpu_status()
                 self.logger.info(f'Val batch {step}: '
                                  f'PER (avg): {val_metrics["avg_per"]:.4f} '
                                  f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
-                                 f'time: {val_step_duration:.3f}')
+                                 f'time: {val_step_duration:.3f}s '
+                                 f'| {tpu_status}')
 
                 val_pers.append(val_metrics['avg_per'])
                 val_losses.append(val_metrics['avg_loss'])
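
Both logging changes follow the same cadence pattern: every `batches_per_train_log` (or `batches_per_val_step`) steps, compose the metrics line and append the status string. A minimal sketch with dummy values (the `get_status` stub stands in for `_get_tpu_status`):

```python
import logging
import time

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("trainer")

batches_per_train_log = 2  # stand-in for self.args['batches_per_train_log']

def get_status() -> str:
    # Hypothetical stub for _get_tpu_status(); returns a compact one-line summary
    return "TPU: 1dev 8cores RAM: 42.0%"

for step in range(5):
    start_time = time.time()
    loss = 1.0 / (step + 1)  # dummy training loss
    if step % batches_per_train_log == 0:
        # Same pattern as the diff: append the status to the periodic log line
        logger.info(f'Train batch {step}: '
                    f'loss: {loss:.2f} '
                    f'time: {time.time() - start_time:.3f}s '
                    f'| {get_status()}')
```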