#!/usr/bin/env python3
"""
Genetic-algorithm optimisation of the TTA-E ensemble parameters.

Uses PyGAD to optimise ``gru_weight`` and ``tta_weights`` with the goal of
minimising the phoneme error rate (PER).
"""

import os
import sys
import torch
import numpy as np
import pandas as pd
from omegaconf import OmegaConf
import time
from tqdm import tqdm
import editdistance
import pygad
import pickle
import multiprocessing as mp
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
import hashlib
from functools import lru_cache
import cupy as cp
import warnings
warnings.filterwarnings("ignore")

# Add parent directories to path to import models
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))

from model_training.rnn_model import GRUDecoder
from model_training_lstm.rnn_model import LSTMDecoder
from model_training.evaluate_model_helpers import *


def _strip_wrapper_prefixes(state_dict):
    """Return a copy of *state_dict* with wrapper prefixes removed from keys.

    Removes every occurrence of ``"module."`` and ``"_orig_mod."`` from each
    key. Building a new dict avoids popping the same key twice, which raised
    ``KeyError`` whenever the first rename had already moved the entry.
    """
    return {
        key.replace("module.", "").replace("_orig_mod.", ""): value
        for key, value in state_dict.items()
    }


class TTAEnsembleCache:
    """In-memory / on-disk cache of GRU and LSTM predictions for the 5 augmentation types."""

    def __init__(self, cache_dir='./tta_cache'):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.gru_cache = {}
        self.lstm_cache = {}
        self.augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    def _get_cache_key(self, session, trial, aug_type):
        """Build the dictionary key for one (session, trial, augmentation) entry."""
        return f"{session}_{trial}_{aug_type}"

    def _get_cache_file(self, model_type):
        """Return the pickle path used for the given model type."""
        return os.path.join(self.cache_dir, f'{model_type}_predictions.pkl')

    def _cache_for(self, model_type):
        """Select the backing dict; anything other than 'gru' maps to the LSTM cache."""
        return self.gru_cache if model_type == 'gru' else self.lstm_cache

    def save_cache(self):
        """Write both caches to disk as pickle files."""
        with open(self._get_cache_file('gru'), 'wb') as f:
            pickle.dump(self.gru_cache, f)
        with open(self._get_cache_file('lstm'), 'wb') as f:
            pickle.dump(self.lstm_cache, f)

    def load_cache(self):
        """Load both caches from disk.

        Returns:
            True when both files were read, False when a file is missing or
            cannot be unpickled (e.g. truncated by an interrupted save).
            If the GRU file loads but the LSTM file does not, the GRU cache
            stays populated and False is returned.
        """
        # NOTE: pickle executes arbitrary code on load; the cache directory
        # must only contain files written by this script.
        try:
            with open(self._get_cache_file('gru'), 'rb') as f:
                self.gru_cache = pickle.load(f)
            with open(self._get_cache_file('lstm'), 'rb') as f:
                self.lstm_cache = pickle.load(f)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            return False
        return True

    def add_prediction(self, model_type, session, trial, aug_type, prediction):
        """Store one prediction in the cache of the given model type."""
        key = self._get_cache_key(session, trial, aug_type)
        self._cache_for(model_type)[key] = prediction

    def get_prediction(self, model_type, session, trial, aug_type):
        """Return the cached prediction, or None when it is not cached."""
        key = self._get_cache_key(session, trial, aug_type)
        return self._cache_for(model_type).get(key, None)

    def is_complete(self, sessions, trials_per_session):
        """Return True when every (session, trial, augmentation) has both a GRU and an LSTM entry."""
        return all(
            self.get_prediction(model_type, session, trial, aug_type) is not None
            for session in sessions
            for trial in range(trials_per_session[session])
            for aug_type in self.augmentation_types
            for model_type in ('gru', 'lstm')
        )


