629 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			629 lines
		
	
	
		
			28 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| #!/usr/bin/env python3
 | ||
| """
 | ||
| 差分进化算法优化TTA-E集成参数
 | ||
| 使用SciPy的差分进化算法优化gru_weight和tta_weights参数,目标是最小化PER
 | ||
| """
 | ||
| 
 | ||
| import os
 | ||
| import sys
 | ||
| import torch
 | ||
| import numpy as np
 | ||
| import pandas as pd
 | ||
| from omegaconf import OmegaConf
 | ||
| import time
 | ||
| from tqdm import tqdm
 | ||
| import editdistance
 | ||
| import pickle
 | ||
| import multiprocessing as mp
 | ||
| from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
 | ||
| from scipy.optimize import differential_evolution
 | ||
| import warnings
 | ||
| warnings.filterwarnings("ignore")
 | ||
| 
 | ||
| # Add parent directories to path to import models
 | ||
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training'))
 | ||
| sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'model_training_lstm'))
 | ||
| 
 | ||
| from model_training.rnn_model import GRUDecoder
 | ||
| from model_training_lstm.rnn_model import LSTMDecoder
 | ||
| from model_training.evaluate_model_helpers import *
 | ||
| 
 | ||
class TTAEnsembleCache:
    """Cache of GRU and LSTM predictions for each TTA augmentation type.

    Predictions live in two in-memory dicts (one per model) keyed by
    ``"{session}_{trial}_{aug_type}"`` and can be persisted to, or restored
    from, one pickle file per model inside ``cache_dir``.
    """

    def __init__(self, cache_dir='./tta_cache'):
        """Create ``cache_dir`` if needed and start with empty caches."""
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.gru_cache = {}
        self.lstm_cache = {}
        self.augmentation_types = ['original', 'noise', 'scale', 'shift', 'smooth']

    def _get_cache_key(self, session, trial, aug_type):
        """Build the dict key for one (session, trial, augmentation) entry."""
        return f"{session}_{trial}_{aug_type}"

    def _get_cache_file(self, model_type):
        """Return the pickle path used for ``model_type`` predictions."""
        return os.path.join(self.cache_dir, f'{model_type}_predictions.pkl')

    def _select_cache(self, model_type):
        """Return the dict backing ``model_type``.

        Raises:
            ValueError: if ``model_type`` is neither ``'gru'`` nor ``'lstm'``
                (previously any other value silently used the LSTM cache).
        """
        if model_type == 'gru':
            return self.gru_cache
        if model_type == 'lstm':
            return self.lstm_cache
        raise ValueError(f"Unknown model_type: {model_type!r}")

    def save_cache(self):
        """Persist both caches to disk.

        Each file is written to a temporary sibling first and then renamed, so
        an interrupted save never leaves a truncated pickle behind.
        """
        for model_type in ('gru', 'lstm'):
            target = self._get_cache_file(model_type)
            tmp_path = f'{target}.tmp'
            with open(tmp_path, 'wb') as f:
                pickle.dump(self._select_cache(model_type), f)
            os.replace(tmp_path, target)

    def load_cache(self):
        """Load both caches from disk.

        Returns:
            True when both files were read; False when a file is missing or
            unreadable. A GRU cache that loaded before the failure is kept.
        """
        # NOTE: pickle.load executes arbitrary code; only use trusted cache files.
        try:
            with open(self._get_cache_file('gru'), 'rb') as f:
                self.gru_cache = pickle.load(f)
            with open(self._get_cache_file('lstm'), 'rb') as f:
                self.lstm_cache = pickle.load(f)
        except FileNotFoundError:
            return False
        except (EOFError, pickle.UnpicklingError) as exc:
            # A truncated/corrupt file is treated like a missing cache.
            print(f"Ignoring unreadable cache file: {exc}")
            return False
        return True

    def add_prediction(self, model_type, session, trial, aug_type, prediction):
        """Store ``prediction`` for the given model / session / trial / augmentation."""
        key = self._get_cache_key(session, trial, aug_type)
        self._select_cache(model_type)[key] = prediction

    def get_prediction(self, model_type, session, trial, aug_type):
        """Return the cached prediction, or ``None`` when it is not cached."""
        key = self._get_cache_key(session, trial, aug_type)
        return self._select_cache(model_type).get(key, None)

    def is_complete(self, sessions, trials_per_session):
        """Check that both models have every expected entry cached.

        Args:
            sessions: iterable of session names.
            trials_per_session: mapping from session name to trial count.

        Returns:
            True when every (session, trial, augmentation) entry exists for
            both models; vacuously True when nothing is expected.
        """
        return all(
            self.get_prediction(model_type, session, trial, aug_type) is not None
            for session in sessions
            for trial in range(trials_per_session[session])
            for aug_type in self.augmentation_types
            for model_type in ('gru', 'lstm')
        )
 | ||
