Adjust safety margin in dataset shape analysis based on sample size for improved accuracy
This commit is contained in:
@@ -934,8 +934,15 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
'n_features': dataset_tf.feature_dim
|
'n_features': dataset_tf.feature_dim
|
||||||
}
|
}
|
||||||
|
|
||||||
# 5. 添加更大的安全边际(30% buffer)防止填充错误
|
# 5. 添加适当的安全边际 - 基于分析范围调整
|
||||||
safety_margin = 1.3
|
if sample_size == -1:
|
||||||
|
# 全数据分析:只需要很小的边际应对可能的舍入误差
|
||||||
|
safety_margin = 1.02 # 2% buffer for rounding errors
|
||||||
|
margin_reason = "minimal buffer for full dataset analysis"
|
||||||
|
else:
|
||||||
|
# 采样分析:需要更大的边际应对未采样到的极值
|
||||||
|
safety_margin = 1.3 # 30% buffer for sampling uncertainty
|
||||||
|
margin_reason = f"larger buffer due to sampling only {sample_size} trials"
|
||||||
|
|
||||||
final_max_shapes = {
|
final_max_shapes = {
|
||||||
'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
|
'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin),
|
||||||
@@ -949,7 +956,7 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
|
|
||||||
print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
|
print(f"✅ Parallel analysis complete in {analysis_time:.2f} seconds!")
|
||||||
print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
|
print(f"📊 Successfully analyzed {len(local_max_shapes)}/{total_trials_to_analyze} trials ({successful_rate:.1f}%)")
|
||||||
print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin):")
|
print(f"📏 Final max shapes (with {int((safety_margin-1)*100)}% safety margin - {margin_reason}):")
|
||||||
print(f" Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
|
print(f" Time steps: {original_max_shapes['max_time_steps']} → {final_max_shapes['max_time_steps']}")
|
||||||
print(f" Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
|
print(f" Phone sequence length: {original_max_shapes['max_phone_seq_len']} → {final_max_shapes['max_phone_seq_len']}")
|
||||||
print(f" Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
|
print(f" Transcription length: {original_max_shapes['max_transcription_len']} → {final_max_shapes['max_transcription_len']}")
|
||||||
|
Reference in New Issue
Block a user