import os
import sys
import torch
import numpy as np
import pandas as pd
import redis
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse

# Add the model directories and the repository root to the path, so that both the
# package-qualified imports below and any intra-package imports resolve
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *

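# Note: the wildcard import above is expected to supply the helper functions and
# constants used below, e.g. gauss_smooth, load_h5py_file, LOGIT_TO_PHONEME,
# rearrange_speech_logits_pt, remove_punctuation, get_current_redis_time_ms,
# reset_remote_language_model, send_logits_to_remote_lm, and finalize_remote_lm.
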
# Argument parser for command line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models using TTA-E on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='test', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.5,
                    help='Weight for GRU model in ensemble (LSTM weight = 1 - gru_weight).')
# TTA parameters
parser.add_argument('--tta_samples', type=int, default=5,
                    help='Number of TTA augmentation samples per trial.')
parser.add_argument('--tta_noise_std', type=float, default=0.01,
                    help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
                    help='Range for TTA smoothing kernel variation (±range from default).')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
                    help='Range for TTA amplitude scaling (±range from 1.0).')
parser.add_argument('--tta_cut_max', type=int, default=3,
                    help='Maximum number of timesteps to circularly shift in the TTA time-shift augmentation.')
args = parser.parse_args()

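# Illustrative invocation (the script filename here is hypothetical; the path
# defaults above may need adjusting for your setup):
#   python evaluate_model_tta_e.py --eval_type val --gru_weight 0.5 --tta_samples 5
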
# Model paths
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir

# Ensemble weights
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight

# TTA parameters
tta_samples = args.tta_samples
tta_noise_std = args.tta_noise_std
tta_smooth_range = args.tta_smooth_range
tta_scale_range = args.tta_scale_range
tta_cut_max = args.tta_cut_max

|  | print(f"TTA-E Configuration:") | ||
|  | print(f"GRU weight: {gru_weight:.2f}") | ||
|  | print(f"LSTM weight: {lstm_weight:.2f}") | ||
|  | print(f"TTA samples per trial: {tta_samples}") | ||
|  | print(f"TTA noise std: {tta_noise_std}") | ||
|  | print(f"TTA smooth range: ±{tta_smooth_range}") | ||
|  | print(f"TTA scale range: ±{tta_scale_range}") | ||
|  | print(f"TTA max cut: {tta_cut_max} timesteps") | ||
|  | print(f"GRU model path: {gru_model_path}") | ||
|  | print(f"LSTM model path: {lstm_model_path}") | ||
|  | print() | ||
|  | 
 | ||
# Define evaluation type
eval_type = args.eval_type

# Load CSV file with dataset metadata
b2txt_csv_df = pd.read_csv(args.csv_path)

# Load model arguments for both models
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))

# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using {device} for model inference.')
else:
    if gpu_number >= 0:
        print(f'GPU number {gpu_number} requested but not available.')
    print('Using CPU for model inference.')
    device = torch.device('cpu')

# Define GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)

# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." and "_orig_mod." prefixes from state dict keys in a single pass
# (these prefixes appear if the model was saved with DataParallel and/or torch.compile)
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])

# Define LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)

# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip "module." and "_orig_mod." prefixes from state dict keys in a single pass
# (these prefixes appear if the model was saved with DataParallel and/or torch.compile)
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])

# Move models to device
gru_model.to(device)
lstm_model.to(device)

# Set models to eval mode
gru_model.eval()
lstm_model.eval()

print("Both models loaded successfully!")
print()

# TTA-E inference function
def runTTAEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                               device, gru_weight, lstm_weight, tta_samples, tta_noise_std,
                               tta_smooth_range, tta_scale_range, tta_cut_max):
    """
    Run TTA-E (Test Time Augmentation + Ensemble) inference:
    1. Apply a different data augmentation to the input on each TTA iteration
    2. Run both the GRU and the LSTM model on each augmented version
    3. Combine the two models' outputs with a weighted average
    4. Average the ensembled logits across all TTA samples
    """
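    # In symbols, with TTA augmentations a_1..a_K and ensemble weights
    # w_gru + w_lstm = 1 (both set on the command line):
    #   final_logits = (1/K) * sum_k [ w_gru * GRU(a_k(x)) + w_lstm * LSTM(a_k(x)) ]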
    all_ensemble_logits = []

    # Get default smoothing parameters
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    for tta_iter in range(tta_samples):
        # Apply a different augmentation strategy on each iteration
        x_augmented = x.clone()

        if tta_iter == 0:
            # Original data (baseline)
            augmentation_type = "original"
        elif tta_iter == 1:
            # Add Gaussian noise
            noise = torch.randn_like(x_augmented) * tta_noise_std
            x_augmented = x_augmented + noise
            augmentation_type = f"noise_std_{tta_noise_std}"
        elif tta_iter == 2:
            # Amplitude scaling by a random factor in [1 - tta_scale_range, 1 + tta_scale_range]
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
            x_augmented = x_augmented * scale_factor
            augmentation_type = f"scale_{scale_factor:.3f}"
        elif tta_iter == 3 and tta_cut_max > 0:
            # Time shift (circular shift instead of cutting, to maintain length);
            # np.random.randint requires high > low, so guard against very short trials
            max_shift = max(2, min(tta_cut_max + 1, x_augmented.shape[1] // 8))
            shift_amount = np.random.randint(1, max_shift)
            # Circular shift: move the beginning to the end
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
            augmentation_type = f"shift_{shift_amount}"
        else:
            # Smoothing variation: jitter the Gaussian smoothing kernel std
            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
            varied_smooth_std = max(0.5, default_smooth_std + smooth_variation)
            augmentation_type = f"smooth_std_{varied_smooth_std:.2f}"

        # Use autocast for efficiency (a no-op when AMP is disabled)
        with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):

            # Apply Gaussian smoothing; only the smoothing-variation augmentation
            # uses a kernel std other than the default
            if augmentation_type.startswith("smooth_std"):
                smooth_std = varied_smooth_std
            else:
                smooth_std = default_smooth_std
            x_smoothed = gauss_smooth(
                inputs=x_augmented,
                device=device,
                smooth_kernel_std=smooth_std,
                smooth_kernel_size=default_smooth_size,
                padding='valid',
            )

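            # Note: 'valid' padding is assumed to trim the smoothed sequence by
            # (smooth_kernel_size - 1) timesteps. Since every augmentation above
            # preserves input length, all TTA samples should normally come out the
            # same length; the min-length truncation in the fusion step below is a
            # defensive fallback.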
            with torch.no_grad():
                # Get GRU logits
                gru_logits, _ = gru_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )

                # Get LSTM logits
                lstm_logits, _ = lstm_model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=device),
                    states=None,
                    return_state=True,
                )

                # Ensemble via weighted averaging of the logits
                ensemble_logits = gru_weight * gru_logits + lstm_weight * lstm_logits
                all_ensemble_logits.append(ensemble_logits)

    # TTA fusion: guard against TTA samples of different sequence lengths by
    # truncating all of them to the minimum length before averaging
    if len(all_ensemble_logits) > 1:
        # Find the minimum sequence length among all TTA samples
        min_length = min(logits.shape[1] for logits in all_ensemble_logits)

        # Truncate every tensor to the minimum length (a no-op for those already at it)
        truncated_logits = [logits[:, :min_length, :] for logits in all_ensemble_logits]

        # Stack and average across TTA samples
        final_logits = torch.mean(torch.stack(truncated_logits), dim=0)
    else:
        final_logits = all_ensemble_logits[0]

    # Convert logits from bfloat16 to float32 and return as a numpy array
    return final_logits.float().cpu().numpy()

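# Note: the fusion above averages raw logits directly. An alternative (not used
# here) is to average log-softmax outputs, which can behave more predictably
# when the two models produce logits on different scales.
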
# Load data for each session (using the GRU model args as the reference session
# list; both models are assumed to be trained on the same sessions)
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')

        data = load_h5py_file(eval_file, b2txt_csv_df)
        test_data[session] = data

        total_test_trials += len(test_data[session]["neural_features"])
        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()

# Put the neural data through the TTA-E ensemble to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc=f'TTA-E inference ({tta_samples} samples/trial)', unit='trial') as pbar:
    for session, data in test_data.items():

        data['logits'] = []
        input_layer = gru_model_args['dataset']['sessions'].index(session)

        for trial in range(len(data['neural_features'])):
            # Get neural input for the trial
            neural_input = data['neural_features'][trial]

            # Add batch dimension
            neural_input = np.expand_dims(neural_input, axis=0)

            # Convert to torch tensor
            neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)

            # Run TTA-E decoding step
            ensemble_logits = runTTAEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight,
                tta_samples, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max
            )
            data['logits'].append(ensemble_logits)

            pbar.update(1)

# Convert logits to phoneme sequences via CTC best-path (greedy) decoding
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        pred_seq = np.argmax(logits, axis=-1)
        # Collapse consecutive duplicates first, then remove CTC blanks (id 0);
        # removing blanks first would wrongly merge repeated phonemes that are
        # separated only by a blank
        pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
        pred_seq = [p for p in pred_seq if p != 0]
        # Convert to phonemes
        pred_seq = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        # Add to data
        data['pred_seq'].append(pred_seq)

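        # Illustrative example (actual phoneme strings depend on LOGIT_TO_PHONEME):
        #   argmax ids [0, 5, 5, 0, 5, 12] -> collapse repeats -> [0, 5, 0, 5, 12]
        #   -> drop blanks -> [5, 5, 12], preserving the genuine repeat
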
        # Print out the predicted sequences
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

            print(f'Sentence label:      {sentence_label}')
            print(f'True sequence:       {" ".join(true_seq)}')
        print(f'Predicted sequence:  {" ".join(pred_seq)}')
        print()

# Language model inference via redis.
# Make sure the standalone language model is running and attached to the
# localhost redis instance; see README.md for instructions on how to run it.
r = redis.Redis(host='localhost', port=6379, db=0)
r.flushall()  # Clear all streams in redis

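# Caution: FLUSHALL wipes every key in every database on this redis server, not
# just the streams used here, so don't point this at a shared redis instance.
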
# Define redis streams for the remote language model
remote_lm_input_stream = 'remote_lm_input'
remote_lm_output_partial_stream = 'remote_lm_output_partial'
remote_lm_output_final_stream = 'remote_lm_output_final'

# Set timestamps for the last entries seen in the redis streams
remote_lm_output_partial_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_output_final_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_resetting_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_finalizing_lastEntrySeen = get_current_redis_time_ms(r)
remote_lm_done_updating_lastEntrySeen = get_current_redis_time_ms(r)

lm_results = {
    'session': [],
    'block': [],
    'trial': [],
    'true_sentence': [],
    'pred_sentence': [],
}

# Loop through all trials, feeding logits to the remote language model to get text predictions
with tqdm(total=total_test_trials, desc='Running remote language model', unit='trial') as pbar:
    for session in test_data.keys():
        for trial in range(len(test_data[session]['logits'])):
            # Get trial logits and rearrange them for the LM
            logits = rearrange_speech_logits_pt(test_data[session]['logits'][trial])[0]

            # Reset the language model before each trial
            remote_lm_done_resetting_lastEntrySeen = reset_remote_language_model(r, remote_lm_done_resetting_lastEntrySeen)

            # Send the logits to the LM
            remote_lm_output_partial_lastEntrySeen, decoded = send_logits_to_remote_lm(
                r,
                remote_lm_input_stream,
                remote_lm_output_partial_stream,
                remote_lm_output_partial_lastEntrySeen,
                logits,
            )

            # Finalize the remote LM output
            remote_lm_output_final_lastEntrySeen, lm_out = finalize_remote_lm(
                r,
                remote_lm_output_final_stream,
                remote_lm_output_final_lastEntrySeen,
            )

            # Get the best candidate sentence
            best_candidate_sentence = lm_out['candidate_sentences'][0]

            # Store results
            lm_results['session'].append(session)
            lm_results['block'].append(test_data[session]['block_num'][trial])
            lm_results['trial'].append(test_data[session]['trial_num'][trial])
            if eval_type == 'val':
                lm_results['true_sentence'].append(test_data[session]['sentence_label'][trial])
            else:
                lm_results['true_sentence'].append(None)
            lm_results['pred_sentence'].append(best_candidate_sentence)

            # Update progress bar
            pbar.update(1)

# If using the validation set, calculate the aggregate word error rate (WER)
if eval_type == 'val':
    total_true_length = 0
    total_edit_distance = 0

    lm_results['edit_distance'] = []
    lm_results['num_words'] = []

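    # Aggregate WER = (sum of word-level edit distances) / (total reference words);
    # the edit distance counts word substitutions, insertions, and deletions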
    for i in range(len(lm_results['pred_sentence'])):
        true_sentence = remove_punctuation(lm_results['true_sentence'][i]).strip()
        pred_sentence = remove_punctuation(lm_results['pred_sentence'][i]).strip()
        ed = editdistance.eval(true_sentence.split(), pred_sentence.split())

        total_true_length += len(true_sentence.split())
        total_edit_distance += ed

        lm_results['edit_distance'].append(ed)
        lm_results['num_words'].append(len(true_sentence.split()))

        print(f'{lm_results["session"][i]} - Block {lm_results["block"][i]}, Trial {lm_results["trial"][i]}')
        print(f'True sentence:       {true_sentence}')
        print(f'Predicted sentence:  {pred_sentence}')
        print(f'WER: {ed} / {len(true_sentence.split())} = {100 * ed / len(true_sentence.split()):.2f}%')
        print()

    print(f'Total true sentence length: {total_true_length}')
    print(f'Total edit distance: {total_edit_distance}')
    print(f'Aggregate Word Error Rate (WER): {100 * total_edit_distance / total_true_length:.2f}%')

# Write predicted sentences to a CSV file whose name encodes the TTA-E configuration and a timestamp
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'TTA-E_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_samples{tta_samples}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)

ids = list(range(len(lm_results['pred_sentence'])))
df_out = pd.DataFrame({'id': ids, 'text': lm_results['pred_sentence']})
df_out.to_csv(output_path, index=False)

print(f'\nResults saved to: {output_path}')
print(f'TTA-E configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}, TTA samples = {tta_samples}')