简单修复

This commit is contained in:
Zchen
2025-10-16 10:53:42 +08:00
parent df4a914bbd
commit 0ff6634192
3 changed files with 106 additions and 10 deletions

View File

@@ -94,7 +94,11 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
indices = tf.range(tf.shape(inputs)[-1])
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
smoothed_features_tensor = tf.map_fn(
smooth_single_feature,
indices,
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
)
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
smoothed = tf.squeeze(smoothed, axis=-1)
else:
@@ -109,7 +113,25 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
return smoothed
```
## 4. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
## 4. ✅ TensorFlow Deprecation Warning Fix (`dataset_tf.py`)
**Problem**: `calling map_fn_v2 with dtype is deprecated and will be removed in a future version`
**Solution**: Replaced deprecated `dtype` parameter with `fn_output_signature` in `tf.map_fn`:
```python
# Before (deprecated):
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
# After (modern API):
smoothed_features_tensor = tf.map_fn(
smooth_single_feature,
indices,
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
)
```
## 5. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
**Problem**: `cannot access local variable 'expected_features' where it is not associated with a value`