{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🎲 Improved Random Batch Generator\n",
"\n",
"This version improves the data generation strategy:\n",
"- **Random file selection**: each batch randomly picks n=4 files from all training files\n",
"- **Random sample drawing**: a fixed number of samples is drawn at random from the selected files\n",
"- **Better data diversity**: files are no longer processed in a fixed order, reducing the risk of overfitting\n",
"- **Controlled batch size**: a fixed sample count per batch keeps training stable"
]
},
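{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a minimal, self-contained sketch of the select-then-sample pattern described above. The file names, the per-file sample count, and the helper name `random_batch_indices` are illustrative placeholders, not part of the trainer implemented later in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: randomly select n files, then draw a fixed-size random sample (placeholder data)\n",
"import random\n",
"import numpy as np\n",
"\n",
"def random_batch_indices(all_files, n_files=4, batch_size=5000, rng=None):\n",
"    rng = rng or np.random.default_rng()\n",
"    selected = random.sample(all_files, min(n_files, len(all_files)))\n",
"    pool = len(selected) * 3000  # pretend each file holds 3000 samples (placeholder)\n",
"    take = min(batch_size, pool)\n",
"    sample_idx = rng.choice(pool, size=take, replace=False)\n",
"    return selected, sample_idx\n",
"\n",
"files = [f\"day_{i}_train.npz\" for i in range(10)]  # hypothetical file names\n",
"sel, idx = random_batch_indices(files)\n",
"print(sel, idx[:5])"
]
},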
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧪 Test the improved random batch generator\n",
"\n",
"def test_improved_generator(trainer, n_batches_to_test=3):\n",
"    \"\"\"\n",
"    Test the improved generator.\n",
"    \"\"\"\n",
"    print(\"🧪 Testing the improved random batch generator...\")\n",
"    print(\"=\" * 50)\n",
"    \n",
"    # Create the generator\n",
"    generator = trainer.get_training_batch_generator(\n",
"        n_files_per_batch=4,   # 4 files per batch\n",
"        batch_size=5000        # 5000 samples per batch\n",
"    )\n",
"    \n",
"    # Test the first few batches\n",
"    for i in range(n_batches_to_test):\n",
"        print(f\"\\n📦 Test batch {i+1}:\")\n",
"        try:\n",
"            batch_data, batch_name = next(generator)\n",
"            print(f\"   ✅ Batch name: {batch_name}\")\n",
"            print(f\"   📊 Data shape: {batch_data.num_data()} samples\")\n",
"            print(f\"   🏷️ Label range: {min(batch_data.get_label())} - {max(batch_data.get_label())}\")\n",
"            \n",
"            # Inspect the label distribution\n",
"            from collections import Counter\n",
"            label_dist = Counter(batch_data.get_label())\n",
"            print(f\"   📈 Label distribution sample: 0={label_dist.get(0,0)}, 40={label_dist.get(40,0)}\")\n",
"            \n",
"        except Exception as e:\n",
"            print(f\"   ❌ Error: {e}\")\n",
"            break\n",
"    \n",
"    print(f\"\\n✅ Generator test complete!\")\n",
"    print(\"💡 Each run produces a different file combination and sample draw\")\n",
"\n",
"# Run the test if a trainer already exists\n",
"if 'trainer' in locals():\n",
"    test_improved_generator(trainer)\n",
"else:\n",
"    print(\"⚠️ Please create a trainer object before running this test\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Environment Setup & Utils"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==================================================\n",
"🔧 LightGBM GPU environment check\n",
"==================================================\n",
"❌ No NVIDIA GPU or driver detected\n",
"\n",
"❌ CUDA toolkit not installed\n"
]
}
],
"source": [
"# 🚀 LightGBM GPU support check and configuration\n",
"\n",
"print(\"=\"*50)\n",
"print(\"🔧 LightGBM GPU environment check\")\n",
"print(\"=\"*50)\n",
"\n",
"# Check CUDA and the GPU driver\n",
"import subprocess\n",
"import sys\n",
"\n",
"def run_command(command):\n",
"    \"\"\"Run a shell command and return (stdout, success).\"\"\"\n",
"    try:\n",
"        result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n",
"        return result.stdout.strip(), result.returncode == 0\n",
"    except Exception as e:\n",
"        return str(e), False\n",
"\n",
"# Check for an NVIDIA GPU\n",
"nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n",
"if nvidia_success:\n",
"    print(\"✅ NVIDIA GPU detected:\")\n",
"    for line in nvidia_output.split('\\n'):\n",
"        if line.strip():\n",
"            print(f\"   {line}\")\n",
"else:\n",
"    print(\"❌ No NVIDIA GPU or driver detected\")\n",
"\n",
"# Check the CUDA version\n",
"cuda_output, cuda_success = run_command(\"nvcc --version\")\n",
"if cuda_success:\n",
"    print(\"\\n✅ CUDA toolkit:\")\n",
"    # Extract the CUDA version\n",
"    for line in cuda_output.split('\\n'):\n",
"        if 'release' in line:\n",
"            print(f\"   {line.strip()}\")\n",
"else:\n",
"    print(\"\\n❌ CUDA toolkit not installed\")\n"
]
},
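{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the checks above pass, LightGBM can be pointed at the GPU by swapping a few parameters. The cell below is a sketch of that switch, assuming a GPU-enabled LightGBM build; `nvidia_success` and `cuda_success` come from the check cell above, and the parameter names match the ones commented out in the trainer later in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: choose LightGBM device params from the checks above\n",
"# (assumes a GPU-enabled LightGBM build; otherwise stay on 'cpu')\n",
"gpu_available = nvidia_success and cuda_success\n",
"\n",
"device_params = (\n",
"    {'device_type': 'gpu', 'gpu_platform_id': 0, 'gpu_device_id': 0}\n",
"    if gpu_available\n",
"    else {'device_type': 'cpu'}\n",
")\n",
"print(f\"LightGBM device params: {device_params}\")"
]
},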
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[WinError 2] The system cannot find the file specified.: 'nejm-brain-to-text'\n",
"f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\brain-to-text-25\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\IPython\\core\\magics\\osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.\n",
"  bkms = self.shell.db.get('bookmarks', {})\n"
]
}
],
"source": [
"%cd nejm-brain-to-text\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from g2p_en import G2p\n",
"import pandas as pd\n",
"from nejm_b2txt_utils.general_utils import *\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n",
"matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n",
"matplotlib.rcParams['axes.unicode_minus'] = False\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"d:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\IPython\\core\\magics\\osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
"  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\model_training\n"
]
}
],
"source": [
"%cd model_training/\n",
"from data_augmentations import gauss_smooth"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
"    'BLANK',\n",
"    'AA', 'AE', 'AH', 'AO', 'AW',\n",
"    'AY', 'B', 'CH', 'D', 'DH',\n",
"    'EH', 'ER', 'EY', 'F', 'G',\n",
"    'HH', 'IH', 'IY', 'JH', 'K',\n",
"    'L', 'M', 'N', 'NG', 'OW',\n",
"    'OY', 'P', 'R', 'S', 'SH',\n",
"    'T', 'TH', 'UH', 'UW', 'V',\n",
"    'W', 'Y', 'Z', 'ZH',\n",
"    ' | ',\n",
"]\n",
"# Global balancing configuration\n",
"BALANCE_CONFIG = {\n",
"    'enable_balance': True,          # whether to balance the data\n",
"    'undersample_labels': [0, 40],   # labels to undersample (blank and other high-frequency labels)\n",
"    'oversample_threshold': 0.5,     # oversampling threshold (as a fraction of the mean count)\n",
"    'random_state': 42               # random seed\n",
"}\n",
"# Global PCA configuration\n",
"PCA_CONFIG = {\n",
"    'enable_pca': True,          # whether to apply PCA\n",
"    'n_components': None,        # None = choose automatically, or set an explicit number\n",
"    'variance_threshold': 0.95,  # keep 95% of the variance\n",
"    'sample_size': 15000,        # number of samples used to fit the PCA\n",
"}\n",
"\n",
"# Global PCA objects (fitted exactly once)\n",
"GLOBAL_PCA = {\n",
"    'scaler': None,\n",
"    'pca': None,\n",
"    'is_fitted': False,\n",
"    'n_components': None\n",
"}\n",
"# Data directory and parameters [PCA initialization]\n",
"data_dir = '../data/concatenated_data'\n",
"MAX_SAMPLES_PER_FILE = -1  # max samples per file, adjustable; -1 means no limit"
]
},
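{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustration of how a 41-way logit vector maps to a phoneme through `LOGIT_TO_PHONEME`. The random logits below are placeholder data; the `argmax` rule is the same one the feature extractor uses later."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: map one 41-way logit vector to its phoneme label (placeholder logits)\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"fake_logits = rng.normal(size=41)    # stand-in for one frame's RNN output\n",
"label = int(np.argmax(fake_logits))  # same rule used by extract_features_labels_batch below\n",
"print(label, LOGIT_TO_PHONEME[label])"
]
},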
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Loading Workflow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2️⃣ Data Loading & PCA Dimensionality Reduction"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# 🚀 Memory-friendly data loading - batched loading strategy + PCA reduction [sampling is still missing here]\n",
"\n",
"import os\n",
"import numpy as np\n",
"import gc\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"import joblib\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"def load_data_batch(data_dir, data_type, max_samples_per_file=5000):\n",
"    \"\"\"\n",
"    Load data of the given type in batches.\n",
"    \n",
"    Args:\n",
"        data_dir: data directory\n",
"        data_type: 'train', 'val', or 'test'\n",
"        max_samples_per_file: max samples to load per file\n",
"    \n",
"    Returns:\n",
"        generator: yields (trials, filename) batches\n",
"    \"\"\"\n",
"    files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
"    \n",
"    for file_idx, f in enumerate(files):\n",
"        print(f\"   Loading file {file_idx+1}/{len(files)}: {f}\")\n",
"        \n",
"        data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n",
"        trials = data['neural_logits_concatenated']\n",
"        \n",
"        # Cap the number of samples per file\n",
"        if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n",
"            trials = trials[:max_samples_per_file]\n",
"            print(f\"   Capped sample count at: {max_samples_per_file}\")\n",
"        \n",
"        yield trials, f\n",
"        \n",
"        # Free memory\n",
"        del data, trials\n",
"        gc.collect()\n",
"\n",
"def extract_features_labels_batch(trials_batch):\n",
"    \"\"\"\n",
"    Extract features and labels from a batch of trials.\n",
"    \"\"\"\n",
"    features = []\n",
"    labels = []\n",
"    \n",
"    for trial in trials_batch:\n",
"        if trial.shape[0] > 0:\n",
"            for t in range(trial.shape[0]):\n",
"                neural_features = trial[t, :7168]  # first 7168 dims: neural features\n",
"                rnn_logits = trial[t, 7168:]       # last 41 dims: RNN output logits\n",
"                phoneme_label = np.argmax(rnn_logits)\n",
"                \n",
"                features.append(neural_features)\n",
"                labels.append(phoneme_label)\n",
"    \n",
"    return np.array(features), np.array(labels)\n",
"\n",
"def fit_global_pca(data_dir, config):\n",
"    \"\"\"\n",
"    Fit the global PCA on training data (runs only once).\n",
"    \"\"\"\n",
"    if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n",
"        print(\"🔧 PCA already fitted or disabled; skipping\")\n",
"        return\n",
"    \n",
"    print(f\"\\n🔧 Fitting the global PCA reducer...\")\n",
"    print(f\"   Config: {config}\")\n",
"    \n",
"    # Collect training samples\n",
"    sample_features = []\n",
"    collected_samples = 0\n",
"    \n",
"    for trials_batch, filename in load_data_batch(data_dir, 'train', 5000):\n",
"        features, labels = extract_features_labels_batch(trials_batch)\n",
"        sample_features.append(features)\n",
"        collected_samples += features.shape[0]\n",
"        \n",
"        if collected_samples >= config['sample_size']:\n",
"            break\n",
"    \n",
"    if sample_features:\n",
"        # Stack the collected samples\n",
"        X_sample = np.vstack(sample_features)[:config['sample_size']]\n",
"        print(f\"   Actual sample count: {X_sample.shape[0]}\")\n",
"        print(f\"   Original feature count: {X_sample.shape[1]}\")\n",
"        \n",
"        # Standardize\n",
"        GLOBAL_PCA['scaler'] = StandardScaler()\n",
"        X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n",
"        \n",
"        # Choose the number of PCA components\n",
"        if config['n_components'] is None:\n",
"            print(f\"   🔍 Selecting the number of PCA components automatically...\")\n",
"            pca_full = PCA()\n",
"            pca_full.fit(X_sample_scaled)\n",
"            \n",
"            cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
"            optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n",
"            GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n",
"            \n",
"            print(f\"   Keeping {config['variance_threshold']*100}% of the variance requires: {optimal_components} components\")\n",
"            print(f\"   Selected component count: {GLOBAL_PCA['n_components']}\")\n",
"        else:\n",
"            GLOBAL_PCA['n_components'] = config['n_components']\n",
"            print(f\"   Using the specified component count: {GLOBAL_PCA['n_components']}\")\n",
"        \n",
"        # Fit the final PCA\n",
"        GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n",
"        GLOBAL_PCA['pca'].fit(X_sample_scaled)\n",
"        GLOBAL_PCA['is_fitted'] = True\n",
"        \n",
"        # Save the model\n",
"        pca_path = \"global_pca_model.joblib\"\n",
"        joblib.dump({\n",
"            'scaler': GLOBAL_PCA['scaler'], \n",
"            'pca': GLOBAL_PCA['pca'],\n",
"            'n_components': GLOBAL_PCA['n_components']\n",
"        }, pca_path)\n",
"        \n",
"        print(f\"   ✅ Global PCA fitted!\")\n",
"        print(f\"   Reduction: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n",
"        print(f\"   Reduction ratio: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n",
"        print(f\"   Retained variance: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n",
"        print(f\"   Model saved to: {pca_path}\")\n",
"        \n",
"        # Free the sample data\n",
"        del sample_features, X_sample, X_sample_scaled\n",
"        gc.collect()\n",
"    else:\n",
"        print(\"❌ Could not collect sample data for PCA fitting\")\n",
"\n",
"def apply_pca_transform(features):\n",
"    \"\"\"\n",
"    Apply the global PCA transform.\n",
"    \"\"\"\n",
"    if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n",
"        return features\n",
"    \n",
"    # Standardize + PCA transform\n",
"    features_scaled = GLOBAL_PCA['scaler'].transform(features)\n",
"    features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n",
"    return features_pca\n",
"\n"
]
},
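{
"cell_type": "markdown",
"metadata": {},
"source": [
"The component-count rule above (smallest k whose cumulative explained variance crosses the threshold) can be checked in isolation. The cell below is a sketch on synthetic correlated data; the toy matrix is a placeholder, not the neural features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: pick the smallest k whose cumulative explained variance crosses a threshold\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X_toy = rng.normal(size=(500, 30)) @ rng.normal(size=(30, 30))  # correlated toy features\n",
"\n",
"X_scaled = StandardScaler().fit_transform(X_toy)\n",
"pca_full = PCA().fit(X_scaled)\n",
"cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
"k = int(np.argmax(cumsum_var >= 0.95) + 1)  # same rule as fit_global_pca\n",
"print(f\"{k} components keep {cumsum_var[k-1]:.3f} of the variance\")"
]
},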
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📊 Data Balancing Strategy - Label Distribution Analysis & Sampling Optimization"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# [Core sampling implementation]\n",
"import random\n",
"from collections import Counter\n",
"\n",
"import numpy as np\n",
"from sklearn.utils import resample\n",
"\n",
"def balance_dataset(X, y, config=BALANCE_CONFIG):\n",
"    \"\"\"\n",
"    Balance the dataset: undersampling + oversampling.\n",
"    \n",
"    Args:\n",
"        X: feature matrix\n",
"        y: label vector\n",
"        config: balancing configuration\n",
"    \n",
"    Returns:\n",
"        X_balanced, y_balanced: the balanced data\n",
"    \"\"\"\n",
"    if not config['enable_balance']:\n",
"        print(\"🔕 Data balancing disabled; returning the original data\")\n",
"        return X, y\n",
"    \n",
"    print(f\"\\n⚖️ Starting data balancing...\")\n",
"    print(f\"   Original data: {X.shape[0]:,} samples\")\n",
"    \n",
"    # Analyze the current distribution (mean over labels 1-39 only)\n",
"    label_counts = Counter(y)\n",
"    counts_exclude_0_40 = [label_counts.get(i, 0) for i in range(1, 40)]  # labels 1-39\n",
"    mean_count = np.mean(counts_exclude_0_40)  # mean count over labels 1-39 only\n",
"    \n",
"    print(f\"   Mean sample count (labels 1-39): {mean_count:.0f}\")\n",
"    print(f\"   Undersampled labels: {config['undersample_labels']}\")\n",
"    print(f\"   Oversampling threshold: {config['oversample_threshold']} * mean\")\n",
"    \n",
"    # Accumulate the balanced data\n",
"    X_balanced = []\n",
"    y_balanced = []\n",
"    \n",
"    random.seed(config['random_state'])\n",
"    np.random.seed(config['random_state'])\n",
"    \n",
"    for label in range(41):\n",
"        # All samples with the current label\n",
"        label_mask = (y == label)\n",
"        X_label = X[label_mask]\n",
"        y_label = y[label_mask]\n",
"        current_count = len(y_label)\n",
"        \n",
"        if current_count == 0:\n",
"            continue\n",
"        \n",
"        # Choose the sampling strategy\n",
"        if label in config['undersample_labels']:\n",
"            # Undersample down to the mean\n",
"            target_count = int(mean_count)\n",
"            if current_count > target_count:\n",
"                # Undersample\n",
"                indices = np.random.choice(current_count, target_count, replace=False)\n",
"                X_resampled = X_label[indices]\n",
"                y_resampled = y_label[indices]\n",
"                print(f\"   📉 Label {label}: {current_count} → {target_count} (undersampled)\")\n",
"            else:\n",
"                X_resampled = X_label\n",
"                y_resampled = y_label\n",
"                print(f\"   ➡️ Label {label}: {current_count} (no undersampling needed)\")\n",
"        \n",
"        elif current_count < mean_count * config['oversample_threshold']:\n",
"            # Oversample up to the threshold\n",
"            target_count = int(mean_count * config['oversample_threshold'])\n",
"            if current_count < target_count:\n",
"                # Oversample\n",
"                X_resampled, y_resampled = resample(\n",
"                    X_label, y_label, \n",
"                    n_samples=target_count, \n",
"                    random_state=config['random_state']\n",
"                )\n",
"                print(f\"   📈 Label {label}: {current_count} → {target_count} (oversampled)\")\n",
"            else:\n",
"                X_resampled = X_label\n",
"                y_resampled = y_label\n",
"                print(f\"   ➡️ Label {label}: {current_count} (no oversampling needed)\")\n",
"        else:\n",
"            # Leave unchanged\n",
"            X_resampled = X_label\n",
"            y_resampled = y_label\n",
"            print(f\"   ✅ Label {label}: {current_count} (already balanced)\")\n",
"        \n",
"        X_balanced.append(X_resampled)\n",
"        y_balanced.append(y_resampled)\n",
"    \n",
"    # Concatenate the balanced data\n",
"    X_balanced = np.vstack(X_balanced)\n",
"    y_balanced = np.hstack(y_balanced)\n",
"    \n",
"    # Shuffle\n",
"    shuffle_indices = np.random.permutation(len(y_balanced))\n",
"    X_balanced = X_balanced[shuffle_indices]\n",
"    y_balanced = y_balanced[shuffle_indices]\n",
"    \n",
"    print(f\"   ✅ Balancing complete: {X_balanced.shape[0]:,} samples\")\n",
"    print(f\"   Size change: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n",
"    \n",
"    return X_balanced, y_balanced\n"
]
},
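{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick smoke test of `balance_dataset` on tiny synthetic data. The label counts below are hypothetical (frequent 0/40, one rare label to trigger oversampling), not the real neural dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: exercise balance_dataset on placeholder data\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"counts = {0: 500, 40: 400}                 # over-represented labels, like blank\n",
"counts.update({i: 50 for i in range(1, 40)})\n",
"counts[3] = 10                             # one rare label to trigger oversampling\n",
"y_toy = np.concatenate([np.full(n, lab) for lab, n in counts.items()])\n",
"X_toy = rng.normal(size=(y_toy.size, 8))\n",
"\n",
"X_bal, y_bal = balance_dataset(X_toy, y_toy)\n",
"print(X_bal.shape, np.bincount(y_bal.astype(int), minlength=41)[[0, 1, 3, 40]])"
]
},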
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔄 Memory-Friendly Data Loader with Integrated Balancing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🧪 Testing the Balancing Effect"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 Improved Smart Data Processing Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Creating the smart data pipeline...\n",
"✅ Pipeline created; ready to run step 1...\n"
]
}
],
"source": [
"# 🚀 Improved smart data processing pipeline [batched training not solved yet]\n",
"# Flow: analyze distribution → fix sampling ratios → fit PCA (undersampling only) → process data (under- + oversampling + PCA)\n",
"\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from collections import Counter\n",
"from sklearn.utils import resample\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"import joblib\n",
"import random\n",
"import gc\n",
"\n",
"class SmartDataPipeline:\n",
"    \"\"\"\n",
"    Smart data processing pipeline.\n",
"    Step 1: analyze the data distribution and fix the sampling strategy\n",
"    Step 2: fit the PCA parameters on undersampled data only\n",
"    Step 3: process data with full sampling + PCA reduction\n",
"    \"\"\"\n",
"    \n",
"    def __init__(self, data_dir, random_state=42):\n",
"        self.data_dir = data_dir\n",
"        self.random_state = random_state\n",
"        \n",
"        # Step 1: distribution analysis results\n",
"        self.distribution_analysis = None\n",
"        self.sampling_strategy = None\n",
"        \n",
"        # Step 2: PCA parameters (fitted on undersampled data)\n",
"        self.pca_scaler = None\n",
"        self.pca_model = None\n",
"        self.pca_components = None\n",
"        self.pca_fitted = False\n",
"        \n",
"        # Configuration\n",
"        self.undersample_labels = [0, 40]   # labels to undersample\n",
"        self.oversample_threshold = 0.5     # oversampling threshold (relative to the mean)\n",
"        self.pca_variance_threshold = 0.95  # fraction of variance PCA keeps\n",
"        self.pca_sample_size = 15000        # number of samples for PCA fitting\n",
"        \n",
"    def step1_analyze_distribution(self, max_samples=100000):\n",
"        \"\"\"\n",
"        Step 1: analyze the data distribution and fix the sampling strategy.\n",
"        \"\"\"\n",
"        print(\"🔍 Step 1: analyzing the data distribution...\")\n",
"        \n",
"        # Analyze the validation-set distribution (as a proxy for the overall distribution)\n",
"        all_labels = []\n",
"        for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n",
"            _, labels = extract_features_labels_batch(trials_batch)\n",
"            all_labels.extend(labels.tolist())\n",
"            if len(all_labels) >= max_samples:\n",
"                break\n",
"        \n",
"        # Count labels\n",
"        label_counts = Counter(all_labels)\n",
"        \n",
"        # Mean over labels 1-39 (excluding 0 and 40)\n",
"        counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n",
"        target_mean = np.mean(counts_1_39)\n",
"        \n",
"        # Build the sampling strategy\n",
"        sampling_strategy = {}\n",
"        for label in range(41):\n",
"            current_count = label_counts.get(label, 0)\n",
"            \n",
"            if label in self.undersample_labels:\n",
"                # Undersample down to the mean\n",
"                target_count = int(target_mean)\n",
"                action = 'undersample' if current_count > target_count else 'keep'\n",
"            elif current_count < target_mean * self.oversample_threshold:\n",
"                # Oversample up to the threshold\n",
"                target_count = int(target_mean * self.oversample_threshold)\n",
"                action = 'oversample' if current_count < target_count else 'keep'\n",
"            else:\n",
"                # Leave unchanged\n",
"                target_count = current_count\n",
"                action = 'keep'\n",
"            \n",
"            sampling_strategy[label] = {\n",
"                'current_count': current_count,\n",
"                'target_count': target_count,\n",
"                'action': action\n",
"            }\n",
"        \n",
"        self.distribution_analysis = {\n",
"            'label_counts': label_counts,\n",
"            'target_mean': target_mean,\n",
"            'total_samples': len(all_labels)\n",
"        }\n",
"        self.sampling_strategy = sampling_strategy\n",
"        \n",
"        print(f\"   ✅ Analysis complete: {len(all_labels):,} samples\")\n",
"        print(f\"   📊 Mean over labels 1-39: {target_mean:.0f}\")\n",
"        print(f\"   📉 Undersampled labels: {self.undersample_labels} → {target_mean:.0f}\")\n",
"        print(f\"   📈 Oversampling threshold: {self.oversample_threshold} × mean = {target_mean * self.oversample_threshold:.0f}\")\n",
"        \n",
"        return self.distribution_analysis, self.sampling_strategy\n",
"\n",
"# Create the smart data pipeline\n",
"print(\"🚀 Creating the smart data pipeline...\")\n",
"pipeline = SmartDataPipeline(data_dir, random_state=42)\n",
"print(\"✅ Pipeline created; ready to run step 1...\")"
]
},
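{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-label decision rule inside step 1 can be isolated as a small pure function. The cell below sketches it on hypothetical counts; the mean of 455 mirrors the value observed later in this notebook's step 1 output."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the per-label decision rule from step1, isolated (toy counts)\n",
"def decide(label, count, mean, undersample_labels=(0, 40), threshold=0.5):\n",
"    if label in undersample_labels:\n",
"        return ('undersample', int(mean)) if count > int(mean) else ('keep', count)\n",
"    if count < mean * threshold:\n",
"        return ('oversample', int(mean * threshold))\n",
"    return ('keep', count)\n",
"\n",
"for label, count in [(0, 5000), (7, 120), (12, 600), (40, 900)]:  # hypothetical counts\n",
"    print(label, count, decide(label, count, mean=455.0))"
]
},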
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Step 2 methods added to the pipeline\n"
]
}
],
"source": [
"# Add more smart-pipeline methods [pipeline completion]\n",
"\n",
"def step2_fit_pca_with_undersampling(self):\n",
"    \"\"\"\n",
"    Step 2: fit the PCA parameters on undersampled data only (no oversampling, so duplicated samples cannot skew the PCA).\n",
"    \"\"\"\n",
"    if self.sampling_strategy is None:\n",
"        raise ValueError(\"Run step 1 first: step1_analyze_distribution()\")\n",
"    \n",
"    print(\"\\n🔧 Step 2: fitting PCA parameters (undersampling only, no oversampling)...\")\n",
"    \n",
"    # Collect samples for PCA fitting (undersampling only)\n",
"    pca_features = []\n",
"    collected_samples = 0\n",
"    \n",
"    for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000):\n",
"        features, labels = extract_features_labels_batch(trials_batch)\n",
"        \n",
"        # Apply the undersampling-only strategy to this batch\n",
"        downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n",
"        \n",
"        if downsampled_features.shape[0] > 0:\n",
"            pca_features.append(downsampled_features)\n",
"            collected_samples += downsampled_features.shape[0]\n",
"        \n",
"        if collected_samples >= self.pca_sample_size:\n",
"            break\n",
"    \n",
"    if pca_features:\n",
"        # Stack the samples\n",
"        X_pca_sample = np.vstack(pca_features)[:self.pca_sample_size]\n",
"        print(f\"   📦 PCA fitting set: {X_pca_sample.shape[0]:,} undersampled samples\")\n",
"        print(f\"   🔢 Original feature dimension: {X_pca_sample.shape[1]}\")\n",
"        \n",
"        # Standardize\n",
"        self.pca_scaler = StandardScaler()\n",
"        X_scaled = self.pca_scaler.fit_transform(X_pca_sample)\n",
"        \n",
"        # Choose the number of PCA components\n",
"        pca_full = PCA()\n",
"        pca_full.fit(X_scaled)\n",
"        cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
"        optimal_components = np.argmax(cumsum_var >= self.pca_variance_threshold) + 1\n",
"        self.pca_components = min(optimal_components, X_pca_sample.shape[1])\n",
"        \n",
"        # Fit the final PCA\n",
"        self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n",
"        self.pca_model.fit(X_scaled)\n",
"        self.pca_fitted = True\n",
"        \n",
"        # Save the PCA model\n",
"        pca_path = \"smart_pipeline_pca.joblib\"\n",
"        joblib.dump({\n",
"            'scaler': self.pca_scaler,\n",
"            'pca': self.pca_model,\n",
"            'components': self.pca_components\n",
"        }, pca_path)\n",
"        \n",
"        print(f\"   ✅ PCA fitted!\")\n",
"        print(f\"   Reduction: {X_pca_sample.shape[1]} → {self.pca_components}\")\n",
"        print(f\"   Reduction ratio: {self.pca_components/X_pca_sample.shape[1]:.2%}\")\n",
"        print(f\"   Retained variance: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
"        print(f\"   Model saved to: {pca_path}\")\n",
"        \n",
"        # Free memory\n",
"        del pca_features, X_pca_sample, X_scaled\n",
"        gc.collect()\n",
"    else:\n",
"        raise ValueError(\"Could not collect samples for PCA fitting\")\n",
"\n",
"def _apply_undersampling_only(self, X, y):\n",
"    \"\"\"\n",
"    Apply the undersampling part of the strategy only (used for PCA fitting).\n",
"    \"\"\"\n",
"    X_result = []\n",
"    y_result = []\n",
"    \n",
"    np.random.seed(self.random_state)\n",
"    \n",
"    for label in range(41):\n",
"        label_mask = (y == label)\n",
"        X_label = X[label_mask]\n",
"        y_label = y[label_mask]\n",
"        current_count = len(y_label)\n",
"        \n",
"        if current_count == 0:\n",
"            continue\n",
"        \n",
"        strategy = self.sampling_strategy[label]\n",
"        \n",
"        if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n",
"            # Undersample\n",
"            indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n",
"            X_resampled = X_label[indices]\n",
"            y_resampled = y_label[indices]\n",
"        else:\n",
"            # Keep as-is\n",
"            X_resampled = X_label\n",
"            y_resampled = y_label\n",
"        \n",
"        X_result.append(X_resampled)\n",
"        y_result.append(y_resampled)\n",
"    \n",
"    if X_result:\n",
"        return np.vstack(X_result), np.hstack(y_result)\n",
"    else:\n",
"        return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
"\n",
"# Attach the methods to the class\n",
"SmartDataPipeline.step2_fit_pca_with_undersampling = step2_fit_pca_with_undersampling\n",
"SmartDataPipeline._apply_undersampling_only = _apply_undersampling_only\n",
"\n",
"print(\"✅ Step 2 methods added to the pipeline\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ All methods added to the smart pipeline\n",
"\n",
"📋 Smart data pipeline status:\n",
"   🔍 Step 1 - distribution analysis: ❌ not done\n",
"   🔧 Step 2 - PCA fitting: ❌ not done\n",
"\n",
"🎯 Usage flow:\n",
"   1. pipeline.step1_analyze_distribution()\n",
"   2. pipeline.step2_fit_pca_with_undersampling()\n",
"   3. pipeline.step3_process_data('train')  # training set\n",
"      pipeline.step3_process_data('val')    # validation set\n"
]
}
],
"source": [
"# Add the remaining smart-pipeline methods\n",
"\n",
"def _apply_full_sampling(self, X, y):\n",
"    \"\"\"\n",
"    Apply the full sampling strategy (undersampling + oversampling).\n",
"    \"\"\"\n",
"    X_result = []\n",
"    y_result = []\n",
"    \n",
"    np.random.seed(self.random_state)\n",
"    \n",
"    for label in range(41):\n",
"        label_mask = (y == label)\n",
"        X_label = X[label_mask]\n",
"        y_label = y[label_mask]\n",
"        current_count = len(y_label)\n",
"        \n",
"        if current_count == 0:\n",
"            continue\n",
"        \n",
"        strategy = self.sampling_strategy[label]\n",
"        target_count = strategy['target_count']\n",
"        \n",
"        if strategy['action'] == 'undersample' and current_count > target_count:\n",
"            # Undersample\n",
"            indices = np.random.choice(current_count, target_count, replace=False)\n",
"            X_resampled = X_label[indices]\n",
"            y_resampled = y_label[indices]\n",
"        elif strategy['action'] == 'oversample' and current_count < target_count:\n",
"            # Oversample\n",
"            X_resampled, y_resampled = resample(\n",
"                X_label, y_label, \n",
"                n_samples=target_count, \n",
"                random_state=self.random_state\n",
"            )\n",
"        else:\n",
"            # Keep as-is\n",
"            X_resampled = X_label\n",
"            y_resampled = y_label\n",
"        \n",
"        X_result.append(X_resampled)\n",
"        y_result.append(y_resampled)\n",
"    \n",
"    if X_result:\n",
"        return np.vstack(X_result), np.hstack(y_result)\n",
"    else:\n",
"        return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
"\n",
"def _apply_pca_transform(self, X):\n",
"    \"\"\"\n",
"    Apply the PCA transform.\n",
"    \"\"\"\n",
"    if not self.pca_fitted:\n",
"        return X\n",
"    \n",
"    X_scaled = self.pca_scaler.transform(X)\n",
"    X_pca = self.pca_model.transform(X_scaled)\n",
"    return X_pca\n",
"\n",
"def step3_process_data(self, data_type, apply_sampling=None):\n",
"    \"\"\"\n",
"    Step 3: process the data (sampling + PCA reduction).\n",
"    \n",
"    Args:\n",
"        data_type: 'train', 'val', or 'test'\n",
"        apply_sampling: whether to apply the sampling strategy; None = apply for train, skip for val/test\n",
"    \"\"\"\n",
"    if not self.pca_fitted:\n",
"        raise ValueError(\"Run step 2 first: step2_fit_pca_with_undersampling()\")\n",
"    \n",
"    if apply_sampling is None:\n",
"        apply_sampling = (data_type == 'train')\n",
"    \n",
"    print(f\"\\n🔄 Step 3: processing {data_type} data...\")\n",
"    print(f\"   Sampling strategy: {'enabled' if apply_sampling else 'disabled'}\")\n",
"    \n",
"    all_features = []\n",
"    all_labels = []\n",
"    \n",
"    for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000):\n",
"        features, labels = extract_features_labels_batch(trials_batch)\n",
"        \n",
"        # Apply the sampling strategy\n",
"        if apply_sampling:\n",
"            features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n",
"        else:\n",
"            features_sampled, labels_sampled = features, labels\n",
"        \n",
"        # Apply the PCA reduction\n",
"        if features_sampled.shape[0] > 0:\n",
"            features_pca = self._apply_pca_transform(features_sampled)\n",
"            all_features.append(features_pca)\n",
"            all_labels.append(labels_sampled)\n",
"    \n",
"    if all_features:\n",
"        X = np.vstack(all_features)\n",
"        y = np.hstack(all_labels)\n",
"        \n",
"        # Shuffle\n",
"        shuffle_indices = np.random.permutation(len(y))\n",
"        X = X[shuffle_indices]\n",
"        y = y[shuffle_indices]\n",
"        \n",
"        print(f\"   ✅ Processing complete: {X.shape[0]:,} samples, {X.shape[1]} features\")\n",
"        \n",
"        # Free memory\n",
"        del all_features, all_labels\n",
"        gc.collect()\n",
"        \n",
"        return X, y\n",
"    else:\n",
"        return None, None\n",
"\n",
"def print_summary(self):\n",
"    \"\"\"\n",
"    Print a summary of the pipeline state.\n",
"    \"\"\"\n",
"    print(\"\\n📋 Smart data pipeline status:\")\n",
"    print(f\"   🔍 Step 1 - distribution analysis: {'✅ done' if self.distribution_analysis else '❌ not done'}\")\n",
"    print(f\"   🔧 Step 2 - PCA fitting: {'✅ done' if self.pca_fitted else '❌ not done'}\")\n",
"    \n",
"    if self.distribution_analysis:\n",
"        target_mean = self.distribution_analysis['target_mean']\n",
"        print(f\"   📊 Mean over labels 1-39: {target_mean:.0f}\")\n",
"    \n",
"    if self.pca_fitted:\n",
"        print(f\"   🔬 PCA reduction: 7168 → {self.pca_components} ({self.pca_components/7168:.1%})\")\n",
"        print(f\"   📈 Retained variance: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
"    \n",
"    print(f\"\\n🎯 Usage flow:\")\n",
"    print(f\"   1. pipeline.step1_analyze_distribution()\")\n",
"    print(f\"   2. pipeline.step2_fit_pca_with_undersampling()\")\n",
"    print(f\"   3. pipeline.step3_process_data('train')  # training set\")\n",
"    print(f\"      pipeline.step3_process_data('val')    # validation set\")\n",
"\n",
"# Attach the remaining methods to the class\n",
"SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n",
"SmartDataPipeline._apply_pca_transform = _apply_pca_transform\n",
"SmartDataPipeline.step3_process_data = step3_process_data\n",
"SmartDataPipeline.print_summary = print_summary\n",
"\n",
"print(\"✅ All methods added to the smart pipeline\")\n",
"pipeline.print_summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔥 Run the Smart Data Pipeline"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Running the smart data pipeline...\n",
"============================================================\n",
"\n",
"=============🔍 STEP 1: Analyze the distribution=============\n",
"🔍 Step 1: analyzing the data distribution...\n",
"   Loading file 1/41: t15.2023.08.13_val_concatenated.npz\n",
"   Loading file 2/41: t15.2023.08.18_val_concatenated.npz\n",
"   Loading file 3/41: t15.2023.08.20_val_concatenated.npz\n",
"   Loading file 4/41: t15.2023.08.25_val_concatenated.npz\n",
"   Loading file 5/41: t15.2023.08.27_val_concatenated.npz\n",
"   Loading file 6/41: t15.2023.09.01_val_concatenated.npz\n",
"   Loading file 7/41: t15.2023.09.03_val_concatenated.npz\n",
"   Loading file 8/41: t15.2023.09.24_val_concatenated.npz\n",
"   Loading file 9/41: t15.2023.09.29_val_concatenated.npz\n",
"   Loading file 10/41: t15.2023.10.01_val_concatenated.npz\n",
"   Loading file 11/41: t15.2023.10.06_val_concatenated.npz\n",
"   Loading file 12/41: t15.2023.10.08_val_concatenated.npz\n",
"   Loading file 13/41: t15.2023.10.13_val_concatenated.npz\n",
"   Loading file 14/41: t15.2023.10.15_val_concatenated.npz\n",
"   ✅ Analysis complete: 108,742 samples\n",
"   📊 Mean over labels 1-39: 455\n",
"   📉 Undersampled labels: [0, 40] → 455\n",
"   📈 Oversampling threshold: 0.5 × mean = 227\n",
"\n",
"📊 Sampling strategy summary:\n",
"   📉 Labels to undersample: 2\n",
"   📈 Labels to oversample: 11\n",
"   ✅ Left unchanged: 28\n",
"\n",
"✅ Step 1 complete!\n"
]
}
],
"source": [
"# 🔥 Run the smart data pipeline [fix the sampling strategy]\n",
"\n",
"print(\"🚀 Running the smart data pipeline...\")\n",
"print(\"=\" * 60)\n",
"\n",
"# Step 1: analyze the data distribution\n",
"print(\"\\n\" + \"🔍 STEP 1: Analyze the distribution\".center(60, \"=\"))\n",
"distribution, strategy = pipeline.step1_analyze_distribution()\n",
"\n",
"# Summarize the sampling strategy\n",
"print(f\"\\n📊 Sampling strategy summary:\")\n",
"undersample_count = sum(1 for s in strategy.values() if s['action'] == 'undersample')\n",
"oversample_count = sum(1 for s in strategy.values() if s['action'] == 'oversample')\n",
"keep_count = sum(1 for s in strategy.values() if s['action'] == 'keep')\n",
"\n",
"print(f\"   📉 Labels to undersample: {undersample_count}\")\n",
"print(f\"   📈 Labels to oversample: {oversample_count}\") \n",
"print(f\"   ✅ Left unchanged: {keep_count}\")\n",
"\n",
"print(\"\\n✅ Step 1 complete!\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"==============🔧 STEP 2: Fit the PCA parameters==============\n",
"\n",
"🔧 Step 2: fitting PCA parameters (undersampling only, no oversampling)...\n",
"   Loading file 1/45: t15.2023.08.11_train_concatenated.npz\n",
"   Loading file 2/45: t15.2023.08.13_train_concatenated.npz\n",
"   📦 PCA fitting set: 15,000 undersampled samples\n",
"   🔢 Original feature dimension: 7168\n",
"   ✅ PCA fitted!\n",
"   Reduction: 7168 → 1153\n",
"   Reduction ratio: 16.09%\n",
"   Retained variance: 0.9491\n",
"   Model saved to: smart_pipeline_pca.joblib\n",
"\n",
"✅ Step 2 complete!\n",
"\n",
"📋 Smart data pipeline status:\n",
"   🔍 Step 1 - distribution analysis: ✅ done\n",
"   🔧 Step 2 - PCA fitting: ✅ done\n",
"   📊 Mean over labels 1-39: 455\n",
"   🔬 PCA reduction: 7168 → 1153 (16.1%)\n",
"   📈 Retained variance: 0.9491\n",
"\n",
"🎯 Usage flow:\n",
"   1. pipeline.step1_analyze_distribution()\n",
"   2. pipeline.step2_fit_pca_with_undersampling()\n",
"   3. pipeline.step3_process_data('train')  # training set\n",
"      pipeline.step3_process_data('val')    # validation set\n"
]
}
],
"source": [
"# Step 2: fit the PCA parameters [fix the PCA strategy]\n",
"print(\"\\n\" + \"🔧 STEP 2: Fit the PCA parameters\".center(60, \"=\"))\n",
"pipeline.step2_fit_pca_with_undersampling()\n",
"\n",
"print(\"\\n✅ Step 2 complete!\")\n",
"pipeline.print_summary()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 Batch Training with the Smart Pipeline"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 Creating the smart batch trainer...\n",
"🎯 Smart batch trainer created\n",
"   🔧 LightGBM params configured: CPU mode\n",
"   💡 LR schedule: cosine annealing with warm restarts (from 0.1 down to 0.003)\n",
"   🔄 Restart params: T_0=50, T_mult=2\n",
"✅ Trainer created; ready to train!\n"
]
}
],
"source": [
"# 🚀 Batch training with the smart pipeline\n",
"\n",
"import lightgbm as lgb\n",
"import time\n",
"from collections import Counter\n",
"import matplotlib.pyplot as plt\n",
"import random\n",
"\n",
"class SmartBatchTrainer:\n",
"    \"\"\"\n",
"    Smart batch trainer built on top of the smart data pipeline.\n",
"    \"\"\"\n",
"    \n",
"    def __init__(self, pipeline, params=None, min_learning_rate=1e-4, t_0=50, t_mult=2):\n",
"        self.pipeline = pipeline\n",
"        self.model = None\n",
"        self.training_history = {}  # a dict, since there is a single training run\n",
"        self.batch_count = 0\n",
"        self.min_learning_rate = min_learning_rate\n",
"        self.lr_history = []  # for plotting\n",
"        \n",
"        # Cosine annealing with warm restarts\n",
"        self.t_0 = t_0        # length of the first restart cycle\n",
"        self.t_mult = t_mult  # cycle-length multiplier\n",
"        \n",
"        # Default LightGBM params (GPU-oriented defaults)\n",
"        self.params = params or {\n",
"            'objective': 'multiclass',\n",
"            'num_class': 41,\n",
"            'metric': 'multi_logloss',\n",
"            'boosting_type': 'gbdt',\n",
"            'device_type': 'cpu',\n",
"            # 'gpu_platform_id': 0,\n",
"            # 'gpu_device_id': 0,\n",
"            'max_bin': 255,\n",
"            'num_leaves': 127,\n",
"            'learning_rate': 0.10,  # default 0.08\n",
"            'feature_fraction': 0.8,\n",
"            'bagging_fraction': 0.8,\n",
"            'bagging_freq': 5,\n",
"            'min_data_in_leaf': 20,\n",
"            'lambda_l1': 0.1,\n",
"            'lambda_l2': 0.1,\n",
"            'verbose': -1,\n",
"            'num_threads': -1\n",
"        }\n",
"        \n",
"        self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n",
"        \n",
"        print(f\"🎯 Smart batch trainer created\")\n",
"        print(f\"   🔧 LightGBM params configured: {self.params['device_type'].upper()} mode\")\n",
"        print(f\"   💡 LR schedule: cosine annealing with warm restarts (from {self.initial_learning_rate} down to {self.min_learning_rate})\")\n",
"        print(f\"   🔄 Restart params: T_0={self.t_0}, T_mult={self.t_mult}\")\n",
"        \n",
"    def prepare_validation_data(self):\n",
"        \"\"\"\n",
"        Prepare the validation data (PCA only, original distribution preserved).\n",
"        \"\"\"\n",
"        print(\"🔄 Preparing validation data...\")\n",
"        X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
"        if X_val is None:\n",
"            raise ValueError(\"Could not load validation data\")\n",
"        val_counts = Counter(y_val)\n",
"        print(f\"   ✅ Validation data ready: {X_val.shape[0]:,} samples\")\n",
"        print(f\"   📊 Validation distribution (label 0: {val_counts.get(0, 0):,}, label 40: {val_counts.get(40, 0):,})\")\n",
"        \n",
"        return lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
"    \n",
"    def get_training_batch_generator(self, n_files_per_batch=4, batch_size=8000):\n",
"        \"\"\"\n",
"        Improved training batch generator: randomly pick n files from all training files, then sample at random.\n",
"        \n",
"        Args:\n",
"            n_files_per_batch: number of files randomly selected per batch (default 4)\n",
"            batch_size: target number of samples per batch (default 8000)\n",
"        \"\"\"\n",
"        print(f\"🔄 Preparing the improved training batch generator...\")\n",
"        print(f\"   📁 Files per batch: {n_files_per_batch}\")\n",
"        print(f\"   📊 Target samples per batch: {batch_size:,}\")\n",
"        \n",
"        # List all training files\n",
"        all_train_files = [f for f in os.listdir(self.pipeline.data_dir) \n",
"                           if f.endswith('.npz') and 'train' in f]\n",
"        \n",
"        if len(all_train_files) < n_files_per_batch:\n",
"            print(f\"   ⚠️ Available files ({len(all_train_files)}) fewer than requested per batch ({n_files_per_batch})\")\n",
"            n_files_per_batch = len(all_train_files)\n",
"        \n",
"        print(f\"   📂 Training files available: {len(all_train_files)}\")\n",
"        \n",
"        batch_id = 0\n",
"        while True:  # infinite generator; files can be re-sampled\n",
"            batch_id += 1\n",
"            \n",
"            # Randomly pick n files\n",
"            selected_files = random.sample(all_train_files, n_files_per_batch)\n",
"            \n",
"            print(f\"   🎲 Batch {batch_id} - randomly selected files:\")\n",
"            for i, f in enumerate(selected_files, 1):\n",
"                print(f\"      {i}. {f}\")\n",
"            \n",
"            # Load the selected files\n",
"            all_features = []\n",
"            all_labels = []\n",
"            total_available_samples = 0\n",
"            \n",
"            for filename in selected_files:\n",
"                # Load one file\n",
"                data = np.load(os.path.join(self.pipeline.data_dir, filename), allow_pickle=True)\n",
"                trials = data['neural_logits_concatenated']\n",
"                \n",
"                # Extract features and labels\n",
"                features, labels = extract_features_labels_batch(trials)\n",
"                \n",
"                if features.shape[0] > 0:\n",
"                    all_features.append(features)\n",
"                    all_labels.append(labels)\n",
"                    total_available_samples += features.shape[0]\n",
"                \n",
"                # Free the per-file data\n",
"                del data, trials\n",
"                gc.collect()\n",
"            \n",
"            if all_features:\n",
"                # Merge the selected files' data\n",
"                combined_features = np.vstack(all_features)\n",
"                combined_labels = np.hstack(all_labels)\n",
"                \n",
"                print(f\"   📊 Merged sample count: {combined_features.shape[0]:,}\")\n",
"                \n",
"                # Randomly sample down to the target batch_size\n",
"                if combined_features.shape[0] > batch_size:\n",
"                    # Draw batch_size samples at random\n",
"                    sample_indices = np.random.choice(\n",
"                        combined_features.shape[0], \n",
"                        size=batch_size, \n",
"                        replace=False\n",
"                    )\n",
"                    sampled_features = combined_features[sample_indices]\n",
"                    sampled_labels = combined_labels[sample_indices]\n",
"                    print(f\"   🎯 Randomly sampled down to: {batch_size:,} samples\")\n",
"                else:\n",
"                    # Not enough samples; use everything\n",
"                    sampled_features = combined_features\n",
"                    sampled_labels = combined_labels\n",
"                    print(f\"   ⚠️ Not enough samples; using all: {sampled_features.shape[0]:,} samples\")\n",
"                \n",
"                # Apply the sampling strategy (balancing)\n",
"                features_balanced, labels_balanced = self.pipeline._apply_full_sampling(\n",
"                    sampled_features, sampled_labels\n",
"                )\n",
"                \n",
"                # Apply the PCA reduction\n",
"                if features_balanced.shape[0] > 0:\n",
"                    features_pca = self.pipeline._apply_pca_transform(features_balanced)\n",
"                    \n",
"                    # Inspect this batch's distribution\n",
"                    batch_counts = Counter(labels_balanced)\n",
"                    \n",
"                    print(f\"   📦 Batch {batch_id} final result:\")\n",
"                    print(f\"      Balanced sample count: {features_pca.shape[0]:,}\")\n",
"                    print(f\"      Feature dimension: {features_pca.shape[1]}\")\n",
"                    print(f\"      Distribution: label 0={batch_counts.get(0,0)}, label 40={batch_counts.get(40,0)}\")\n",
"                    print(f\"   \" + \"=\"*50)\n",
"                    \n",
"                    yield lgb.Dataset(features_pca, label=labels_balanced), f\"batch_{batch_id}_files_{len(selected_files)}\"\n",
"                \n",
"                # Free the batch data\n",
"                del all_features, all_labels, combined_features, combined_labels\n",
"                del sampled_features, sampled_labels, features_balanced, labels_balanced\n",
"                gc.collect()\n",
"            else:\n",
"                print(f\"   ❌ Batch {batch_id} has no valid data\")\n",
"                continue\n",
"    \n",
"    def prepare_full_data(self):\n",
"        \"\"\"\n",
"        Prepare all training and validation data in one pass.\n",
"        \"\"\"\n",
"        print(\"🔄 Preparing full training and validation data...\")\n",
"        \n",
"        # 1. Validation data (original distribution preserved)\n",
"        X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
"        if X_val is None:\n",
"            raise ValueError(\"Could not load validation data\")\n",
"        val_counts = Counter(y_val)\n",
"        print(f\"   ✅ Validation data ready: {X_val.shape[0]:,} samples\")\n",
"        print(f\"   📊 Validation distribution (label 0: {val_counts.get(0, 0):,}, label 40: {val_counts.get(40, 0):,})\")\n",
"        val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
"        \n",
"        # 2. Training data (full sampling + PCA strategy applied)\n",
"        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
"        if X_train is None:\n",
"            raise ValueError(\"Could not load training data\")\n",
"        train_counts = Counter(y_train)\n",
"        print(f\"   ✅ Training data ready: {X_train.shape[0]:,} samples, {X_train.shape[1]} features\")\n",
"        print(f\"   📊 Training (post-sampling) distribution (label 0: {train_counts.get(0, 0):,}, label 40: {train_counts.get(40, 0):,})\")\n",
"        train_data = lgb.Dataset(X_train, label=y_train)\n",
"        \n",
"        return train_data, val_data, X_val, y_val\n",
"    \n",
"    def prepare_training_data(self):\n",
"        \"\"\"\n",
"        Prepare the training data (full sampling + PCA strategy applied).\n",
"        \"\"\"\n",
"        print(\"🔄 Preparing training data...\")\n",
"        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
"        if X_train is None:\n",
"            raise ValueError(\"Could not load training data\")\n",
"        train_counts = Counter(y_train)\n",
"        print(f\"   ✅ Training data ready: {X_train.shape[0]:,} samples, {X_train.shape[1]} features\")\n",
"        print(f\"   📊 Training (post-sampling) distribution (label 0: {train_counts.get(0, 0):,}, label 40: {train_counts.get(40, 0):,})\")\n",
"        \n",
"        return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n",
"    \n",
"    # Cosine annealing scheduler with warm restarts\n",
"    def _cosine_annealing_with_warm_restarts(self, current_round):\n",
"        \"\"\"\n",
"        Cosine annealing scheduler with warm restarts (SGDR).\n",
"        \n",
"        Args:\n",
"            current_round: current boosting round\n",
"        \n",
"        Returns:\n",
"            the learning rate\n",
"        \"\"\"\n",
"        eta_max = self.initial_learning_rate\n",
"        eta_min = self.min_learning_rate\n",
"        \n",
"        # Locate the current restart cycle\n",
"        t_cur = current_round\n",
"        t_i = self.t_0\n",
"        \n",
"        cycle = 0\n",
"        while t_cur >= t_i:\n",
"            t_cur -= t_i\n",
"            cycle += 1\n",
"            t_i *= self.t_mult\n",
"        \n",
"        # Position within the current cycle\n",
"        progress = t_cur / t_i\n",
"        \n",
"        # Compute the learning rate\n",
"        lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * progress))\n",
"        \n",
"        return lr\n",
"    \n",
"    def train_incremental(self, num_boost_round=100, early_stopping_rounds=10, \n",
"                          n_files_per_batch=4, batch_size=8000, max_batches=None):\n",
"        \"\"\"\n",
"        Incremental batch training with configurable batch parameters.\n",
"        \n",
"        Args:\n",
"            num_boost_round: boosting rounds per batch\n",
"            early_stopping_rounds: early-stopping patience\n",
"            n_files_per_batch: files randomly selected per batch\n",
"            batch_size: target samples per batch\n",
"            max_batches: maximum number of batches, None = unlimited\n",
"        \"\"\"\n",
"        print(f\"\\n🚀 Starting smart batch training...\")\n",
"        print(f\"   📝 Boosting rounds (per batch): {num_boost_round}\")\n",
"        print(f\"   ⏹️ Early-stopping patience: {early_stopping_rounds}\")\n",
"        print(f\"   📁 Files per batch: {n_files_per_batch}\")\n",
"        print(f\"   Samples per batch: {batch_size:,}\")\n",
"        if max_batches:\n",
"            print(f\"   Max batches: {max_batches}\")\n",
"        print(\"=\" * 60)\n",
"        \n",
"        # Prepare the validation data\n",
"        val_data = self.prepare_validation_data()\n",
"        \n",
"        print(f\"\\n   Starting incremental batch training...\")\n",
"        total_start_time = time.time()\n",
"        \n",
"        # ⭐️ New: define T_max for the LR scheduler\n",
"        # Each batch's training is treated as one full annealing cycle\n",
"        t_max_per_batch = num_boost_round\n",
"        \n",
"        # Create the improved generator\n",
"        batch_generator = self.get_training_batch_generator(n_files_per_batch, batch_size)\n",
"        \n",
"        for train_data, batch_name in batch_generator:\n",
"            self.batch_count += 1\n",
"            batch_start_time = time.time()\n",
"            \n",
"            # Stop once the maximum batch count is reached\n",
"            if max_batches and self.batch_count > max_batches:\n",
"                print(f\"Reached the maximum batch count {max_batches}; stopping training\")\n",
"                break"
]
},
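{
"cell_type": "markdown",
"metadata": {},
"source": [
"To sanity-check the SGDR schedule implemented in `_cosine_annealing_with_warm_restarts`, the standalone re-implementation below reproduces the same math and plots the first few cycles. The parameter values mirror the trainer defaults (eta_max=0.10, T_0=50, T_mult=2); `sgdr_lr` is a local helper, not part of the trainer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: plot the SGDR schedule implemented above (same math, standalone)\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def sgdr_lr(round_idx, eta_max=0.10, eta_min=0.003, t_0=50, t_mult=2):\n",
"    t_cur, t_i = round_idx, t_0\n",
"    while t_cur >= t_i:  # walk forward to the current restart cycle\n",
"        t_cur -= t_i\n",
"        t_i *= t_mult\n",
"    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t_cur / t_i))\n",
"\n",
"rounds = np.arange(350)\n",
"plt.plot(rounds, [sgdr_lr(r) for r in rounds])\n",
"plt.xlabel('boosting round')\n",
"plt.ylabel('learning rate')\n",
"plt.title('Cosine annealing with warm restarts')\n",
"plt.show()"
]
},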
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001)"
]
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"🔥 开始智能分批训练!\n",
|
|||
|
"================================================================================\n",
|
|||
|
"📝 训练配置:\n",
|
|||
|
" 训练轮数: 300\n",
|
|||
|
" 早停轮数: 15\n",
|
|||
|
" 数据平衡: 启用(下采样标签0,40 + 过采样少数类)\n",
|
|||
|
" PCA降维: 7168 → 1153 特征\n",
|
|||
|
"\n",
|
|||
|
"🚀 启动训练...\n",
|
|||
|
"\n",
|
|||
|
"🚀 开始全量数据训练...\n",
|
|||
|
" 📝 训练轮数: 300\n",
|
|||
|
" ⏹️ 早停轮数: 15\n",
|
|||
|
"============================================================\n",
|
|||
|
"🔄 准备全量训练和验证数据...\n",
|
|||
|
"\n",
|
|||
|
"🔄 步骤3: 处理val数据...\n",
|
|||
|
" 采样策略: 禁用\n",
|
|||
|
" 正在加载文件 1/41: t15.2023.08.13_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 2/41: t15.2023.08.18_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 3/41: t15.2023.08.20_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 4/41: t15.2023.08.25_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 5/41: t15.2023.08.27_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 6/41: t15.2023.09.01_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 7/41: t15.2023.09.03_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 8/41: t15.2023.09.24_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 9/41: t15.2023.09.29_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 10/41: t15.2023.10.01_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 11/41: t15.2023.10.06_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 12/41: t15.2023.10.08_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 13/41: t15.2023.10.13_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 14/41: t15.2023.10.15_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 15/41: t15.2023.10.20_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 16/41: t15.2023.10.22_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 17/41: t15.2023.11.03_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 18/41: t15.2023.11.04_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 19/41: t15.2023.11.17_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 20/41: t15.2023.11.19_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 21/41: t15.2023.11.26_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 22/41: t15.2023.12.03_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 23/41: t15.2023.12.08_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 24/41: t15.2023.12.10_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 25/41: t15.2023.12.17_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 26/41: t15.2023.12.29_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 27/41: t15.2024.02.25_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 28/41: t15.2024.03.08_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 29/41: t15.2024.03.15_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 30/41: t15.2024.03.17_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 31/41: t15.2024.05.10_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 32/41: t15.2024.06.14_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 33/41: t15.2024.07.19_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 34/41: t15.2024.07.21_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 35/41: t15.2024.07.28_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 36/41: t15.2025.01.10_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 37/41: t15.2025.01.12_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 38/41: t15.2025.03.14_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 39/41: t15.2025.03.16_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 40/41: t15.2025.03.30_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 41/41: t15.2025.04.13_val_concatenated.npz\n",
|
|||
|
" ✅ 处理完成: 321,773 样本, 1153 特征\n",
|
|||
|
" ✅ 验证数据准备完成: 321,773 样本\n",
|
|||
|
" 📊 验证集分布 (标签0: 238,705, 标签40: 35,425)\n",
|
|||
|
"\n",
|
|||
|
"🔄 步骤3: 处理train数据...\n",
|
|||
|
" 采样策略: 启用\n",
|
|||
|
" 正在加载文件 1/45: t15.2023.08.11_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 2/45: t15.2023.08.13_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 3/45: t15.2023.08.18_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 4/45: t15.2023.08.20_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 5/45: t15.2023.08.25_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 6/45: t15.2023.08.27_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 7/45: t15.2023.09.01_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 8/45: t15.2023.09.03_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 9/45: t15.2023.09.24_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 10/45: t15.2023.09.29_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 11/45: t15.2023.10.01_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 12/45: t15.2023.10.06_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 13/45: t15.2023.10.08_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 14/45: t15.2023.10.13_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 15/45: t15.2023.10.15_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 16/45: t15.2023.10.20_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 17/45: t15.2023.10.22_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 18/45: t15.2023.11.03_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 19/45: t15.2023.11.04_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 20/45: t15.2023.11.17_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 21/45: t15.2023.11.19_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 22/45: t15.2023.11.26_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 23/45: t15.2023.12.03_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 24/45: t15.2023.12.08_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 25/45: t15.2023.12.10_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 26/45: t15.2023.12.17_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 27/45: t15.2023.12.29_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 28/45: t15.2024.02.25_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 29/45: t15.2024.03.03_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 30/45: t15.2024.03.08_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 31/45: t15.2024.03.15_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 32/45: t15.2024.03.17_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 33/45: t15.2024.04.25_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 34/45: t15.2024.04.28_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 35/45: t15.2024.05.10_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 36/45: t15.2024.06.14_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 37/45: t15.2024.07.19_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 38/45: t15.2024.07.21_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 39/45: t15.2024.07.28_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 40/45: t15.2025.01.10_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 41/45: t15.2025.01.12_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 42/45: t15.2025.03.14_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 43/45: t15.2025.03.16_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 44/45: t15.2025.03.30_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 45/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" ✅ 处理完成: 398,908 样本, 1153 特征\n",
|
|||
|
" ✅ 训练数据准备完成: 398,908 样本, 1153 特征\n",
|
|||
|
" 📊 训练集(采样后)分布 (标签0: 20,430, 标签40: 20,430)\n",
|
|||
|
"\n",
|
|||
|
"📈 开始模型训练...\n",
|
|||
|
"[1]\tvalidation's multi_logloss: 2.80791\n",
|
|||
|
"Training until validation scores don't improve for 15 rounds\n",
|
|||
|
"[2]\tvalidation's multi_logloss: 2.70345\n",
|
|||
|
"[3]\tvalidation's multi_logloss: 2.63795\n",
|
|||
|
"[4]\tvalidation's multi_logloss: 2.57732\n",
|
|||
|
"[5]\tvalidation's multi_logloss: 2.52478\n",
|
|||
|
"[6]\tvalidation's multi_logloss: 2.48274\n",
|
|||
|
"[7]\tvalidation's multi_logloss: 2.44558\n",
|
|||
|
"[8]\tvalidation's multi_logloss: 2.41266\n",
|
|||
|
"[9]\tvalidation's multi_logloss: 2.3834\n",
|
|||
|
"[10]\tvalidation's multi_logloss: 2.36057\n",
|
|||
|
"[11]\tvalidation's multi_logloss: 2.33851\n",
|
|||
|
"[12]\tvalidation's multi_logloss: 2.31705\n",
|
|||
|
"[13]\tvalidation's multi_logloss: 2.29887\n",
|
|||
|
"[14]\tvalidation's multi_logloss: 2.28414\n",
|
|||
|
"[15]\tvalidation's multi_logloss: 2.26917\n",
|
|||
|
"[16]\tvalidation's multi_logloss: 2.25536\n",
|
|||
|
"[17]\tvalidation's multi_logloss: 2.24213\n",
|
|||
|
"[18]\tvalidation's multi_logloss: 2.22994\n",
|
|||
|
"[19]\tvalidation's multi_logloss: 2.21842\n",
|
|||
|
"[20]\tvalidation's multi_logloss: 2.20886\n",
|
|||
|
"[21]\tvalidation's multi_logloss: 2.19978\n",
|
|||
|
"[22]\tvalidation's multi_logloss: 2.19116\n",
|
|||
|
"[23]\tvalidation's multi_logloss: 2.18282\n",
|
|||
|
"[24]\tvalidation's multi_logloss: 2.17482\n",
|
|||
|
"[25]\tvalidation's multi_logloss: 2.16816\n",
|
|||
|
"[26]\tvalidation's multi_logloss: 2.16188\n",
|
|||
|
"[27]\tvalidation's multi_logloss: 2.15594\n",
|
|||
|
"[28]\tvalidation's multi_logloss: 2.15015\n",
|
|||
|
"[29]\tvalidation's multi_logloss: 2.1444\n",
|
|||
|
"[30]\tvalidation's multi_logloss: 2.13993\n",
|
|||
|
"[31]\tvalidation's multi_logloss: 2.13535\n",
|
|||
|
"[32]\tvalidation's multi_logloss: 2.13122\n",
|
|||
|
"[33]\tvalidation's multi_logloss: 2.12758\n",
|
|||
|
"[34]\tvalidation's multi_logloss: 2.12433\n",
|
|||
|
"[35]\tvalidation's multi_logloss: 2.1213\n",
|
|||
|
"[36]\tvalidation's multi_logloss: 2.1188\n"
|
|||
|
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[14], line 21\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m🚀 启动训练...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 20\u001b[0m \u001b[38;5;66;03m# 开始训练\u001b[39;00m\n\u001b[1;32m---> 21\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_boost_round\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAINING_PARAMS\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mnum_boost_round\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 23\u001b[0m \u001b[43m \u001b[49m\u001b[43mearly_stopping_rounds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mTRAINING_PARAMS\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mearly_stopping_rounds\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[0;32m 24\u001b[0m \u001b[43m)\u001b[49m\n",
"Cell \u001b[1;32mIn[13], line 293\u001b[0m, in \u001b[0;36mSmartBatchTrainer.train\u001b[1;34m(self, num_boost_round, early_stopping_rounds)\u001b[0m\n\u001b[0;32m 291\u001b[0m \u001b[38;5;66;03m# 训练模型\u001b[39;00m\n\u001b[0;32m 292\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m📈 开始模型训练...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m--> 293\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel \u001b[38;5;241m=\u001b[39m \u001b[43mlgb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 294\u001b[0m \u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 295\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_set\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrain_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 296\u001b[0m \u001b[43m \u001b[49m\u001b[43mnum_boost_round\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnum_boost_round\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 297\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_sets\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[43mval_data\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 298\u001b[0m \u001b[43m \u001b[49m\u001b[43mvalid_names\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mvalidation\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 299\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtraining_callbacks\u001b[49m\n\u001b[0;32m 300\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 302\u001b[0m training_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime() \u001b[38;5;241m-\u001b[39m start_time\n\u001b[0;32m 304\u001b[0m \u001b[38;5;66;03m# 评估模型\u001b[39;00m\n",
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\lightgbm\\engine.py:322\u001b[0m, in \u001b[0;36mtrain\u001b[1;34m(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, keep_training_booster, callbacks)\u001b[0m\n\u001b[0;32m 310\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m cb \u001b[38;5;129;01min\u001b[39;00m callbacks_before_iter:\n\u001b[0;32m 311\u001b[0m cb(\n\u001b[0;32m 312\u001b[0m callback\u001b[38;5;241m.\u001b[39mCallbackEnv(\n\u001b[0;32m 313\u001b[0m model\u001b[38;5;241m=\u001b[39mbooster,\n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 319\u001b[0m )\n\u001b[0;32m 320\u001b[0m )\n\u001b[1;32m--> 322\u001b[0m \u001b[43mbooster\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfobj\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfobj\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 324\u001b[0m evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType] \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m 325\u001b[0m \u001b[38;5;66;03m# check evaluation result.\u001b[39;00m\n",
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\lightgbm\\basic.py:4155\u001b[0m, in \u001b[0;36mBooster.update\u001b[1;34m(self, train_set, fobj)\u001b[0m\n\u001b[0;32m 4152\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__set_objective_to_none:\n\u001b[0;32m 4153\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m LightGBMError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCannot update due to null objective function.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 4154\u001b[0m _safe_call(\n\u001b[1;32m-> 4155\u001b[0m \u001b[43m_LIB\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mLGBM_BoosterUpdateOneIter\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m 4156\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_handle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4157\u001b[0m \u001b[43m \u001b[49m\u001b[43mctypes\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbyref\u001b[49m\u001b[43m(\u001b[49m\u001b[43mis_finished\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 4158\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 4159\u001b[0m )\n\u001b[0;32m 4160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__is_predicted_cur_iter \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;01mFalse\u001b[39;00m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m__num_dataset)]\n\u001b[0;32m 4161\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m is_finished\u001b[38;5;241m.\u001b[39mvalue \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"\n",
|
|||
|
"# 改进的训练参数\n",
|
|||
|
"IMPROVED_TRAINING_PARAMS = {\n",
|
|||
|
" 'num_boost_round': 1, # 每批次的提升轮数\n",
|
|||
|
" 'early_stopping_rounds': 30, # 早停轮数\n",
|
|||
|
" 'n_files_per_batch': 4, # 每批次随机选择的文件数 ✨\n",
|
|||
|
" 'batch_size': 8000, # 每批次目标样本数 ✨\n",
|
|||
|
" 'max_batches': 200 # 最大批次数(可选,None为无限制)\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"# 开始使用改进的训练器\n",
|
|||
|
"model = trainer.train_incremental(\n",
|
|||
|
" num_boost_round=IMPROVED_TRAINING_PARAMS['num_boost_round'],\n",
|
|||
|
" early_stopping_rounds=IMPROVED_TRAINING_PARAMS['early_stopping_rounds'],\n",
|
|||
|
" n_files_per_batch=IMPROVED_TRAINING_PARAMS['n_files_per_batch'],\n",
|
|||
|
" batch_size=IMPROVED_TRAINING_PARAMS['batch_size'],\n",
|
|||
|
" max_batches=IMPROVED_TRAINING_PARAMS['max_batches']\n",
|
|||
|
")"
|
|||
|
]
},
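{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 🔍 Reference sketch: incremental boosting with `init_model`\n",
"\n",
"`train_incremental` is defined in an earlier cell. As a reading aid, the next cell sketches the general pattern such a loop can follow using LightGBM's `init_model`/`keep_training_booster` continuation. It is a minimal sketch under assumptions, not the trainer's actual implementation; `sample_batches` is a hypothetical stand-in for the trainer's random batch generator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of incremental boosting (assumption: train_incremental\n",
"# follows this general pattern; sample_batches is a hypothetical generator\n",
"# yielding (X_batch, y_batch) tuples)\n",
"import lightgbm as lgb\n",
"\n",
"def train_incremental_sketch(params, sample_batches, X_val, y_val,\n",
"                             rounds_per_batch=1, max_batches=200):\n",
"    booster = None\n",
"    for i, (X_batch, y_batch) in enumerate(sample_batches):\n",
"        if i >= max_batches:\n",
"            break\n",
"        train_data = lgb.Dataset(X_batch, label=y_batch)\n",
"        val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)\n",
"        # init_model continues boosting from the existing trees instead of\n",
"        # fitting a fresh model for every batch\n",
"        booster = lgb.train(\n",
"            params,\n",
"            train_data,\n",
"            num_boost_round=rounds_per_batch,\n",
"            init_model=booster,\n",
"            valid_sets=[val_data],\n",
"            valid_names=['validation'],\n",
"            keep_training_booster=True\n",
"        )\n",
"    return booster"
]
},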
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📊 Training Results Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📊 训练结果分析和可视化\n",
|
|||
|
"\n",
|
|||
|
"print(\"📊 分析智能分批训练结果...\")\n",
|
|||
|
"print(\"=\" * 60)\n",
|
|||
|
"\n",
|
|||
|
"# 显示训练进度图表\n",
|
|||
|
"trainer.plot_training_progress()\n",
|
|||
|
"\n",
|
|||
|
"# 保存最终模型\n",
|
|||
|
"final_model_path = \"smart_pipeline_final_model.txt\"\n",
|
|||
|
"if trainer.model:\n",
|
|||
|
" trainer.model.save_model(final_model_path)\n",
|
|||
|
" print(f\"\\n💾 最终模型已保存: {final_model_path}\")\n",
|
|||
|
"\n",
|
|||
|
"# 详细分析\n",
|
|||
|
"if trainer.training_history:\n",
|
|||
|
" print(f\"\\n📈 详细训练分析:\")\n",
|
|||
|
" print(f\" 🎯 训练批次总数: {len(trainer.training_history)}\")\n",
|
|||
|
" \n",
|
|||
|
" # 最佳批次\n",
|
|||
|
" best_batch = max(trainer.training_history, key=lambda x: x['val_accuracy'])\n",
|
|||
|
" print(f\" 🏆 最佳验证准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n",
|
|||
|
" \n",
|
|||
|
" # 训练效率\n",
|
|||
|
" total_training_time = sum(h['time'] for h in trainer.training_history)\n",
|
|||
|
" avg_batch_time = total_training_time / len(trainer.training_history)\n",
|
|||
|
" print(f\" ⏱️ 总训练时间: {total_training_time:.1f}秒\")\n",
|
|||
|
" print(f\" ⏱️ 平均批次时间: {avg_batch_time:.1f}秒\")\n",
|
|||
|
" \n",
|
|||
|
" # 模型复杂度\n",
|
|||
|
" final_trees = trainer.training_history[-1]['num_trees']\n",
|
|||
|
" print(f\" 🌳 最终模型树数: {final_trees}\")\n",
|
|||
|
" \n",
|
|||
|
" # 收敛性分析\n",
|
|||
|
" recent_accs = [h['val_accuracy'] for h in trainer.training_history[-3:]]\n",
|
|||
|
" if len(recent_accs) >= 2:\n",
|
|||
|
" acc_stability = max(recent_accs) - min(recent_accs)\n",
|
|||
|
" print(f\" 📈 准确率稳定性: {acc_stability:.4f} (最近3批次方差)\")\n",
|
|||
|
" \n",
|
|||
|
" if acc_stability < 0.01:\n",
|
|||
|
" print(\" ✅ 模型已收敛 (准确率变化 < 1%)\")\n",
|
|||
|
" else:\n",
|
|||
|
" print(\" ⚠️ 模型可能需要更多训练\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🎉 智能分批训练分析完成!\")\n",
|
|||
|
"print(f\" 💡 使用了改进的数据平衡策略和PCA降维\")\n",
|
|||
|
"print(f\" 💡 训练集应用了下采样+过采样,验证集保持原始分布\")\n",
|
|||
|
"print(f\" 💡 实现了内存友好的分批处理\")"
|
|||
|
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🧪 Model Performance Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧪 模型性能评估\n",
|
|||
|
"\n",
|
|||
|
"from sklearn.metrics import classification_report, confusion_matrix\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"def evaluate_model_performance(model, pipeline, data_type='val'):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 评估模型在指定数据集上的性能\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(f\"🧪 评估模型在{data_type}数据集上的性能...\")\n",
|
|||
|
" \n",
|
|||
|
" # 加载数据\n",
|
|||
|
" X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n",
|
|||
|
" \n",
|
|||
|
" if X is None or y is None:\n",
|
|||
|
" print(f\"❌ 无法加载{data_type}数据\")\n",
|
|||
|
" return None\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 📊 数据集大小: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
|
|||
|
" \n",
|
|||
|
" # 预测\n",
|
|||
|
" start_time = time.time()\n",
|
|||
|
" y_pred_proba = model.predict(X)\n",
|
|||
|
" y_pred = y_pred_proba.argmax(axis=1)\n",
|
|||
|
" pred_time = time.time() - start_time\n",
|
|||
|
" \n",
|
|||
|
" # 计算性能指标\n",
|
|||
|
" accuracy = (y_pred == y).mean()\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ⏱️ 预测时间: {pred_time:.2f}秒\")\n",
|
|||
|
" print(f\" 🎯 整体准确率: {accuracy:.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 分析各类别性能\n",
|
|||
|
" from collections import Counter\n",
|
|||
|
" true_counts = Counter(y)\n",
|
|||
|
" pred_counts = Counter(y_pred)\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n📊 标签分布对比:\")\n",
|
|||
|
" print(\"标签 | 真实数量 | 预测数量 | 准确率\")\n",
|
|||
|
" print(\"-\" * 40)\n",
|
|||
|
" \n",
|
|||
|
" label_accuracies = {}\n",
|
|||
|
" for label in range(41):\n",
|
|||
|
" if label in true_counts:\n",
|
|||
|
" label_mask = (y == label)\n",
|
|||
|
" if label_mask.sum() > 0:\n",
|
|||
|
" label_acc = (y_pred[label_mask] == label).mean()\n",
|
|||
|
" label_accuracies[label] = label_acc\n",
|
|||
|
" true_count = true_counts.get(label, 0)\n",
|
|||
|
" pred_count = pred_counts.get(label, 0)\n",
|
|||
|
" print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 重点分析关键标签\n",
|
|||
|
" print(f\"\\n🔍 关键标签性能分析:\")\n",
|
|||
|
" key_labels = [0, 40] # 下采样的标签\n",
|
|||
|
" for label in key_labels:\n",
|
|||
|
" if label in label_accuracies:\n",
|
|||
|
" acc = label_accuracies[label]\n",
|
|||
|
" count = true_counts.get(label, 0)\n",
|
|||
|
" print(f\" 标签 {label} (下采样目标): 准确率 {acc:.4f}, 样本数 {count:,}\")\n",
|
|||
|
" \n",
|
|||
|
" # 少数类性能\n",
|
|||
|
" minority_labels = [label for label, count in true_counts.items() \n",
|
|||
|
" if count < 200 and label not in [0, 40]]\n",
|
|||
|
" if minority_labels:\n",
|
|||
|
" minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n",
|
|||
|
" avg_minority_acc = np.mean(minority_accs) if minority_accs else 0\n",
|
|||
|
" print(f\" 少数类平均准确率 (前5个): {avg_minority_acc:.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 置信度分析\n",
|
|||
|
" max_proba = y_pred_proba.max(axis=1)\n",
|
|||
|
" print(f\"\\n📈 预测置信度分析:\")\n",
|
|||
|
" print(f\" 平均置信度: {max_proba.mean():.4f}\")\n",
|
|||
|
" print(f\" 置信度中位数: {np.median(max_proba):.4f}\")\n",
|
|||
|
" print(f\" 高置信度预测 (>0.9): {(max_proba > 0.9).sum():,} / {len(max_proba):,} ({(max_proba > 0.9).mean():.2%})\")\n",
|
|||
|
" \n",
|
|||
|
" return {\n",
|
|||
|
" 'accuracy': accuracy,\n",
|
|||
|
" 'prediction_time': pred_time,\n",
|
|||
|
" 'label_accuracies': label_accuracies,\n",
|
|||
|
" 'confidence_stats': {\n",
|
|||
|
" 'mean': max_proba.mean(),\n",
|
|||
|
" 'median': np.median(max_proba),\n",
|
|||
|
" 'high_confidence_ratio': (max_proba > 0.9).mean()\n",
|
|||
|
" }\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
"# 评估模型性能\n",
|
|||
|
"if trainer.model:\n",
|
|||
|
" print(\"🧪 开始模型性能评估...\")\n",
|
|||
|
" \n",
|
|||
|
" # 验证集评估\n",
|
|||
|
" val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n\" + \"=\"*60)\n",
|
|||
|
" print(\"🎉 智能分批训练+数据平衡 评估完成!\")\n",
|
|||
|
" print(f\"✅ 实现了数据平衡和PCA降维的完整流程\")\n",
|
|||
|
" print(f\"✅ 使用了内存友好的分批训练策略\")\n",
|
|||
|
" print(f\"✅ 保持了验证集的原始分布以确保评估客观性\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"❌ 模型尚未训练完成,请等待训练结束后运行此评估\")"
|
|||
|
]
},
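{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above imports `classification_report` and `confusion_matrix` but never calls them. The sketch below shows one way they could be applied to the downsampled key labels; `report_key_labels` and its \"other\" bucket are illustrative additions, not part of the original pipeline, and they expect the true/predicted label arrays, which `evaluate_model_performance` could be extended to return."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: sklearn reports for the downsampled key labels (hypothetical\n",
"# helper; collapses all other labels into an \"other\" bucket, -1, so the\n",
"# 41x41 confusion matrix stays readable)\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"import numpy as np\n",
"\n",
"def report_key_labels(y_true, y_pred, key_labels=(0, 40)):\n",
"    labels = list(key_labels)\n",
"    y_true_m = np.where(np.isin(y_true, labels), y_true, -1)\n",
"    y_pred_m = np.where(np.isin(y_pred, labels), y_pred, -1)\n",
"    cm = confusion_matrix(y_true_m, y_pred_m, labels=labels + [-1])\n",
"    print(f\"Confusion matrix over labels {labels} + other (-1):\")\n",
"    print(cm)\n",
"    print(classification_report(y_true_m, y_pred_m, labels=labels, zero_division=0))"
]
},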
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ✅ 余弦退火已更新为带重启版本\n",
|
|||
|
"\n",
|
|||
|
"print(\"🎉 余弦退火调度器更新完成!\")\n",
|
|||
|
"\n",
|
|||
|
"# 检查trainer是否已创建,如果未创建则先创建\n",
|
|||
|
"if 'trainer' not in globals():\n",
|
|||
|
" print(\"⚠️ 训练器尚未创建,请先运行前面的代码创建训练器\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(f\"✅ 当前使用:带重启的余弦退火 (SGDR)\")\n",
|
|||
|
" print(f\" 🔄 重启参数: T_0={trainer.t_0}, T_mult={trainer.t_mult}\")\n",
|
|||
|
" print(f\" 📈 学习率范围: {trainer.initial_learning_rate} → {trainer.min_learning_rate}\")\n",
|
|||
|
"\n",
|
|||
|
" # 可视化新的学习率调度\n",
|
|||
|
" import matplotlib.pyplot as plt\n",
|
|||
|
" import numpy as np\n",
|
|||
|
"\n",
|
|||
|
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n",
|
|||
|
"\n",
|
|||
|
" # 模拟300轮的学习率变化\n",
|
|||
|
" rounds = list(range(300))\n",
|
|||
|
" old_lrs = [] # 原始余弦退火\n",
|
|||
|
" new_lrs = [] # 带重启的余弦退火\n",
|
|||
|
"\n",
|
|||
|
" for r in rounds:\n",
|
|||
|
" # 原始余弦退火 (单调递减)\n",
|
|||
|
" old_lr = trainer.min_learning_rate + 0.5 * (trainer.initial_learning_rate - trainer.min_learning_rate) * (1 + np.cos(np.pi * r / 300))\n",
|
|||
|
" old_lrs.append(old_lr)\n",
|
|||
|
" \n",
|
|||
|
" # 带重启的余弦退火\n",
|
|||
|
" new_lr = trainer._cosine_annealing_with_warm_restarts(r)\n",
|
|||
|
" new_lrs.append(new_lr)\n",
|
|||
|
"\n",
|
|||
|
" # 绘制对比图\n",
|
|||
|
" ax1.plot(rounds, old_lrs, 'b-', label='原始余弦退火', linewidth=2)\n",
|
|||
|
" ax1.set_xlabel('Training Round')\n",
|
|||
|
" ax1.set_ylabel('Learning Rate')\n",
|
|||
|
" ax1.set_title('原始余弦退火 (单调递减)')\n",
|
|||
|
" ax1.grid(True, alpha=0.3)\n",
|
|||
|
" ax1.legend()\n",
|
|||
|
"\n",
|
|||
|
" ax2.plot(rounds, new_lrs, 'r-', label='带重启的余弦退火', linewidth=2)\n",
|
|||
|
" ax2.set_xlabel('Training Round')\n",
|
|||
|
" ax2.set_ylabel('Learning Rate')\n",
|
|||
|
" ax2.set_title('带重启的余弦退火 (SGDR)')\n",
|
|||
|
" ax2.grid(True, alpha=0.3)\n",
|
|||
|
" ax2.legend()\n",
|
|||
|
"\n",
|
|||
|
" plt.tight_layout()\n",
|
|||
|
" plt.show()\n",
|
|||
|
"\n",
|
|||
|
" print(\"📊 学习率调度对比可视化完成\")\n",
|
|||
|
" print(\" 🔵 原始版本:单调递减的余弦曲线\")\n",
|
|||
|
" print(\" 🔴 新版本:周期性重启,每次重启后学习率回到最大值\")\n",
|
|||
|
" print(\" 💡 SGDR的优势:多次重启可以帮助模型跳出局部最优解\")\n",
|
|||
|
"\n",
|
|||
|
" # 显示重启点\n",
|
|||
|
" restart_points = []\n",
|
|||
|
" t_cur = 0\n",
|
|||
|
" t_i = trainer.t_0\n",
|
|||
|
" while t_cur < 300:\n",
|
|||
|
" restart_points.append(t_cur)\n",
|
|||
|
" t_cur += t_i\n",
|
|||
|
" t_i *= trainer.t_mult\n",
|
|||
|
"\n",
|
|||
|
" print(f\" 🔄 在300轮训练中的重启点: {restart_points[:5]}...\") # 显示前5个重启点"
|
|||
|
]
},
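{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the SGDR schedule (Loshchilov & Hutter, 2017) restarts whenever the position `T_cur` inside the current cycle reaches the cycle length `T_i`, and `T_i` grows by a factor of `T_mult` after each restart:\n",
"\n",
"$$\\eta_t = \\eta_{min} + \\tfrac{1}{2}(\\eta_{max} - \\eta_{min})\\left(1 + \\cos\\left(\\pi \\tfrac{T_{cur}}{T_i}\\right)\\right)$$\n",
"\n",
"The standalone function below is a minimal sketch of this formula; it is presumably what `trainer._cosine_annealing_with_warm_restarts` evaluates, but the name `sgdr_learning_rate` and the `t_mult=2` default are assumptions, not the trainer's code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal standalone sketch of the SGDR formula (assumed to mirror\n",
"# trainer._cosine_annealing_with_warm_restarts; t_mult=2 is an assumption)\n",
"import numpy as np\n",
"\n",
"def sgdr_learning_rate(round_idx, lr_max, lr_min, t_0, t_mult=2):\n",
"    # Walk through cycles of length t_0, t_0*t_mult, ... to find where\n",
"    # round_idx falls inside the current cycle\n",
"    t_cur, t_i = round_idx, t_0\n",
"    while t_cur >= t_i:\n",
"        t_cur -= t_i\n",
"        t_i *= t_mult\n",
"    # Cosine decay from lr_max to lr_min within the current cycle\n",
"    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * t_cur / t_i))\n",
"\n",
"# At the start of every cycle the schedule returns lr_max\n",
"print(sgdr_learning_rate(0, 0.1, 0.001, t_0=20))   # 0.1\n",
"print(sgdr_learning_rate(20, 0.1, 0.001, t_0=20))  # 0.1 again (first restart)"
]
},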
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}