Increase safety margin to 30% in dataset shape analysis for improved padding accuracy
@@ -934,8 +934,8 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
         'n_features': dataset_tf.feature_dim
     }
 
-    # 5. Add safety margin (10% buffer)
-    safety_margin = 1.1
+    # 5. Add a larger safety margin (30% buffer) to prevent padding errors
+    safety_margin = 1.3
 
     final_max_shapes = {
         'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
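For context, a minimal sketch of what the enlarged margin does: the analyzed maxima are multiplied by safety_margin and truncated to integers before being used as static padding shapes, so rare examples slightly longer than anything seen during analysis still fit. The helper name apply_safety_margin, the extra dict keys, and the numeric values below are illustrative assumptions, not taken from the repository.

# Hedged sketch: how a 30% safety margin inflates analyzed maxima into final pad shapes.
def apply_safety_margin(original_max_shapes: dict, safety_margin: float = 1.3) -> dict:
    """Inflate analyzed maximum shapes so slightly longer examples still fit after padding."""
    return {key: int(value * safety_margin) for key, value in original_max_shapes.items()}

if __name__ == "__main__":
    # Illustrative values only; the real maxima come from analyze_dataset_shapes.
    analyzed = {"max_time_steps": 1200, "max_phone_seq_len": 80}
    print(apply_safety_margin(analyzed))        # 30% buffer, as in this commit
    print(apply_safety_margin(analyzed, 1.1))   # previous 10% buffer, for comparison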
@@ -998,7 +998,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
 
     # Analyze dataset to get maximum shapes
     print("📊 Analyzing dataset for maximum shapes...")
-    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=100)
+    max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1)  # Analyze ALL data for maximum accuracy
 
     # Use static shapes based on analysis
     padded_shapes = {
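As a rough illustration of the second change, a sample_size of -1 is presumably treated as "scan every example" instead of stopping after a fixed sample, which is what makes the analyzed maxima exact. The sketch below shows one way that interpretation could work; the actual logic inside analyze_dataset_shapes may differ, and scan_max_time_steps is a hypothetical helper.

from typing import Iterable

# Hedged sketch: negative sample_size means "iterate over all examples".
def scan_max_time_steps(lengths: Iterable[int], sample_size: int = -1) -> int:
    """Return the maximum length over the first sample_size examples, or over all of them when sample_size < 0."""
    max_len = 0
    for i, length in enumerate(lengths):
        if 0 <= sample_size <= i:  # stop early only when a non-negative sample size is given
            break
        max_len = max(max_len, length)
    return max_len

print(scan_max_time_steps([512, 900, 1210, 764], sample_size=-1))  # 1210: every example scanned
print(scan_max_time_steps([512, 900, 1210, 764], sample_size=2))   # 900: only the first two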