typo fix, get corpus for each trial

This commit is contained in:
nckcard
2025-07-14 13:58:34 -07:00
parent 82274632af
commit e93cff1e2e
3 changed files with 16 additions and 4 deletions

View File

@@ -21,6 +21,8 @@ parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
help='Evaluation type: "val" for validation set, "test" for test set. '
'If "test", ground truth is not available.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
parser.add_argument('--gpu_number', type=int, default=1,
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
args = parser.parse_args()
@@ -33,6 +35,9 @@ data_dir = args.data_dir
# define evaluation type
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
# load csv file
b2txt_csv_df = pd.read_csv(args.csv_path)
# load model args
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
@@ -85,7 +90,7 @@ for session in model_args['dataset']['sessions']:
if f'data_{eval_type}.hdf5' in files:
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
data = load_h5py_file(eval_file)
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_test_trials += len(test_data[session]["neural_features"])

View File

@@ -26,7 +26,7 @@ def _extract_transcription(input):
trans += chr(input[c])
return trans
def load_h5py_file(file_path):
def load_h5py_file(file_path, b2txt_csv_df):
data = {
'neural_features': [],
'n_time_steps': [],
@@ -36,7 +36,8 @@ def load_h5py_file(file_path):
'sentence_label': [],
'session': [],
'block_num': [],
'trial_num': []
'trial_num': [],
'corpus': [],
}
# Open the hdf5 file for that day
with h5py.File(file_path, 'r') as f:
@@ -57,6 +58,11 @@ def load_h5py_file(file_path):
block_num = g.attrs['block_num']
trial_num = g.attrs['trial_num']
# match this trial up with the csv to get the corpus name
year, month, day = session.split('.')[1:]
date = f'{year}-{month}-{day}'
row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
corpus_name = row['Corpus'].values[0]
data['neural_features'].append(neural_features)
data['n_time_steps'].append(n_time_steps)
@@ -67,6 +73,7 @@ def load_h5py_file(file_path):
data['session'].append(session)
data['block_num'].append(block_num)
data['trial_num'].append(trial_num)
data['corpus'].append(corpus_name)
return data
def rearrange_speech_logits_pt(logits):