From 3c993a6268759a5224e7bd3679375fa863414311 Mon Sep 17 00:00:00 2001 From: Zchen <161216199+ZH-CEN@users.noreply.github.com> Date: Wed, 22 Oct 2025 01:47:08 +0800 Subject: [PATCH] Increase safety margin to 30% and analyze the full dataset in shape analysis for improved padding accuracy --- model_training_nnn_tpu/dataset_tf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model_training_nnn_tpu/dataset_tf.py b/model_training_nnn_tpu/dataset_tf.py index ba9afb8..5c07fab 100644 --- a/model_training_nnn_tpu/dataset_tf.py +++ b/model_training_nnn_tpu/dataset_tf.py @@ -934,8 +934,8 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int = 'n_features': dataset_tf.feature_dim } - # 5. 添加安全边际(10% buffer) - safety_margin = 1.1 + # 5. 添加更大的安全边际(30% buffer)防止填充错误 + safety_margin = 1.3 final_max_shapes = { 'max_time_steps': int(original_max_shapes['max_time_steps'] * safety_margin), @@ -998,7 +998,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF, # Analyze dataset to get maximum shapes print("📊 Analyzing dataset for maximum shapes...") - max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=100) + max_shapes = analyze_dataset_shapes(dataset_tf, sample_size=-1) # Analyze ALL data for maximum accuracy # Use static shapes based on analysis padded_shapes = {