| 
 | ||
class TTAEDifferentialEvolutionOptimizer:
    """Tune TTA-E ensemble parameters with SciPy differential evolution.

    The search vector has six entries: the GRU ensemble weight (the LSTM
    weight is ``1 - gru_weight``) followed by one weight per augmentation
    type. The objective is the phoneme error rate (PER, in percent) computed
    from cached model predictions on the validation trials.
    """

    def __init__(self,
                 gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
                 lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
                 data_dir='../data/hdf5_data_final',
                 csv_path='../data/t15_copyTaskData_description.csv',
                 gpu_number=0):
        """Select the device, set hyper-parameters, then load models and data.

        Args:
            gru_model_path: directory containing ``checkpoint/args.yaml`` and
                ``checkpoint/best_checkpoint`` for the GRU decoder.
            lstm_model_path: same layout, for the LSTM decoder.
            data_dir: directory with one sub-directory per session.
            csv_path: CSV file handed to ``load_h5py_file``.
            gpu_number: CUDA device index; a negative value forces the CPU.
        """
        self.gru_model_path = gru_model_path
        self.lstm_model_path = lstm_model_path
        self.data_dir = data_dir
        self.csv_path = csv_path
        self.gpu_number = gpu_number

        # Device selection.
        if torch.cuda.is_available() and gpu_number >= 0:
            self.device = torch.device(f'cuda:{gpu_number}')
            print(f'Using {self.device} for model inference.')
        else:
            self.device = torch.device('cpu')
            print('Using CPU for model inference.')

        # Prediction cache.
        self.cache = TTAEnsembleCache()

        # Test-time-augmentation settings.
        self.tta_noise_std = 0.01
        self.tta_smooth_range = 0.5
        self.tta_scale_range = 0.05
        self.tta_cut_max = 3

        # Differential evolution settings.
        self.population_size = 15   # popsize multiplier (15 * 6 dimensions = 90 individuals)
        self.max_iterations = 50    # maximum number of generations
        self.tolerance = 1e-6       # convergence tolerance
        self.mutation_factor = 0.7  # mutation constant
        self.crossover_prob = 0.9   # recombination probability

        # Bookkeeping for the objective function.
        self.evaluation_count = 0
        self.best_per_history = []  # every evaluated PER, in evaluation order

        self._load_models()
        self._load_data()

    @staticmethod
    def _strip_wrapper_prefixes(state_dict):
        """Return a copy of ``state_dict`` with ``module.`` and ``_orig_mod.`` removed from its keys.

        Both substrings are removed in one pass per key. The previous in-place
        loop popped the same key twice and raised ``KeyError`` for any key that
        contained ``module.``.
        """
        return {
            key.replace("module.", "").replace("_orig_mod.", ""): value
            for key, value in state_dict.items()
        }

    def _build_decoder(self, decoder_cls, model_path):
        """Instantiate ``decoder_cls`` from ``model_path`` and load its best checkpoint.

        Returns:
            ``(args, model)`` where ``args`` is the loaded ``args.yaml`` config
            and ``model`` is on ``self.device`` in eval mode.
        """
        args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))
        model_cfg = args['model']
        dataset_cfg = args['dataset']

        model = decoder_cls(
            neural_dim=model_cfg['n_input_features'],
            n_units=model_cfg['n_units'],
            n_days=len(dataset_cfg['sessions']),
            n_classes=dataset_cfg['n_classes'],
            rnn_dropout=model_cfg['rnn_dropout'],
            input_dropout=model_cfg['input_network']['input_layer_dropout'],
            n_layers=model_cfg['n_layers'],
            patch_size=model_cfg['patch_size'],
            patch_stride=model_cfg['patch_stride'],
        )

        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(os.path.join(model_path, 'checkpoint/best_checkpoint'),
                                weights_only=False, map_location=self.device)
        model.load_state_dict(self._strip_wrapper_prefixes(checkpoint['model_state_dict']))

        model.to(self.device)
        model.eval()
        return args, model

    def _load_models(self):
        """Load the GRU and LSTM decoders and their configs."""
        print("Loading models...")

        self.gru_model_args, self.gru_model = self._build_decoder(GRUDecoder, self.gru_model_path)
        self.lstm_model_args, self.lstm_model = self._build_decoder(LSTMDecoder, self.lstm_model_path)

        print("Models loaded successfully!")

    def _load_data(self):
        """Load ``data_val.hdf5`` for every session listed in the GRU config.

        Sessions whose directory has no ``data_val.hdf5`` are skipped.
        """
        print("Loading validation data...")

        b2txt_csv_df = pd.read_csv(self.csv_path)

        self.test_data = {}
        self.trials_per_session = {}
        total_test_trials = 0

        for session in self.gru_model_args['dataset']['sessions']:
            session_dir = os.path.join(self.data_dir, session)
            if 'data_val.hdf5' not in os.listdir(session_dir):
                continue

            data = load_h5py_file(os.path.join(session_dir, 'data_val.hdf5'), b2txt_csv_df)
            n_trials = len(data["neural_features"])
            self.test_data[session] = data
            self.trials_per_session[session] = n_trials
            total_test_trials += n_trials
            print(f'Loaded {len(data["neural_features"])} validation trials for session {session}.')

        print(f'Total number of validation trials: {total_test_trials}')

    def _apply_augmentation(self, x, aug_type):
        """Return an augmented copy of ``x``; ``x`` itself is never modified.

        ``'noise'`` adds Gaussian noise, ``'scale'`` multiplies by a random
        factor, and ``'shift'`` circularly rotates along dimension 1.
        ``'original'`` and ``'smooth'`` return an unchanged copy; the smoothing
        variation is applied in ``generate_all_predictions`` through the
        kernel width.

        Args:
            x: input tensor; ``'shift'`` rotates dimension 1 (presumably time
                — TODO confirm against the data loader).
            aug_type: one of the cache's augmentation type names.
        """
        x_augmented = x.clone()

        if aug_type == 'noise':
            noise_scale = self.tta_noise_std * (0.5 + 0.5 * np.random.rand())
            x_augmented = x_augmented + torch.randn_like(x_augmented) * noise_scale
        elif aug_type == 'scale':
            scale_factor = 1.0 + (torch.rand(1).item() - 0.5) * 2 * self.tta_scale_range
            x_augmented = x_augmented * scale_factor
        elif aug_type == 'shift' and self.tta_cut_max > 0:
            # randint's upper bound is exclusive; for sequences shorter than
            # 16 steps it drops to <= 1 and randint(1, upper) would raise, so
            # such inputs are left unshifted.
            upper = min(self.tta_cut_max + 1, x_augmented.shape[1] // 8)
            if upper > 1:
                shift_amount = np.random.randint(1, upper)
                x_augmented = torch.cat([x_augmented[:, shift_amount:, :],
                                         x_augmented[:, :shift_amount, :]], dim=1)

        return x_augmented

    def _get_model_prediction(self, model, model_args, x_smoothed, input_layer):
        """Run ``model`` without gradients and return softmax probabilities.

        ``model_args`` is accepted for call-site symmetry but is not used.
        """
        with torch.no_grad():
            logits, _ = model(
                x=x_smoothed,
                day_idx=torch.tensor([input_layer], device=self.device),
                states=None,
                return_state=True,
            )
            return torch.softmax(logits, dim=-1)

    def generate_all_predictions(self):
        """Fill the cache with both models' predictions for every trial and augmentation.

        An existing on-disk cache is reused; only missing entries are computed.
        The cache is written back to disk once all entries are generated.
        """
        print("Generating all TTA predictions for caching...")

        if self.cache.load_cache():
            if self.cache.is_complete(self.test_data.keys(), self.trials_per_session):
                print("Complete cache found, skipping prediction generation.")
                return
            else:
                print("Incomplete cache found, generating missing predictions...")

        models = (
            ('gru', self.gru_model, self.gru_model_args),
            ('lstm', self.lstm_model, self.lstm_model_args),
        )

        # Loop-invariant smoothing configuration (taken from the GRU config).
        transforms = self.gru_model_args['dataset']['data_transforms']
        default_smooth_std = transforms['smooth_kernel_std']
        default_smooth_size = transforms['smooth_kernel_size']

        total_trials = sum(self.trials_per_session.values())
        total_predictions = total_trials * len(self.cache.augmentation_types) * len(models)

        with tqdm(total=total_predictions, desc='Generating cached predictions', unit='pred') as pbar:
            for session, data in self.test_data.items():
                # The GRU config's session order provides the day index for both
                # models (assumes the LSTM was trained with the same session list).
                input_layer = self.gru_model_args['dataset']['sessions'].index(session)

                for trial in range(len(data['neural_features'])):
                    neural_input = np.expand_dims(data['neural_features'][trial], axis=0)
                    neural_input = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)

                    for aug_type in self.cache.augmentation_types:
                        missing = [
                            entry for entry in models
                            if self.cache.get_prediction(entry[0], session, trial, aug_type) is None
                        ]
                        # Count already-cached entries so the bar reaches its total.
                        pbar.update(len(models) - len(missing))
                        if not missing:
                            continue

                        x_augmented = self._apply_augmentation(neural_input, aug_type)

                        # The 'smooth' augmentation perturbs the Gaussian kernel width.
                        if aug_type == 'smooth':
                            smooth_variation = (torch.rand(1).item() - 0.5) * 2 * self.tta_smooth_range
                            varied_smooth_std = max(0.3, default_smooth_std + smooth_variation)
                        else:
                            varied_smooth_std = default_smooth_std

                        # Use the active device type so the CPU fallback is not
                        # given a hard-coded "cuda" autocast context.
                        with torch.autocast(device_type=self.device.type,
                                            enabled=self.gru_model_args['use_amp'],
                                            dtype=torch.bfloat16):
                            x_smoothed = gauss_smooth(
                                inputs=x_augmented,
                                device=self.device,
                                smooth_kernel_std=varied_smooth_std,
                                smooth_kernel_size=default_smooth_size,
                                padding='valid',
                            )

                            for name, model, model_args in missing:
                                probs = self._get_model_prediction(
                                    model, model_args, x_smoothed, input_layer
                                )
                                # NumPy has no bfloat16; cast before converting.
                                self.cache.add_prediction(name, session, trial, aug_type,
                                                          probs.float().cpu().numpy())
                                pbar.update(1)

        print("Saving cache to disk...")
        self.cache.save_cache()
        print("Cache generation completed!")

    @staticmethod
    def _weighted_average(probs_list, weights, length):
        """Average ``probs_list`` with normalised ``weights``.

        Every tensor is truncated to ``length`` along dimension 1 first.
        """
        norm_weights = torch.tensor(weights, dtype=torch.float32)
        norm_weights = norm_weights / norm_weights.sum()

        fused = torch.zeros_like(probs_list[0][:, :length, :])
        for probs, weight in zip(probs_list, norm_weights):
            fused += weight * probs[:, :length, :]
        return fused

    @staticmethod
    def _greedy_decode(frame_ids):
        """Drop class 0 (blank), then collapse consecutive duplicates.

        NOTE(review): standard CTC greedy decoding collapses repeats *before*
        removing blanks; this order also merges repeats that were separated
        only by blanks. It is kept unchanged so the PER stays comparable with
        earlier results — confirm it matches the evaluation pipeline.
        """
        no_blank = [int(p) for p in frame_ids if p != 0]
        return [p for i, p in enumerate(no_blank) if i == 0 or p != no_blank[i - 1]]

    def evaluate_parameters(self, gru_weight, tta_weights):
        """Compute the PER (in percent) for one parameter combination.

        Args:
            gru_weight: GRU weight of the geometric-mean ensemble; the LSTM
                weight is ``1 - gru_weight``.
            tta_weights: five weights ordered as original, noise, scale,
                shift, smooth. Augmentations with weight <= 0 are ignored.

        Returns:
            ``100 * total edit distance / total reference length`` over all
            scored trials, or ``100.0`` when nothing could be scored. Trials
            without usable cached predictions are skipped.
        """
        lstm_weight = 1.0 - gru_weight

        tta_weights_dict = {
            'original': tta_weights[0],
            'noise': tta_weights[1],
            'scale': tta_weights[2],
            'shift': tta_weights[3],
            'smooth': tta_weights[4]
        }

        epsilon = 1e-8
        total_true_length = 0
        total_edit_distance = 0

        for session, data in self.test_data.items():
            for trial in range(len(data['neural_features'])):
                # Gather cached predictions of every active augmentation.
                all_gru_probs = []
                all_lstm_probs = []
                sample_weights = []

                for aug_type in self.cache.augmentation_types:
                    aug_weight = tta_weights_dict[aug_type]
                    if aug_weight <= 0:
                        continue

                    gru_probs = self.cache.get_prediction('gru', session, trial, aug_type)
                    lstm_probs = self.cache.get_prediction('lstm', session, trial, aug_type)
                    if gru_probs is None or lstm_probs is None:
                        continue

                    all_gru_probs.append(torch.tensor(gru_probs))
                    all_lstm_probs.append(torch.tensor(lstm_probs))
                    sample_weights.append(aug_weight)

                if not all_gru_probs:
                    continue

                # TTA fusion: truncate to the shortest sequence and average.
                # With one active augmentation its normalised weight is
                # exactly 1, so that prediction is used as-is.
                min_length = min(probs.shape[1] for probs in all_gru_probs + all_lstm_probs)
                avg_gru_probs = self._weighted_average(all_gru_probs, sample_weights, min_length)
                avg_lstm_probs = self._weighted_average(all_lstm_probs, sample_weights, min_length)

                # Ensemble fusion: weighted geometric mean in log space.
                log_ensemble_probs = (gru_weight * torch.log(avg_gru_probs + epsilon) +
                                      lstm_weight * torch.log(avg_lstm_probs + epsilon))

                # exp, per-frame normalisation and log are all monotone, so the
                # per-frame argmax is taken directly on the log scores.
                frame_ids = np.argmax(log_ensemble_probs[0].numpy(), axis=-1)
                pred_seq = [LOGIT_TO_PHONEME[p] for p in self._greedy_decode(frame_ids)]

                # Reference phoneme sequence.
                true_seq = data['seq_class_ids'][trial][0:data['seq_len'][trial]]
                true_seq = [LOGIT_TO_PHONEME[p] for p in true_seq]

                total_edit_distance += editdistance.eval(true_seq, pred_seq)
                total_true_length += len(true_seq)

        if total_true_length == 0:
            return 100.0  # worst-case PER as a penalty

        return 100 * total_edit_distance / total_true_length

    def objective_function(self, params):
        """Objective minimised by differential evolution.

        Args:
            params: ``[gru_weight, w_original, w_noise, w_scale, w_shift, w_smooth]``.

        Returns:
            The PER for ``params``, or ``100.0`` when all TTA weights are zero
            or the evaluation raised.
        """
        self.evaluation_count += 1

        gru_weight = params[0]
        # Clamp the five TTA weights to be non-negative.
        tta_weights = np.maximum(params[1:6], 0)

        # No active augmentation: return the worst-case PER as a penalty.
        if np.sum(tta_weights) == 0:
            return 100.0

        try:
            per = self.evaluate_parameters(gru_weight, tta_weights)

            if len(self.best_per_history) == 0 or per < min(self.best_per_history):
                print(f"🎯 Eval {self.evaluation_count}: New best PER = {per:.4f}%")
                print(f"   GRU weight = {gru_weight:.4f}, TTA weights = {tta_weights}")

            self.best_per_history.append(per)

            # Progress report every 10 evaluations.
            if self.evaluation_count % 10 == 0:
                current_best = min(self.best_per_history)
                print(f"📊 Progress: {self.evaluation_count} evaluations, Best PER = {current_best:.4f}%")

            return per

        except Exception as e:
            # Deliberate best-effort: a failed evaluation is reported and
            # penalised instead of aborting the whole optimisation run.
            print(f"Error in objective function evaluation: {e}")
            return 100.0

    def optimize(self):
        """Run the optimisation and return a dict describing the best solution.

        Predictions are cached first, differential evolution is then run over
        the six parameters, and the result dict is pickled to
        ``de_optimization_result_<timestamp>.pkl`` in the working directory.
        """
        print("Starting Differential Evolution optimization...")
        print(f"Algorithm parameters:")
        print(f"  - Population size multiplier: {self.population_size}")
        print(f"  - Max iterations: {self.max_iterations}")
        print(f"  - Mutation factor: {self.mutation_factor}")
        print(f"  - Crossover probability: {self.crossover_prob}")
        print(f"  - Tolerance: {self.tolerance}")

        # Every objective evaluation reads from the prediction cache.
        self.generate_all_predictions()

        # params = [gru_weight, original, noise, scale, shift, smooth]
        bounds = [(0.0, 1.0)] + [(0.0, 5.0)] * 5

        print(f"\nParameter bounds:")
        param_names = ['GRU weight', 'Original', 'Noise', 'Scale', 'Shift', 'Smooth']
        for name, (low, high) in zip(param_names, bounds):
            print(f"  - {name}: [{low}, {high}]")

        print(f"\nRunning differential evolution optimization...")
        start_time = time.time()

        # NOTE(review): PER looks piecewise constant in the weights, so the
        # gradient-based polish step may not improve the result — confirm
        # whether polish=True is worth its extra evaluations.
        result = differential_evolution(
            func=self.objective_function,
            bounds=bounds,
            popsize=self.population_size,
            maxiter=self.max_iterations,
            tol=self.tolerance,
            mutation=self.mutation_factor,
            recombination=self.crossover_prob,
            seed=42,              # fixed seed for the optimiser's own randomness
            disp=True,
            polish=True,
            workers=1,            # single worker: the objective mutates instance state
            updating='deferred',
        )

        elapsed = time.time() - start_time

        best_params = result.x
        best_per = result.fun

        print("\n" + "="*60)
        print("DIFFERENTIAL EVOLUTION OPTIMIZATION COMPLETED!")
        print("="*60)
        print(f"Optimization time: {elapsed:.2f} seconds")
        print(f"Total function evaluations: {result.nfev}")
        print(f"Optimization success: {result.success}")
        print(f"Termination message: {result.message}")

        print(f"\nBest solution:")
        print(f"Best GRU weight: {best_params[0]:.4f}")
        print(f"Best LSTM weight: {1.0 - best_params[0]:.4f}")
        print(f"Best TTA weights:")
        aug_types = ['original', 'noise', 'scale', 'shift', 'smooth']
        for aug_type, weight in zip(aug_types, best_params[1:6]):
            print(f"  - {aug_type}: {weight:.4f}")
        print(f"Best PER: {best_per:.4f}%")

        optimization_result = {
            'best_params': best_params,
            'gru_weight': best_params[0],
            'lstm_weight': 1.0 - best_params[0],
            'tta_weights': dict(zip(aug_types, best_params[1:6])),
            'best_per': best_per,
            'optimization_time': elapsed,
            'function_evaluations': result.nfev,
            'success': result.success,
            'message': result.message,
            'per_history': self.best_per_history,
            'algorithm': 'differential_evolution',
            'algorithm_params': {
                'popsize': self.population_size,
                'maxiter': self.max_iterations,
                'mutation': self.mutation_factor,
                'recombination': self.crossover_prob,
                'tolerance': self.tolerance
            }
        }

        timestamp = time.strftime("%Y%m%d_%H%M%S")
        result_file = f'de_optimization_result_{timestamp}.pkl'
        with open(result_file, 'wb') as f:
            pickle.dump(optimization_result, f)

        print(f"\nResults saved to: {result_file}")

        print(f"\nPerformance Analysis:")
        print(f"  - Average evaluation time: {elapsed / result.nfev:.3f} seconds")
        if elapsed > 0:  # avoid dividing by a zero-length interval
            print(f"  - Evaluations per minute: {result.nfev / (elapsed / 60):.1f}")
        if len(self.best_per_history) > 10:
            improvement = self.best_per_history[0] - min(self.best_per_history)
            print(f"  - Total PER improvement: {improvement:.4f}%")

        return optimization_result
 | ||
| 
 | ||
def main():
    """Build the optimizer with the default paths, run it, and return its result dict."""
    print("TTA-E Differential Evolution Optimization")
    print("="*50)

    # Same values as the constructor defaults, spelled out explicitly.
    optimizer_kwargs = dict(
        gru_model_path='/root/autodl-tmp/nejm-brain-to-text/data/t15_pretrained_rnn_baseline',
        lstm_model_path='/root/autodl-tmp/nejm-brain-to-text/model_training_lstm/trained_models/baseline_rnn',
        data_dir='../data/hdf5_data_final',
        csv_path='../data/t15_copyTaskData_description.csv',
        gpu_number=0,
    )
    optimizer = TTAEDifferentialEvolutionOptimizer(**optimizer_kwargs)

    # Run the optimisation.
    outcome = optimizer.optimize()

    print("\nDifferential Evolution optimization completed successfully!")
    return outcome
 | ||
| 
 | ||
if __name__ == "__main__":
    # Restrict the process to GPU 0 before main() builds the models.
    os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

    # Run the optimisation; the result stays available as a module-level name.
    result = main()
