import os
import sys
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *

# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.5,
                    help='Weight for the GRU model in the ensemble (LSTM weight = 1 - gru_weight).')

# TTA parameters
parser.add_argument('--tta_samples', type=int, default=5,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from the default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to shift from the beginning in TTA.')
args = parser.parse_args()

# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir

# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max

print('TTA-E Configuration:')
print(f'GRU weight: {gru_weight:.2f}')
print(f'LSTM weight: {lstm_weight:.2f}')
print(f'TTA samples per trial: {tta_samples}')
print(f'TTA noise std: {tta_noise_std}')
print(f'TTA smooth range: ±{tta_smooth_range}')
print(f'TTA scale range: ±{tta_scale_range}')
print(f'TTA max cut: {tta_cut_max} timesteps')
print(f'GRU model path: {gru_model_path}')
print(f'LSTM model path: {lstm_model_path}')
print()

# Define evaluation type
eval_type = args.eval_type

# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)
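# Minimal sanity check (an addition, not part of the original pipeline): since
# lstm_weight is defined as 1 - gru_weight, the ensemble is a convex
# combination only when gru_weight lies in [0, 1]. Fail fast otherwise.
if not 0.0 <= gru_weight <= 1.0:
    raise ValueError(f'--gru_weight must be in [0, 1], got {gru_weight}')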
# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))

# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# Define GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from
# the state-dict keys. Both replacements are chained into a single rename so
# that each key is popped exactly once (popping the same key twice would raise
# a KeyError after the first rename).
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace('module.', '').replace('_orig_mod.', '')
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

# Define LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Same prefix-stripping as for the GRU checkpoint
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace('module.', '').replace('_orig_mod.', '')
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

# Move models to device
gru_model.to(device)
lstm_model.to(device)

# Set models to eval mode
gru_model.eval()
lstm_model.eval()

print('Both models loaded successfully!')
print()
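# Optional diagnostic (an addition, not required by the pipeline): report
# parameter counts so a mismatched checkpoint/config pair is easy to spot.
# Uses only standard torch APIs.
gru_params = sum(p.numel() for p in gru_model.parameters())
lstm_params = sum(p.numel() for p in lstm_model.parameters())
print(f'GRU parameters: {gru_params:,} | LSTM parameters: {lstm_params:,}')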
# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_samples, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run TTA-E (Test Time Augmentation + Ensemble) inference:
    1. Apply a different data augmentation to the input on each TTA pass
    2. Run both the GRU and LSTM models on each augmented version
    3. Ensemble the two models' logits with fixed weights
    4. Average across all TTA samples
    """
    all_ensemble_logits = []

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    for tta_iter in range(tta_samples):
        # Apply a different augmentation strategy on each TTA pass
        x_augmented = x.clone()

        if tta_iter == 0:
            # Original data (baseline)
            augmentation_type = 'original'
        elif tta_iter == 1:
            # Add Gaussian noise
            noise = torch.randn_like(x_augmented) * tta_noise_std
            x_augmented = x_augmented + noise
            augmentation_type = f'noise_std_{tta_noise_std}'
        elif tta_iter == 2:
            # Amplitude scaling
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
            augmentation_type = f'scale_{scale_factor:.3f}'
        elif tta_iter == 3 and tta_cut_max > 0:
            # Time shift (circular shift instead of cutting, to maintain length).
            # Assumes the trial is at least ~16 timesteps long so the upper
            # bound of randint exceeds its lower bound.
            shift_amount = np.random.randint(1, min(tta_cut_max + 1, x_augmented.shape[1] // 8))
            # Circular shift: move the beginning to the end
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :], x_augmented[:, :shift_amount, :]], dim=1)
            augmentation_type = f'shift_{shift_amount}'
        else:
            # Smoothing variation
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            varied_smooth_std = max(0.5, default_smooth_std + smooth_variation)
            augmentation_type = f'smooth_std_{varied_smooth_std:.2f}'

        # Use autocast for efficiency (device.type rather than a hard-coded
        # "cuda" so the CPU path also works)
        with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
            # Apply Gaussian smoothing; only the smoothing-variation branch
            # uses the varied kernel std, all other augmentations use the default
            if augmentation_type.startswith('smooth_std'):
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=varied_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )
            else:
                x_smoothed = gauss_smooth(
                    inputs=x_augmented,
                    device=device,
                    smooth_kernel_std=default_smooth_std,
                    smooth_kernel_size=default_smooth_size,
                    padding='valid',
                )

            with torch.no_grad():
                # Get GRU logits
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )
                # Get LSTM logits
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )

                # Ensemble via weighted averaging
                ensemble_logits = gru_weight * gru_logits + lstm_weight * lstm_logits
                all_ensemble_logits.append(ensemble_logits)

    # TTA fusion: the augmented passes can produce different sequence lengths
    # (e.g. after shifts), so truncate everything to the minimum length before
    # stacking and averaging
    if len(all_ensemble_logits) > 1:
        min_length = min(logits.shape[1] for logits in all_ensemble_logits)
        truncated_logits = [logits[:, :min_length, :] for logits in all_ensemble_logits]
        final_logits = torch.mean(torch.stack(truncated_logits), dim=0)
    else:
        final_logits = all_ensemble_logits[0]

    # Convert logits from bfloat16 to float32
    return final_logits.float().cpu().numpy()
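# Worked example of the fusion arithmetic above (illustrative numbers only,
# not real data): with gru_weight = 0.5 and two TTA passes whose logits for
# one class are
#   pass 1: 0.5 * 2.0 + 0.5 * 4.0 = 3.0
#   pass 2: 0.5 * 1.0 + 0.5 * 1.0 = 1.0
# the TTA-E output for that class is mean(3.0, 1.0) = 2.0. In tensor form:
#   final = mean_i( w * gru_logits_i + (1 - w) * lstm_logits_i )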
# Load data for each session (using the GRU model args as the reference, since
# both configs should list the same sessions)
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
        data = load_h5py_file(eval_file, b2txt_csv_df)
        test_data[session] = data
        total_test_trials += len(test_data[session]['neural_features'])
        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()

# Put neural data through the TTA-E ensemble model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc=f'TTA-E inference ({tta_samples} samples/trial)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)
        for trial in range(len(data['neural_features'])):
            # Get neural input for the trial and add a batch dimension
            neural_input = data['neural_features'][trial]
            neural_input = np.expand_dims(neural_input, axis=0)
            # Convert to torch tensor
            neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
            # Run the TTA-E decoding step
            ensemble_logits = runTTAEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                device, gru_weight, lstm_weight, tta_samples, tta_noise_std,
                tta_smooth_range, tta_scale_range, tta_cut_max,
            )
            data['logits'].append(ensemble_logits)
            pbar.update(1)

# Convert logits to phoneme sequences and print them out
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        # Greedy decode: take the argmax class per frame, drop CTC blanks (id 0),
        # then collapse consecutive duplicates. Because blanks are dropped first,
        # repeated phonemes separated only by a blank also collapse.
        pred_seq = np.argmax(logits, axis=-1)
        pred_seq = [int(p) for p in pred_seq if p != 0]
        pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
        # Convert class ids to phoneme strings
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        data['pred_seq'].append(pred_seq)

        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]
            print(f'Sentence label:     {sentence_label}')
            print(f'True sequence:      {" ".join(true_seq)}')
        print(f'Predicted sequence: {" ".join(pred_seq)}')
        print()

# Language model inference via redis.
# Make sure that the standalone language model is running on the localhost
# redis instance; see README.md for instructions on how to run it.
r = redis.Redis(host='localhost', port=6379, db=0)
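# Optional connectivity check (an addition, not in the original flow): fail
# early with a clear message if the language-model redis server is not
# reachable, instead of erroring mid-loop. ping() is a standard redis-py call.
try:
    r.ping()
except redis.exceptions.ConnectionError as e:
    raise RuntimeError('Cannot reach redis at localhost:6379 -- is the standalone '
                       'language model running? See README.md.') from e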
r.flushall()  # Clear all streams in redis

# Define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'

# Set timestamps for the last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)

lm_results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_sentence': [],
}

# Loop through all trials and put logits into the remote language model to get text predictions
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
    for session in test_data.keys():
        for trial in range(len(test_data[session]['logits'])):
            # Get trial logits and rearrange them for the LM
            logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]

            # Reset the language model
            remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)

            # Put logits into the LM
            remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
                r,
                remote_lm_input_stream,
                remote_lm_output_partial_stream,
                remote_lm_output_partial_lastEntrySeen,
                logits,
            )

            # Finalize the remote LM
            remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
                r,
                remote_lm_output_final_stream,
                remote_lm_output_final_lastEntrySeen,
            )

            # Get the best candidate sentence
            best_candidate_sentence = lm_out['candidate_sentences'][0]

            # Store results
            lm_results['session'].append(session)
            lm_results['block'].append(test_data[session]['block_num'][trial])
            lm_results['trial'].append(test_data[session]['trial_num'][trial])
            if eval_type == 'val':
                lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
            else:
                lm_results['true_sentence'].append(None)
            lm_results['pred_sentence'].append(best_candidate_sentence)

            # Update progress bar
            pbar.update(1)

# If using the validation set, calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0
    lm_results['edit_distance'] = []
    lm_results['num_words'] = []
    for i in range(len(lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
        pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
        # Word-level edit distance between the true and predicted sentences
        ed = editdistance.eval(true_sentence.split(), pred_sentence.split())
        total_true_length += len(true_sentence.split())
        total_edit_distance += ed
        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(len(true_sentence.split()))
        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence:      {true_sentence}')
        print(f'Predicted sentence: {pred_sentence}')
        print(f'WER: {ed} / {len(true_sentence.split())} = {100 * ed / len(true_sentence.split()):.2f}%')
        print()
    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate word error rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')

# Write predicted sentences to a CSV file with a timestamp and the TTA-E configuration
timestamp = time.strftime('%Y%m%d_%H%M%S')
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_samples{tta_samples}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)
ids = list(range(len(lm_results['pred_sentence'])))
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_path, index=False)
print(f'\nResults saved to: {output_path}')
print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}, TTA samples = {tta_samples}')
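# Example invocation (illustrative; the script filename below is a placeholder
# for this file's actual name -- adjust it and the paths to your checkout):
#   python evaluate_tta_ensemble.py --eval_type val --gru_weight 0.6 \
#       --tta_samples 5 --gpu_number 0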