# b2txt25/TTA-E/temp.py
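"""
Evaluate an ensemble of pretrained GRU and LSTM models (without TTA) on the
copy-task dataset. Logits from both models are variance-normalized, averaged
with a configurable weight, and greedily decoded into phoneme sequences;
aggregate PER is reported on the validation split, and results are saved to CSV.
"""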
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
# Argument parser for command-line arguments
parser = argparse.ArgumentParser(description='Evaluate ensemble GRU+LSTM models (without TTA) on the copy task dataset.')
parser.add_argument('--gru_model_path', type=str, default='../data/t15_pretrained_rnn_baseline',
                    help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='../model_training_lstm/trained_models/baseline_rnn',
                    help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
                    help='Path to the dataset directory (relative to the current working directory).')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
                    help='Evaluation type: "val" for validation set, "test" for test set.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
                    help='Path to the CSV file with metadata about the dataset.')
parser.add_argument('--gpu_number', type=int, default=0,
                    help='GPU number to use for model inference. Set to -1 to use CPU.')
parser.add_argument('--gru_weight', type=float, default=0.5,
                    help='Weight for the GRU model in the ensemble (LSTM weight = 1 - gru_weight).')
args = parser.parse_args()
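# Example invocation (all flags optional; the defaults above are used otherwise):
#   python temp.py --eval_type val --gpu_number 0 --gru_weight 0.5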
# Unpack command-line arguments (the two ensemble weights sum to 1)
gru_model_path = args.gru_model_path
lstm_model_path = args.lstm_model_path
data_dir = args.data_dir
eval_type = args.eval_type
gru_weight = args.gru_weight
lstm_weight = 1.0 - gru_weight
# Load CSV file
b2txt_csv_df = pd.read_csv(args.csv_path)
# Load model args
gru_model_args = OmegaConf.load(os.path.join(gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(lstm_model_path, 'checkpoint/args.yaml'))
print(f'GRU model path: {gru_model_path}')
print(f'LSTM model path: {lstm_model_path}')
print(f'Data directory: {data_dir}')
print(f'Evaluation type: {eval_type}')
print(f'GRU weight: {gru_weight:.2f}, LSTM weight: {lstm_weight:.2f}')
print()
# Set up GPU device
gpu_number = args.gpu_number
if torch.cuda.is_available() and gpu_number >= 0:
    if gpu_number >= torch.cuda.device_count():
        raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')
    device = torch.device(f'cuda:{gpu_number}')
    print(f'Using GPU device: {device}')
else:
    device = torch.device('cpu')
    print('Using CPU device')
# Define GRU model
gru_model = GRUDecoder(
    neural_dim=gru_model_args['model']['n_input_features'],
    n_units=gru_model_args['model']['n_units'],
    n_days=len(gru_model_args['dataset']['sessions']),
    n_classes=gru_model_args['dataset']['n_classes'],
    rnn_dropout=gru_model_args['model']['rnn_dropout'],
    input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=gru_model_args['model']['n_layers'],
    patch_size=gru_model_args['model']['patch_size'],
    patch_stride=gru_model_args['model']['patch_stride'],
)
# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(gru_model_path, 'checkpoint/best_checkpoint'),
                            weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from state-dict keys.
# Note: each key must be popped only once; popping it twice raises a KeyError when a prefix is present.
for key in list(gru_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    gru_checkpoint['model_state_dict'][new_key] = gru_checkpoint['model_state_dict'].pop(key)
gru_model.load_state_dict(gru_checkpoint['model_state_dict'])
# Define LSTM model
lstm_model = LSTMDecoder(
    neural_dim=lstm_model_args['model']['n_input_features'],
    n_units=lstm_model_args['model']['n_units'],
    n_days=len(lstm_model_args['dataset']['sessions']),
    n_classes=lstm_model_args['dataset']['n_classes'],
    rnn_dropout=lstm_model_args['model']['rnn_dropout'],
    input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
    n_layers=lstm_model_args['model']['n_layers'],
    patch_size=lstm_model_args['model']['patch_size'],
    patch_stride=lstm_model_args['model']['patch_stride'],
)
# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(lstm_model_path, 'checkpoint/best_checkpoint'),
                             weights_only=False, map_location=device)
# Strip "module." (DataParallel) and "_orig_mod." (torch.compile) prefixes from state-dict keys
for key in list(lstm_checkpoint['model_state_dict'].keys()):
    new_key = key.replace("module.", "").replace("_orig_mod.", "")
    lstm_checkpoint['model_state_dict'][new_key] = lstm_checkpoint['model_state_dict'].pop(key)
lstm_model.load_state_dict(lstm_checkpoint['model_state_dict'])
# Move models to the device
gru_model.to(device)
lstm_model.to(device)
# Set models to eval mode
gru_model.eval()
lstm_model.eval()
print(f'Loaded GRU model from: {gru_model_path}')
print(f'Loaded LSTM model from: {lstm_model_path}')
print()
def runEnsembleDecodingStep(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
                            device, gru_weight, lstm_weight):
    """
    Run ensemble inference without TTA:
    1. Apply Gaussian smoothing to the input
    2. Run both the GRU and LSTM models
    3. Ensemble the model outputs with the given weights
    """
    # Get smoothing parameters
    smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']

    # Autocast would be more efficient, but it is disabled for now to avoid dtype issues:
    # with torch.autocast(device_type="cuda", enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):

    # Convert to float32 for the smoothing operations to avoid dtype mismatch
    x_float = x.float()

    # Apply Gaussian smoothing
    x_smoothed = gauss_smooth(
        inputs=x_float,
        device=device,
        smooth_kernel_std=smooth_std,
        smooth_kernel_size=smooth_size,
        padding='valid',
    )
    # Keep as float32 for model inference
    # x_smoothed = x_smoothed.to(torch.bfloat16)

    with torch.no_grad():
        # Convert to float32 for model inference to avoid einsum dtype mismatch
        x_smoothed_float = x_smoothed.float()

        # Get GRU logits
        gru_logits, _ = gru_model(
            x=x_smoothed_float,
            day_idx=torch.tensor([input_layer], device=device),
            states=None,
            return_state=True,
        )

        # Get LSTM logits
        lstm_logits, _ = lstm_model(
            x=x_smoothed_float,
            day_idx=torch.tensor([input_layer], device=device),
            states=None,
            return_state=True,
        )
        # Corrected ensemble method: scale-normalized averaging.
        # Original problem: the GRU logit variance (~7.97) exceeds the LSTM logit
        # variance (~5.73), so averaging the raw logits would bias the ensemble toward the GRU.
        # Solution: variance-normalize each model's logits, then average.
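        # Per timestep: ensemble = w * gru / std(gru) + (1 - w) * lstm / std(lstm),
        # where w = gru_weight and std() is taken over the class (last) axis.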
        # Convert to numpy for easier manipulation
        gru_logits_np = gru_logits.float().cpu().numpy()[0]
        lstm_logits_np = lstm_logits.float().cpu().numpy()[0]

        # Calculate per-timestep variance for normalization
        gru_var = np.var(gru_logits_np, axis=-1, keepdims=True)
        lstm_var = np.var(lstm_logits_np, axis=-1, keepdims=True)

        # Normalize by the standard deviation to equalize scales
        gru_normalized = gru_logits_np / np.sqrt(gru_var + 1e-8)
        lstm_normalized = lstm_logits_np / np.sqrt(lstm_var + 1e-8)

        # Apply weighted averaging on the normalized logits
        ensemble_logits_np = gru_weight * gru_normalized + lstm_weight * lstm_normalized

        # Convert back to a tensor and restore the batch dimension
        ensemble_logits = torch.tensor(ensemble_logits_np, device=device, dtype=torch.float32).unsqueeze(0)

    # Return the ensembled logits as a float32 numpy array
    return ensemble_logits.float().cpu().numpy()
# Load data for each session
test_data = {}
total_test_trials = 0
for session in gru_model_args['dataset']['sessions']:
    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]
    if f'data_{eval_type}.hdf5' in files:
        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')
        data = load_h5py_file(eval_file, b2txt_csv_df)
        test_data[session] = data
        total_test_trials += len(test_data[session]["neural_features"])
        print(f'Loaded {len(test_data[session]["neural_features"])} {eval_type} trials for session {session}.')
print(f'Total number of {eval_type} trials: {total_test_trials}')
print()
# Put neural data through the ensemble model to get phoneme predictions (logits)
with tqdm(total=total_test_trials, desc='Ensemble inference (GRU+LSTM)', unit='trial') as pbar:
    for session, data in test_data.items():
        data['logits'] = []
        data['pred_seq'] = []
        # Day index for the day-specific input layer; assumes both models share the same session list
        input_layer = gru_model_args['dataset']['sessions'].index(session)
        for trial in range(len(data['neural_features'])):
            # Get neural input for the trial and add a batch dimension
            neural_input = data['neural_features'][trial]
            neural_input = np.expand_dims(neural_input, axis=0)
            # Convert to a torch tensor (float32 to avoid dtype issues)
            neural_input = torch.tensor(neural_input, device=device, dtype=torch.float32)
            # Run the ensemble decoding step
            ensemble_logits = runEnsembleDecodingStep(
                neural_input, input_layer, gru_model, lstm_model,
                gru_model_args, lstm_model_args, device, gru_weight, lstm_weight
            )
            data['logits'].append(ensemble_logits)
            pbar.update(1)
# Convert logits to phoneme sequences and calculate PER
total_phonemes = 0
total_phoneme_errors = 0
per_results = []
for session, data in test_data.items():
    data['pred_seq'] = []
    for trial in range(len(data['logits'])):
        logits = data['logits'][trial][0]
        # Greedy (best-path) CTC decoding: take the argmax class per frame,
        # collapse consecutive duplicates, then remove CTC blanks (class 0)
        pred_seq = np.argmax(logits, axis=-1)
        pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i - 1]]
        pred_seq = [p for p in pred_seq if p != 0]
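        # Example: frame-wise argmax [5, 5, 0, 5] decodes to [5, 5]; removing
        # blanks before collapsing repeats would wrongly merge it into [5].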
        # Convert to phonemes and store
        pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]
        data['pred_seq'].append(pred_phonemes)

        # Print out the predicted sequences and calculate PER
        block_num = data['block_num'][trial]
        trial_num = data['trial_num'][trial]
        print(f'Session: {session}, Block: {block_num}, Trial: {trial_num}')
        if eval_type == 'val':
            sentence_label = data['sentence_label'][trial]
            true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
            true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]

            # Calculate phoneme error rate (PER)
            phoneme_errors = editdistance.eval(true_phonemes, pred_phonemes)
            num_phonemes = len(true_phonemes)
            per = phoneme_errors / num_phonemes if num_phonemes > 0 else 0
            total_phonemes += num_phonemes
            total_phoneme_errors += phoneme_errors
            per_results.append({
                'session': session,
                'block': block_num,
                'trial': trial_num,
                'sentence_label': sentence_label,
                'true_phonemes': true_phonemes,
                'pred_phonemes': pred_phonemes,
                'phoneme_errors': phoneme_errors,
                'num_phonemes': num_phonemes,
                'per': per,
            })
            print(f'Sentence label: {sentence_label}')
            print(f'True phonemes: {" ".join(true_phonemes)}')
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
            print(f'PER: {phoneme_errors} / {num_phonemes} = {100 * per:.2f}%')
        else:
            print(f'Pred phonemes: {" ".join(pred_phonemes)}')
        print()
# Calculate and print the aggregate PER when using the validation set
if eval_type == 'val' and total_phonemes > 0:
    aggregate_per = total_phoneme_errors / total_phonemes
    print(f'Total phonemes: {total_phonemes}')
    print(f'Total phoneme errors: {total_phoneme_errors}')
    print(f'Aggregate Phoneme Error Rate (PER): {100 * aggregate_per:.2f}%')
    print()
# Save results to CSV
timestamp = time.strftime("%Y%m%d_%H%M%S")
output_file = f'ensemble_gru{gru_weight:.1f}_lstm{lstm_weight:.1f}_{eval_type}_{timestamp}.csv'
output_path = os.path.join(os.path.dirname(__file__), output_file)
if eval_type == 'val':
    # Save detailed per-trial results for the validation set
    df_out = pd.DataFrame(per_results)
    df_out.to_csv(output_path, index=False)
    print(f'Detailed results saved to: {output_path}')
else:
    # Save only the predictions for the test set
    ids = []
    pred_phonemes_str = []
    for session, data in test_data.items():
        for trial in range(len(data['pred_seq'])):
            ids.append(len(ids))
            pred_phonemes_str.append(' '.join(data['pred_seq'][trial]))
    df_out = pd.DataFrame({'id': ids, 'phonemes': pred_phonemes_str})
    df_out.to_csv(output_path, index=False)
    print(f'Predictions saved to: {output_path}')
print(f'Ensemble configuration: GRU weight = {gru_weight:.2f}, LSTM weight = {lstm_weight:.2f}')