class TTAEGeneticOptimizer:
    """Optimises the TTA-E ensemble parameters with a genetic algorithm."""

    def __init__(self,
                 gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                 lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                 data_dir='../data/hdf5_data_final',
                 csv_path='../data/t15_copyTaskData_description.csv',
                 gpu_number=0):
        self.gru_model_path = gru_model_path
        self.lstm_model_path = lstm_model_path
        self.data_dir = data_dir
        self.csv_path = csv_path
        self.gpu_number = gpu_number

        # Select the inference device
        if torch.cuda.is_available() and gpu_number >= 0:
            self.device = torch.device(f'cuda:{gpu_number}')
            print(f'Using {self.device} for model inference.')
        else:
            self.device = torch.device('cpu')
            print('Using CPU for model inference.')

        # Prediction cache
        self.cache = TTAEnsembleCache()

        # TTA parameters
        self.tta_noise_std = 0.01
        self.tta_smooth_range = 0.5
        self.tta_scale_range = 0.05
        self.tta_cut_max = 3

        # Genetic-algorithm parameters
        self.population_size = 20
        self.num_generations = 20
        self.num_parents_mating = 5
        self.mutation_percent_genes = 20

        # Load models and data
        self._load_models()
        self._load_data()

    def _load_decoder(self, decoder_cls, model_path):
        """Build one decoder from ``checkpoint/args.yaml`` and load ``checkpoint/best_checkpoint``.

        Returns:
            ``(model, model_args)`` with the model on ``self.device`` in eval mode.
        """
        model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
        model_cfg = model_args['model']
        dataset_cfg = model_args['dataset']

        model = decoder_cls(
            neural_dim=model_cfg['n_input_features'],
            n_units=model_cfg['n_units'],
            n_days=len(dataset_cfg['sessions']),
            n_classes=dataset_cfg['n_classes'],
            rnn_dropout=model_cfg['rnn_dropout'],
            input_dropout=model_cfg['input_network']['input_layer_dropout'],
            n_layers=model_cfg['n_layers'],
            patch_size=model_cfg['patch_size'],
            patch_stride=model_cfg['patch_stride'],
        )

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(
            os.path.join(model_path, 'checkpoint/best_checkpoint'),
            weights_only=False,
            map_location=self.device,
        )
        model.load_state_dict(_strip_wrapper_prefixes(checkpoint['model_state_dict']))

        model.to(self.device)
        model.eval()
        return model, model_args

    def _load_models(self):
        """Load the GRU and LSTM models and their configuration."""
        print("Loading models...")
        self.gru_model, self.gru_model_args = self._load_decoder(GRUDecoder, self.gru_model_path)
        self.lstm_model, self.lstm_model_args = self._load_decoder(LSTMDecoder, self.lstm_model_path)
        print("Models loaded successfully!")

    def _load_data(self):
        """Load the validation split (``data_val.hdf5``) of every session listed in the GRU config."""
        print("Loading validation data...")

        b2txt_csv_df = pd.read_csv(self.csv_path)

        self.test_data = {}
        self.trials_per_session = {}
        total_test_trials = 0

        for session in self.gru_model_args['dataset']['sessions']:
            session_dir = os.path.join(self.data_dir, session)
            # Sessions without a validation file are skipped.
            if 'data_val.hdf5' not in os.listdir(session_dir):
                continue

            eval_file = os.path.join(session_dir, 'data_val.hdf5')
            data = load_h5py_file(eval_file, b2txt_csv_df)
            n_trials = len(data["neural_features"])

            self.test_data[session] = data
            self.trials_per_session[session] = n_trials
            total_test_trials += n_trials
            print(f'Loaded {len(data["neural_features"])} validation trials for session {session}.')

        print(f'Total number of validation trials: {total_test_trials}')

    def _apply_augmentation(self, x, aug_type):
        """Return an augmented copy of ``x``.

        ``x`` is indexed as ``[:, time, :]`` by the shift augmentation.
        'original' and 'smooth' return an unmodified copy; the smoothing
        variation for 'smooth' is applied by the caller when it picks the
        Gaussian kernel std.
        """
        x_augmented = x.clone()

        if aug_type == 'noise':
            noise_scale = self.tta_noise_std * (0.5 + 0.5 * np.random.rand())
            x_augmented = x_augmented + torch.randn_like(x_augmented) * noise_scale
        elif aug_type == 'scale':
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * self.tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and self.tta_cut_max > 0:
            upper = min(self.tta_cut_max + 1, x_augmented.shape[1] // 8)
            # randint(1, upper) needs upper > 1; trials too short to shift
            # are left unchanged instead of raising ValueError.
            if upper > 1:
                shift_amount = np.random.randint(1, upper)
                x_augmented = torch.cat(
                    [x_augmented[:, shift_amount:, :], x_augmented[:, :shift_amount, :]],
                    dim=1,
                )

        return x_augmented

    def _get_model_prediction(self, model, model_args, x_smoothed, input_layer):
        """Run one model on the smoothed input and return float32 class probabilities.

        The forward pass runs under autocast (enabled by ``model_args['use_amp']``)
        because the input tensor is bfloat16. The logits are cast to float32 so
        the caller can convert the result to NumPy, which has no bfloat16 dtype.
        """
        use_amp = bool(model_args.get('use_amp', False))
        with torch.no_grad():
            with torch.autocast(device_type=self.device.type, enabled=use_amp, dtype=torch.bfloat16):
                logits, _ = model(
                    x=x_smoothed,
                    day_idx=torch.tensor([input_layer], device=self.device),
                    states=None,
                    return_state=True,
                )
            probs = torch.softmax(logits.float(), dim=-1)
        return probs

    def generate_all_predictions(self):
        """Compute and cache predictions of both models for every trial and augmentation.

        Entries already present in a loaded cache are skipped. The cache is
        written to disk once, after all predictions have been generated.
        """
        print("Generating all TTA predictions for caching...")

        # Try to reuse an existing cache
        if self.cache.load_cache():
            if self.cache.is_complete(self.test_data.keys(), self.trials_per_session):
                print("Complete cache found, skipping prediction generation.")
                return
            else:
                print("Incomplete cache found, generating missing predictions...")

        transforms = self.gru_model_args['dataset']['data_transforms']
        default_smooth_std = transforms['smooth_kernel_std']
        default_smooth_size = transforms['smooth_kernel_size']

        total_trials = sum(self.trials_per_session.values())
        total_predictions = total_trials * len(self.cache.augmentation_types) * 2  # 2 models

        with tqdm(total=total_predictions, desc='Generating cached predictions', unit='pred') as pbar:
            for session, data in self.test_data.items():
                input_layer = self.gru_model_args['dataset']['sessions'].index(session)

                for trial in range(len(data['neural_features'])):
                    # Neural input with a leading batch dimension of 1
                    neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
                    neural_input = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)

                    for aug_type in self.cache.augmentation_types:
                        have_gru = self.cache.get_prediction('gru', session, trial, aug_type) is not None
                        have_lstm = self.cache.get_prediction('lstm', session, trial, aug_type) is not None
                        if have_gru and have_lstm:
                            pbar.update(2)
                            continue

                        x_augmented = self._apply_augmentation(neural_input, aug_type)

                        # The 'smooth' augmentation perturbs the Gaussian kernel std
                        if aug_type == 'smooth':
                            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * self.tta_smooth_range
                            varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)
                        else:
                            varied_smooth_std = default_smooth_std

                        with torch.autocast(device_type=self.device.type,
                                            enabled=self.gru_model_args['use_amp'],
                                            dtype=torch.bfloat16):
                            x_smoothed = gauss_smooth(
                                inputs=x_augmented,
                                device=self.device,
                                smooth_kernel_std=varied_smooth_std,
                                smooth_kernel_size=default_smooth_size,
                                padding='valid',
                            )

                        if not have_gru:
                            gru_probs = self._get_model_prediction(
                                self.gru_model, self.gru_model_args, x_smoothed, input_layer
                            )
                            self.cache.add_prediction('gru', session, trial, aug_type,
                                                      gru_probs.cpu().numpy())
                        pbar.update(1)

                        if not have_lstm:
                            lstm_probs = self._get_model_prediction(
                                self.lstm_model, self.lstm_model_args, x_smoothed, input_layer
                            )
                            self.cache.add_prediction('lstm', session, trial, aug_type,
                                                      lstm_probs.cpu().numpy())
                        pbar.update(1)

        print("Saving cache to disk...")
        self.cache.save_cache()
        print("Cache generation completed!")

    @staticmethod
    def _fuse_tta(prob_list, norm_weights, length):
        """Weighted sum of probability tensors, each truncated to ``length`` time steps."""
        fused = torch.zeros_like(prob_list[0][:, :length, :])
        for weight, probs in zip(norm_weights, prob_list):
            fused += weight * probs[:, :length, :]
        return fused

    def evaluate_parameters(self, gru_weight, tta_weights):
        """Compute the PER (in percent) of the ensemble for one parameter set.

        Args:
            gru_weight: weight of the GRU model; the LSTM gets ``1 - gru_weight``.
            tta_weights: one weight per entry of ``self.cache.augmentation_types``,
                in that order. Augmentations with weight <= 0 are ignored.

        Returns:
            ``100 * total_edit_distance / total_true_length``, or 100.0 when no
            trial could be evaluated.
        """
        gru_weight = float(gru_weight)
        lstm_weight = 1.0 - gru_weight
        tta_weights_dict = dict(zip(self.cache.augmentation_types, tta_weights))
        epsilon = 1e-8

        total_true_length = 0
        total_edit_distance = 0

        for session, data in self.test_data.items():
            for trial in range(len(data['neural_features'])):
                # Collect cached predictions of every active augmentation
                all_gru_probs = []
                all_lstm_probs = []
                sample_weights = []

                for aug_type in self.cache.augmentation_types:
                    weight = tta_weights_dict[aug_type]
                    if weight <= 0:
                        continue

                    gru_probs = self.cache.get_prediction('gru', session, trial, aug_type)
                    lstm_probs = self.cache.get_prediction('lstm', session, trial, aug_type)
                    if gru_probs is None or lstm_probs is None:
                        continue

                    # as_tensor avoids copying the cached arrays; they are
                    # never modified in place below.
                    all_gru_probs.append(torch.as_tensor(gru_probs))
                    all_lstm_probs.append(torch.as_tensor(lstm_probs))
                    sample_weights.append(float(weight))

                if not all_gru_probs:
                    continue

                # TTA fusion: normalised weighted average over a common length
                min_length = min(probs.shape[1] for probs in all_gru_probs + all_lstm_probs)
                norm_weights = torch.tensor(sample_weights, dtype=torch.float32)
                norm_weights = norm_weights / norm_weights.sum()
                avg_gru_probs = self._fuse_tta(all_gru_probs, norm_weights, min_length)
                avg_lstm_probs = self._fuse_tta(all_lstm_probs, norm_weights, min_length)

                # Ensemble fusion (weighted geometric mean, computed in log space).
                # Renormalising and taking the log again are monotone per frame,
                # so the argmax can be taken on the log-space scores directly.
                log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs + epsilon) +
                                      lstm_weight * torch.log(avg_lstm_probs + epsilon))
                frame_ids = torch.argmax(log_ensemble_probs[0], dim=-1).tolist()

                # Drop blanks (class 0), then collapse consecutive repeats.
                # NOTE(review): standard CTC greedy decoding collapses repeats
                # BEFORE removing blanks; this order merges a phoneme repeated
                # across a blank. Kept as-is so PER values stay comparable with
                # earlier runs - confirm against the reference evaluation code.
                non_blank = [p for p in frame_ids if p != 0]
                collapsed = [p for i, p in enumerate(non_blank)
                             if i == 0 or p != non_blank[i - 1]]
                pred_seq = [LOGIT_TO_PHONEME[p] for p in collapsed]

                # Ground-truth sequence
                true_ids = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
                true_seq = [LOGIT_TO_PHONEME[p] for p in true_ids]

                total_edit_distance += editdistance.eval(true_seq, pred_seq)
                total_true_length += len(true_seq)

        if total_true_length == 0:
            return 100.0  # maximum PER as a penalty

        return 100 * total_edit_distance / total_true_length

    def fitness_function(self, ga_instance, solution, solution_idx):
        """PyGAD fitness callback: returns ``-PER`` (higher is better).

        Gene 0 is the GRU weight, genes 1-5 are the TTA weights. Negative TTA
        weights are clamped to 0. If all TTA weights are 0, or the evaluation
        raises, the worst fitness (-100.0) is returned.
        """
        gru_weight = solution[0]
        tta_weights = np.maximum(solution[1:6], 0)

        if np.sum(tta_weights) == 0:
            return -100.0

        try:
            per = self.evaluate_parameters(gru_weight, tta_weights)
        except Exception as e:
            # Deliberate best-effort: one failing candidate must not abort
            # the whole optimisation run.
            print(f"Error in fitness evaluation: {e}")
            return -100.0

        return -per

    def on_generation(self, ga_instance):
        """PyGAD per-generation callback: print the best solution so far."""
        # Passing the stored fitness values avoids re-evaluating the population.
        solution, solution_fitness, solution_idx = ga_instance.best_solution(
            pop_fitness=ga_instance.last_generation_fitness
        )
        print(f"Generation {ga_instance.generations_completed}")
        print(f"Best solution: GRU weight={solution[0]:.3f}, TTA weights={solution[1:6]}")
        print(f"Best fitness (negative PER): {solution_fitness:.3f}")
        print(f"Best PER: {-solution_fitness:.3f}%")
        print("-" * 50)

    def optimize(self):
        """Run the genetic algorithm and return the best parameters as a dict.

        The result is also pickled to ``ga_optimization_result_<timestamp>.pkl``
        in the current working directory.
        """
        print("Starting genetic algorithm optimization...")

        # Make sure every prediction is cached first
        self.generate_all_predictions()

        aug_types = self.cache.augmentation_types

        # Gene bounds: gru_weight in [0, 1], each TTA weight in [0, 5]
        gene_space = [{'low': 0.0, 'high': 1.0}]
        gene_space += [{'low': 0.0, 'high': 5.0} for _ in aug_types]

        # cpu_count() // 2 is 0 on a single-core host; always use at least one worker.
        num_workers = max(1, mp.cpu_count() // 2)

        ga_instance = pygad.GA(
            num_generations=self.num_generations,
            num_parents_mating=self.num_parents_mating,
            fitness_func=self.fitness_function,
            sol_per_pop=self.population_size,
            num_genes=len(gene_space),  # 1 gru_weight + one weight per augmentation
            gene_space=gene_space,
            mutation_percent_genes=self.mutation_percent_genes,
            parent_selection_type="sss",  # steady-state selection
            keep_parents=2,  # must be <= num_parents_mating
            crossover_type="single_point",
            mutation_type="random",
            on_generation=self.on_generation,
            parallel_processing=['thread', num_workers],
            save_solutions=True,
        )

        print(f"Running optimization with {self.population_size} population size for {self.num_generations} generations...")
        start_time = time.time()
        ga_instance.run()
        end_time = time.time()

        solution, solution_fitness, solution_idx = ga_instance.best_solution(
            pop_fitness=ga_instance.last_generation_fitness
        )

        print("\n" + "="*60)
        print("OPTIMIZATION COMPLETED!")
        print("="*60)
        print(f"Optimization time: {end_time - start_time:.2f} seconds")
        print(f"Best GRU weight: {solution[0]:.4f}")
        print(f"Best LSTM weight: {1.0 - solution[0]:.4f}")
        print(f"Best TTA weights:")
        for i, aug_type in enumerate(aug_types):
            print(f" - {aug_type}: {solution[i+1]:.4f}")
        print(f"Best PER: {-solution_fitness:.4f}%")

        result = {
            'gru_weight': solution[0],
            'lstm_weight': 1.0 - solution[0],
            'tta_weights': {aug_type: solution[i + 1] for i, aug_type in enumerate(aug_types)},
            'best_per': -solution_fitness,
            'optimization_time': end_time - start_time,
            'generations': self.num_generations,
            'population_size': self.population_size
        }

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = f'ga_optimization_result_{timestamp}.pkl'
        with open(result_file, 'wb') as f:
            pickle.dump(result, f)

        print(f"Results saved to: {result_file}")

        return result


def main():
    """Build the optimiser with the default paths, run it and return the result dict."""
    print("TTA-E Genetic Algorithm Optimization")
    print("="*50)

    optimizer = TTAEGeneticOptimizer(
        gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
        lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
        data_dir='../data/hdf5_data_final',
        csv_path='../data/t15_copyTaskData_description.csv',
        gpu_number=0
    )

    result = optimizer.optimize()

    print("\nOptimization completed successfully!")
    return result


if __name__ == "__main__":
    # Restrict CUDA to the first GPU before any CUDA context is created
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'

    result = main()