{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Environment Setup and Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "rm -rf /kaggle/working/nejm-brain-to-text/\n",
    "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
    "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
    "\n",
    "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
    "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
    "ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data/concatenated_data\n",
    "\n",
    "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
    "\n",
    "pip install \\\n",
    "    jupyter==1.1.1 \\\n",
    "    \"numpy>=1.26.0,<2.1.0\" \\\n",
    "    pandas==2.3.0 \\\n",
    "    matplotlib==3.10.1 \\\n",
    "    scipy==1.15.2 \\\n",
    "    scikit-learn==1.6.1 \\\n",
    "    lightgbm==4.3.0 \\\n",
    "    tqdm==4.67.1 \\\n",
    "    g2p_en==2.1.0 \\\n",
    "    h5py==3.13.0 \\\n",
    "    omegaconf==2.3.0 \\\n",
    "    editdistance==0.8.1 \\\n",
    "    huggingface-hub==0.33.1 \\\n",
    "    transformers==4.53.0 \\\n",
    "    tokenizers==0.21.2 \\\n",
    "    accelerate==1.8.1 \\\n",
    "    bitsandbytes==0.46.0 \\\n",
    "    seaborn==0.13.2\n",
    "cd /kaggle/working/nejm-brain-to-text/\n",
    "pip install -e ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "🔧 LightGBM GPU environment check\n",
      "==================================================\n",
      "❌ No NVIDIA GPU or driver detected\n",
      "\n",
      "❌ CUDA toolkit not installed\n"
     ]
    }
   ],
   "source": [
    "# 🚀 LightGBM GPU support check and configuration\n",
    "\n",
    "print(\"=\"*50)\n",
    "print(\"🔧 LightGBM GPU environment check\")\n",
    "print(\"=\"*50)\n",
    "\n",
    "# Check CUDA and the GPU driver\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "def run_command(command):\n",
    "    \"\"\"Run a shell command and return (stdout, success).\"\"\"\n",
    "    try:\n",
    "        result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n",
    "        return result.stdout.strip(), result.returncode == 0\n",
    "    except Exception as e:\n",
    "        return str(e), False\n",
    "\n",
    "# Check for an NVIDIA GPU\n",
    "nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n",
    "if nvidia_success:\n",
    "    print(\"✅ NVIDIA GPU detected:\")\n",
    "    for line in nvidia_output.split('\\n'):\n",
    "        if line.strip():\n",
    "            print(f\"   {line}\")\n",
    "else:\n",
    "    print(\"❌ No NVIDIA GPU or driver detected\")\n",
    "\n",
    "# Check the CUDA version\n",
    "cuda_output, cuda_success = run_command(\"nvcc --version\")\n",
    "if cuda_success:\n",
    "    print(\"\\n✅ CUDA toolkit:\")\n",
    "    # Extract the CUDA release line\n",
    "    for line in cuda_output.split('\\n'):\n",
    "        if 'release' in line:\n",
    "            print(f\"   {line.strip()}\")\n",
    "else:\n",
    "    print(\"\\n❌ CUDA toolkit not installed\")\n"
   ]
  },
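  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The check above only inspects the environment. A minimal smoke test (a sketch; `pick_lgbm_device` is a helper defined here, not a LightGBM API) is to attempt one boosting round with `device_type='gpu'` and fall back to `'cpu'` if this LightGBM build has no GPU support:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: probe whether this LightGBM build can actually train on the GPU.\n",
    "import numpy as np\n",
    "import lightgbm as lgb\n",
    "\n",
    "def pick_lgbm_device():\n",
    "    \"\"\"Return 'gpu' if a tiny GPU training run succeeds, else 'cpu'.\"\"\"\n",
    "    X = np.random.rand(100, 10)\n",
    "    y = np.random.randint(0, 2, 100)\n",
    "    try:\n",
    "        lgb.train({'objective': 'binary', 'device_type': 'gpu', 'verbose': -1},\n",
    "                  lgb.Dataset(X, label=y), num_boost_round=1)\n",
    "        return 'gpu'\n",
    "    except lgb.basic.LightGBMError:\n",
    "        return 'cpu'\n",
    "\n",
    "print(f\"LightGBM device_type to use: {pick_lgbm_device()}\")"
   ]
  },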
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WinError 3] The system cannot find the path specified.: '/kaggle/working/nejm-brain-to-text'\n",
      "f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\brain-to-text-25\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\IPython\\core\\magics\\osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.\n",
      "  bkms = self.shell.db.get('bookmarks', {})\n"
     ]
    }
   ],
   "source": [
    "%cd /kaggle/working/nejm-brain-to-text\n",
    "import numpy as np\n",
    "import os\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib\n",
    "from g2p_en import G2p\n",
    "import pandas as pd\n",
    "from nejm_b2txt_utils.general_utils import *\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rcParams['font.family'] = 'sans-serif'\n",
    "matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n",
    "matplotlib.rcParams['axes.unicode_minus'] = False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\IPython\\core\\magics\\osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n",
      "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\model_training\n"
     ]
    }
   ],
   "source": [
    "%cd model_training/\n",
    "from data_augmentations import gauss_smooth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOGIT_TO_PHONEME = [\n",
    "    'BLANK',\n",
    "    'AA', 'AE', 'AH', 'AO', 'AW',\n",
    "    'AY', 'B',  'CH', 'D', 'DH',\n",
    "    'EH', 'ER', 'EY', 'F', 'G',\n",
    "    'HH', 'IH', 'IY', 'JH', 'K',\n",
    "    'L', 'M', 'N', 'NG', 'OW',\n",
    "    'OY', 'P', 'R', 'S', 'SH',\n",
    "    'T', 'TH', 'UH', 'UW', 'V',\n",
    "    'W', 'Y', 'Z', 'ZH',\n",
    "    ' | ',\n",
    "]\n",
    "# Global balancing configuration\n",
    "BALANCE_CONFIG = {\n",
    "    'enable_balance': True,        # whether to apply data balancing\n",
    "    'undersample_labels': [0, 40], # labels to undersample (high-frequency labels such as blank)\n",
    "    'oversample_threshold': 0.5,   # oversampling threshold (as a fraction of the mean count)\n",
    "    'random_state': 42            # random seed\n",
    "}\n",
    "# Global PCA configuration\n",
    "PCA_CONFIG = {\n",
    "    'enable_pca': True,           # whether to enable PCA\n",
    "    'n_components': None,         # None = choose automatically, or set an explicit number\n",
    "    'variance_threshold': 0.95,   # keep 95% of the variance\n",
    "    'sample_size': 15000,         # number of samples used to fit the PCA\n",
    "}\n",
    "\n",
    "# Global PCA state (ensures the PCA is fitted only once)\n",
    "GLOBAL_PCA = {\n",
    "    'scaler': None,\n",
    "    'pca': None,\n",
    "    'is_fitted': False,\n",
    "    'n_components': None\n",
    "}\n",
    "# Data directory and parameters [PCA initialization]\n",
    "data_dir = '/kaggle/working/nejm-brain-to-text/data/concatenated_data'\n",
    "MAX_SAMPLES_PER_FILE = -1  # max samples to load per file; -1 = no limit"
   ]
  },
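  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check of the table above: index 0 is the CTC blank, indices 1-39 are ARPABET phonemes, and index 40 is the word separator ' | '. A one-frame decoding example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: map one 41-dim logit vector to its phoneme via argmax.\n",
    "import numpy as np\n",
    "\n",
    "fake_logits = np.zeros(41)\n",
    "fake_logits[31] = 1.0  # pretend the RNN is most confident in class 31\n",
    "print(LOGIT_TO_PHONEME[int(np.argmax(fake_logits))])  # -> 'T'"
   ]
  },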
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Loading Workflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2️⃣ Data Loading and PCA Dimensionality Reduction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 🚀 Memory-friendly data loading: batched loading + PCA reduction [sampling is still missing here]\n",
    "\n",
    "import os\n",
    "import numpy as np\n",
    "import gc\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import joblib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "def load_data_batch(data_dir, data_type, max_samples_per_file=5000):\n",
    "    \"\"\"\n",
    "    Load data of the given type in batches.\n",
    "    \n",
    "    Args:\n",
    "        data_dir: data directory\n",
    "        data_type: 'train', 'val', 'test'\n",
    "        max_samples_per_file: max number of samples to load per file\n",
    "    \n",
    "    Returns:\n",
    "        generator: yields (trials, filename) batches\n",
    "    \"\"\"\n",
    "    files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
    "    \n",
    "    for file_idx, f in enumerate(files):\n",
    "        print(f\"  Loading file {file_idx+1}/{len(files)}: {f}\")\n",
    "        \n",
    "        data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n",
    "        trials = data['neural_logits_concatenated']\n",
    "        \n",
    "        # Cap the number of samples per file\n",
    "        if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n",
    "            trials = trials[:max_samples_per_file]\n",
    "            print(f\"    Capping samples at: {max_samples_per_file}\")\n",
    "        \n",
    "        yield trials, f\n",
    "        \n",
    "        # Free memory\n",
    "        del data, trials\n",
    "        gc.collect()\n",
    "\n",
    "def extract_features_labels_batch(trials_batch):\n",
    "    \"\"\"\n",
    "    Extract features and labels from a batch of trials.\n",
    "    \"\"\"\n",
    "    features = []\n",
    "    labels = []\n",
    "    \n",
    "    for trial in trials_batch:\n",
    "        if trial.shape[0] > 0:\n",
    "            for t in range(trial.shape[0]):\n",
    "                neural_features = trial[t, :7168]  # first 7168 dims: neural features\n",
    "                rnn_logits = trial[t, 7168:]       # last 41 dims: RNN logits\n",
    "                phoneme_label = np.argmax(rnn_logits)\n",
    "                \n",
    "                features.append(neural_features)\n",
    "                labels.append(phoneme_label)\n",
    "    \n",
    "    return np.array(features), np.array(labels)\n",
    "\n",
    "def fit_global_pca(data_dir, config):\n",
    "    \"\"\"\n",
    "    Fit the global PCA on training data (runs only once).\n",
    "    \"\"\"\n",
    "    if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n",
    "        print(\"🔧 PCA already fitted or disabled; skipping\")\n",
    "        return\n",
    "    \n",
    "    print(f\"\\n🔧 Fitting the global PCA reducer...\")\n",
    "    print(f\"   Config: {config}\")\n",
    "    \n",
    "    # Collect training samples\n",
    "    sample_features = []\n",
    "    collected_samples = 0\n",
    "    \n",
    "    for trials_batch, filename in load_data_batch(data_dir, 'train', 5000):\n",
    "        features, labels = extract_features_labels_batch(trials_batch)\n",
    "        sample_features.append(features)\n",
    "        collected_samples += features.shape[0]\n",
    "        \n",
    "        if collected_samples >= config['sample_size']:\n",
    "            break\n",
    "    \n",
    "    if sample_features:\n",
    "        # Merge the collected samples\n",
    "        X_sample = np.vstack(sample_features)[:config['sample_size']]\n",
    "        print(f\"   Actual sample count: {X_sample.shape[0]}\")\n",
    "        print(f\"   Original feature count: {X_sample.shape[1]}\")\n",
    "        \n",
    "        # Standardize\n",
    "        GLOBAL_PCA['scaler'] = StandardScaler()\n",
    "        X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n",
    "        \n",
    "        # Determine the number of PCA components\n",
    "        if config['n_components'] is None:\n",
    "            print(f\"   🔍 Choosing the number of PCA components automatically...\")\n",
    "            pca_full = PCA()\n",
    "            pca_full.fit(X_sample_scaled)\n",
    "            \n",
    "            cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
    "            optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n",
    "            GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n",
    "            \n",
    "            print(f\"   Keeping {config['variance_threshold']*100}% of the variance requires: {optimal_components} components\")\n",
    "            print(f\"   Components selected: {GLOBAL_PCA['n_components']}\")\n",
    "        else:\n",
    "            GLOBAL_PCA['n_components'] = config['n_components']\n",
    "            print(f\"   Using the specified component count: {GLOBAL_PCA['n_components']}\")\n",
    "        \n",
    "        # Fit the final PCA\n",
    "        GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n",
    "        GLOBAL_PCA['pca'].fit(X_sample_scaled)\n",
    "        GLOBAL_PCA['is_fitted'] = True\n",
    "        \n",
    "        # Save the model\n",
    "        pca_path = \"global_pca_model.joblib\"\n",
    "        joblib.dump({\n",
    "            'scaler': GLOBAL_PCA['scaler'], \n",
    "            'pca': GLOBAL_PCA['pca'],\n",
    "            'n_components': GLOBAL_PCA['n_components']\n",
    "        }, pca_path)\n",
    "        \n",
    "        print(f\"   ✅ Global PCA fitted!\")\n",
    "        print(f\"      Reduction: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n",
    "        print(f\"      Reduction ratio: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n",
    "        print(f\"      Variance retained: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n",
    "        print(f\"      Model saved: {pca_path}\")\n",
    "        \n",
    "        # Free the sample data\n",
    "        del sample_features, X_sample, X_sample_scaled\n",
    "        gc.collect()\n",
    "    else:\n",
    "        print(\"❌ Could not collect samples for PCA fitting\")\n",
    "\n",
    "def apply_pca_transform(features):\n",
    "    \"\"\"\n",
    "    Apply the global PCA transform.\n",
    "    \"\"\"\n",
    "    if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n",
    "        return features\n",
    "    \n",
    "    # Standardize, then apply PCA\n",
    "    features_scaled = GLOBAL_PCA['scaler'].transform(features)\n",
    "    features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n",
    "    return features_pca\n",
    "\n"
   ]
  },
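  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The scaler-then-PCA pattern above can also be written more compactly: scikit-learn's `PCA` accepts a float `n_components` and picks the component count that retains that variance fraction itself. A self-contained sketch on synthetic data (not the brain-to-text features):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: standardize, then keep enough components for 95% variance.\n",
    "import numpy as np\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "\n",
    "rng = np.random.default_rng(42)\n",
    "X_demo = rng.normal(size=(1000, 50)) @ rng.normal(size=(50, 50))  # correlated features\n",
    "\n",
    "scaler = StandardScaler()\n",
    "pca = PCA(n_components=0.95, random_state=42)  # float = variance fraction to keep\n",
    "X_reduced = pca.fit_transform(scaler.fit_transform(X_demo))\n",
    "print(X_demo.shape, '->', X_reduced.shape)"
   ]
  },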
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 📊 Data Balancing Strategy: Label Distribution Analysis and Sampling Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [Core sampling implementation]\n",
    "import random\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "from sklearn.utils import resample\n",
    "\n",
    "def balance_dataset(X, y, config=BALANCE_CONFIG):\n",
    "    \"\"\"\n",
    "    Balance the dataset: undersampling + oversampling.\n",
    "    \n",
    "    Args:\n",
    "        X: feature matrix\n",
    "        y: label array\n",
    "        config: balancing configuration\n",
    "    \n",
    "    Returns:\n",
    "        X_balanced, y_balanced: the balanced data\n",
    "    \"\"\"\n",
    "    if not config['enable_balance']:\n",
    "        print(\"🔕 Balancing disabled; returning the original data\")\n",
    "        return X, y\n",
    "    \n",
    "    print(f\"\\n⚖️ Starting data balancing...\")\n",
    "    print(f\"   Original data: {X.shape[0]:,} samples\")\n",
    "    \n",
    "    # Analyze the current distribution (mean over labels 1-39 only)\n",
    "    label_counts = Counter(y)\n",
    "    counts_exclude_0_40 = [label_counts.get(i, 0) for i in range(1, 40)]  # labels 1-39\n",
    "    mean_count = np.mean(counts_exclude_0_40)  # mean over labels 1-39 only\n",
    "    \n",
    "    print(f\"   Mean count (labels 1-39): {mean_count:.0f}\")\n",
    "    print(f\"   Labels to undersample: {config['undersample_labels']}\")\n",
    "    print(f\"   Oversampling threshold: {config['oversample_threshold']} * mean\")\n",
    "    \n",
    "    # Accumulate the balanced data\n",
    "    X_balanced = []\n",
    "    y_balanced = []\n",
    "    \n",
    "    random.seed(config['random_state'])\n",
    "    np.random.seed(config['random_state'])\n",
    "    \n",
    "    for label in range(41):\n",
    "        # All samples for the current label\n",
    "        label_mask = (y == label)\n",
    "        X_label = X[label_mask]\n",
    "        y_label = y[label_mask]\n",
    "        current_count = len(y_label)\n",
    "        \n",
    "        if current_count == 0:\n",
    "            continue\n",
    "        \n",
    "        # Decide the sampling strategy\n",
    "        if label in config['undersample_labels']:\n",
    "            # Undersample down to the mean\n",
    "            target_count = int(mean_count)\n",
    "            if current_count > target_count:\n",
    "                # Undersample\n",
    "                indices = np.random.choice(current_count, target_count, replace=False)\n",
    "                X_resampled = X_label[indices]\n",
    "                y_resampled = y_label[indices]\n",
    "                print(f\"   📉 Label {label}: {current_count} → {target_count} (undersampled)\")\n",
    "            else:\n",
    "                X_resampled = X_label\n",
    "                y_resampled = y_label\n",
    "                print(f\"   ➡️ Label {label}: {current_count} (no undersampling needed)\")\n",
    "                \n",
    "        elif current_count < mean_count * config['oversample_threshold']:\n",
    "            # Oversample up to the threshold\n",
    "            target_count = int(mean_count * config['oversample_threshold'])\n",
    "            if current_count < target_count:\n",
    "                # Oversample\n",
    "                X_resampled, y_resampled = resample(\n",
    "                    X_label, y_label, \n",
    "                    n_samples=target_count, \n",
    "                    random_state=config['random_state']\n",
    "                )\n",
    "                print(f\"   📈 Label {label}: {current_count} → {target_count} (oversampled)\")\n",
    "            else:\n",
    "                X_resampled = X_label\n",
    "                y_resampled = y_label\n",
    "                print(f\"   ➡️ Label {label}: {current_count} (no oversampling needed)\")\n",
    "        else:\n",
    "            # Leave unchanged\n",
    "            X_resampled = X_label\n",
    "            y_resampled = y_label\n",
    "            print(f\"   ✅ Label {label}: {current_count} (already balanced)\")\n",
    "        \n",
    "        X_balanced.append(X_resampled)\n",
    "        y_balanced.append(y_resampled)\n",
    "    \n",
    "    # Merge all balanced data\n",
    "    X_balanced = np.vstack(X_balanced)\n",
    "    y_balanced = np.hstack(y_balanced)\n",
    "    \n",
    "    # Shuffle\n",
    "    shuffle_indices = np.random.permutation(len(y_balanced))\n",
    "    X_balanced = X_balanced[shuffle_indices]\n",
    "    y_balanced = y_balanced[shuffle_indices]\n",
    "    \n",
    "    print(f\"   ✅ Balancing complete: {X_balanced.shape[0]:,} samples\")\n",
    "    print(f\"   Size change: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n",
    "    \n",
    "    return X_balanced, y_balanced\n"
   ]
  },
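  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick demo of `balance_dataset` on a small synthetic set where labels 0 and 40 dominate (a sketch; the sizes are made up and far smaller than the real data):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Demo: labels 0 and 40 are overrepresented, labels 1-39 are sparse.\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "y_demo = np.concatenate([np.full(500, 0), np.full(400, 40),\n",
    "                         rng.integers(1, 40, size=300)])\n",
    "X_demo = rng.normal(size=(len(y_demo), 8))\n",
    "\n",
    "X_bal, y_bal = balance_dataset(X_demo, y_demo)  # 0/40 shrink toward the 1-39 mean"
   ]
  },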
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔄 Memory-Friendly Data Loader with Integrated Balancing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🧪 Testing the Balancing Effect"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🚀 Improved Smart Data Processing Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🚀 Creating the smart data pipeline...\n",
      "✅ Pipeline created; ready to run step 1...\n"
     ]
    }
   ],
   "source": [
    "# 🚀 Improved smart data pipeline [batched training is still not addressed]\n",
    "# Flow: analyze distribution → set sampling ratios → fit PCA (undersample only) → process data (under+oversampling+PCA)\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import Counter\n",
    "from sklearn.utils import resample\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import joblib\n",
    "import random\n",
    "import gc\n",
    "\n",
    "class SmartDataPipeline:\n",
    "    \"\"\"\n",
    "    Smart data processing pipeline.\n",
    "    Step 1: analyze the label distribution and derive a sampling strategy\n",
    "    Step 2: fit the PCA parameters on undersampled data only\n",
    "    Step 3: process data with full sampling + PCA reduction\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, data_dir, random_state=42):\n",
    "        self.data_dir = data_dir\n",
    "        self.random_state = random_state\n",
    "        \n",
    "        # Step 1: distribution analysis results\n",
    "        self.distribution_analysis = None\n",
    "        self.sampling_strategy = None\n",
    "        \n",
    "        # Step 2: PCA parameters (fitted on undersampled data)\n",
    "        self.pca_scaler = None\n",
    "        self.pca_model = None\n",
    "        self.pca_components = None\n",
    "        self.pca_fitted = False\n",
    "        \n",
    "        # Configuration\n",
    "        self.undersample_labels = [0, 40]  # labels to undersample\n",
    "        self.oversample_threshold = 0.5    # oversampling threshold (fraction of the mean)\n",
    "        self.pca_variance_threshold = 0.95 # variance fraction the PCA keeps\n",
    "        self.pca_sample_size = 15000       # number of samples for PCA fitting\n",
    "        \n",
    "    def step1_analyze_distribution(self, max_samples=100000):\n",
    "        \"\"\"\n",
    "        Step 1: analyze the label distribution and derive a sampling strategy.\n",
    "        \"\"\"\n",
    "        print(\"🔍 Step 1: analyzing the data distribution...\")\n",
    "        \n",
    "        # Analyze the validation-set distribution (representative of the overall distribution)\n",
    "        all_labels = []\n",
    "        for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n",
    "            _, labels = extract_features_labels_batch(trials_batch)\n",
    "            all_labels.extend(labels.tolist())\n",
    "            if len(all_labels) >= max_samples:\n",
    "                break\n",
    "        \n",
    "        # Count the labels\n",
    "        label_counts = Counter(all_labels)\n",
    "        \n",
    "        # Mean over labels 1-39 (excluding 0 and 40)\n",
    "        counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n",
    "        target_mean = np.mean(counts_1_39)\n",
    "        \n",
    "        # Build the sampling strategy\n",
    "        sampling_strategy = {}\n",
    "        for label in range(41):\n",
    "            current_count = label_counts.get(label, 0)\n",
    "            \n",
    "            if label in self.undersample_labels:\n",
    "                # Undersample down to the mean\n",
    "                target_count = int(target_mean)\n",
    "                action = 'undersample' if current_count > target_count else 'keep'\n",
    "            elif current_count < target_mean * self.oversample_threshold:\n",
    "                # Oversample up to the threshold\n",
    "                target_count = int(target_mean * self.oversample_threshold)\n",
    "                action = 'oversample' if current_count < target_count else 'keep'\n",
    "            else:\n",
    "                # Leave unchanged\n",
    "                target_count = current_count\n",
    "                action = 'keep'\n",
    "            \n",
    "            sampling_strategy[label] = {\n",
    "                'current_count': current_count,\n",
    "                'target_count': target_count,\n",
    "                'action': action\n",
    "            }\n",
    "        \n",
    "        self.distribution_analysis = {\n",
    "            'label_counts': label_counts,\n",
    "            'target_mean': target_mean,\n",
    "            'total_samples': len(all_labels)\n",
    "        }\n",
    "        self.sampling_strategy = sampling_strategy\n",
    "        \n",
    "        print(f\"   ✅ Analysis complete: {len(all_labels):,} samples\")\n",
    "        print(f\"   📊 Mean over labels 1-39: {target_mean:.0f}\")\n",
    "        print(f\"   📉 Undersampled labels: {self.undersample_labels} → {target_mean:.0f}\")\n",
    "        print(f\"   📈 Oversampling threshold: {self.oversample_threshold} × mean = {target_mean * self.oversample_threshold:.0f}\")\n",
    "        \n",
    "        return self.distribution_analysis, self.sampling_strategy\n",
    "\n",
    "# Create the smart data pipeline\n",
    "print(\"🚀 Creating the smart data pipeline...\")\n",
    "pipeline = SmartDataPipeline(data_dir, random_state=42)\n",
    "print(\"✅ Pipeline created; ready to run step 1...\")"
   ]
  },
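  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once `step1_analyze_distribution()` has run, `pipeline.sampling_strategy` maps each of the 41 labels to a small dict. A minimal way to inspect a few entries (field names as defined above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the per-label plan produced by step 1 (only meaningful after step 1 succeeds).\n",
    "if pipeline.sampling_strategy is not None:\n",
    "    for label in (0, 1, 40):\n",
    "        s = pipeline.sampling_strategy[label]\n",
    "        print(f\"label {label}: {s['action']} {s['current_count']} -> {s['target_count']}\")"
   ]
  },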
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Step 2 methods attached to the pipeline\n"
     ]
    }
   ],
   "source": [
    "# Attach more methods to the smart pipeline [pipeline completion]\n",
    "\n",
    "def step2_fit_pca_with_undersampling(self):\n",
    "    \"\"\"\n",
    "    Step 2: fit PCA on undersampled data only (no oversampling, so the PCA is not skewed by duplicated samples).\n",
    "    \"\"\"\n",
    "    if self.sampling_strategy is None:\n",
    "        raise ValueError(\"Run step 1 first: step1_analyze_distribution()\")\n",
    "    \n",
    "    print(\"\\n🔧 Step 2: fitting PCA parameters (undersample only, no oversampling)...\")\n",
    "    \n",
    "    # Collect samples for PCA fitting (undersample only, no oversampling)\n",
    "    pca_features = []\n",
    "    collected_samples = 0\n",
    "    \n",
    "    for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000):\n",
    "        features, labels = extract_features_labels_batch(trials_batch)\n",
    "        \n",
    "        # Apply the undersample-only strategy to this batch\n",
    "        downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n",
    "        \n",
    "        if downsampled_features.shape[0] > 0:\n",
    "            pca_features.append(downsampled_features)\n",
    "            collected_samples += downsampled_features.shape[0]\n",
    "        \n",
    "        if collected_samples >= self.pca_sample_size:\n",
    "            break\n",
    "    \n",
    "    if pca_features:\n",
    "        # Merge the samples\n",
    "        X_pca_sample = np.vstack(pca_features)[:self.pca_sample_size]\n",
    "        print(f\"   📦 PCA fitting set: {X_pca_sample.shape[0]:,} undersampled samples\")\n",
    "        print(f\"   🔢 Original feature dimension: {X_pca_sample.shape[1]}\")\n",
    "        \n",
    "        # Standardize\n",
    "        self.pca_scaler = StandardScaler()\n",
    "        X_scaled = self.pca_scaler.fit_transform(X_pca_sample)\n",
    "        \n",
    "        # Determine the number of PCA components\n",
    "        pca_full = PCA()\n",
    "        pca_full.fit(X_scaled)\n",
    "        cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
    "        optimal_components = np.argmax(cumsum_var >= self.pca_variance_threshold) + 1\n",
    "        self.pca_components = min(optimal_components, X_pca_sample.shape[1])\n",
    "        \n",
    "        # Fit the final PCA\n",
    "        self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n",
    "        self.pca_model.fit(X_scaled)\n",
    "        self.pca_fitted = True\n",
    "        \n",
    "        # Save the PCA model\n",
    "        pca_path = \"smart_pipeline_pca.joblib\"\n",
    "        joblib.dump({\n",
    "            'scaler': self.pca_scaler,\n",
    "            'pca': self.pca_model,\n",
    "            'components': self.pca_components\n",
    "        }, pca_path)\n",
    "        \n",
    "        print(f\"   ✅ PCA fitted!\")\n",
    "        print(f\"      Reduction: {X_pca_sample.shape[1]} → {self.pca_components}\")\n",
    "        print(f\"      Reduction ratio: {self.pca_components/X_pca_sample.shape[1]:.2%}\")\n",
    "        print(f\"      Variance retained: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
    "        print(f\"      Model saved: {pca_path}\")\n",
    "        \n",
    "        # Free memory\n",
    "        del pca_features, X_pca_sample, X_scaled\n",
    "        gc.collect()\n",
    "    else:\n",
    "        raise ValueError(\"Could not collect samples for PCA fitting\")\n",
    "\n",
    "def _apply_undersampling_only(self, X, y):\n",
    "    \"\"\"\n",
    "    Apply only the undersampling part of the strategy (used for PCA fitting).\n",
    "    \"\"\"\n",
    "    X_result = []\n",
    "    y_result = []\n",
    "    \n",
    "    np.random.seed(self.random_state)\n",
    "    \n",
    "    for label in range(41):\n",
    "        label_mask = (y == label)\n",
    "        X_label = X[label_mask]\n",
    "        y_label = y[label_mask]\n",
    "        current_count = len(y_label)\n",
    "        \n",
    "        if current_count == 0:\n",
    "            continue\n",
    "        \n",
    "        strategy = self.sampling_strategy[label]\n",
    "        \n",
    "        if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n",
    "            # Undersample\n",
    "            indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n",
    "            X_resampled = X_label[indices]\n",
    "            y_resampled = y_label[indices]\n",
    "        else:\n",
    "            # Keep as-is\n",
    "            X_resampled = X_label\n",
    "            y_resampled = y_label\n",
    "        \n",
    "        X_result.append(X_resampled)\n",
    "        y_result.append(y_resampled)\n",
    "    \n",
    "    if X_result:\n",
    "        return np.vstack(X_result), np.hstack(y_result)\n",
    "    else:\n",
    "        return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
    "\n",
    "# Attach the methods to the class dynamically\n",
    "SmartDataPipeline.step2_fit_pca_with_undersampling = step2_fit_pca_with_undersampling\n",
    "SmartDataPipeline._apply_undersampling_only = _apply_undersampling_only\n",
    "\n",
    "print(\"✅ Step 2 methods attached to the pipeline\")"
   ]
  },
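  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The notebook grows `SmartDataPipeline` by assigning functions onto the class after its definition. A tiny standalone sketch of that pattern (hypothetical `Greeter` class, just for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions assigned to a class become bound methods on its instances.\n",
    "class Greeter:\n",
    "    pass\n",
    "\n",
    "def say_hi(self, name):\n",
    "    return f\"hi, {name}\"\n",
    "\n",
    "Greeter.say_hi = say_hi\n",
    "print(Greeter().say_hi(\"world\"))  # -> hi, world"
   ]
  },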
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ All methods attached to the smart pipeline\n",
      "\n",
      "📋 Smart data pipeline status:\n",
      "   🔍 Step 1 - distribution analysis: ❌ not done\n",
      "   🔧 Step 2 - PCA fitting: ❌ not done\n",
      "\n",
      "🎯 Usage flow:\n",
      "   1. pipeline.step1_analyze_distribution()\n",
      "   2. pipeline.step2_fit_pca_with_undersampling()\n",
      "   3. pipeline.step3_process_data('train')  # training set\n",
      "      pipeline.step3_process_data('val')    # validation set\n"
     ]
    }
   ],
   "source": [
    "# Attach the remaining methods to the smart pipeline\n",
    "\n",
    "def _apply_full_sampling(self, X, y):\n",
    "    \"\"\"\n",
    "    Apply the full sampling strategy (undersampling + oversampling).\n",
    "    \"\"\"\n",
    "    X_result = []\n",
    "    y_result = []\n",
    "    \n",
    "    np.random.seed(self.random_state)\n",
    "    \n",
    "    for label in range(41):\n",
    "        label_mask = (y == label)\n",
    "        X_label = X[label_mask]\n",
    "        y_label = y[label_mask]\n",
    "        current_count = len(y_label)\n",
    "        \n",
    "        if current_count == 0:\n",
    "            continue\n",
    "        \n",
    "        strategy = self.sampling_strategy[label]\n",
    "        target_count = strategy['target_count']\n",
    "        \n",
    "        if strategy['action'] == 'undersample' and current_count > target_count:\n",
    "            # Undersample\n",
    "            indices = np.random.choice(current_count, target_count, replace=False)\n",
    "            X_resampled = X_label[indices]\n",
    "            y_resampled = y_label[indices]\n",
    "        elif strategy['action'] == 'oversample' and current_count < target_count:\n",
    "            # Oversample\n",
    "            X_resampled, y_resampled = resample(\n",
    "                X_label, y_label, \n",
    "                n_samples=target_count, \n",
    "                random_state=self.random_state\n",
    "            )\n",
    "        else:\n",
    "            # Keep as-is\n",
    "            X_resampled = X_label\n",
    "            y_resampled = y_label\n",
    "        \n",
    "        X_result.append(X_resampled)\n",
    "        y_result.append(y_resampled)\n",
    "    \n",
    "    if X_result:\n",
    "        return np.vstack(X_result), np.hstack(y_result)\n",
    "    else:\n",
    "        return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
    "\n",
    "def _apply_pca_transform(self, X):\n",
    "    \"\"\"\n",
    "    Apply the PCA transform.\n",
    "    \"\"\"\n",
    "    if not self.pca_fitted:\n",
    "        return X\n",
    "    \n",
    "    X_scaled = self.pca_scaler.transform(X)\n",
    "    X_pca = self.pca_model.transform(X_scaled)\n",
    "    return X_pca\n",
    "\n",
    "def step3_process_data(self, data_type, apply_sampling=None):\n",
    "    \"\"\"\n",
    "    Step 3: process the data (sampling + PCA reduction).\n",
    "    \n",
    "    Args:\n",
    "        data_type: 'train', 'val', 'test'\n",
    "        apply_sampling: whether to apply the sampling strategy; None = apply for train only, not for val/test\n",
    "    \"\"\"\n",
    "    if not self.pca_fitted:\n",
    "        raise ValueError(\"Run step 2 first: step2_fit_pca_with_undersampling()\")\n",
    "    \n",
    "    if apply_sampling is None:\n",
    "        apply_sampling = (data_type == 'train')\n",
    "    \n",
    "    print(f\"\\n🔄 Step 3: processing {data_type} data...\")\n",
    "    print(f\"   Sampling strategy: {'enabled' if apply_sampling else 'disabled'}\")\n",
    "    \n",
    "    all_features = []\n",
    "    all_labels = []\n",
    "    \n",
    "    for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000):\n",
    "        features, labels = extract_features_labels_batch(trials_batch)\n",
    "        \n",
    "        # Apply the sampling strategy\n",
    "        if apply_sampling:\n",
    "            features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n",
    "        else:\n",
    "            features_sampled, labels_sampled = features, labels\n",
    "        \n",
    "        # Apply the PCA reduction\n",
    "        if features_sampled.shape[0] > 0:\n",
    "            features_pca = self._apply_pca_transform(features_sampled)\n",
    "            all_features.append(features_pca)\n",
    "            all_labels.append(labels_sampled)\n",
    "    \n",
    "    if all_features:\n",
    "        X = np.vstack(all_features)\n",
    "        y = np.hstack(all_labels)\n",
    "        \n",
    "        # Shuffle\n",
    "        shuffle_indices = np.random.permutation(len(y))\n",
    "        X = X[shuffle_indices]\n",
    "        y = y[shuffle_indices]\n",
    "        \n",
    "        print(f\"   ✅ Processing complete: {X.shape[0]:,} samples, {X.shape[1]} features\")\n",
    "        \n",
    "        # Free memory\n",
    "        del all_features, all_labels\n",
    "        gc.collect()\n",
    "        \n",
    "        return X, y\n",
    "    else:\n",
    "        return None, None\n",
    "\n",
    "def print_summary(self):\n",
    "    \"\"\"\n",
    "    Print a summary of the pipeline state.\n",
    "    \"\"\"\n",
    "    print(\"\\n📋 Smart data pipeline status:\")\n",
    "    print(f\"   🔍 Step 1 - distribution analysis: {'✅ done' if self.distribution_analysis else '❌ not done'}\")\n",
    "    print(f\"   🔧 Step 2 - PCA fitting: {'✅ done' if self.pca_fitted else '❌ not done'}\")\n",
    "    \n",
    "    if self.distribution_analysis:\n",
    "        target_mean = self.distribution_analysis['target_mean']\n",
    "        print(f\"   📊 Mean over labels 1-39: {target_mean:.0f}\")\n",
    "    \n",
    "    if self.pca_fitted:\n",
    "        print(f\"   🔬 PCA reduction: 7168 → {self.pca_components} ({self.pca_components/7168:.1%})\")\n",
    "        print(f\"   📈 Variance retained: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
    "    \n",
    "    print(f\"\\n🎯 Usage flow:\")\n",
    "    print(f\"   1. pipeline.step1_analyze_distribution()\")\n",
    "    print(f\"   2. pipeline.step2_fit_pca_with_undersampling()\")\n",
    "    print(f\"   3. pipeline.step3_process_data('train')  # training set\")\n",
    "    print(f\"      pipeline.step3_process_data('val')    # validation set\")\n",
    "\n",
    "# Attach the remaining methods to the class dynamically\n",
    "SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n",
    "SmartDataPipeline._apply_pca_transform = _apply_pca_transform\n",
    "SmartDataPipeline.step3_process_data = step3_process_data\n",
    "SmartDataPipeline.print_summary = print_summary\n",
    "\n",
    "print(\"✅ All methods attached to the smart pipeline\")\n",
    "pipeline.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🔥 Run the Smart Data Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🚀 Starting the smart data pipeline...\n",
      "============================================================\n",
      "\n",
      "============🔍 STEP 1: Analyze data distribution=============\n",
      "🔍 Step 1: analyzing the data distribution...\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[WinError 3] The system cannot find the path specified.: '/kaggle/working/nejm-brain-to-text/data/concatenated_data'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[16], line 8\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[38;5;66;03m# Step 1: analyze the data distribution\u001b[39;00m\n\u001b[0;32m      7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m🔍 STEP 1: Analyze data distribution\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mcenter(\u001b[38;5;241m60\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m----> 8\u001b[0m distribution, strategy \u001b[38;5;241m=\u001b[39m \u001b[43mpipeline\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep1_analyze_distribution\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m     10\u001b[0m \u001b[38;5;66;03m# Show a summary of the sampling strategy\u001b[39;00m\n\u001b[0;32m     11\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m📊 Sampling strategy summary:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "Cell \u001b[1;32mIn[13], line 50\u001b[0m, in \u001b[0;36mSmartDataPipeline.step1_analyze_distribution\u001b[1;34m(self, max_samples)\u001b[0m\n\u001b[0;32m     48\u001b[0m \u001b[38;5;66;03m# Analyze the validation-set distribution (representative of the overall distribution)\u001b[39;00m\n\u001b[0;32m     49\u001b[0m all_labels \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m---> 50\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m trials_batch, filename \u001b[38;5;129;01min\u001b[39;00m load_data_batch(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata_dir, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mval\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;241m5000\u001b[39m):\n\u001b[0;32m     51\u001b[0m     _, labels \u001b[38;5;241m=\u001b[39m extract_features_labels_batch(trials_batch)\n\u001b[0;32m     52\u001b[0m     all_labels\u001b[38;5;241m.\u001b[39mextend(labels\u001b[38;5;241m.\u001b[39mtolist())\n",
      "Cell \u001b[1;32mIn[8], line 24\u001b[0m, in \u001b[0;36mload_data_batch\u001b[1;34m(data_dir, data_type, max_samples_per_file)\u001b[0m\n\u001b[0;32m     12\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload_data_batch\u001b[39m(data_dir, data_type, max_samples_per_file\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5000\u001b[39m):\n\u001b[0;32m     13\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m     14\u001b[0m \u001b[38;5;124;03m    Load data of the given type in batches.\u001b[39;00m\n\u001b[0;32m     15\u001b[0m \u001b[38;5;124;03m    \u001b[39;00m\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m     22\u001b[0m \u001b[38;5;124;03m        generator: yields data batches\u001b[39;00m\n\u001b[0;32m     23\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[1;32m---> 24\u001b[0m     files \u001b[38;5;241m=\u001b[39m [f \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlistdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m f\u001b[38;5;241m.\u001b[39mendswith(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.npz\u001b[39m\u001b[38;5;124m'\u001b[39m) \u001b[38;5;129;01mand\u001b[39;00m data_type \u001b[38;5;129;01min\u001b[39;00m f]\n\u001b[0;32m     26\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m file_idx, f \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28menumerate\u001b[39m(files):\n\u001b[0;32m     27\u001b[0m         \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m  Loading file \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfile_idx\u001b[38;5;241m+\u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(files)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mf\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[1;31mFileNotFoundError\u001b[0m: [WinError 3] The system cannot find the path specified.: '/kaggle/working/nejm-brain-to-text/data/concatenated_data'"
     ]
    }
   ],
   "source": [
    "# 🔥 Run the smart data pipeline [determine the sampling strategy]\n",
    "\n",
    "print(\"🚀 Starting the smart data pipeline...\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "# Step 1: analyze the data distribution\n",
    "print(\"\\n\" + \"🔍 STEP 1: Analyze data distribution\".center(60, \"=\"))\n",
    "distribution, strategy = pipeline.step1_analyze_distribution()\n",
    "\n",
    "# Show a summary of the sampling strategy\n",
    "print(f\"\\n📊 Sampling strategy summary:\")\n",
    "undersample_count = sum(1 for s in strategy.values() if s['action'] == 'undersample')\n",
    "oversample_count = sum(1 for s in strategy.values() if s['action'] == 'oversample')\n",
    "keep_count = sum(1 for s in strategy.values() if s['action'] == 'keep')\n",
    "\n",
    "print(f\"   📉 Labels to undersample: {undersample_count}\")\n",
    "print(f\"   📈 Labels to oversample: {oversample_count}\") \n",
    "print(f\"   ✅ Unchanged: {keep_count}\")\n",
    "\n",
    "print(\"\\n✅ Step 1 complete!\")"
   ]
  },
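  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The FileNotFoundError above comes from re-running the notebook locally on Windows while `data_dir` still points at the Kaggle path. A small guard (a sketch; the local mirror path is an assumption based on the paths visible in the outputs) keeps one notebook working in both places:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick the first data directory that actually exists on this machine.\n",
    "import os\n",
    "\n",
    "candidates = [\n",
    "    '/kaggle/working/nejm-brain-to-text/data/concatenated_data',\n",
    "    r'f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\data\\concatenated_data',  # assumed local mirror\n",
    "]\n",
    "data_dir = next((p for p in candidates if os.path.isdir(p)), candidates[0])\n",
    "print('Using data_dir =', data_dir)"
   ]
  },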
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "================🔧 STEP 2: Fit PCA parameters================\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "Run step 1 first: step1_analyze_distribution()",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[17], line 3\u001b[0m\n\u001b[0;32m      1\u001b[0m \u001b[38;5;66;03m# Step 2: fit PCA parameters [determine the PCA strategy]\u001b[39;00m\n\u001b[0;32m      2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m🔧 STEP 2: Fit PCA parameters\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mcenter(\u001b[38;5;241m60\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m=\u001b[39m\u001b[38;5;124m\"\u001b[39m))\n\u001b[1;32m----> 3\u001b[0m \u001b[43mpipeline\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep2_fit_pca_with_undersampling\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m      5\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m✅ Step 2 complete!\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m      6\u001b[0m pipeline\u001b[38;5;241m.\u001b[39mprint_summary()\n",
      "Cell \u001b[1;32mIn[14], line 8\u001b[0m, in \u001b[0;36mstep2_fit_pca_with_undersampling\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m      5\u001b[0m \u001b[38;5;124;03mStep 2: fit PCA on undersampled data only (no oversampling, so the PCA is not skewed by duplicated samples).\u001b[39;00m\n\u001b[0;32m      6\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m      7\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msampling_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m----> 8\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRun step 1 first: step1_analyze_distribution()\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     10\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m🔧 Step 2: fitting PCA parameters (undersample only, no oversampling)...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     12\u001b[0m \u001b[38;5;66;03m# Collect samples for PCA fitting (undersample only, no oversampling)\u001b[39;00m\n",
      "\u001b[1;31mValueError\u001b[0m: Run step 1 first: step1_analyze_distribution()"
     ]
    }
   ],
   "source": [
    "# Step 2: fit PCA parameters [determine the PCA strategy]\n",
    "print(\"\\n\" + \"🔧 STEP 2: Fit PCA parameters\".center(60, \"=\"))\n",
    "pipeline.step2_fit_pca_with_undersampling()\n",
    "\n",
    "print(\"\\n✅ Step 2 complete!\")\n",
    "pipeline.print_summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 🚀 Batched Training with the Smart Pipeline"
   ]
  },
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# 🚀 使用智能管道进行分批训练\n",
 | ||
|     "\n",
 | ||
|     "import lightgbm as lgb\n",
 | ||
|     "import time\n",
 | ||
|     "from collections import Counter\n",
 | ||
|     "import matplotlib.pyplot as plt\n",
 | ||
|     "\n",
 | ||
|     "class SmartBatchTrainer:\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    智能分批训练器,集成智能数据管道\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    \n",
 | ||
|     "    def __init__(self, pipeline, params=None, min_learning_rate=1e-4, t_0=50, t_mult=2):\n",
 | ||
|     "        self.pipeline = pipeline\n",
 | ||
|     "        self.model = None\n",
 | ||
|     "        self.training_history = {} # 改为字典,因为只有一次训练\n",
 | ||
|     "        self.batch_count = 0\n",
 | ||
|     "        self.min_learning_rate = min_learning_rate\n",
 | ||
|     "        self.lr_history = [] # 用于可视化\n",
 | ||
|     "        \n",
 | ||
|     "        # 带重启的余弦退火参数\n",
 | ||
|     "        self.t_0 = t_0  # 第一个重启周期的长度\n",
 | ||
|     "        self.t_mult = t_mult  # 重启周期的乘数\n",
 | ||
|     "        \n",
 | ||
|     "        # 默认LightGBM参数(GPU优化)\n",
 | ||
|     "        self.params = params or {\n",
 | ||
|     "            'objective': 'multiclass',\n",
 | ||
|     "            'num_class': 41,\n",
 | ||
|     "            'metric': 'multi_logloss',\n",
 | ||
|     "            'boosting_type': 'gbdt',\n",
 | ||
|     "            'device_type': 'cpu',\n",
 | ||
|     "            # 'gpu_platform_id': 0,\n",
 | ||
|     "            # 'gpu_device_id': 0,\n",
 | ||
|     "            'max_bin': 255,\n",
 | ||
|     "            'num_leaves': 127,\n",
 | ||
|     "            'learning_rate': 0.08, #默认0.08\n",
 | ||
|     "            'feature_fraction': 0.8,\n",
 | ||
|     "            'bagging_fraction': 0.8,\n",
 | ||
|     "            'bagging_freq': 5,\n",
 | ||
|     "            'min_data_in_leaf': 20,\n",
 | ||
|     "            'lambda_l1': 0.1,\n",
 | ||
|     "            'lambda_l2': 0.1,\n",
 | ||
|     "            'verbose': -1,\n",
 | ||
|     "            'num_threads': -1\n",
 | ||
|     "        }\n",
 | ||
|     "        \n",
 | ||
|     "        self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"🎯 智能分批训练器创建完成\")\n",
 | ||
|     "        print(f\"   🔧 LightGBM参数已配置:{self.params['device_type'].upper()}模式\")\n",
 | ||
|     "        print(f\"   💡 学习率调度: 带重启的余弦退火 (从 {self.initial_learning_rate} 到 {self.min_learning_rate})\")\n",
 | ||
|     "        print(f\"   🔄 重启参数: T_0={self.t_0}, T_mult={self.t_mult}\")\n",
 | ||
|     "    \n",
 | ||
|     "    def prepare_validation_data(self):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        准备验证数据(仅PCA,保持原始分布)\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(\"🔄 准备验证数据...\")\n",
 | ||
|     "        self.X_val, self.y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
 | ||
|     "        if self.X_val is None:\n",
 | ||
|     "            raise ValueError(\"无法加载验证数据\")\n",
 | ||
|     "        val_counts = Counter(self.y_val)\n",
 | ||
|     "        print(f\"   ✅ 验证数据准备完成: {self.X_val.shape[0]:,} 样本\")\n",
 | ||
|     "        print(f\"   📊 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n",
 | ||
|     "\n",
 | ||
|     "        return lgb.Dataset(self.X_val, label=self.y_val, free_raw_data=False)\n",
 | ||
|     "\n",
 | ||
|     "    def get_training_batch_generator(self):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        获取训练批次生成器(平衡采样+PCA)\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(\"🔄 准备训练批次生成器...\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 使用管道的批次生成器\n",
 | ||
|     "        for trials_batch, filename in load_data_batch(self.pipeline.data_dir, 'train', 2000):\n",
 | ||
|     "            features, labels = extract_features_labels_batch(trials_batch)\n",
 | ||
|     "            \n",
 | ||
|     "            # 应用完整采样策略\n",
 | ||
|     "            features_sampled, labels_sampled = self.pipeline._apply_full_sampling(features, labels)\n",
 | ||
|     "            \n",
 | ||
|     "            # 应用PCA降维\n",
 | ||
|     "            if features_sampled.shape[0] > 0:\n",
 | ||
|     "                features_pca = self.pipeline._apply_pca_transform(features_sampled)\n",
 | ||
|     "                \n",
 | ||
|     "                # 分析当前批次分布\n",
 | ||
|     "                batch_counts = Counter(labels_sampled)\n",
 | ||
|     "                \n",
 | ||
|     "                print(f\"   📦 批次: {filename}\")\n",
 | ||
|     "                print(f\"      样本数: {features_pca.shape[0]:,}\")\n",
 | ||
|     "                print(f\"      平衡后分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n",
 | ||
|     "                \n",
 | ||
|     "                yield lgb.Dataset(features_pca, label=labels_sampled), filename\n",
 | ||
|     "    \n",
 | ||
|     "    def prepare_full_data(self):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        一次性准备所有训练和验证数据\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(\"🔄 准备全量训练和验证数据...\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 1. 准备验证数据 (保持原始分布)\n",
 | ||
|     "        X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
 | ||
|     "        if X_val is None:\n",
 | ||
|     "            raise ValueError(\"无法加载验证数据\")\n",
 | ||
|     "        val_counts = Counter(y_val)\n",
 | ||
|     "        print(f\"   ✅ 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n",
 | ||
|     "        print(f\"   📊 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n",
 | ||
|     "        val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
 | ||
|     "        \n",
 | ||
|     "        # 2. 准备训练数据 (应用完整采样和PCA策略)\n",
 | ||
|     "        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
 | ||
|     "        if X_train is None:\n",
 | ||
|     "            raise ValueError(\"无法加载训练数据\")\n",
 | ||
|     "        train_counts = Counter(y_train)\n",
 | ||
|     "        print(f\"   ✅ 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n",
 | ||
|     "        print(f\"   📊 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n",
 | ||
|     "        train_data = lgb.Dataset(X_train, label=y_train)\n",
 | ||
|     "        \n",
 | ||
|     "        return train_data, val_data, X_val, y_val\n",
 | ||
|     "    \n",
 | ||
|     "    def prepare_training_data(self):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        准备训练数据(仅PCA,保持原始分布)\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(\"🔄 准备训练数据...\")\n",
 | ||
|     "        # 2. 准备训练数据 (应用完整采样和PCA策略)\n",
 | ||
|     "        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
 | ||
|     "        if X_train is None:\n",
 | ||
|     "            raise ValueError(\"无法加载训练数据\")\n",
 | ||
|     "        train_counts = Counter(y_train)\n",
 | ||
|     "        print(f\"   ✅ 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n",
 | ||
|     "        print(f\"   📊 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n",
 | ||
|     "        \n",
 | ||
|     "        return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n",
 | ||
|     "                \n",
 | ||
|     "    # 带重启的余弦退火调度器函数\n",
 | ||
|     "    def _cosine_annealing_with_warm_restarts(self, current_round):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        带重启的余弦退火调度器 (SGDR)\n",
 | ||
|     "        \n",
 | ||
|     "        Args:\n",
 | ||
|     "            current_round: 当前训练轮数\n",
 | ||
|     "            \n",
 | ||
|     "        Returns:\n",
 | ||
|     "            学习率\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        eta_max = self.initial_learning_rate\n",
 | ||
|     "        eta_min = self.min_learning_rate\n",
 | ||
|     "        \n",
 | ||
|     "        # 计算当前在哪个重启周期中\n",
 | ||
|     "        t_cur = current_round\n",
 | ||
|     "        t_i = self.t_0\n",
 | ||
|     "        \n",
 | ||
|     "        # 找到当前的重启周期\n",
 | ||
|     "        cycle = 0\n",
 | ||
|     "        while t_cur >= t_i:\n",
 | ||
|     "            t_cur -= t_i\n",
 | ||
|     "            cycle += 1\n",
 | ||
|     "            t_i *= self.t_mult\n",
 | ||
|     "        \n",
 | ||
|     "        # 在当前周期内的位置\n",
 | ||
|     "        progress = t_cur / t_i\n",
 | ||
|     "        \n",
 | ||
|     "        # 计算学习率\n",
 | ||
|     "        lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * progress))\n",
 | ||
|     "        \n",
 | ||
|     "        return lr\n",
 | ||
|     "    \n",
 | ||
|     "    def train_incremental(self, num_boost_round=100, early_stopping_rounds=10):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        增量分批训练\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(f\"\\n🚀 开始智能分批训练...\")\n",
 | ||
|     "        print(f\"   📝 训练轮数 (每批次): {num_boost_round}\")\n",
 | ||
|     "        print(f\"   ⏹️ 早停轮数: {early_stopping_rounds}\")\n",
 | ||
|     "        print(\"=\" * 60)\n",
 | ||
|     "        \n",
 | ||
|     "        # 准备验证数据\n",
 | ||
|     "        val_data = self.prepare_validation_data()\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n🔄 开始分批增量训练...\")\n",
 | ||
|     "        total_start_time = time.time()\n",
 | ||
|     "        \n",
 | ||
|     "        # ⭐️ 新增: 为学习率调度器定义T_max\n",
 | ||
|     "        # 我们将每个批次的训练视为一个完整的退火周期\n",
 | ||
|     "        t_max_per_batch = num_boost_round\n",
 | ||
|     "        \n",
 | ||
|     "        for train_data, filename in self.get_training_batch_generator():\n",
 | ||
|     "            self.batch_count += 1\n",
 | ||
|     "            batch_start_time = time.time()\n",
 | ||
|     "            self.last_batch_lr_history = [] # 重置每个批次的LR历史\n",
 | ||
|     "            \n",
 | ||
|     "            print(f\"\\n📈 批次 {self.batch_count}: {filename}\")\n",
 | ||
|     "            \n",
 | ||
|     "            # ⭐️ 新增: 创建学习率调度回调 和 记录回调\n",
 | ||
|     "            lr_scheduler_callback = lgb.reset_parameter(\n",
 | ||
|     "                learning_rate=lambda current_round: self._cosine_annealing_with_warm_restarts(current_round)\n",
 | ||
|     "            )\n",
 | ||
|     "\n",
 | ||
|     "            # 这个简单的回调用于记录每个周期的学习率,以便后续可视化\n",
 | ||
|     "            def record_lr_callback(env):\n",
 | ||
|     "                self.last_batch_lr_history.append(env.model.params['learning_rate'])\n",
 | ||
|     "\n",
 | ||
|     "            # 组合所有回调\n",
 | ||
|     "            training_callbacks = [\n",
 | ||
|     "                lgb.early_stopping(stopping_rounds=early_stopping_rounds, verbose=True),\n",
 | ||
|     "                lgb.log_evaluation(period=10), # 每10轮打印一次\n",
 | ||
|     "                lr_scheduler_callback,\n",
 | ||
|     "                record_lr_callback\n",
 | ||
|     "            ]\n",
 | ||
|     "\n",
 | ||
|     "            # 训练当前批次\n",
 | ||
|     "            current_model_args = {\n",
 | ||
|     "                'params': self.params,\n",
 | ||
|     "                'train_set': train_data,\n",
 | ||
|     "                'num_boost_round': num_boost_round,\n",
 | ||
|     "                'valid_sets': [val_data],\n",
 | ||
|     "                'valid_names': ['validation'],\n",
 | ||
|     "                'callbacks': training_callbacks\n",
 | ||
|     "            }\n",
 | ||
|     "            \n",
 | ||
|     "            if self.model is None:\n",
 | ||
|     "                print(\"   🎯 初始模型训练...\")\n",
 | ||
|     "                self.model = lgb.train(**current_model_args)\n",
 | ||
|     "            else:\n",
 | ||
|     "                print(\"   ⚡ 增量训练...\")\n",
 | ||
|     "                current_model_args['init_model'] = self.model\n",
 | ||
|     "                self.model = lgb.train(**current_model_args)\n",
 | ||
|     "\n",
 | ||
|     "            # 记录训练历史\n",
 | ||
|     "            batch_time = time.time() - batch_start_time\n",
 | ||
|     "            \n",
 | ||
|     "            # 评估当前模型\n",
 | ||
|     "            val_pred = self.model.predict(self.X_val)\n",
 | ||
|     "            val_accuracy = (val_pred.argmax(axis=1) == self.y_val).mean()\n",
 | ||
|     "            \n",
 | ||
|     "            batch_info = {\n",
 | ||
|     "                'batch': self.batch_count,\n",
 | ||
|     "                'filename': filename,\n",
 | ||
|     "                'time': batch_time,\n",
 | ||
|     "                'val_accuracy': val_accuracy,\n",
 | ||
|     "                'num_trees': self.model.num_trees(),\n",
 | ||
|     "                'lr_history': self.last_batch_lr_history.copy() # 保存当前批次的LR历史\n",
 | ||
|     "            }\n",
 | ||
|     "            \n",
 | ||
|     "            self.training_history.append(batch_info)\n",
 | ||
|     "            \n",
 | ||
|     "            print(f\"   ✅ 批次完成: {batch_time:.1f}秒\")\n",
 | ||
|     "            print(f\"   📊 验证准确率: {val_accuracy:.4f}\")\n",
 | ||
|     "            print(f\"   🌳 模型树数: {self.model.num_trees()}\")\n",
 | ||
|     "            \n",
 | ||
|     "            model_path = f\"smart_batch_model_batch_{self.batch_count}.txt\"\n",
 | ||
|     "            self.model.save_model(model_path)\n",
 | ||
|     "            print(f\"   💾 模型已保存: {model_path}\")\n",
 | ||
|     "        \n",
 | ||
|     "        total_time = time.time() - total_start_time\n",
 | ||
|     "        print(f\"\\n🎉 智能分批训练完成!\")\n",
 | ||
|     "        print(f\"   ⏱️ 总训练时间: {total_time:.1f}秒\")\n",
 | ||
|     "        print(f\"   📊 处理批次数: {self.batch_count}\")\n",
 | ||
|     "        print(f\"   🌳 最终模型树数: {self.model.num_trees()}\")\n",
 | ||
|     "        \n",
 | ||
|     "        return self.model\n",
 | ||
|     "    \n",
 | ||
|     "    def train(self, num_boost_round=1000, early_stopping_rounds=50):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        执行一次性全量训练\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(f\"\\n🚀 开始全量数据训练...\")\n",
 | ||
|     "        print(f\"   📝 训练轮数: {num_boost_round}\")\n",
 | ||
|     "        print(f\"   ⏹️ 早停轮数: {early_stopping_rounds}\")\n",
 | ||
|     "        print(\"=\" * 60)\n",
 | ||
|     "        \n",
 | ||
|     "        # 准备数据\n",
 | ||
|     "        train_data, val_data, X_val, y_val = self.prepare_full_data()\n",
 | ||
|     "        \n",
 | ||
|     "        start_time = time.time()\n",
 | ||
|     "        \n",
 | ||
|     "        # 定义学习率调度和记录回调\n",
 | ||
|     "        lr_scheduler_callback = lgb.reset_parameter(\n",
 | ||
|     "            learning_rate=lambda current_round: self._cosine_annealing_with_warm_restarts(current_round)\n",
 | ||
|     "        )\n",
 | ||
|     "        def record_lr_callback(env):\n",
 | ||
|     "            self.lr_history.append(env.model.params['learning_rate'])\n",
 | ||
|     "        \n",
 | ||
|     "        training_callbacks = [\n",
 | ||
|     "            lgb.early_stopping(stopping_rounds=early_stopping_rounds, verbose=True),\n",
 | ||
|     "            lgb.log_evaluation(period=1), # 每100轮打印日志\n",
 | ||
|     "            lr_scheduler_callback,\n",
 | ||
|     "            record_lr_callback\n",
 | ||
|     "        ]\n",
 | ||
|     "        \n",
 | ||
|     "        # 训练模型\n",
 | ||
|     "        print(\"\\n📈 开始模型训练...\")\n",
 | ||
|     "        self.model = lgb.train(\n",
 | ||
|     "            params=self.params,\n",
 | ||
|     "            train_set=train_data,\n",
 | ||
|     "            num_boost_round=num_boost_round,\n",
 | ||
|     "            valid_sets=[val_data],\n",
 | ||
|     "            valid_names=['validation'],\n",
 | ||
|     "            callbacks=training_callbacks\n",
 | ||
|     "        )\n",
 | ||
|     "        \n",
 | ||
|     "        training_time = time.time() - start_time\n",
 | ||
|     "        \n",
 | ||
|     "        # 评估模型\n",
 | ||
|     "        val_pred = self.model.predict(X_val)\n",
 | ||
|     "        val_accuracy = (val_pred.argmax(axis=1) == y_val).mean()\n",
 | ||
|     "        \n",
 | ||
|     "        # 记录训练历史\n",
 | ||
|     "        self.training_history = {\n",
 | ||
|     "            'time': training_time,\n",
 | ||
|     "            'val_accuracy': val_accuracy,\n",
 | ||
|     "            'num_trees': self.model.num_trees(),\n",
 | ||
|     "            'lr_history': self.lr_history,\n",
 | ||
|     "            'best_iteration': self.model.best_iteration\n",
 | ||
|     "        }\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n🎉 全量数据训练完成!\")\n",
 | ||
|     "        print(f\"   ⏱️ 总训练时间: {training_time:.1f}秒\")\n",
 | ||
|     "        print(f\"   🌳 最终模型树数: {self.model.num_trees()} (最佳轮次: {self.model.best_iteration})\")\n",
 | ||
|     "        print(f\"   🎯 最终验证准确率: {val_accuracy:.4f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 保存模型\n",
 | ||
|     "        model_path = \"full_train_model.txt\"\n",
 | ||
|     "        self.model.save_model(model_path)\n",
 | ||
|     "        print(f\"   💾 模型已保存: {model_path}\")\n",
 | ||
|     "        \n",
 | ||
|     "        return self.model\n",
 | ||
|     "    \n",
 | ||
|     "    def plot_training_progress(self):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        绘制训练进度\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        if not self.training_history:\n",
 | ||
|     "            print(\"❌ 没有训练历史记录\")\n",
 | ||
|     "            return\n",
 | ||
|     "        \n",
 | ||
|     "        # ⭐️ 修改: 增加学习率的可视化图表\n",
 | ||
|     "        fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(15, 15))\n",
 | ||
|     "        \n",
 | ||
|     "        batches = [h['batch'] for h in self.training_history]\n",
 | ||
|     "        accuracies = [h['val_accuracy'] for h in self.training_history]\n",
 | ||
|     "        times = [h['time'] for h in self.training_history]\n",
 | ||
|     "        trees = [h['num_trees'] for h in self.training_history]\n",
 | ||
|     "        \n",
 | ||
|     "        # 1. 验证准确率\n",
 | ||
|     "        ax1.plot(batches, accuracies, 'b-o', linewidth=2, markersize=6)\n",
 | ||
|     "        ax1.set_xlabel('Training Batch')\n",
 | ||
|     "        ax1.set_ylabel('Validation Accuracy')\n",
 | ||
|     "        ax1.set_title('Validation Accuracy Progress')\n",
 | ||
|     "        ax1.grid(True, alpha=0.3)\n",
 | ||
|     "        ax1.set_ylim(0, 1)\n",
 | ||
|     "        \n",
 | ||
|     "        # 2. 批次训练时间\n",
 | ||
|     "        ax2.bar(batches, times, color='green', alpha=0.7)\n",
 | ||
|     "        ax2.set_xlabel('Training Batch')\n",
 | ||
|     "        ax2.set_ylabel('Training Time (seconds)')\n",
 | ||
|     "        ax2.set_title('Training Time per Batch')\n",
 | ||
|     "        ax2.grid(True, alpha=0.3)\n",
 | ||
|     "        \n",
 | ||
|     "        # 3. 模型树数增长\n",
 | ||
|     "        ax3.plot(batches, trees, 'r-s', linewidth=2, markersize=6)\n",
 | ||
|     "        ax3.set_xlabel('Training Batch')\n",
 | ||
|     "        ax3.set_ylabel('Number of Trees')\n",
 | ||
|     "        ax3.set_title('Model Complexity Growth')\n",
 | ||
|     "        ax3.grid(True, alpha=0.3)\n",
 | ||
|     "        \n",
 | ||
|     "        # 4. 累计准确率提升\n",
 | ||
|     "        ax4.plot(batches, [acc - accuracies[0] for acc in accuracies], 'purple', linewidth=2, marker='D')\n",
 | ||
|     "        ax4.set_xlabel('Training Batch')\n",
 | ||
|     "        ax4.set_ylabel('Accuracy Improvement')\n",
 | ||
|     "        ax4.set_title('Cumulative Accuracy Improvement')\n",
 | ||
|     "        ax4.grid(True, alpha=0.3)\n",
 | ||
|     "        ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)\n",
 | ||
|     "\n",
 | ||
|     "        # ⭐️ 新增: 5. 最后一个批次的学习率曲线\n",
 | ||
|     "        last_lr_history = self.training_history[-1]['lr_history']\n",
 | ||
|     "        ax5.plot(range(len(last_lr_history)), last_lr_history, color='orange', marker='.')\n",
 | ||
|     "        ax5.set_xlabel('Boosting Round in Last Batch')\n",
 | ||
|     "        ax5.set_ylabel('Learning Rate')\n",
 | ||
|     "        ax5.set_title(f'Cosine Annealing LR in Last Batch (Batch {batches[-1]})')\n",
 | ||
|     "        ax5.grid(True, alpha=0.3)\n",
 | ||
|     "        \n",
 | ||
|     "        # 隐藏第六个子图\n",
 | ||
|     "        ax6.axis('off')\n",
 | ||
|     "\n",
 | ||
|     "        plt.tight_layout()\n",
 | ||
|     "        plt.show()\n",
 | ||
|     "        \n",
 | ||
|     "        # 打印统计信息\n",
 | ||
|     "        print(f\"\\n📈 训练进度统计:\")\n",
 | ||
|     "        print(f\"   🎯 初始准确率: {accuracies[0]:.4f}\")\n",
 | ||
|     "        print(f\"   🎯 最终准确率: {accuracies[-1]:.4f}\")\n",
 | ||
|     "        print(f\"   📈 准确率提升: {accuracies[-1] - accuracies[0]:.4f}\")\n",
 | ||
|     "        print(f\"   ⏱️ 平均批次时间: {np.mean(times):.1f}秒\")\n",
 | ||
|     "        print(f\"   🌳 最终模型树数: {trees[-1]}\")\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "print(\"🚀 创建智能分批训练器...\")\n",
 | ||
|     "# 实例化时可以传入最小学习率\n",
 | ||
|     "trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001) \n",
 | ||
|     "print(\"✅ 训练器创建完成,准备开始训练!\")"
 | ||
|    ]
 | ||
|   },
 | ||
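|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "The `_cosine_annealing_with_warm_restarts` method above implements the SGDR schedule of Loshchilov & Hutter: with $T_{cur}$ rounds elapsed in the current cycle of length $T_i = T_0 \\cdot T_{mult}^{\\,i}$,\n",
 | ||
|     "\n",
 | ||
|     "$$\\eta_t = \\eta_{min} + \\tfrac{1}{2}(\\eta_{max} - \\eta_{min})\\left(1 + \\cos\\left(\\pi \\frac{T_{cur}}{T_i}\\right)\\right).$$\n",
 | ||
|     "\n",
 | ||
|     "The sketch below shows how such a per-round schedule plugs into LightGBM via `lgb.reset_parameter`; it runs on synthetic data, and all shapes and parameter values are illustrative rather than the pipeline's real settings."
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# Minimal sketch (toy data): feeding a per-round LR schedule to LightGBM\n",
 | ||
|     "import lightgbm as lgb\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "\n",
 | ||
|     "rng = np.random.default_rng(0)\n",
 | ||
|     "X_toy = rng.normal(size=(1000, 20))\n",
 | ||
|     "y_toy = (X_toy[:, 0] + 0.5 * X_toy[:, 1] > 0).astype(int)\n",
 | ||
|     "\n",
 | ||
|     "def toy_cosine_lr(round_idx, total_rounds=50, lr_max=0.1, lr_min=0.01):\n",
 | ||
|     "    \"\"\"Plain cosine decay over total_rounds (no restarts, for brevity).\"\"\"\n",
 | ||
|     "    return lr_min + 0.5 * (lr_max - lr_min) * (1 + np.cos(np.pi * round_idx / total_rounds))\n",
 | ||
|     "\n",
 | ||
|     "toy_booster = lgb.train(\n",
 | ||
|     "    {'objective': 'binary', 'verbose': -1},\n",
 | ||
|     "    lgb.Dataset(X_toy, label=y_toy),\n",
 | ||
|     "    num_boost_round=50,\n",
 | ||
|     "    # reset_parameter accepts a callable of the current round index (0-based)\n",
 | ||
|     "    callbacks=[lgb.reset_parameter(learning_rate=toy_cosine_lr)],\n",
 | ||
|     ")\n",
 | ||
|     "print('toy trees:', toy_booster.num_trees())"
 | ||
|    ]
 | ||
|   },
 | ||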
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 11,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "🔥 开始智能分批训练!\n",
 | ||
|       "================================================================================\n",
 | ||
|       "📝 训练配置:\n",
 | ||
|       "   训练轮数: 500\n",
 | ||
|       "   早停轮数: 15\n",
 | ||
|       "   数据平衡: 启用(下采样标签0,40 + 过采样少数类)\n",
 | ||
|       "   PCA降维: 7168 → None 特征\n",
 | ||
|       "\n",
 | ||
|       "🚀 启动训练...\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "ename": "NameError",
 | ||
|      "evalue": "name 'trainer' is not defined",
 | ||
|      "output_type": "error",
 | ||
|      "traceback": [
 | ||
|       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
 | ||
|       "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
 | ||
|       "Cell \u001b[1;32mIn[11], line 21\u001b[0m\n\u001b[0;32m     18\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m🚀 启动训练...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     20\u001b[0m \u001b[38;5;66;03m# 开始训练\u001b[39;00m\n\u001b[1;32m---> 21\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241m.\u001b[39mtrain(\n\u001b[0;32m     22\u001b[0m     num_boost_round\u001b[38;5;241m=\u001b[39mTRAINING_PARAMS[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnum_boost_round\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[0;32m     23\u001b[0m     early_stopping_rounds\u001b[38;5;241m=\u001b[39mTRAINING_PARAMS[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mearly_stopping_rounds\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m     24\u001b[0m )\n",
 | ||
|       "\u001b[1;31mNameError\u001b[0m: name 'trainer' is not defined"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "# 全量训练\n",
 | ||
|     "\n",
 | ||
|     "print(\"🔥 开始智能分批训练!\")\n",
 | ||
|     "print(\"=\" * 80)\n",
 | ||
|     "\n",
 | ||
|     "# 训练参数\n",
 | ||
|     "TRAINING_PARAMS = {\n",
 | ||
|     "    'num_boost_round': 500,      # 每批次的提升轮数\n",
 | ||
|     "    'early_stopping_rounds': 15  # 早停轮数\n",
 | ||
|     "}\n",
 | ||
|     "\n",
 | ||
|     "print(f\"📝 训练配置:\")\n",
 | ||
|     "print(f\"   训练轮数: {TRAINING_PARAMS['num_boost_round']}\")\n",
 | ||
|     "print(f\"   早停轮数: {TRAINING_PARAMS['early_stopping_rounds']}\")\n",
 | ||
|     "print(f\"   数据平衡: 启用(下采样标签0,40 + 过采样少数类)\")\n",
 | ||
|     "print(f\"   PCA降维: 7168 → {pipeline.pca_components} 特征\")\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n🚀 启动训练...\")\n",
 | ||
|     "\n",
 | ||
|     "# 开始训练\n",
 | ||
|     "model = trainer.train(\n",
 | ||
|     "    num_boost_round=TRAINING_PARAMS['num_boost_round'],\n",
 | ||
|     "    early_stopping_rounds=TRAINING_PARAMS['early_stopping_rounds']\n",
 | ||
|     ")"
 | ||
|    ]
 | ||
|   },
 | ||
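|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "The cell above uses the one-shot `train()` path. The incremental path (`train_incremental`) instead keeps calling `lgb.train` with `init_model=` set to the previous booster, so each data batch appends trees to the same model. A minimal, self-contained sketch of that mechanism on synthetic data (batch sizes and parameters are illustrative):"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# Sketch (toy data): continued boosting across batches via init_model\n",
 | ||
|     "import lightgbm as lgb\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "\n",
 | ||
|     "rng = np.random.default_rng(0)\n",
 | ||
|     "toy_params = {'objective': 'binary', 'verbose': -1, 'learning_rate': 0.1}\n",
 | ||
|     "\n",
 | ||
|     "toy_model = None\n",
 | ||
|     "for batch_idx in range(3):  # pretend each iteration loads one data batch\n",
 | ||
|     "    X_b = rng.normal(size=(500, 10))\n",
 | ||
|     "    y_b = (X_b[:, 0] > 0).astype(int)\n",
 | ||
|     "    toy_model = lgb.train(\n",
 | ||
|     "        toy_params,\n",
 | ||
|     "        lgb.Dataset(X_b, label=y_b),\n",
 | ||
|     "        num_boost_round=20,\n",
 | ||
|     "        init_model=toy_model,  # None on the first batch, then continue boosting\n",
 | ||
|     "    )\n",
 | ||
|     "    print(f'batch {batch_idx}: {toy_model.num_trees()} trees')  # grows by ~20 per batch"
 | ||
|    ]
 | ||
|   },
 | ||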
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 12,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "🔥 开始智能分批训练!\n",
 | ||
|       "================================================================================\n",
 | ||
|       "📝 训练配置:\n",
 | ||
|       "   训练轮数: 300\n",
 | ||
|       "   早停轮数: 15\n",
 | ||
|       "   数据平衡: 启用(下采样标签0,40 + 过采样少数类)\n",
 | ||
|       "   PCA降维: 7168 → None 特征\n",
 | ||
|       "\n",
 | ||
|       "🚀 启动训练...\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "ename": "NameError",
 | ||
|      "evalue": "name 'trainer' is not defined",
 | ||
|      "output_type": "error",
 | ||
|      "traceback": [
 | ||
|       "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
 | ||
|       "\u001b[1;31mNameError\u001b[0m                                 Traceback (most recent call last)",
 | ||
|       "Cell \u001b[1;32mIn[12], line 19\u001b[0m\n\u001b[0;32m     16\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124m🚀 启动训练...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m     18\u001b[0m \u001b[38;5;66;03m# 开始训练\u001b[39;00m\n\u001b[1;32m---> 19\u001b[0m model \u001b[38;5;241m=\u001b[39m \u001b[43mtrainer\u001b[49m\u001b[38;5;241m.\u001b[39mtrain(\n\u001b[0;32m     20\u001b[0m     num_boost_round\u001b[38;5;241m=\u001b[39mTRAINING_PARAMS[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnum_boost_round\u001b[39m\u001b[38;5;124m'\u001b[39m],\n\u001b[0;32m     21\u001b[0m     early_stopping_rounds\u001b[38;5;241m=\u001b[39mTRAINING_PARAMS[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mearly_stopping_rounds\u001b[39m\u001b[38;5;124m'\u001b[39m]\n\u001b[0;32m     22\u001b[0m )\n",
 | ||
|       "\u001b[1;31mNameError\u001b[0m: name 'trainer' is not defined"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "print(\"🔥 开始智能分批训练!\")\n",
 | ||
|     "print(\"=\" * 80)\n",
 | ||
|     "\n",
 | ||
|     "# 训练参数\n",
 | ||
|     "TRAINING_PARAMS = {\n",
 | ||
|     "    'num_boost_round': 300,      # 每批次的提升轮数\n",
 | ||
|     "    'early_stopping_rounds': 15  # 早停轮数\n",
 | ||
|     "}\n",
 | ||
|     "\n",
 | ||
|     "print(f\"📝 训练配置:\")\n",
 | ||
|     "print(f\"   训练轮数: {TRAINING_PARAMS['num_boost_round']}\")\n",
 | ||
|     "print(f\"   早停轮数: {TRAINING_PARAMS['early_stopping_rounds']}\")\n",
 | ||
|     "print(f\"   数据平衡: 启用(下采样标签0,40 + 过采样少数类)\")\n",
 | ||
|     "print(f\"   PCA降维: 7168 → {pipeline.pca_components} 特征\")\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n🚀 启动训练...\")\n",
 | ||
|     "\n",
 | ||
|     "# 开始训练\n",
 | ||
|     "model = trainer.train(\n",
 | ||
|     "    num_boost_round=TRAINING_PARAMS['num_boost_round'],\n",
 | ||
|     "    early_stopping_rounds=TRAINING_PARAMS['early_stopping_rounds']\n",
 | ||
|     ")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "## 📊 训练结果分析"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# 📊 训练结果分析和可视化\n",
 | ||
|     "\n",
 | ||
|     "print(\"📊 分析智能分批训练结果...\")\n",
 | ||
|     "print(\"=\" * 60)\n",
 | ||
|     "\n",
 | ||
|     "# 显示训练进度图表\n",
 | ||
|     "trainer.plot_training_progress()\n",
 | ||
|     "\n",
 | ||
|     "# 保存最终模型\n",
 | ||
|     "final_model_path = \"smart_pipeline_final_model.txt\"\n",
 | ||
|     "if trainer.model:\n",
 | ||
|     "    trainer.model.save_model(final_model_path)\n",
 | ||
|     "    print(f\"\\n💾 最终模型已保存: {final_model_path}\")\n",
 | ||
|     "\n",
 | ||
|     "# 详细分析\n",
 | ||
|     "if trainer.training_history:\n",
 | ||
|     "    print(f\"\\n📈 详细训练分析:\")\n",
 | ||
|     "    print(f\"   🎯 训练批次总数: {len(trainer.training_history)}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 最佳批次\n",
 | ||
|     "    best_batch = max(trainer.training_history, key=lambda x: x['val_accuracy'])\n",
 | ||
|     "    print(f\"   🏆 最佳验证准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 训练效率\n",
 | ||
|     "    total_training_time = sum(h['time'] for h in trainer.training_history)\n",
 | ||
|     "    avg_batch_time = total_training_time / len(trainer.training_history)\n",
 | ||
|     "    print(f\"   ⏱️ 总训练时间: {total_training_time:.1f}秒\")\n",
 | ||
|     "    print(f\"   ⏱️ 平均批次时间: {avg_batch_time:.1f}秒\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 模型复杂度\n",
 | ||
|     "    final_trees = trainer.training_history[-1]['num_trees']\n",
 | ||
|     "    print(f\"   🌳 最终模型树数: {final_trees}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 收敛性分析\n",
 | ||
|     "    recent_accs = [h['val_accuracy'] for h in trainer.training_history[-3:]]\n",
 | ||
|     "    if len(recent_accs) >= 2:\n",
 | ||
|     "        acc_stability = max(recent_accs) - min(recent_accs)\n",
 | ||
|     "        print(f\"   📈 准确率稳定性: {acc_stability:.4f} (最近3批次方差)\")\n",
 | ||
|     "        \n",
 | ||
|     "        if acc_stability < 0.01:\n",
 | ||
|     "            print(\"   ✅ 模型已收敛 (准确率变化 < 1%)\")\n",
 | ||
|     "        else:\n",
 | ||
|     "            print(\"   ⚠️ 模型可能需要更多训练\")\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n🎉 智能分批训练分析完成!\")\n",
 | ||
|     "print(f\"   💡 使用了改进的数据平衡策略和PCA降维\")\n",
 | ||
|     "print(f\"   💡 训练集应用了下采样+过采样,验证集保持原始分布\")\n",
 | ||
|     "print(f\"   💡 实现了内存友好的分批处理\")"
 | ||
|    ]
 | ||
|   },
 | ||
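|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "For completeness: a model written out with `save_model()` above can be reloaded with `lgb.Booster(model_file=...)`. A small sketch, assuming the analysis cell above has already saved the file:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# Sketch: reload the saved model (assumes the file from the cell above exists)\n",
 | ||
|     "import lightgbm as lgb\n",
 | ||
|     "\n",
 | ||
|     "reloaded = lgb.Booster(model_file='smart_pipeline_final_model.txt')\n",
 | ||
|     "print('reloaded trees:', reloaded.num_trees())"
 | ||
|    ]
 | ||
|   },
 | ||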
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "## 🧪 模型性能评估"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# 🧪 模型性能评估\n",
 | ||
|     "\n",
 | ||
|     "from sklearn.metrics import classification_report, confusion_matrix\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "\n",
 | ||
|     "def evaluate_model_performance(model, pipeline, data_type='val'):\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    评估模型在指定数据集上的性能\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    print(f\"🧪 评估模型在{data_type}数据集上的性能...\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 加载数据\n",
 | ||
|     "    X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n",
 | ||
|     "    \n",
 | ||
|     "    if X is None or y is None:\n",
 | ||
|     "        print(f\"❌ 无法加载{data_type}数据\")\n",
 | ||
|     "        return None\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"   📊 数据集大小: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 预测\n",
 | ||
|     "    start_time = time.time()\n",
 | ||
|     "    y_pred_proba = model.predict(X)\n",
 | ||
|     "    y_pred = y_pred_proba.argmax(axis=1)\n",
 | ||
|     "    pred_time = time.time() - start_time\n",
 | ||
|     "    \n",
 | ||
|     "    # 计算性能指标\n",
 | ||
|     "    accuracy = (y_pred == y).mean()\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"   ⏱️ 预测时间: {pred_time:.2f}秒\")\n",
 | ||
|     "    print(f\"   🎯 整体准确率: {accuracy:.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 分析各类别性能\n",
 | ||
|     "    from collections import Counter\n",
 | ||
|     "    true_counts = Counter(y)\n",
 | ||
|     "    pred_counts = Counter(y_pred)\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n📊 标签分布对比:\")\n",
 | ||
|     "    print(\"标签 | 真实数量 | 预测数量 | 准确率\")\n",
 | ||
|     "    print(\"-\" * 40)\n",
 | ||
|     "    \n",
 | ||
|     "    label_accuracies = {}\n",
 | ||
|     "    for label in range(41):\n",
 | ||
|     "        if label in true_counts:\n",
 | ||
|     "            label_mask = (y == label)\n",
 | ||
|     "            if label_mask.sum() > 0:\n",
 | ||
|     "                label_acc = (y_pred[label_mask] == label).mean()\n",
 | ||
|     "                label_accuracies[label] = label_acc\n",
 | ||
|     "                true_count = true_counts.get(label, 0)\n",
 | ||
|     "                pred_count = pred_counts.get(label, 0)\n",
 | ||
|     "                print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 重点分析关键标签\n",
 | ||
|     "    print(f\"\\n🔍 关键标签性能分析:\")\n",
 | ||
|     "    key_labels = [0, 40]  # 下采样的标签\n",
 | ||
|     "    for label in key_labels:\n",
 | ||
|     "        if label in label_accuracies:\n",
 | ||
|     "            acc = label_accuracies[label]\n",
 | ||
|     "            count = true_counts.get(label, 0)\n",
 | ||
|     "            print(f\"   标签 {label} (下采样目标): 准确率 {acc:.4f}, 样本数 {count:,}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 少数类性能\n",
 | ||
|     "    minority_labels = [label for label, count in true_counts.items() \n",
 | ||
|     "                      if count < 200 and label not in [0, 40]]\n",
 | ||
|     "    if minority_labels:\n",
 | ||
|     "        minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n",
 | ||
|     "        avg_minority_acc = np.mean(minority_accs) if minority_accs else 0\n",
 | ||
|     "        print(f\"   少数类平均准确率 (前5个): {avg_minority_acc:.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 置信度分析\n",
 | ||
|     "    max_proba = y_pred_proba.max(axis=1)\n",
 | ||
|     "    print(f\"\\n📈 预测置信度分析:\")\n",
 | ||
|     "    print(f\"   平均置信度: {max_proba.mean():.4f}\")\n",
 | ||
|     "    print(f\"   置信度中位数: {np.median(max_proba):.4f}\")\n",
 | ||
|     "    print(f\"   高置信度预测 (>0.9): {(max_proba > 0.9).sum():,} / {len(max_proba):,} ({(max_proba > 0.9).mean():.2%})\")\n",
 | ||
|     "    \n",
 | ||
|     "    return {\n",
 | ||
|     "        'accuracy': accuracy,\n",
 | ||
|     "        'prediction_time': pred_time,\n",
 | ||
|     "        'label_accuracies': label_accuracies,\n",
 | ||
|     "        'confidence_stats': {\n",
 | ||
|     "            'mean': max_proba.mean(),\n",
 | ||
|     "            'median': np.median(max_proba),\n",
 | ||
|     "            'high_confidence_ratio': (max_proba > 0.9).mean()\n",
 | ||
|     "        }\n",
 | ||
|     "    }\n",
 | ||
|     "\n",
 | ||
|     "# 评估模型性能\n",
 | ||
|     "if trainer.model:\n",
 | ||
|     "    print(\"🧪 开始模型性能评估...\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 验证集评估\n",
 | ||
|     "    val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n\" + \"=\"*60)\n",
 | ||
|     "    print(\"🎉 智能分批训练+数据平衡 评估完成!\")\n",
 | ||
|     "    print(f\"✅ 实现了数据平衡和PCA降维的完整流程\")\n",
 | ||
|     "    print(f\"✅ 使用了内存友好的分批训练策略\")\n",
 | ||
|     "    print(f\"✅ 保持了验证集的原始分布以确保评估客观性\")\n",
 | ||
|     "else:\n",
 | ||
|     "    print(\"❌ 模型尚未训练完成,请等待训练结束后运行此评估\")"
 | ||
|    ]
 | ||
|   },
 | ||
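|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "The evaluation above tallies per-label accuracy by hand; `classification_report` (imported in that cell but not used there) produces per-class precision/recall/F1 in one call. A sketch on synthetic labels, with 5 stand-in classes instead of the notebook's 41:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# Sketch (toy labels): per-class metrics via sklearn's classification_report\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "from sklearn.metrics import classification_report\n",
 | ||
|     "\n",
 | ||
|     "rng = np.random.default_rng(0)\n",
 | ||
|     "y_true_toy = rng.integers(0, 5, size=200)\n",
 | ||
|     "# corrupt ~30% of the labels to imitate an imperfect classifier\n",
 | ||
|     "y_pred_toy = np.where(rng.random(200) < 0.7, y_true_toy, rng.integers(0, 5, size=200))\n",
 | ||
|     "\n",
 | ||
|     "print(classification_report(y_true_toy, y_pred_toy, digits=3, zero_division=0))"
 | ||
|    ]
 | ||
|   },
 | ||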
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# ✅ 余弦退火已更新为带重启版本\n",
 | ||
|     "\n",
 | ||
|     "print(\"🎉 余弦退火调度器更新完成!\")\n",
 | ||
|     "\n",
 | ||
|     "# 检查trainer是否已创建,如果未创建则先创建\n",
 | ||
|     "if 'trainer' not in globals():\n",
 | ||
|     "    print(\"⚠️ 训练器尚未创建,请先运行前面的代码创建训练器\")\n",
 | ||
|     "else:\n",
 | ||
|     "    print(f\"✅ 当前使用:带重启的余弦退火 (SGDR)\")\n",
 | ||
|     "    print(f\"   🔄 重启参数: T_0={trainer.t_0}, T_mult={trainer.t_mult}\")\n",
 | ||
|     "    print(f\"   📈 学习率范围: {trainer.initial_learning_rate} → {trainer.min_learning_rate}\")\n",
 | ||
|     "\n",
 | ||
|     "    # 可视化新的学习率调度\n",
 | ||
|     "    import matplotlib.pyplot as plt\n",
 | ||
|     "    import numpy as np\n",
 | ||
|     "\n",
 | ||
|     "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n",
 | ||
|     "\n",
 | ||
|     "    # 模拟300轮的学习率变化\n",
 | ||
|     "    rounds = list(range(300))\n",
 | ||
|     "    old_lrs = []  # 原始余弦退火\n",
 | ||
|     "    new_lrs = []  # 带重启的余弦退火\n",
 | ||
|     "\n",
 | ||
|     "    for r in rounds:\n",
 | ||
|     "        # 原始余弦退火 (单调递减)\n",
 | ||
|     "        old_lr = trainer.min_learning_rate + 0.5 * (trainer.initial_learning_rate - trainer.min_learning_rate) * (1 + np.cos(np.pi * r / 300))\n",
 | ||
|     "        old_lrs.append(old_lr)\n",
 | ||
|     "        \n",
 | ||
|     "        # 带重启的余弦退火\n",
 | ||
|     "        new_lr = trainer._cosine_annealing_with_warm_restarts(r)\n",
 | ||
|     "        new_lrs.append(new_lr)\n",
 | ||
|     "\n",
 | ||
|     "    # 绘制对比图\n",
 | ||
|     "    ax1.plot(rounds, old_lrs, 'b-', label='原始余弦退火', linewidth=2)\n",
 | ||
|     "    ax1.set_xlabel('Training Round')\n",
 | ||
|     "    ax1.set_ylabel('Learning Rate')\n",
 | ||
|     "    ax1.set_title('原始余弦退火 (单调递减)')\n",
 | ||
|     "    ax1.grid(True, alpha=0.3)\n",
 | ||
|     "    ax1.legend()\n",
 | ||
|     "\n",
 | ||
|     "    ax2.plot(rounds, new_lrs, 'r-', label='带重启的余弦退火', linewidth=2)\n",
 | ||
|     "    ax2.set_xlabel('Training Round')\n",
 | ||
|     "    ax2.set_ylabel('Learning Rate')\n",
 | ||
|     "    ax2.set_title('带重启的余弦退火 (SGDR)')\n",
 | ||
|     "    ax2.grid(True, alpha=0.3)\n",
 | ||
|     "    ax2.legend()\n",
 | ||
|     "\n",
 | ||
|     "    plt.tight_layout()\n",
 | ||
|     "    plt.show()\n",
 | ||
|     "\n",
 | ||
|     "    print(\"📊 学习率调度对比可视化完成\")\n",
 | ||
|     "    print(\"   🔵 原始版本:单调递减的余弦曲线\")\n",
 | ||
|     "    print(\"   🔴 新版本:周期性重启,每次重启后学习率回到最大值\")\n",
 | ||
|     "    print(\"   💡 SGDR的优势:多次重启可以帮助模型跳出局部最优解\")\n",
 | ||
|     "\n",
 | ||
|     "    # 显示重启点\n",
 | ||
|     "    restart_points = []\n",
 | ||
|     "    t_cur = 0\n",
 | ||
|     "    t_i = trainer.t_0\n",
 | ||
|     "    while t_cur < 300:\n",
 | ||
|     "        restart_points.append(t_cur)\n",
 | ||
|     "        t_cur += t_i\n",
 | ||
|     "        t_i *= trainer.t_mult\n",
 | ||
|     "\n",
 | ||
|     "    print(f\"   🔄 在300轮训练中的重启点: {restart_points[:5]}...\")  # 显示前5个重启点"
 | ||
|    ]
 | ||
|   },
 | ||
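|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "The restart points printed above follow a geometric series of cycle lengths $T_0, T_0 T_{mult}, T_0 T_{mult}^2, \\dots$, so for $T_{mult} > 1$ restart $k$ lands at round\n",
 | ||
|     "\n",
 | ||
|     "$$R_k = T_0\\,\\frac{T_{mult}^{k} - 1}{T_{mult} - 1},$$\n",
 | ||
|     "\n",
 | ||
|     "and at $R_k = k\\,T_0$ when $T_{mult} = 1$. For example, with $T_0 = 10$ and $T_{mult} = 2$ the restarts would land at rounds 10, 30, 70, 150, ... (the actual $T_0$/$T_{mult}$ values are set in the trainer's constructor)."
 | ||
|    ]
 | ||
|   },
 | ||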
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": []
 | ||
|   }
 | ||
|  ],
 | ||
|  "metadata": {
 | ||
|   "kaggle": {
 | ||
|    "accelerator": "tpu1vmV38",
 | ||
|    "dataSources": [
 | ||
|     {
 | ||
|      "databundleVersionId": 13056355,
 | ||
|      "sourceId": 106809,
 | ||
|      "sourceType": "competition"
 | ||
|     }
 | ||
|    ],
 | ||
|    "dockerImageVersionId": 31091,
 | ||
|    "isGpuEnabled": false,
 | ||
|    "isInternetEnabled": true,
 | ||
|    "language": "python",
 | ||
|    "sourceType": "notebook"
 | ||
|   },
 | ||
|   "kernelspec": {
 | ||
|    "display_name": "b2txt25",
 | ||
|    "language": "python",
 | ||
|    "name": "python3"
 | ||
|   },
 | ||
|   "language_info": {
 | ||
|    "codemirror_mode": {
 | ||
|     "name": "ipython",
 | ||
|     "version": 3
 | ||
|    },
 | ||
|    "file_extension": ".py",
 | ||
|    "mimetype": "text/x-python",
 | ||
|    "name": "python",
 | ||
|    "nbconvert_exporter": "python",
 | ||
|    "pygments_lexer": "ipython3",
 | ||
|    "version": "3.10.18"
 | ||
|   }
 | ||
|  },
 | ||
|  "nbformat": 4,
 | ||
|  "nbformat_minor": 4
 | ||
| }
 | 
