简单修复
This commit is contained in:
@@ -94,7 +94,11 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
|
||||
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
||||
|
||||
indices = tf.range(tf.shape(inputs)[-1])
|
||||
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
|
||||
smoothed_features_tensor = tf.map_fn(
|
||||
smooth_single_feature,
|
||||
indices,
|
||||
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
|
||||
)
|
||||
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
||||
smoothed = tf.squeeze(smoothed, axis=-1)
|
||||
else:
|
||||
@@ -109,7 +113,25 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
|
||||
return smoothed
|
||||
```
|
||||
|
||||
## 4. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
|
||||
## 4. ✅ TensorFlow Deprecation Warning Fix (`dataset_tf.py`)
|
||||
|
||||
**Problem**: `calling map_fn_v2 with dtype is deprecated and will be removed in a future version`
|
||||
|
||||
**Solution**: Replaced deprecated `dtype` parameter with `fn_output_signature` in `tf.map_fn`:
|
||||
|
||||
```python
|
||||
# Before (deprecated):
|
||||
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
|
||||
|
||||
# After (modern API):
|
||||
smoothed_features_tensor = tf.map_fn(
|
||||
smooth_single_feature,
|
||||
indices,
|
||||
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
|
||||
)
|
||||
```
|
||||
|
||||
## 5. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
|
||||
|
||||
**Problem**: `cannot access local variable 'expected_features' where it is not associated with a value`
|
||||
|
||||
|
||||
Reference in New Issue
Block a user