118 lines
6.9 KiB
Plaintext
# argument parser for command line arguments
|
||
|
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
|
||
|
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
|
||
|
help='Path to the pretrained GRU model directory.')
|
||
|
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
|
||
|
help='Path to the pretrained LSTM model directory.')
|
||
|
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
|
||
|
help='Path to the dataset directory (relative to the current working directory).')
|
||
|
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
|
||
|
help='Evaluation type: "val" for validation set, "test" for test set.')
|
||
|
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
|
||
|
help='Path to the CSV file with metadata about the dataset.')
|
||
|
parser.add_argument('--gpu_number', type=int, default=0,
|
||
|
help='GPU number to use for model inference. Set to -1 to use CPU.')
|
||
|
parser.add_argument('--gru_weight', type=float, default=0.6,
|
||
|
help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). Improved default for better balance.')
|
||
|
# TTA parameters
|
||
|
parser.add_argument('--tta_samples', type=int, default=8,
|
||
|
help='Number of TTA augmentation samples per trial. Increased from 5 for better coverage.')
|
||
|
parser.add_argument('--tta_noise_std', type=float, default=0.01,
|
||
|
help='Standard deviation for TTA noise augmentation.')
|
||
|
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
|
||
|
help='Range for TTA smoothing kernel variation (±range from default).')
|
||
|
parser.add_argument('--tta_scale_range', type=float, default=0.05,
|
||
|
help='Range for TTA amplitude scaling (±range from 1.0).')
|
||
|
parser.add_argument('--tta_cut_max', type=int, default=3,
|
||
|
help='Maximum number of timesteps to cut from beginning in TTA.')
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 6473
|
||
|
Aggregate Phoneme Error Rate (PER): 15.64%
|
||
|
|
||
|
Results saved to: /root/autodl-tmp/nejm-brain-to-text/TTA-E/TTA-E_gru0.6_lstm0.4_samples8_val_20250917_210946.csv
|
||
|
TTA-E configuration: GRU weight = 0.60, LSTM weight = 0.40, TTA samples = 8
|
||
|
|
||
|
============================================================
|
||
|
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
|
||
|
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
|
||
|
help='Path to the pretrained GRU model directory.')
|
||
|
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
|
||
|
help='Path to the pretrained LSTM model directory.')
|
||
|
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
|
||
|
help='Path to the dataset directory (relative to the current working directory).')
|
||
|
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
|
||
|
help='Evaluation type: "val" for validation set, "test" for test set.')
|
||
|
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
|
||
|
help='Path to the CSV file with metadata about the dataset.')
|
||
|
parser.add_argument('--gpu_number', type=int, default=0,
|
||
|
help='GPU number to use for model inference. Set to -1 to use CPU.')
|
||
|
parser.add_argument('--gru_weight', type=float, default=0.8,
|
||
|
help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight). Improved default for better balance.')
|
||
|
# TTA parameters
|
||
|
parser.add_argument('--tta_samples', type=int, default=8,
|
||
|
help='Number of TTA augmentation samples per trial. Increased from 5 for better coverage.')
|
||
|
parser.add_argument('--tta_noise_std', type=float, default=0.01,
|
||
|
help='Standard deviation for TTA noise augmentation.')
|
||
|
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
|
||
|
help='Range for TTA smoothing kernel variation (±range from default).')
|
||
|
parser.add_argument('--tta_scale_range', type=float, default=0.05,
|
||
|
help='Range for TTA amplitude scaling (±range from 1.0).')
|
||
|
parser.add_argument('--tta_cut_max', type=int, default=3,
|
||
|
help='Maximum number of timesteps to cut from beginning in TTA.')
|
||
|
提高GRU权重 (raise GRU weight: gru_weight 0.6 → 0.8, so LSTM weight 0.4 → 0.2; all other arguments unchanged)
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 4705
|
||
|
Aggregate Phoneme Error Rate (PER): 11.37%
|
||
|
|
||
|
============================================================
|
||
|
去TTA (TTA removed: no test-time augmentation)
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 4326
|
||
|
Aggregate Phoneme Error Rate (PER): 10.45%
|
||
|
|
||
|
============================================================
|
||
|
纯GRU (GRU model only, no ensemble)
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 4215
|
||
|
Aggregate Phoneme Error Rate (PER): 10.18%
|
||
|
============================================================
|
||
|
纯LSTM (LSTM model only, no ensemble)
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 4498
|
||
|
Aggregate Phoneme Error Rate (PER): 10.87%
|
||
|
============================================================
|
||
|
纯GRU + 3TTA (GRU model only, with 3TTA)
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 4213
|
||
|
Aggregate Phoneme Error Rate (PER): 10.18%
|
||
|
============================================================
|
||
|
纯GRU + 5TTA (GRU model only, with 5TTA)
|
||
|
Total true phoneme length: 41392
|
||
|
Total edit distance: 4218
|
||
|
Aggregate Phoneme Error Rate (PER): 10.19%
|
||
|
============================================================
|
||
|
纯GRU + 4TTA (GRU model only, with 4TTA)
|
||
|
Aggregate Phoneme Error Rate (PER): 10.13%
|
||
|
============================================================
|
||
|
纯GRU + 4TTA 去高斯噪声 (GRU model only, with 4TTA, Gaussian-noise augmentation removed)
|
||
|
Aggregate Phoneme Error Rate (PER): 10.14%
|
||
|
============================================================
|
||
|
parser.add_argument('--tta_weights', type=str, default='0.6,0.6,0.6,1.0,0.0',
|
||
|
Aggregate Phoneme Error Rate (PER): 10.16%
|
||
|
============================================================
|
||
|
parser.add_argument('--tta_weights', type=str, default='1.0,0.6,0.6,1.0,0.0',
|
||
|
Aggregate Phoneme Error Rate (PER): 10.13%
|
||
|
============================================================
|
||
|
parser.add_argument('--tta_weights', type=str, default='1.0,0.6,0.0,1.0,0.0',
|
||
|
Aggregate Phoneme Error Rate (PER): 10.23%
|
||
|
============================================================
|
||
|
parser.add_argument('--tta_weights', type=str, default='1.0,1.0,0.0,1.0,0.0',
|
||
|
Aggregate Phoneme Error Rate (PER): 10.14%
|
||
|
============================================================
|
||
|
============================================================
|
||
|
============================================================
|
||
|
纯GRU + 5TTA 去高斯噪声 (GRU model only, with 5TTA, Gaussian-noise augmentation removed)
|
||
|
Total edit distance: 4308
|
||
|
Aggregate Phoneme Error Rate (PER): 10.41%
|