#!/usr/bin/env python3
"""
TTA-E parameter search script.

Stage 1 runs every base augmentation configuration once and caches the model
predictions; stage 2 then searches over ensemble weight combinations using only
the cached results, avoiding repeated model inference and speeding up the search.
"""
import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import argparse
import itertools
import json
import pickle
# Add the repo root and the sibling training directories to the import path so
# the package-style imports below resolve regardless of the working directory
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *
def parse_arguments():
parser = argparse.ArgumentParser(description='TTA-E Parameter Search for optimal configuration.')
parser.add_argument('--gru_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
help='Path to the pretrained GRU model directory.')
parser.add_argument('--lstm_model_path', type=str, default='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
help='Path to the pretrained LSTM model directory.')
parser.add_argument('--data_dir', type=str, default='../data/hdf5_data_final',
help='Path to the dataset directory.')
parser.add_argument('--csv_path', type=str, default='../data/t15_copyTaskData_description.csv',
help='Path to the CSV file with metadata.')
parser.add_argument('--gpu_number', type=int, default=0,
help='GPU number to use for model inference.')
parser.add_argument('--eval_type', type=str, default='val', choices=['val', 'test'],
help='Evaluation type.')
    # Search-space parameters
parser.add_argument('--gru_weights', type=str, default='0.4,0.5,0.6,0.7,0.8,1.0',
help='Comma-separated GRU weights to search.')
parser.add_argument('--tta_noise_weights', type=str, default='0.0,0.5,1.0',
help='Comma-separated noise weights to search.')
parser.add_argument('--tta_scale_weights', type=str, default='0.0,0.5,1.0',
help='Comma-separated scale weights to search.')
parser.add_argument('--tta_shift_weights', type=str, default='0.0,0.5,1.0',
help='Comma-separated shift weights to search.')
parser.add_argument('--tta_smooth_weights', type=str, default='0.0,0.5,1.0',
help='Comma-separated smooth weights to search.')
    # Fixed TTA parameters
parser.add_argument('--tta_noise_std', type=float, default=0.01,
help='Standard deviation for TTA noise augmentation.')
parser.add_argument('--tta_smooth_range', type=float, default=0.5,
help='Range for TTA smoothing kernel variation.')
parser.add_argument('--tta_scale_range', type=float, default=0.05,
help='Range for TTA amplitude scaling.')
parser.add_argument('--tta_cut_max', type=int, default=3,
help='Maximum timesteps for TTA shift.')
    # Output control
parser.add_argument('--cache_file', type=str, default='tta_predictions_cache.pkl',
help='File to cache model predictions.')
parser.add_argument('--results_file', type=str, default='parameter_search_results.json',
help='File to save search results.')
parser.add_argument('--force_recache', action='store_true',
help='Force re-computation of predictions cache.')
return parser.parse_args()
def generate_all_base_configs():
"""生成所有需要运行的基础配置(每种增强单独运行)"""
base_configs = [
{'name': 'original', 'tta_weights': {'original': 1.0, 'noise': 0.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 0.0}},
{'name': 'noise', 'tta_weights': {'original': 0.0, 'noise': 1.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 0.0}},
{'name': 'scale', 'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 1.0, 'shift': 0.0, 'smooth': 0.0}},
{'name': 'shift', 'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 0.0, 'shift': 1.0, 'smooth': 0.0}},
{'name': 'smooth', 'tta_weights': {'original': 0.0, 'noise': 0.0, 'scale': 0.0, 'shift': 0.0, 'smooth': 1.0}},
]
return base_configs
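# With the default search lists (6 GRU weights x 3^4 TTA weight settings = 486
# combinations), caching these 5 base predictions means each model performs only
# 5 forward passes per trial; every weighted ensemble is recomposed offline.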
def run_single_tta_prediction(x, input_layer, gru_model, lstm_model, gru_model_args, lstm_model_args,
device, aug_type, tta_noise_std, tta_smooth_range, tta_scale_range, tta_cut_max):
"""运行单个TTA增强的预测"""
x_augmented = x.clone()
    # Default smoothing parameters from the GRU training config
    default_smooth_std = gru_model_args['dataset']['data_transforms']['smooth_kernel_std']
    default_smooth_size = gru_model_args['dataset']['data_transforms']['smooth_kernel_size']
    smooth_std = default_smooth_std  # the 'smooth' augmentation overrides this below
    if aug_type == 'original':
        # Use the raw features unchanged
        pass
    elif aug_type == 'noise':
        # Additive Gaussian noise with std drawn uniformly from [0.5, 1.0] * tta_noise_std
        noise_scale = tta_noise_std * (0.5 + 0.5 * np.random.rand())
        noise = torch.randn_like(x_augmented) * noise_scale
        x_augmented = x_augmented + noise
    elif aug_type == 'scale':
        # Random amplitude scaling in [1 - tta_scale_range, 1 + tta_scale_range]
        scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * tta_scale_range
        x_augmented = x_augmented * scale_factor
    elif aug_type == 'shift' and tta_cut_max > 0:
        # Circular shift along the time axis by a small random number of timesteps;
        # the guard avoids np.random.randint(1, 1) on very short trials
        max_shift = min(tta_cut_max + 1, x_augmented.shape[1] // 8)
        if max_shift > 1:
            shift_amount = np.random.randint(1, max_shift)
            x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                     x_augmented[:, :shift_amount, :]], dim=1)
    elif aug_type == 'smooth':
        # Perturb the smoothing kernel std (floored at 0.3); the smoothing itself happens below
        smooth_variation = (torch.rand(1).item() - 0.5) * 2 * tta_smooth_range
        smooth_std = max(0.3, default_smooth_std + smooth_variation)
    # Use autocast for efficiency; device.type keeps this valid on CPU-only runs
    with torch.autocast(device_type=device.type, enabled=gru_model_args['use_amp'], dtype=torch.bfloat16):
        # Apply Gaussian smoothing (smooth_std deviates from the default only for 'smooth')
        x_smoothed = gauss_smooth(
            inputs=x_augmented,
            device=device,
            smooth_kernel_std=smooth_std,
            smooth_kernel_size=default_smooth_size,
            padding='valid',
        )
with torch.no_grad():
# Get GRU logits and convert to probabilities
gru_logits, _ = gru_model(
x=x_smoothed,
day_idx=torch.tensor([input_layer], device=device),
states=None,
return_state=True,
)
gru_probs = torch.softmax(gru_logits, dim=-1)
# Get LSTM logits and convert to probabilities
lstm_logits, _ = lstm_model(
x=x_smoothed,
day_idx=torch.tensor([input_layer], device=device),
states=None,
return_state=True,
)
lstm_probs = torch.softmax(lstm_logits, dim=-1)
return gru_probs.float().cpu().numpy(), lstm_probs.float().cpu().numpy()
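# Note: softmax probabilities (not raw logits) are cached, so stage 2 can form
# weighted arithmetic means across augmentations without touching the models.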
def cache_model_predictions(args):
"""缓存所有基础模型预测结果"""
print("=== 第一阶段:缓存模型预测结果 ===")
# 设置设备
if torch.cuda.is_available() and args.gpu_number >= 0:
device = torch.device(f'cuda:{args.gpu_number}')
else:
device = torch.device('cpu')
print(f'Using device: {device}')
    # Load the models
print("Loading models...")
gru_model_args = OmegaConf.load(os.path.join(args.gru_model_path, 'checkpoint/args.yaml'))
lstm_model_args = OmegaConf.load(os.path.join(args.lstm_model_path, 'checkpoint/args.yaml'))
# Define GRU model
gru_model = GRUDecoder(
neural_dim=gru_model_args['model']['n_input_features'],
n_units=gru_model_args['model']['n_units'],
n_days=len(gru_model_args['dataset']['sessions']),
n_classes=gru_model_args['dataset']['n_classes'],
rnn_dropout=gru_model_args['model']['rnn_dropout'],
input_dropout=gru_model_args['model']['input_network']['input_layer_dropout'],
n_layers=gru_model_args['model']['n_layers'],
patch_size=gru_model_args['model']['patch_size'],
patch_stride=gru_model_args['model']['patch_stride'],
)
# Load GRU model weights
gru_checkpoint = torch.load(os.path.join(args.gru_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
    # Strip DataParallel ("module.") and torch.compile ("_orig_mod.") prefixes in one pass;
    # popping the original key twice would raise KeyError once the first rename fires
    gru_state_dict = gru_checkpoint['model_state_dict']
    for key in list(gru_state_dict.keys()):
        gru_state_dict[key.replace("module.", "").replace("_orig_mod.", "")] = gru_state_dict.pop(key)
    gru_model.load_state_dict(gru_state_dict)
# Define LSTM model
lstm_model = LSTMDecoder(
neural_dim=lstm_model_args['model']['n_input_features'],
n_units=lstm_model_args['model']['n_units'],
n_days=len(lstm_model_args['dataset']['sessions']),
n_classes=lstm_model_args['dataset']['n_classes'],
rnn_dropout=lstm_model_args['model']['rnn_dropout'],
input_dropout=lstm_model_args['model']['input_network']['input_layer_dropout'],
n_layers=lstm_model_args['model']['n_layers'],
patch_size=lstm_model_args['model']['patch_size'],
patch_stride=lstm_model_args['model']['patch_stride'],
)
# Load LSTM model weights
lstm_checkpoint = torch.load(os.path.join(args.lstm_model_path, 'checkpoint/best_checkpoint'),
weights_only=False, map_location=device)
    # Strip DataParallel ("module.") and torch.compile ("_orig_mod.") prefixes in one pass
    lstm_state_dict = lstm_checkpoint['model_state_dict']
    for key in list(lstm_state_dict.keys()):
        lstm_state_dict[key.replace("module.", "").replace("_orig_mod.", "")] = lstm_state_dict.pop(key)
    lstm_model.load_state_dict(lstm_state_dict)
gru_model.to(device)
lstm_model.to(device)
gru_model.eval()
lstm_model.eval()
    # Load the data
print("Loading data...")
b2txt_csv_df = pd.read_csv(args.csv_path)
test_data = {}
total_trials = 0
for session in gru_model_args['dataset']['sessions']:
files = [f for f in os.listdir(os.path.join(args.data_dir, session)) if f.endswith('.hdf5')]
if f'data_{args.eval_type}.hdf5' in files:
eval_file = os.path.join(args.data_dir, session, f'data_{args.eval_type}.hdf5')
data = load_h5py_file(eval_file, b2txt_csv_df)
test_data[session] = data
total_trials += len(test_data[session]["neural_features"])
print(f'Loaded {len(test_data[session]["neural_features"])} {args.eval_type} trials for session {session}.')
print(f'Total trials: {total_trials}')
    # Generate all base predictions
base_configs = generate_all_base_configs()
predictions_cache = {}
for config in base_configs:
config_name = config['name']
print(f"\nGenerating predictions for: {config_name}")
predictions_cache[config_name] = {}
with tqdm(total=total_trials, desc=f'Processing {config_name}', unit='trial') as pbar:
for session, data in test_data.items():
input_layer = gru_model_args['dataset']['sessions'].index(session)
session_predictions = []
for trial in range(len(data['neural_features'])):
neural_input = data['neural_features'][trial]
neural_input = np.expand_dims(neural_input, axis=0)
neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)
gru_probs, lstm_probs = run_single_tta_prediction(
neural_input, input_layer, gru_model, lstm_model,
gru_model_args, lstm_model_args, device, config_name,
args.tta_noise_std, args.tta_smooth_range,
args.tta_scale_range, args.tta_cut_max
)
session_predictions.append({
'gru_probs': gru_probs,
'lstm_probs': lstm_probs,
'trial_info': {
'session': session,
'block_num': data['block_num'][trial],
'trial_num': data['trial_num'][trial],
'seq_class_ids': data['seq_class_ids'][trial] if args.eval_type == 'val' else None,
'seq_len': data['seq_len'][trial] if args.eval_type == 'val' else None,
'sentence_label': data['sentence_label'][trial] if args.eval_type == 'val' else None,
}
})
pbar.update(1)
predictions_cache[config_name][session] = session_predictions
    # Save the cache
print(f"\nSaving predictions cache to {args.cache_file}...")
with open(args.cache_file, 'wb') as f:
pickle.dump(predictions_cache, f)
print("✓ Predictions cache saved successfully!")
return predictions_cache
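# Cache layout: predictions_cache[aug_name][session] is a list with one entry per
# trial, each holding the GRU and LSTM probability arrays plus trial metadata.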
def ensemble_and_evaluate(predictions_cache, gru_weight, tta_weights, eval_type='val'):
"""基于缓存的预测结果进行集成和评估"""
lstm_weight = 1.0 - gru_weight
epsilon = 1e-8
total_trials = 0
total_edit_distance = 0
total_true_length = 0
    # Determine which augmentations are enabled
    enabled_augmentations = [aug_type for aug_type, weight in tta_weights.items() if weight > 0]
    if not enabled_augmentations:
        return float('inf')  # invalid configuration: nothing to ensemble
    # Normalize the enabled TTA weights so they sum to 1
    total_tta_weight = sum(weight for weight in tta_weights.values() if weight > 0)
    normalized_tta_weights = {k: v/total_tta_weight for k, v in tta_weights.items() if v > 0}
for session in predictions_cache['original'].keys():
session_predictions = predictions_cache['original'][session]
for trial_idx in range(len(session_predictions)):
trial_info = session_predictions[trial_idx]['trial_info']
if eval_type == 'val' and trial_info['seq_class_ids'] is None:
continue
            # Accumulate the probabilities of all enabled augmentations
weighted_gru_probs = None
weighted_lstm_probs = None
for aug_type in enabled_augmentations:
weight = normalized_tta_weights[aug_type]
gru_probs = predictions_cache[aug_type][session][trial_idx]['gru_probs']
lstm_probs = predictions_cache[aug_type][session][trial_idx]['lstm_probs']
gru_probs = torch.tensor(gru_probs)
lstm_probs = torch.tensor(lstm_probs)
if weighted_gru_probs is None:
weighted_gru_probs = weight * gru_probs
weighted_lstm_probs = weight * lstm_probs
else:
                    # Truncate to the shorter sequence if the lengths differ
min_len = min(weighted_gru_probs.shape[1], gru_probs.shape[1])
weighted_gru_probs = weighted_gru_probs[:, :min_len, :] + weight * gru_probs[:, :min_len, :]
weighted_lstm_probs = weighted_lstm_probs[:, :min_len, :] + weight * lstm_probs[:, :min_len, :]
            # Geometric-mean ensemble of GRU and LSTM: p ∝ p_gru^w_gru * p_lstm^w_lstm,
            # computed in log space; epsilon guards against log(0)
weighted_gru_probs = weighted_gru_probs + epsilon
weighted_lstm_probs = weighted_lstm_probs + epsilon
log_ensemble_probs = (gru_weight * torch.log(weighted_gru_probs) +
lstm_weight * torch.log(weighted_lstm_probs))
ensemble_probs = torch.exp(log_ensemble_probs)
ensemble_probs = ensemble_probs / ensemble_probs.sum(dim=-1, keepdim=True)
            # Greedy CTC decoding: collapse repeated tokens first, then drop blanks (id 0)
            pred_seq = torch.argmax(ensemble_probs[0], dim=-1).numpy()
            pred_seq = [int(pred_seq[i]) for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]
            pred_seq = [p for p in pred_seq if p != 0]
pred_phonemes = [LOGIT_TO_PHONEME[p] for p in pred_seq]
if eval_type == 'val':
                # Compute the phoneme error rate (PER) against the reference
true_seq = trial_info['seq_class_ids'][0:trial_info['seq_len']]
true_phonemes = [LOGIT_TO_PHONEME[p] for p in true_seq]
ed = editdistance.eval(true_phonemes, pred_phonemes)
total_edit_distance += ed
total_true_length += len(true_phonemes)
total_trials += 1
if eval_type == 'val' and total_true_length > 0:
per = 100 * total_edit_distance / total_true_length
return per
else:
        return 0.0  # test mode has no labels, so there is nothing to score
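# Minimal usage sketch (assumes a populated predictions_cache):
#   per = ensemble_and_evaluate(predictions_cache, gru_weight=0.6,
#                               tta_weights={'original': 1.0, 'noise': 0.5,
#                                            'scale': 0.0, 'shift': 0.0, 'smooth': 0.5})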
def search_optimal_parameters(predictions_cache, args):
"""搜索最优参数组合"""
print("\n=== 第二阶段:参数搜索 ===")
# 解析搜索空间
gru_weights = [float(x) for x in args.gru_weights.split(',')]
noise_weights = [float(x) for x in args.tta_noise_weights.split(',')]
scale_weights = [float(x) for x in args.tta_scale_weights.split(',')]
shift_weights = [float(x) for x in args.tta_shift_weights.split(',')]
smooth_weights = [float(x) for x in args.tta_smooth_weights.split(',')]
print(f"Search space:")
print(f" GRU weights: {gru_weights}")
print(f" Noise weights: {noise_weights}")
print(f" Scale weights: {scale_weights}")
print(f" Shift weights: {shift_weights}")
print(f" Smooth weights: {smooth_weights}")
    # Generate all parameter combinations
all_combinations = list(itertools.product(gru_weights, noise_weights, scale_weights, shift_weights, smooth_weights))
total_combinations = len(all_combinations)
print(f"Total combinations to evaluate: {total_combinations}")
best_per = float('inf')
best_config = None
results = []
with tqdm(total=total_combinations, desc='Parameter search', unit='config') as pbar:
for gru_w, noise_w, scale_w, shift_w, smooth_w in all_combinations:
tta_weights = {
                'original': 1.0,  # always include the original data
'noise': noise_w,
'scale': scale_w,
'shift': shift_w,
'smooth': smooth_w
}
per = ensemble_and_evaluate(predictions_cache, gru_w, tta_weights, args.eval_type)
config = {
'gru_weight': gru_w,
'lstm_weight': 1.0 - gru_w,
'tta_weights': tta_weights,
'per': per
}
results.append(config)
if per < best_per:
best_per = per
best_config = config
print(f"\n🎯 New best PER: {per:.3f}%")
print(f" GRU weight: {gru_w:.1f}")
print(f" TTA weights: {tta_weights}")
pbar.update(1)
return results, best_config
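# The search loop above only recombines cached arrays, so evaluating hundreds of
# weight configurations requires no additional GPU inference.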
def main():
args = parse_arguments()
print("TTA-E Parameter Search")
print("=" * 50)
    # Stage 1: cache the predictions
if args.force_recache or not os.path.exists(args.cache_file):
predictions_cache = cache_model_predictions(args)
else:
print(f"Loading existing predictions cache from {args.cache_file}...")
with open(args.cache_file, 'rb') as f:
predictions_cache = pickle.load(f)
print("✓ Cache loaded successfully!")
    # Stage 2: parameter search
results, best_config = search_optimal_parameters(predictions_cache, args)
    # Report the best configuration
    print("\n=== Search complete ===")
    print("Best configuration:")
print(f" PER: {best_config['per']:.3f}%")
print(f" GRU weight: {best_config['gru_weight']:.1f}")
print(f" LSTM weight: {best_config['lstm_weight']:.1f}")
print(f" TTA weights: {best_config['tta_weights']}")
    # Save all results
search_results = {
'best_config': best_config,
'all_results': results,
'search_args': vars(args),
'timestamp': time.strftime("%Y-%m-%d %H:%M:%S")
}
with open(args.results_file, 'w') as f:
json.dump(search_results, f, indent=2)
print(f"\n✓ Results saved to {args.results_file}")
    # Show the 10 best configurations
    sorted_results = sorted(results, key=lambda x: x['per'])
    print("\nTop 10 configurations:")
for i, config in enumerate(sorted_results[:10]):
print(f"{i+1:2d}. PER={config['per']:6.3f}% | GRU={config['gru_weight']:.1f} | TTA={config['tta_weights']}")
if __name__ == "__main__":
main()