typo fix, get corpus for each trial
This commit is contained in:
@@ -32,7 +32,7 @@ The data used in this repository (which can be downloaded from [Dryad](https://d
|
||||
- `t15_copyTask.pkl`: This file contains the online Copy Task results required for generating Figure 2.
|
||||
- `t15_personalUse.pkl`: This file contains the Conversation Mode data required for generating Figure 4.
|
||||
- `t15_copyTask_neuralData.zip`: This dataset contains the neural data for the Copy Task.
|
||||
- There are more than 11,300 sentences from 45 sessions spanning 20 months. Each trial of data includes:
|
||||
- There are 10,948 sentences from 45 sessions spanning 20 months. Each trial of data includes:
|
||||
- The session date, block number, and trial number
|
||||
- 512 neural features (2 features [-4.5 RMS threshold crossings and spike band power] per electrode, 256 electrodes), binned at 20 ms resolution. The data were recorded from the speech motor cortex via four high-density microelectrode arrays (64 electrodes each). The 512 features are ordered as follows in all data files:
|
||||
- 0-64: ventral 6v threshold crossings
|
||||
|
@@ -21,6 +21,8 @@ parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
|
||||
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
|
||||
help='Evaluation type: "val" for validation set, "test" for test set. '
|
||||
'If "test", ground truth is not available.')
|
||||
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
|
||||
help='Path to the CSV file with metadata about the dataset (relative to the current working directory).')
|
||||
parser.add_argument('--gpu_number', type=int, default=1,
|
||||
help='GPU number to use for RNN model inference. Set to -1 to use CPU.')
|
||||
args = parser.parse_args()
|
||||
@@ -33,6 +35,9 @@ data_dir = args.data_dir
|
||||
# define evaluation type
|
||||
eval_type = args.eval_type # can be 'val' or 'test'. if 'test', ground truth is not available
|
||||
|
||||
# load csv file
|
||||
b2txt_csv_df = pd.read_csv(args.csv_path)
|
||||
|
||||
# load model args
|
||||
model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
|
||||
|
||||
@@ -85,7 +90,7 @@ for session in model_args['dataset']['sessions']:
|
||||
if f'data_{eval_type}.hdf5' in files:
|
||||
eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
|
||||
|
||||
data = load_h5py_file(eval_file)
|
||||
data = load_h5py_file(eval_file, b2txt_csv_df)
|
||||
test_data[session] = data
|
||||
|
||||
total_test_trials += len(test_data[session]["neural_features"])
|
||||
|
@@ -26,7 +26,7 @@ def _extract_transcription(input):
|
||||
trans += chr(input[c])
|
||||
return trans
|
||||
|
||||
def load_h5py_file(file_path):
|
||||
def load_h5py_file(file_path, b2txt_csv_df):
|
||||
data = {
|
||||
'neural_features': [],
|
||||
'n_time_steps': [],
|
||||
@@ -36,7 +36,8 @@ def load_h5py_file(file_path):
|
||||
'sentence_label': [],
|
||||
'session': [],
|
||||
'block_num': [],
|
||||
'trial_num': []
|
||||
'trial_num': [],
|
||||
'corpus': [],
|
||||
}
|
||||
# Open the hdf5 file for that day
|
||||
with h5py.File(file_path, 'r') as f:
|
||||
@@ -57,6 +58,11 @@ def load_h5py_file(file_path):
|
||||
block_num = g.attrs['block_num']
|
||||
trial_num = g.attrs['trial_num']
|
||||
|
||||
# match this trial up with the csv to get the corpus name
|
||||
year, month, day = session.split('.')[1:]
|
||||
date = f'{year}-{month}-{day}'
|
||||
row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]
|
||||
corpus_name = row['Corpus'].values[0]
|
||||
|
||||
data['neural_features'].append(neural_features)
|
||||
data['n_time_steps'].append(n_time_steps)
|
||||
@@ -67,6 +73,7 @@ def load_h5py_file(file_path):
|
||||
data['session'].append(session)
|
||||
data['block_num'].append(block_num)
|
||||
data['trial_num'].append(trial_num)
|
||||
data['corpus'].append(corpus_name)
|
||||
return data
|
||||
|
||||
def rearrange_speech_logits_pt(logits):
|
||||
|
Reference in New Issue
Block a user