2424 lines
		
	
	
		
			115 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			2424 lines
		
	
	
		
			115 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
|   | { | |||
|  |  "cells": [ | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🎲 改进的随机批次生成器\n", | |||
|  |     "\n", | |||
|  |     "这个版本改进了数据生成策略:\n", | |||
|  |     "- **随机文件选择**: 每次从所有训练文件中随机选择 n=4 个文件\n", | |||
|  |     "- **随机样本采样**: 从选中的文件中随机采样指定数量的样本\n", | |||
|  |     "- **提高数据多样性**: 避免按固定顺序处理文件,减少过拟合风险\n", | |||
|  |     "- **可控批次大小**: 固定每批次样本数,确保训练稳定性" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# 🧪 测试改进的随机批次生成器\n", | |||
|  |     "\n", | |||
|  |     "def test_improved_generator(trainer, n_batches_to_test=3):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    测试改进的生成器功能\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(\"🧪 测试改进的随机批次生成器...\")\n", | |||
|  |     "    print(\"=\" * 50)\n", | |||
|  |     "    \n", | |||
|  |     "    # 创建生成器\n", | |||
|  |     "    generator = trainer.get_training_batch_generator(\n", | |||
|  |     "        n_files_per_batch=4,  # 每批次选择4个文件\n", | |||
|  |     "        batch_size=5000       # 每批次5000个样本\n", | |||
|  |     "    )\n", | |||
|  |     "    \n", | |||
|  |     "    # 测试前几个批次\n", | |||
|  |     "    for i in range(n_batches_to_test):\n", | |||
|  |     "        print(f\"\\n📦 测试批次 {i+1}:\")\n", | |||
|  |     "        try:\n", | |||
|  |     "            batch_data, batch_name = next(generator)\n", | |||
|  |     "            print(f\"   ✅ 批次名称: {batch_name}\")\n", | |||
|  |     "            print(f\"   📊 数据形状: {batch_data.num_data()} 样本\")\n", | |||
|  |     "            print(f\"   🏷️ 标签范围: {min(batch_data.get_label())} - {max(batch_data.get_label())}\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 分析标签分布\n", | |||
|  |     "            from collections import Counter\n", | |||
|  |     "            label_dist = Counter(batch_data.get_label())\n", | |||
|  |     "            print(f\"   📈 标签分布样例: 0={label_dist.get(0,0)}, 40={label_dist.get(40,0)}\")\n", | |||
|  |     "            \n", | |||
|  |     "        except Exception as e:\n", | |||
|  |     "            print(f\"   ❌ 错误: {e}\")\n", | |||
|  |     "            break\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"\\n✅ 生成器测试完成!\")\n", | |||
|  |     "    print(\"💡 每次运行都会生成不同的文件组合和样本\")\n", | |||
|  |     "\n", | |||
|  |     "# 如果trainer已经创建,运行测试\n", | |||
|  |     "if 'trainer' in locals():\n", | |||
|  |     "    test_improved_generator(trainer)\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"⚠️ 请先创建trainer对象再运行此测试\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "# 环境配置与Utils" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "%%bash\n", | |||
|  |     "# Abort on the first failing command so a failed clone/copy/link doesn't cascade into confusing later errors.\n", | |||
|  |     "set -e\n", | |||
|  |     "\n", | |||
|  |     "REPO_DIR=/kaggle/working/nejm-brain-to-text\n", | |||
|  |     "\n", | |||
|  |     "rm -rf \"$REPO_DIR\"\n", | |||
|  |     "# Clone to an absolute path: a bare `git clone` targets the current directory, which only\n", | |||
|  |     "# matched the absolute paths used below when the cell happened to run from /kaggle/working.\n", | |||
|  |     "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git \"$REPO_DIR\"\n", | |||
|  |     "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl \"$REPO_DIR/data/t15_copyTask.pkl\"\n", | |||
|  |     "\n", | |||
|  |     "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline \"$REPO_DIR/data\"\n", | |||
|  |     "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final \"$REPO_DIR/data\"\n", | |||
|  |     "ln -s /kaggle/input/rnn-pretagged-data \"$REPO_DIR/data/concatenated_data\"\n", | |||
|  |     "\n", | |||
|  |     "# NOTE(review): torch/torchvision/torchaudio are not version-pinned, unlike the packages below.\n", | |||
|  |     "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", | |||
|  |     "\n", | |||
|  |     "pip install \\\n", | |||
|  |     "    jupyter==1.1.1 \\\n", | |||
|  |     "    \"numpy>=1.26.0,<2.1.0\" \\\n", | |||
|  |     "    pandas==2.3.0 \\\n", | |||
|  |     "    matplotlib==3.10.1 \\\n", | |||
|  |     "    scipy==1.15.2 \\\n", | |||
|  |     "    scikit-learn==1.6.1 \\\n", | |||
|  |     "    lightgbm==4.3.0 \\\n", | |||
|  |     "    tqdm==4.67.1 \\\n", | |||
|  |     "    g2p_en==2.1.0 \\\n", | |||
|  |     "    h5py==3.13.0 \\\n", | |||
|  |     "    omegaconf==2.3.0 \\\n", | |||
|  |     "    editdistance==0.8.1 \\\n", | |||
|  |     "    huggingface-hub==0.33.1 \\\n", | |||
|  |     "    transformers==4.53.0 \\\n", | |||
|  |     "    tokenizers==0.21.2 \\\n", | |||
|  |     "    accelerate==1.8.1 \\\n", | |||
|  |     "    bitsandbytes==0.46.0 \\\n", | |||
|  |     "    seaborn==0.13.2\n", | |||
|  |     "\n", | |||
|  |     "# Editable install of the freshly cloned repository.\n", | |||
|  |     "cd \"$REPO_DIR\"\n", | |||
|  |     "pip install -e ." | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 2, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "==================================================\n", | |||
|  |       "🔧 LightGBM GPU环境检查\n", | |||
|  |       "==================================================\n", | |||
|  |       "❌ 未检测到NVIDIA GPU或驱动\n", | |||
|  |       "\n", | |||
|  |       "❌ 未安装CUDA工具包\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🚀 LightGBM GPU支持检查与配置\n", | |||
|  |     "\n", | |||
|  |     "print(\"=\"*50)\n", | |||
|  |     "print(\"🔧 LightGBM GPU环境检查\")\n", | |||
|  |     "print(\"=\"*50)\n", | |||
|  |     "\n", | |||
|  |     "# 检查CUDA和GPU驱动\n", | |||
|  |     "import subprocess\n", | |||
|  |     "import sys\n", | |||
|  |     "\n", | |||
|  |     "def run_command(command):\n", | |||
|  |     "    \"\"\"运行命令并返回结果\"\"\"\n", | |||
|  |     "    try:\n", | |||
|  |     "        result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n", | |||
|  |     "        return result.stdout.strip(), result.returncode == 0\n", | |||
|  |     "    except Exception as e:\n", | |||
|  |     "        return str(e), False\n", | |||
|  |     "\n", | |||
|  |     "# 检查NVIDIA GPU\n", | |||
|  |     "nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n", | |||
|  |     "if nvidia_success:\n", | |||
|  |     "    print(\"✅ NVIDIA GPU检测:\")\n", | |||
|  |     "    for line in nvidia_output.split('\\n'):\n", | |||
|  |     "        if line.strip():\n", | |||
|  |     "            print(f\"   {line}\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"❌ 未检测到NVIDIA GPU或驱动\")\n", | |||
|  |     "\n", | |||
|  |     "# 检查CUDA版本\n", | |||
|  |     "cuda_output, cuda_success = run_command(\"nvcc --version\")\n", | |||
|  |     "if cuda_success:\n", | |||
|  |     "    print(\"\\n✅ CUDA工具包:\")\n", | |||
|  |     "    # 提取CUDA版本\n", | |||
|  |     "    for line in cuda_output.split('\\n'):\n", | |||
|  |     "        if 'release' in line:\n", | |||
|  |     "            print(f\"   {line.strip()}\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"\\n❌ 未安装CUDA工具包\")\n" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 3, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "[WinError 2] 系统找不到指定的文件。: 'nejm-brain-to-text'\n", | |||
|  |       "f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\brain-to-text-25\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "d:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\IPython\\core\\magics\\osm.py:393: UserWarning: This is now an optional IPython functionality, using bookmarks requires you to install the `pickleshare` library.\n", | |||
|  |       "  bkms = self.shell.db.get('bookmarks', {})\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "%cd nejm-brain-to-text\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "import os\n", | |||
|  |     "import pickle\n", | |||
|  |     "import matplotlib.pyplot as plt\n", | |||
|  |     "import matplotlib\n", | |||
|  |     "from g2p_en import G2p\n", | |||
|  |     "import pandas as pd\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "from nejm_b2txt_utils.general_utils import *\n", | |||
|  |     "matplotlib.rcParams['pdf.fonttype'] = 42\n", | |||
|  |     "matplotlib.rcParams['ps.fonttype'] = 42\n", | |||
|  |     "matplotlib.rcParams['font.family'] = 'sans-serif'\n", | |||
|  |     "matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n", | |||
|  |     "matplotlib.rcParams['axes.unicode_minus'] = False\n" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "d:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\IPython\\core\\magics\\osm.py:417: UserWarning: This is now an optional IPython functionality, setting dhist requires you to install the `pickleshare` library.\n", | |||
|  |       "  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "f:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\model_training\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "%cd model_training/\n", | |||
|  |     "# gauss_smooth is presumably resolved from model_training/ via the cwd entered above -- confirm.\n", | |||
|  |     "from data_augmentations import gauss_smooth" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 5, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# Index -> symbol for the 41 RNN output classes: 0 = 'BLANK', 1-39 = phonemes,\n", | |||
|  |     "# 40 = ' | ' (presumably the word separator).\n", | |||
|  |     "LOGIT_TO_PHONEME = [\n", | |||
|  |     "    'BLANK',\n", | |||
|  |     "    'AA', 'AE', 'AH', 'AO', 'AW',\n", | |||
|  |     "    'AY', 'B',  'CH', 'D', 'DH',\n", | |||
|  |     "    'EH', 'ER', 'EY', 'F', 'G',\n", | |||
|  |     "    'HH', 'IH', 'IY', 'JH', 'K',\n", | |||
|  |     "    'L', 'M', 'N', 'NG', 'OW',\n", | |||
|  |     "    'OY', 'P', 'R', 'S', 'SH',\n", | |||
|  |     "    'T', 'TH', 'UH', 'UW', 'V',\n", | |||
|  |     "    'W', 'Y', 'Z', 'ZH',\n", | |||
|  |     "    ' | ',\n", | |||
|  |     "]\n", | |||
|  |     "# Global data-balancing configuration (the default config of balance_dataset)\n", | |||
|  |     "BALANCE_CONFIG = {\n", | |||
|  |     "    'enable_balance': True,        # whether to balance the data\n", | |||
|  |     "    'undersample_labels': [0, 40], # labels to undersample (high-frequency labels such as blank)\n", | |||
|  |     "    'oversample_threshold': 0.5,   # labels below this fraction of the mean count are oversampled up to it\n", | |||
|  |     "    'random_state': 42            # random seed\n", | |||
|  |     "}\n", | |||
|  |     "# Global PCA configuration\n", | |||
|  |     "PCA_CONFIG = {\n", | |||
|  |     "    'enable_pca': True,           # whether to apply PCA\n", | |||
|  |     "    'n_components': None,         # None = choose automatically from variance_threshold, or a fixed count\n", | |||
|  |     "    'variance_threshold': 0.95,   # keep 95% of the variance\n", | |||
|  |     "    'sample_size': 15000,         # number of samples used to fit the PCA\n", | |||
|  |     "}\n", | |||
|  |     "\n", | |||
|  |     "# Global PCA state (ensures the PCA is fitted only once)\n", | |||
|  |     "GLOBAL_PCA = {\n", | |||
|  |     "    'scaler': None,\n", | |||
|  |     "    'pca': None,\n", | |||
|  |     "    'is_fitted': False,\n", | |||
|  |     "    'n_components': None\n", | |||
|  |     "}\n", | |||
|  |     "# Data directory and loading parameters (PCA initialization)\n", | |||
|  |     "data_dir = '../data/concatenated_data'\n", | |||
|  |     "MAX_SAMPLES_PER_FILE = -1  # max samples per file (-1 = no limit), adjustable" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "# 数据读取工作流" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "# 2️⃣ 数据加载与PCA降维" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# 🚀 内存友好的数据读取 - 分批加载策略 + PCA降维 【这里还缺一个采样】\n", | |||
|  |     "import os\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "import gc\n", | |||
|  |     "from sklearn.decomposition import PCA\n", | |||
|  |     "from sklearn.preprocessing import StandardScaler\n", | |||
|  |     "import joblib\n", | |||
|  |     "import matplotlib.pyplot as plt\n", | |||
|  |     "\n", | |||
|  |     "\n", | |||
|  |     "def load_data_batch(data_dir, data_type, max_samples_per_file=5000, verbose=True):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    分批加载指定类型的数据\n", | |||
|  |     "    \n", | |||
|  |     "    Args:\n", | |||
|  |     "        data_dir: 数据目录\n", | |||
|  |     "        data_type: 'train', 'val', 'test'\n", | |||
|  |     "        max_samples_per_file: 每个文件最大加载样本数\n", | |||
|  |     "        verbose: 是否打印每个文件的加载进度\n", | |||
|  |     "    \n", | |||
|  |     "    Returns:\n", | |||
|  |     "        generator: 数据批次生成器\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n", | |||
|  |     "    \n", | |||
|  |     "    for file_idx, f in enumerate(files):\n", | |||
|  |     "        if verbose:\n", | |||
|  |     "            print(f\"  正在加载文件 {file_idx+1}/{len(files)}: {f}\")\n", | |||
|  |     "        \n", | |||
|  |     "        data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n", | |||
|  |     "        trials = data['neural_logits_concatenated']\n", | |||
|  |     "        \n", | |||
|  |     "        # 限制每个文件的样本数\n", | |||
|  |     "        if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n", | |||
|  |     "            trials = trials[:max_samples_per_file]\n", | |||
|  |     "            if verbose:\n", | |||
|  |     "                print(f\"    限制样本数至: {max_samples_per_file}\")\n", | |||
|  |     "        \n", | |||
|  |     "        yield trials, f\n", | |||
|  |     "        \n", | |||
|  |     "        # 清理内存\n", | |||
|  |     "        del data, trials\n", | |||
|  |     "        gc.collect()\n", | |||
|  |     "\n", | |||
|  |     "def extract_features_labels_batch(trials_batch):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    从试验批次中提取特征和标签\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    features = []\n", | |||
|  |     "    labels = []\n", | |||
|  |     "    \n", | |||
|  |     "    for trial in trials_batch:\n", | |||
|  |     "        if trial.shape[0] > 0:\n", | |||
|  |     "            for t in range(trial.shape[0]):\n", | |||
|  |     "                neural_features = trial[t, :7168]  # 前7168维神经特征\n", | |||
|  |     "                rnn_logits = trial[t, 7168:]       # 后41维RNN输出\n", | |||
|  |     "                phoneme_label = np.argmax(rnn_logits)\n", | |||
|  |     "                \n", | |||
|  |     "                features.append(neural_features)\n", | |||
|  |     "                labels.append(phoneme_label)\n", | |||
|  |     "    \n", | |||
|  |     "    return np.array(features), np.array(labels)\n", | |||
|  |     "\n", | |||
|  |     "def fit_global_pca(data_dir, config):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    在训练数据上拟合全局PCA (只执行一次)\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n", | |||
|  |     "        print(\"PCA已拟合或未启用,跳过拟合步骤\")\n", | |||
|  |     "        return\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"拟合全局PCA降维器...\")\n", | |||
|  |     "    print(f\"   配置: {config}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 收集训练样本\n", | |||
|  |     "    sample_features = []\n", | |||
|  |     "    collected_samples = 0\n", | |||
|  |     "    \n", | |||
|  |     "    for trials_batch, filename in load_data_batch(data_dir, 'train', 5000, verbose=False):\n", | |||
|  |     "        features, labels = extract_features_labels_batch(trials_batch)\n", | |||
|  |     "        sample_features.append(features)\n", | |||
|  |     "        collected_samples += features.shape[0]\n", | |||
|  |     "        \n", | |||
|  |     "        if collected_samples >= config['sample_size']:\n", | |||
|  |     "            break\n", | |||
|  |     "    \n", | |||
|  |     "    if sample_features:\n", | |||
|  |     "        # 合并样本数据\n", | |||
|  |     "        X_sample = np.vstack(sample_features)[:config['sample_size']]\n", | |||
|  |     "        print(f\"   实际样本数: {X_sample.shape[0]}\")\n", | |||
|  |     "        print(f\"   原始特征数: {X_sample.shape[1]}\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 标准化\n", | |||
|  |     "        GLOBAL_PCA['scaler'] = StandardScaler()\n", | |||
|  |     "        X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n", | |||
|  |     "        \n", | |||
|  |     "        # 确定PCA成分数\n", | |||
|  |     "        if config['n_components'] is None:\n", | |||
|  |     "            print(f\"   自动选择PCA成分数...\")\n", | |||
|  |     "            pca_full = PCA()\n", | |||
|  |     "            pca_full.fit(X_sample_scaled)\n", | |||
|  |     "            \n", | |||
|  |     "            cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n", | |||
|  |     "            optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n", | |||
|  |     "            GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"   保留{config['variance_threshold']*100}%方差需要: {optimal_components} 个成分\")\n", | |||
|  |     "            print(f\"   选择成分数: {GLOBAL_PCA['n_components']}\")\n", | |||
|  |     "        else:\n", | |||
|  |     "            GLOBAL_PCA['n_components'] = config['n_components']\n", | |||
|  |     "            print(f\"   使用指定成分数: {GLOBAL_PCA['n_components']}\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 拟合最终PCA\n", | |||
|  |     "        GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n", | |||
|  |     "        GLOBAL_PCA['pca'].fit(X_sample_scaled)\n", | |||
|  |     "        GLOBAL_PCA['is_fitted'] = True\n", | |||
|  |     "        \n", | |||
|  |     "        # 保存模型\n", | |||
|  |     "        pca_path = \"global_pca_model.joblib\"\n", | |||
|  |     "        joblib.dump({\n", | |||
|  |     "            'scaler': GLOBAL_PCA['scaler'], \n", | |||
|  |     "            'pca': GLOBAL_PCA['pca'],\n", | |||
|  |     "            'n_components': GLOBAL_PCA['n_components']\n", | |||
|  |     "        }, pca_path)\n", | |||
|  |     "        \n", | |||
|  |     "        print(f\"   全局PCA拟合完成\")\n", | |||
|  |     "        print(f\"      降维: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n", | |||
|  |     "        print(f\"      降维比例: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n", | |||
|  |     "        print(f\"      保留方差: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n", | |||
|  |     "        print(f\"      模型已保存: {pca_path}\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 清理样本数据\n", | |||
|  |     "        del sample_features, X_sample, X_sample_scaled\n", | |||
|  |     "        gc.collect()\n", | |||
|  |     "    else:\n", | |||
|  |     "        print(\"无法收集样本数据用于PCA拟合\")\n", | |||
|  |     "\n", | |||
|  |     "def apply_pca_transform(features):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    应用全局PCA变换\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n", | |||
|  |     "        return features\n", | |||
|  |     "    \n", | |||
|  |     "    # 标准化 + PCA变换\n", | |||
|  |     "    features_scaled = GLOBAL_PCA['scaler'].transform(features)\n", | |||
|  |     "    features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n", | |||
|  |     "    return features_pca\n", | |||
|  |     "\n" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 📊 数据平衡策略 - 标签分布分析与采样优化" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 7, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# 【采样核心实现】\n", | |||
|  |     "def balance_dataset(X, y, config=BALANCE_CONFIG):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    对数据集进行平衡处理:下采样 + 过采样\n", | |||
|  |     "    \n", | |||
|  |     "    Args:\n", | |||
|  |     "        X: 特征数据\n", | |||
|  |     "        y: 标签数据\n", | |||
|  |     "        config: 平衡配置\n", | |||
|  |     "    \n", | |||
|  |     "    Returns:\n", | |||
|  |     "        X_balanced, y_balanced: 平衡后的数据\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if not config['enable_balance']:\n", | |||
|  |     "        print(\"🔕 数据平衡已禁用,返回原始数据\")\n", | |||
|  |     "        return X, y\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"\\n⚖️ 开始数据平衡处理...\")\n", | |||
|  |     "    print(f\"   原始数据: {X.shape[0]:,} 样本\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 分析当前分布 (只考虑1-39号标签的均值)\n", | |||
|  |     "    label_counts = Counter(y)\n", | |||
|  |     "    counts_exclude_0_40 = [label_counts.get(i, 0) for i in range(1, 40)]  # 1-39号标签\n", | |||
|  |     "    mean_count = np.mean(counts_exclude_0_40)  # 只计算1-39号标签的均值\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"   均值样本数 (标签1-39): {mean_count:.0f}\")\n", | |||
|  |     "    print(f\"   下采样标签: {config['undersample_labels']}\")\n", | |||
|  |     "    print(f\"   过采样阈值: {config['oversample_threshold']} * 均值\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 准备平衡后的数据\n", | |||
|  |     "    X_balanced = []\n", | |||
|  |     "    y_balanced = []\n", | |||
|  |     "    \n", | |||
|  |     "    random.seed(config['random_state'])\n", | |||
|  |     "    np.random.seed(config['random_state'])\n", | |||
|  |     "    \n", | |||
|  |     "    for label in range(41):\n", | |||
|  |     "        # 获取当前标签的所有样本\n", | |||
|  |     "        label_mask = (y == label)\n", | |||
|  |     "        X_label = X[label_mask]\n", | |||
|  |     "        y_label = y[label_mask]\n", | |||
|  |     "        current_count = len(y_label)\n", | |||
|  |     "        \n", | |||
|  |     "        if current_count == 0:\n", | |||
|  |     "            continue\n", | |||
|  |     "        \n", | |||
|  |     "        # 决定采样策略\n", | |||
|  |     "        if label in config['undersample_labels']:\n", | |||
|  |     "            # 下采样到均值水平\n", | |||
|  |     "            target_count = int(mean_count)\n", | |||
|  |     "            if current_count > target_count:\n", | |||
|  |     "                # 下采样\n", | |||
|  |     "                indices = np.random.choice(current_count, target_count, replace=False)\n", | |||
|  |     "                X_resampled = X_label[indices]\n", | |||
|  |     "                y_resampled = y_label[indices]\n", | |||
|  |     "                print(f\"   📉 标签 {label}: {current_count} → {target_count} (下采样)\")\n", | |||
|  |     "            else:\n", | |||
|  |     "                X_resampled = X_label\n", | |||
|  |     "                y_resampled = y_label\n", | |||
|  |     "                print(f\"   ➡️ 标签 {label}: {current_count} (无需下采样)\")\n", | |||
|  |     "                \n", | |||
|  |     "        elif current_count < mean_count * config['oversample_threshold']:\n", | |||
|  |     "            # 过采样到阈值水平\n", | |||
|  |     "            target_count = int(mean_count * config['oversample_threshold'])\n", | |||
|  |     "            if current_count < target_count:\n", | |||
|  |     "                # 过采样\n", | |||
|  |     "                X_resampled, y_resampled = resample(\n", | |||
|  |     "                    X_label, y_label, \n", | |||
|  |     "                    n_samples=target_count, \n", | |||
|  |     "                    random_state=config['random_state']\n", | |||
|  |     "                )\n", | |||
|  |     "                print(f\"   📈 标签 {label}: {current_count} → {target_count} (过采样)\")\n", | |||
|  |     "            else:\n", | |||
|  |     "                X_resampled = X_label\n", | |||
|  |     "                y_resampled = y_label\n", | |||
|  |     "                print(f\"   ➡️ 标签 {label}: {current_count} (无需过采样)\")\n", | |||
|  |     "        else:\n", | |||
|  |     "            # 保持不变\n", | |||
|  |     "            X_resampled = X_label\n", | |||
|  |     "            y_resampled = y_label\n", | |||
|  |     "            print(f\"   ✅ 标签 {label}: {current_count} (已平衡)\")\n", | |||
|  |     "        \n", | |||
|  |     "        X_balanced.append(X_resampled)\n", | |||
|  |     "        y_balanced.append(y_resampled)\n", | |||
|  |     "    \n", | |||
|  |     "    # 合并所有平衡后的数据\n", | |||
|  |     "    X_balanced = np.vstack(X_balanced)\n", | |||
|  |     "    y_balanced = np.hstack(y_balanced)\n", | |||
|  |     "    \n", | |||
|  |     "    # 随机打乱\n", | |||
|  |     "    shuffle_indices = np.random.permutation(len(y_balanced))\n", | |||
|  |     "    X_balanced = X_balanced[shuffle_indices]\n", | |||
|  |     "    y_balanced = y_balanced[shuffle_indices]\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"   ✅ 平衡完成: {X_balanced.shape[0]:,} 样本\")\n", | |||
|  |     "    print(f\"   数据变化: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n", | |||
|  |     "    \n", | |||
|  |     "    return X_balanced, y_balanced\n" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🔄 集成数据平衡的内存友好数据加载器" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🧪 数据平衡效果测试" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🚀 改进版智能数据处理管道" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 8, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "🚀 创建智能数据处理管道...\n", | |||
|  |       "✅ 管道创建完成,准备执行步骤1...\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🚀 改进版智能数据处理管道【没有解决分批训练的问题】\n", | |||
|  |     "# 流程:分析分布 → 确定采样比率 → 拟合PCA(只下采样) → 数据处理(下采样+上采样+PCA)\n", | |||
|  |     "\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "import matplotlib.pyplot as plt\n", | |||
|  |     "from collections import Counter\n", | |||
|  |     "from sklearn.utils import resample\n", | |||
|  |     "from sklearn.decomposition import PCA\n", | |||
|  |     "from sklearn.preprocessing import StandardScaler\n", | |||
|  |     "import joblib\n", | |||
|  |     "import random\n", | |||
|  |     "import gc\n", | |||
|  |     "\n", | |||
|  |     "class SmartDataPipeline:\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    智能数据处理管道\n", | |||
|  |     "    步骤1: 分析数据分布,确定采样策略\n", | |||
|  |     "    步骤2: 仅下采样拟合PCA参数\n", | |||
|  |     "    步骤3: 数据处理时应用完整采样+PCA降维\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    \n", | |||
|  |     "    def __init__(self, data_dir, random_state=42):\n", | |||
|  |     "        self.data_dir = data_dir\n", | |||
|  |     "        self.random_state = random_state\n", | |||
|  |     "        \n", | |||
|  |     "        # 步骤1: 分布分析结果\n", | |||
|  |     "        self.distribution_analysis = None\n", | |||
|  |     "        self.sampling_strategy = None\n", | |||
|  |     "        \n", | |||
|  |     "        # 步骤2: PCA参数(基于下采样数据拟合)\n", | |||
|  |     "        self.pca_scaler = None\n", | |||
|  |     "        self.pca_model = None\n", | |||
|  |     "        self.pca_components = None\n", | |||
|  |     "        self.pca_fitted = False\n", | |||
|  |     "        \n", | |||
|  |     "        # 配置参数\n", | |||
|  |     "        self.undersample_labels = [0, 40]  # 需要下采样的标签\n", | |||
|  |     "        self.oversample_threshold = 0.5    # 过采样阈值(相对于均值)\n", | |||
|  |     "        self.pca_variance_threshold = 0.95 # PCA保留方差比例\n", | |||
|  |     "        self.pca_sample_size = 15000       # PCA拟合样本数\n", | |||
|  |     "        \n", | |||
|  |     "    def step1_analyze_distribution(self, max_samples=100000):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        步骤1: 分析数据分布,确定采样策略\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(\"🔍 步骤1: 分析数据分布...\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 分析验证集分布(代表整体分布特征)\n", | |||
|  |     "        all_labels = []\n", | |||
|  |     "        for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n", | |||
|  |     "            _, labels = extract_features_labels_batch(trials_batch)\n", | |||
|  |     "            all_labels.extend(labels.tolist())\n", | |||
|  |     "            if len(all_labels) >= max_samples:\n", | |||
|  |     "                break\n", | |||
|  |     "        \n", | |||
|  |     "        # 统计分析\n", | |||
|  |     "        label_counts = Counter(all_labels)\n", | |||
|  |     "        \n", | |||
|  |     "        # 计算1-39标签的均值(排除0和40)\n", | |||
|  |     "        counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n", | |||
|  |     "        target_mean = np.mean(counts_1_39)\n", | |||
|  |     "        \n", | |||
|  |     "        # 生成采样策略\n", | |||
|  |     "        sampling_strategy = {}\n", | |||
|  |     "        for label in range(41):\n", | |||
|  |     "            current_count = label_counts.get(label, 0)\n", | |||
|  |     "            \n", | |||
|  |     "            if label in self.undersample_labels:\n", | |||
|  |     "                # 下采样到均值水平\n", | |||
|  |     "                target_count = int(target_mean)\n", | |||
|  |     "                action = 'undersample' if current_count > target_count else 'keep'\n", | |||
|  |     "            elif current_count < target_mean * self.oversample_threshold:\n", | |||
|  |     "                # 过采样到阈值水平\n", | |||
|  |     "                target_count = int(target_mean * self.oversample_threshold)\n", | |||
|  |     "                action = 'oversample' if current_count < target_count else 'keep'\n", | |||
|  |     "            else:\n", | |||
|  |     "                # 保持不变\n", | |||
|  |     "                target_count = current_count\n", | |||
|  |     "                action = 'keep'\n", | |||
|  |     "            \n", | |||
|  |     "            sampling_strategy[label] = {\n", | |||
|  |     "                'current_count': current_count,\n", | |||
|  |     "                'target_count': target_count,\n", | |||
|  |     "                'action': action\n", | |||
|  |     "            }\n", | |||
|  |     "        \n", | |||
|  |     "        self.distribution_analysis = {\n", | |||
|  |     "            'label_counts': label_counts,\n", | |||
|  |     "            'target_mean': target_mean,\n", | |||
|  |     "            'total_samples': len(all_labels)\n", | |||
|  |     "        }\n", | |||
|  |     "        self.sampling_strategy = sampling_strategy\n", | |||
|  |     "        \n", | |||
|  |     "        print(f\"   ✅ 分析完成: {len(all_labels):,} 样本\")\n", | |||
|  |     "        print(f\"   📊 标签1-39均值: {target_mean:.0f}\")\n", | |||
|  |     "        print(f\"   📉 下采样标签: {self.undersample_labels} → {target_mean:.0f}\")\n", | |||
|  |     "        print(f\"   📈 过采样阈值: {self.oversample_threshold} × 均值 = {target_mean * self.oversample_threshold:.0f}\")\n", | |||
|  |     "        \n", | |||
|  |     "        return self.distribution_analysis, self.sampling_strategy\n", | |||
|  |     "\n", | |||
|  |     "# Build the smart data-processing pipeline with a fixed random_state\n", | |||
|  |     "print(\"🚀 创建智能数据处理管道...\")\n", | |||
|  |     "pipeline = SmartDataPipeline(\n", | |||
|  |     "    data_dir,\n", | |||
|  |     "    random_state=42,\n", | |||
|  |     ")\n", | |||
|  |     "print(\"✅ 管道创建完成,准备执行步骤1...\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "✅ 步骤2方法已添加到管道\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 继续添加智能管道的其他方法【管道完善】\n", | |||
|  |     "\n", | |||
|  |     "def step2_fit_pca_with_undersampling(self):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Step 2: fit the StandardScaler + PCA on undersampled training data only.\n", | |||
|  |     "\n", | |||
|  |     "    Oversampling is deliberately skipped here so that duplicated minority\n", | |||
|  |     "    samples do not influence the principal components.\n", | |||
|  |     "\n", | |||
|  |     "    Reads self.sampling_strategy (from step 1), self.data_dir, self.pca_sample_size,\n", | |||
|  |     "    self.pca_components, self.pca_variance_threshold and self.random_state.\n", | |||
|  |     "    Sets self.pca_scaler, self.pca_model, self.pca_components and self.pca_fitted.\n", | |||
|  |     "\n", | |||
|  |     "    Raises:\n", | |||
|  |     "        ValueError: if step 1 has not been run, or no samples could be collected.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if self.sampling_strategy is None:\n", | |||
|  |     "        raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n", | |||
|  |     "\n", | |||
|  |     "    print(\"\\n🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\")\n", | |||
|  |     "\n", | |||
|  |     "    # Collect samples for PCA fitting (undersampling only, no oversampling)\n", | |||
|  |     "    pca_features = []\n", | |||
|  |     "    collected_samples = 0\n", | |||
|  |     "\n", | |||
|  |     "    for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000, verbose=False):\n", | |||
|  |     "        features, labels = extract_features_labels_batch(trials_batch)\n", | |||
|  |     "\n", | |||
|  |     "        # Apply the undersampling-only strategy to the current batch\n", | |||
|  |     "        downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n", | |||
|  |     "\n", | |||
|  |     "        if downsampled_features.shape[0] > 0:\n", | |||
|  |     "            pca_features.append(downsampled_features)\n", | |||
|  |     "            collected_samples += downsampled_features.shape[0]\n", | |||
|  |     "\n", | |||
|  |     "        # Stop reading files once enough samples have been gathered\n", | |||
|  |     "        if collected_samples >= self.pca_sample_size:\n", | |||
|  |     "            break\n", | |||
|  |     "\n", | |||
|  |     "    if not pca_features:\n", | |||
|  |     "        raise ValueError(\"无法收集用于PCA拟合的样本,请检查数据或采样策略\")\n", | |||
|  |     "\n", | |||
|  |     "    # _apply_undersampling_only returns rows grouped by ascending label, so a head\n", | |||
|  |     "    # slice [:pca_sample_size] would systematically drop the highest labels of the\n", | |||
|  |     "    # last batch. Draw a seeded random subset instead when there are too many rows.\n", | |||
|  |     "    X_pca_fit = np.vstack(pca_features)\n", | |||
|  |     "    del pca_features\n", | |||
|  |     "    if X_pca_fit.shape[0] > self.pca_sample_size:\n", | |||
|  |     "        subset_rng = np.random.RandomState(self.random_state)\n", | |||
|  |     "        keep = subset_rng.choice(X_pca_fit.shape[0], self.pca_sample_size, replace=False)\n", | |||
|  |     "        X_pca_fit = X_pca_fit[np.sort(keep)]\n", | |||
|  |     "    n_input_features = X_pca_fit.shape[1]\n", | |||
|  |     "    print(f\"   用于PCA拟合的样本数: {X_pca_fit.shape[0]:,}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Standardize, then PCA\n", | |||
|  |     "    self.pca_scaler = StandardScaler()\n", | |||
|  |     "    X_scaled = self.pca_scaler.fit_transform(X_pca_fit)\n", | |||
|  |     "\n", | |||
|  |     "    # Auto-select the number of components that retains the requested variance\n", | |||
|  |     "    if self.pca_components is None:\n", | |||
|  |     "        pca_full = PCA()\n", | |||
|  |     "        pca_full.fit(X_scaled)\n", | |||
|  |     "        cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n", | |||
|  |     "        # First index where cumsum >= threshold. Unlike argmax over a boolean mask,\n", | |||
|  |     "        # this does not silently fall back to 1 component when the threshold is never\n", | |||
|  |     "        # reached (e.g. 1.0 vs. a floating-point cumsum that ends at 0.9999999).\n", | |||
|  |     "        first_idx = int(np.searchsorted(cumsum_var, self.pca_variance_threshold, side='left'))\n", | |||
|  |     "        self.pca_components = min(first_idx + 1, len(cumsum_var))\n", | |||
|  |     "\n", | |||
|  |     "    self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n", | |||
|  |     "    self.pca_model.fit(X_scaled)\n", | |||
|  |     "    self.pca_fitted = True\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"   PCA拟合完成: {n_input_features} → {self.pca_components}\")\n", | |||
|  |     "    print(f\"   保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "def _apply_undersampling_only(self, X, y):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Undersample the labels whose step-1 action is 'undersample'; never oversample.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        X: feature matrix whose rows are aligned with y.\n", | |||
|  |     "        y: label vector.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        (X_out, y_out) with rows grouped by label in ascending order (labels 0-40).\n", | |||
|  |     "        If no row carries one of those labels, zero-row slices of X and y are\n", | |||
|  |     "        returned, so dtypes and trailing shape are preserved.\n", | |||
|  |     "\n", | |||
|  |     "    Raises:\n", | |||
|  |     "        ValueError: if step 1 has not been run.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if self.sampling_strategy is None:\n", | |||
|  |     "        raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n", | |||
|  |     "\n", | |||
|  |     "    X_result = []\n", | |||
|  |     "    y_result = []\n", | |||
|  |     "\n", | |||
|  |     "    for label in range(41):\n", | |||
|  |     "        label_mask = (y == label)\n", | |||
|  |     "        X_label = X[label_mask]\n", | |||
|  |     "        y_label = y[label_mask]\n", | |||
|  |     "        current_count = len(y_label)\n", | |||
|  |     "\n", | |||
|  |     "        if current_count == 0:\n", | |||
|  |     "            continue\n", | |||
|  |     "\n", | |||
|  |     "        strategy = self.sampling_strategy[label]\n", | |||
|  |     "\n", | |||
|  |     "        if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n", | |||
|  |     "            # Undersample without replacement down to target_count\n", | |||
|  |     "            indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n", | |||
|  |     "            X_resampled = X_label[indices]\n", | |||
|  |     "            y_resampled = y_label[indices]\n", | |||
|  |     "        else:\n", | |||
|  |     "            # Keep this label as is\n", | |||
|  |     "            X_resampled = X_label\n", | |||
|  |     "            y_resampled = y_label\n", | |||
|  |     "\n", | |||
|  |     "        X_result.append(X_resampled)\n", | |||
|  |     "        y_result.append(y_resampled)\n", | |||
|  |     "\n", | |||
|  |     "    if X_result:\n", | |||
|  |     "        return np.vstack(X_result), np.hstack(y_result)\n", | |||
|  |     "    # Zero-row slices keep the input dtypes; np.array([]) would turn labels into float64\n", | |||
|  |     "    return X[:0], y[:0]\n", | |||
|  |     "\n", | |||
|  |     "# Dynamically attach the step-2 methods to the pipeline class\n", | |||
|  |     "for _attr_name, _method in (\n", | |||
|  |     "    (\"step2_fit_pca_with_undersampling\", step2_fit_pca_with_undersampling),\n", | |||
|  |     "    (\"_apply_undersampling_only\", _apply_undersampling_only),\n", | |||
|  |     "):\n", | |||
|  |     "    setattr(SmartDataPipeline, _attr_name, _method)\n", | |||
|  |     "del _attr_name, _method\n", | |||
|  |     "\n", | |||
|  |     "print(\"步骤2方法已添加到管道\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "✅ 所有方法已添加到智能管道\n", | |||
|  |       "\n", | |||
|  |       "📋 智能数据处理管道状态:\n", | |||
|  |       "   🔍 步骤1 - 分布分析: ❌ 未完成\n", | |||
|  |       "   🔧 步骤2 - PCA拟合: ❌ 未完成\n", | |||
|  |       "\n", | |||
|  |       "🎯 使用流程:\n", | |||
|  |       "   1. pipeline.step1_analyze_distribution()\n", | |||
|  |       "   2. pipeline.step2_fit_pca_with_undersampling()\n", | |||
|  |       "   3. pipeline.step3_process_data('train')  # 训练集\n", | |||
|  |       "      pipeline.step3_process_data('val')    # 验证集\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 添加智能管道的剩余方法\n", | |||
|  |     "\n", | |||
|  |     "def _apply_full_sampling(self, X, y):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Apply the full step-1 sampling strategy (undersampling + oversampling).\n", | |||
|  |     "\n", | |||
|  |     "    Labels whose action is 'undersample' are reduced without replacement, and\n", | |||
|  |     "    labels whose action is 'oversample' are drawn with resample() up to their\n", | |||
|  |     "    target_count. All other labels are kept unchanged.\n", | |||
|  |     "\n", | |||
|  |     "    Randomness comes from a pipeline-owned RandomState seeded once from\n", | |||
|  |     "    self.random_state. Results stay reproducible for a fresh pipeline, but\n", | |||
|  |     "    successive calls (one per file batch) draw different samples. The global\n", | |||
|  |     "    np.random state is deliberately not reseeded, because that would repeat the\n", | |||
|  |     "    same draws for every batch and reset callers' RNG state.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        (X_out, y_out) with rows grouped by label in ascending order (labels 0-40),\n", | |||
|  |     "        or zero-row slices of X and y when no row carries one of those labels.\n", | |||
|  |     "\n", | |||
|  |     "    Raises:\n", | |||
|  |     "        ValueError: if step 1 has not been run.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if self.sampling_strategy is None:\n", | |||
|  |     "        raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n", | |||
|  |     "\n", | |||
|  |     "    # Create the pipeline-owned RNG lazily so that it persists across calls\n", | |||
|  |     "    rng = getattr(self, '_sampling_rng', None)\n", | |||
|  |     "    if rng is None:\n", | |||
|  |     "        rng = np.random.RandomState(self.random_state)\n", | |||
|  |     "        self._sampling_rng = rng\n", | |||
|  |     "\n", | |||
|  |     "    X_result = []\n", | |||
|  |     "    y_result = []\n", | |||
|  |     "\n", | |||
|  |     "    for label in range(41):\n", | |||
|  |     "        label_mask = (y == label)\n", | |||
|  |     "        X_label = X[label_mask]\n", | |||
|  |     "        y_label = y[label_mask]\n", | |||
|  |     "        current_count = len(y_label)\n", | |||
|  |     "\n", | |||
|  |     "        if current_count == 0:\n", | |||
|  |     "            continue\n", | |||
|  |     "\n", | |||
|  |     "        strategy = self.sampling_strategy[label]\n", | |||
|  |     "        target_count = strategy['target_count']\n", | |||
|  |     "\n", | |||
|  |     "        if strategy['action'] == 'undersample' and current_count > target_count:\n", | |||
|  |     "            indices = rng.choice(current_count, target_count, replace=False)\n", | |||
|  |     "            X_resampled = X_label[indices]\n", | |||
|  |     "            y_resampled = y_label[indices]\n", | |||
|  |     "        elif strategy['action'] == 'oversample' and current_count < target_count:\n", | |||
|  |     "            # Sharing the RandomState means each label gets an independent bootstrap pattern\n", | |||
|  |     "            X_resampled, y_resampled = resample(\n", | |||
|  |     "                X_label, y_label,\n", | |||
|  |     "                n_samples=target_count,\n", | |||
|  |     "                random_state=rng,\n", | |||
|  |     "            )\n", | |||
|  |     "        else:\n", | |||
|  |     "            X_resampled = X_label\n", | |||
|  |     "            y_resampled = y_label\n", | |||
|  |     "\n", | |||
|  |     "        X_result.append(X_resampled)\n", | |||
|  |     "        y_result.append(y_resampled)\n", | |||
|  |     "\n", | |||
|  |     "    if X_result:\n", | |||
|  |     "        return np.vstack(X_result), np.hstack(y_result)\n", | |||
|  |     "    # Zero-row slices keep the input dtypes and trailing shape\n", | |||
|  |     "    return X[:0], y[:0]\n", | |||
|  |     "\n", | |||
|  |     "def _apply_pca_transform(self, X):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Project X through the fitted StandardScaler followed by the fitted PCA.\n", | |||
|  |     "\n", | |||
|  |     "    When self.pca_fitted is falsy, X is returned unchanged.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if self.pca_fitted:\n", | |||
|  |     "        return self.pca_model.transform(self.pca_scaler.transform(X))\n", | |||
|  |     "    return X\n", | |||
|  |     "\n", | |||
|  |     "def step3_process_data(self, data_type, apply_sampling=None, batch_size=3000):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Step 3: load one data split, optionally resample it, and project it with the fitted PCA.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        data_type: split name passed to load_data_batch (e.g. 'train' or 'val').\n", | |||
|  |     "        apply_sampling: apply the full under/oversampling strategy. When None, it is\n", | |||
|  |     "            enabled only for data_type == 'train'.\n", | |||
|  |     "        batch_size: batch size passed through to load_data_batch (default 3000).\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        (X, y) shuffled with a RandomState seeded from self.random_state, or\n", | |||
|  |     "        (None, None) when no samples were produced.\n", | |||
|  |     "\n", | |||
|  |     "    Raises:\n", | |||
|  |     "        ValueError: if step 2 (PCA fitting) has not been run.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if not self.pca_fitted:\n", | |||
|  |     "        raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n", | |||
|  |     "\n", | |||
|  |     "    if apply_sampling is None:\n", | |||
|  |     "        apply_sampling = (data_type == 'train')\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"\\n处理{data_type}数据...\")\n", | |||
|  |     "    print(f\"   采样策略: {'启用' if apply_sampling else '禁用'}\")\n", | |||
|  |     "\n", | |||
|  |     "    all_features = []\n", | |||
|  |     "    all_labels = []\n", | |||
|  |     "\n", | |||
|  |     "    # verbose=False silences the loader's per-file messages\n", | |||
|  |     "    for trials_batch, filename in load_data_batch(self.data_dir, data_type, batch_size, verbose=False):\n", | |||
|  |     "        features, labels = extract_features_labels_batch(trials_batch)\n", | |||
|  |     "\n", | |||
|  |     "        if apply_sampling:\n", | |||
|  |     "            features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n", | |||
|  |     "        else:\n", | |||
|  |     "            features_sampled, labels_sampled = features, labels\n", | |||
|  |     "\n", | |||
|  |     "        if features_sampled.shape[0] > 0:\n", | |||
|  |     "            features_pca = self._apply_pca_transform(features_sampled)\n", | |||
|  |     "            all_features.append(features_pca)\n", | |||
|  |     "            all_labels.append(labels_sampled)\n", | |||
|  |     "\n", | |||
|  |     "    if not all_features:\n", | |||
|  |     "        return None, None\n", | |||
|  |     "\n", | |||
|  |     "    X = np.vstack(all_features)\n", | |||
|  |     "    y = np.hstack(all_labels)\n", | |||
|  |     "    del all_features, all_labels\n", | |||
|  |     "    gc.collect()\n", | |||
|  |     "\n", | |||
|  |     "    # Shuffle with a local seeded RNG: reproducible, and independent of whatever\n", | |||
|  |     "    # state the global np.random RNG happens to be in\n", | |||
|  |     "    shuffle_indices = np.random.RandomState(self.random_state).permutation(len(y))\n", | |||
|  |     "    X = X[shuffle_indices]\n", | |||
|  |     "    y = y[shuffle_indices]\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"   完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", | |||
|  |     "\n", | |||
|  |     "    return X, y\n", | |||
|  |     "\n", | |||
|  |     "def print_summary(self):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Print the completion status of steps 1-2, plus the label-1..39 mean and PCA stats when available.\n", | |||
|  |     "\n", | |||
|  |     "    The PCA input dimensionality is read from the fitted model's n_features_in_\n", | |||
|  |     "    (falling back to 7168 if that attribute is missing) instead of being assumed.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(\"\\n智能数据处理管道状态:\")\n", | |||
|  |     "    print(f\"   步骤1 - 分布分析: {'完成' if self.distribution_analysis else '未完成'}\")\n", | |||
|  |     "    print(f\"   步骤2 - PCA拟合: {'完成' if self.pca_fitted else '未完成'}\")\n", | |||
|  |     "\n", | |||
|  |     "    if self.distribution_analysis:\n", | |||
|  |     "        target_mean = self.distribution_analysis['target_mean']\n", | |||
|  |     "        print(f\"   标签1-39均值: {target_mean:.0f}\")\n", | |||
|  |     "\n", | |||
|  |     "    if self.pca_fitted:\n", | |||
|  |     "        n_in = getattr(self.pca_model, 'n_features_in_', 7168)\n", | |||
|  |     "        print(f\"   PCA降维: {n_in} → {self.pca_components} ({self.pca_components/n_in:.1%})\")\n", | |||
|  |     "        print(f\"   保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "# Dynamically attach the remaining pipeline methods to the class\n", | |||
|  |     "for _attr_name, _method in (\n", | |||
|  |     "    (\"_apply_full_sampling\", _apply_full_sampling),\n", | |||
|  |     "    (\"_apply_pca_transform\", _apply_pca_transform),\n", | |||
|  |     "    (\"step3_process_data\", step3_process_data),\n", | |||
|  |     "    (\"print_summary\", print_summary),\n", | |||
|  |     "):\n", | |||
|  |     "    setattr(SmartDataPipeline, _attr_name, _method)\n", | |||
|  |     "del _attr_name, _method\n", | |||
|  |     "\n", | |||
|  |     "print(\"所有方法已添加到智能管道\")\n", | |||
|  |     "pipeline.print_summary()" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🎯 数据增强模块 - 时序神经数据增强策略" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🔥 执行智能数据处理管道" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 32, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "🚀 开始执行智能数据处理管道...\n", | |||
|  |       "============================================================\n", | |||
|  |       "\n", | |||
|  |       "======================🔍 STEP 1: 分析数据分布======================\n", | |||
|  |       "🔍 步骤1: 分析数据分布...\n", | |||
|  |       "  正在加载文件 1/41: t15.2023.11.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 2/41: t15.2023.12.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 2/41: t15.2023.12.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 3/41: t15.2023.10.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 3/41: t15.2023.10.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 4/41: t15.2023.10.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 4/41: t15.2023.10.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 5/41: t15.2025.01.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 5/41: t15.2025.01.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 6/41: t15.2023.12.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 6/41: t15.2023.12.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 7/41: t15.2024.03.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 7/41: t15.2024.03.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 8/41: t15.2024.03.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 8/41: t15.2024.03.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 9/41: t15.2025.03.14_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 9/41: t15.2025.03.14_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 10/41: t15.2024.02.25_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 10/41: t15.2024.02.25_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 11/41: t15.2025.03.30_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 11/41: t15.2025.03.30_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 12/41: t15.2023.09.29_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 12/41: t15.2023.09.29_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 13/41: t15.2023.09.01_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 13/41: t15.2023.09.01_val_concatenated.npz\n", | |||
|  |       "   ✅ 分析完成: 101,906 样本\n", | |||
|  |       "   📊 标签1-39均值: 389\n", | |||
|  |       "   📉 下采样标签: [0, 40] → 389\n", | |||
|  |       "   📈 过采样阈值: 0.5 × 均值 = 194\n", | |||
|  |       "\n", | |||
|  |       "📊 采样策略总结:\n", | |||
|  |       "   📉 下采样标签: 2 个\n", | |||
|  |       "   📈 过采样标签: 11 个\n", | |||
|  |       "   ✅ 保持不变: 28 个\n", | |||
|  |       "\n", | |||
|  |       "✅ 步骤1完成!\n", | |||
|  |       "   ✅ 分析完成: 101,906 样本\n", | |||
|  |       "   📊 标签1-39均值: 389\n", | |||
|  |       "   📉 下采样标签: [0, 40] → 389\n", | |||
|  |       "   📈 过采样阈值: 0.5 × 均值 = 194\n", | |||
|  |       "\n", | |||
|  |       "📊 采样策略总结:\n", | |||
|  |       "   📉 下采样标签: 2 个\n", | |||
|  |       "   📈 过采样标签: 11 个\n", | |||
|  |       "   ✅ 保持不变: 28 个\n", | |||
|  |       "\n", | |||
|  |       "✅ 步骤1完成!\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🔥 Run the smart data-processing pipeline [decide the sampling strategy]\n", | |||
|  |     "\n", | |||
|  |     "print(\"🚀 开始执行智能数据处理管道...\")\n", | |||
|  |     "print(\"=\" * 60)\n", | |||
|  |     "\n", | |||
|  |     "# Step 1: analyze the label distribution and derive per-label sampling actions\n", | |||
|  |     "print(\"\\n\" + \"🔍 STEP 1: 分析数据分布\".center(60, \"=\"))\n", | |||
|  |     "distribution, strategy = pipeline.step1_analyze_distribution()\n", | |||
|  |     "\n", | |||
|  |     "# Count how many labels fall under each sampling action\n", | |||
|  |     "print(\"\\n📊 采样策略总结:\")\n", | |||
|  |     "actions = [s['action'] for s in strategy.values()]\n", | |||
|  |     "undersample_count = actions.count('undersample')\n", | |||
|  |     "oversample_count = actions.count('oversample')\n", | |||
|  |     "keep_count = actions.count('keep')\n", | |||
|  |     "\n", | |||
|  |     "print(f\"   📉 下采样标签: {undersample_count} 个\")\n", | |||
|  |     "print(f\"   📈 过采样标签: {oversample_count} 个\")\n", | |||
|  |     "print(f\"   ✅ 保持不变: {keep_count} 个\")\n", | |||
|  |     "\n", | |||
|  |     "print(\"\\n✅ 步骤1完成!\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 33, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "\n", | |||
|  |       "=====================🔧 STEP 2: 拟合PCA参数======================\n", | |||
|  |       "\n", | |||
|  |       "🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\n", | |||
|  |       "  正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n", | |||
|  |       "  正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n", | |||
|  |       "  正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n", | |||
|  |       "  正在加载文件 3/45: t15.2024.03.17_train_concatenated.npz\n", | |||
|  |       "  正在加载文件 3/45: t15.2024.03.17_train_concatenated.npz\n", | |||
|  |       "   📦 PCA拟合样本: 15,000 个下采样样本\n", | |||
|  |       "   🔢 原始特征维度: 7168\n", | |||
|  |       "   📦 PCA拟合样本: 15,000 个下采样样本\n", | |||
|  |       "   🔢 原始特征维度: 7168\n", | |||
|  |       "   ✅ PCA拟合完成!\n", | |||
|  |       "      降维: 7168 → 1219\n", | |||
|  |       "      降维比例: 17.01%\n", | |||
|  |       "      保留方差: 0.9491\n", | |||
|  |       "      模型保存: smart_pipeline_pca.joblib\n", | |||
|  |       "\n", | |||
|  |       "✅ 步骤2完成!\n", | |||
|  |       "\n", | |||
|  |       "📋 智能数据处理管道状态:\n", | |||
|  |       "   🔍 步骤1 - 分布分析: ✅ 完成\n", | |||
|  |       "   🔧 步骤2 - PCA拟合: ✅ 完成\n", | |||
|  |       "   📊 标签1-39均值: 389\n", | |||
|  |       "   🔬 PCA降维: 7168 → 1219 (17.0%)\n", | |||
|  |       "   📈 保留方差: 0.9491\n", | |||
|  |       "\n", | |||
|  |       "🎯 使用流程:\n", | |||
|  |       "   1. pipeline.step1_analyze_distribution()\n", | |||
|  |       "   2. pipeline.step2_fit_pca_with_undersampling()\n", | |||
|  |       "   3. pipeline.step3_process_data('train')  # 训练集\n", | |||
|  |       "      pipeline.step3_process_data('val')    # 验证集\n", | |||
|  |       "   ✅ PCA拟合完成!\n", | |||
|  |       "      降维: 7168 → 1219\n", | |||
|  |       "      降维比例: 17.01%\n", | |||
|  |       "      保留方差: 0.9491\n", | |||
|  |       "      模型保存: smart_pipeline_pca.joblib\n", | |||
|  |       "\n", | |||
|  |       "✅ 步骤2完成!\n", | |||
|  |       "\n", | |||
|  |       "📋 智能数据处理管道状态:\n", | |||
|  |       "   🔍 步骤1 - 分布分析: ✅ 完成\n", | |||
|  |       "   🔧 步骤2 - PCA拟合: ✅ 完成\n", | |||
|  |       "   📊 标签1-39均值: 389\n", | |||
|  |       "   🔬 PCA降维: 7168 → 1219 (17.0%)\n", | |||
|  |       "   📈 保留方差: 0.9491\n", | |||
|  |       "\n", | |||
|  |       "🎯 使用流程:\n", | |||
|  |       "   1. pipeline.step1_analyze_distribution()\n", | |||
|  |       "   2. pipeline.step2_fit_pca_with_undersampling()\n", | |||
|  |       "   3. pipeline.step3_process_data('train')  # 训练集\n", | |||
|  |       "      pipeline.step3_process_data('val')    # 验证集\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# Step 2: fit the scaler + PCA parameters [decide the PCA strategy]\n", | |||
|  |     "print(\"\\n\" + \"🔧 STEP 2: 拟合PCA参数\".center(60, \"=\"))\n", | |||
|  |     "pipeline.step2_fit_pca_with_undersampling()\n", | |||
|  |     "\n", | |||
|  |     "# Report completion and the resulting pipeline state\n", | |||
|  |     "print(\"\\n✅ 步骤2完成!\")\n", | |||
|  |     "pipeline.print_summary()" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🚀 使用智能管道进行分批训练" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "ename": "", | |||
|  |      "evalue": "", | |||
|  |      "output_type": "error", | |||
|  |      "traceback": [ | |||
|  |       "\u001b[1;31mnotebook controller is DISPOSED. \n", | |||
|  |       "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details." | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 使用智能管道进行分批训练\n", | |||
|  |     "import lightgbm as lgb\n", | |||
|  |     "import time\n", | |||
|  |     "from collections import Counter\n", | |||
|  |     "import matplotlib.pyplot as plt\n", | |||
|  |     "import random\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "import os\n", | |||
|  |     "import gc\n", | |||
|  |     "\n", | |||
|  |     "class SmartBatchTrainer:\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    智能分批训练器,集成智能数据管道\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    \n", | |||
|  |     "    def __init__(self, pipeline, params=None, min_learning_rate=1e-4, t_0=50, t_mult=2):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Args:\n", | |||
|  |     "            pipeline: the SmartDataPipeline used to produce training/validation data\n", | |||
|  |     "                (its step3_process_data requires step 2 to have been run).\n", | |||
|  |     "            params: LightGBM parameter dict. When falsy (None or empty), the default\n", | |||
|  |     "                parameters below are used. Otherwise a shallow copy is stored, so\n", | |||
|  |     "                later changes to self.params do not leak into the caller's dict.\n", | |||
|  |     "            min_learning_rate: lower end of the cosine-annealing learning-rate schedule.\n", | |||
|  |     "            t_0: length of the first restart period (cosine annealing with restarts).\n", | |||
|  |     "            t_mult: multiplier applied to the restart period.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        self.pipeline = pipeline\n", | |||
|  |     "        self.model = None\n", | |||
|  |     "        self.training_history = {}  # a dict, because training runs only once\n", | |||
|  |     "        self.batch_count = 0\n", | |||
|  |     "        self.min_learning_rate = min_learning_rate\n", | |||
|  |     "        self.lr_history = []  # kept for visualization\n", | |||
|  |     "\n", | |||
|  |     "        # Cosine annealing with warm restarts\n", | |||
|  |     "        self.t_0 = t_0  # length of the first restart period\n", | |||
|  |     "        self.t_mult = t_mult  # multiplier for the restart period\n", | |||
|  |     "\n", | |||
|  |     "        # Default LightGBM parameters (CPU; the GPU options are kept commented out)\n", | |||
|  |     "        default_params = {\n", | |||
|  |     "            'objective': 'multiclass',\n", | |||
|  |     "            'num_class': 41,\n", | |||
|  |     "            'metric': 'multi_logloss',\n", | |||
|  |     "            'boosting_type': 'gbdt',\n", | |||
|  |     "            'device_type': 'cpu',\n", | |||
|  |     "            # 'gpu_platform_id': 0,\n", | |||
|  |     "            # 'gpu_device_id': 0,\n", | |||
|  |     "            'max_bin': 255,\n", | |||
|  |     "            'num_leaves': 127,\n", | |||
|  |     "            'learning_rate': 0.10,  # default: 0.08\n", | |||
|  |     "            'feature_fraction': 0.8,\n", | |||
|  |     "            'bagging_fraction': 0.8,\n", | |||
|  |     "            'bagging_freq': 5,\n", | |||
|  |     "            'min_data_in_leaf': 20,\n", | |||
|  |     "            'lambda_l1': 0.1,\n", | |||
|  |     "            'lambda_l2': 0.1,\n", | |||
|  |     "            'verbose': -1,\n", | |||
|  |     "            'num_threads': -1\n", | |||
|  |     "        }\n", | |||
|  |     "        self.params = dict(params) if params else default_params\n", | |||
|  |     "\n", | |||
|  |     "        self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n", | |||
|  |     "        # User-supplied params may omit device_type; LightGBM then runs on the CPU\n", | |||
|  |     "        device_type = self.params.get('device_type', 'cpu')\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"智能分批训练器创建完成\")\n", | |||
|  |     "        print(f\"   LightGBM参数已配置:{device_type.upper()}模式\")\n", | |||
|  |     "        print(f\"   学习率调度: 带重启的余弦退火 (从 {self.initial_learning_rate} 到 {self.min_learning_rate})\")\n", | |||
|  |     "        print(f\"   重启参数: T_0={self.t_0}, T_mult={self.t_mult}\")\n", | |||
|  |     "    \n", | |||
|  |     "    def prepare_validation_data(self):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Build the LightGBM validation Dataset (PCA only, original label distribution kept).\n", | |||
|  |     "\n", | |||
|  |     "        Also caches the raw arrays as self._X_val_np / self._y_val_np so that\n", | |||
|  |     "        accuracy can be computed later.\n", | |||
|  |     "\n", | |||
|  |     "        Raises:\n", | |||
|  |     "            ValueError: if the pipeline yields no validation data.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(\"准备验证数据...\")\n", | |||
|  |     "        X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", | |||
|  |     "        if X_val is None:\n", | |||
|  |     "            raise ValueError(\"无法加载验证数据\")\n", | |||
|  |     "\n", | |||
|  |     "        label_counts = Counter(y_val)\n", | |||
|  |     "        n_label0 = label_counts.get(0, 0)\n", | |||
|  |     "        n_label40 = label_counts.get(40, 0)\n", | |||
|  |     "        print(f\"   验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", | |||
|  |     "        print(f\"   验证集分布 (标签0: {n_label0:,}, 标签40: {n_label40:,})\")\n", | |||
|  |     "\n", | |||
|  |     "        # Keep the raw arrays around for accuracy computation\n", | |||
|  |     "        self._X_val_np, self._y_val_np = X_val, y_val\n", | |||
|  |     "\n", | |||
|  |     "        return lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n", | |||
|  |     "    \n", | |||
|  |     "    def get_training_batch_generator(self, n_files_per_batch=4, batch_size=8000):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Infinite generator of balanced, PCA-projected LightGBM training batches.\n", | |||
|  |     "\n", | |||
|  |     "        Each iteration does the following:\n", | |||
|  |     "        - randomly picks n_files_per_batch training files from self.pipeline.data_dir;\n", | |||
|  |     "        - pools their samples;\n", | |||
|  |     "        - randomly subsamples the pool to batch_size rows (all rows if there are fewer);\n", | |||
|  |     "        - applies the pipeline's full under/oversampling strategy and then its PCA transform.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            n_files_per_batch: files drawn without replacement per batch; capped at\n", | |||
|  |     "                the number of available training files.\n", | |||
|  |     "            batch_size: rows drawn from the pooled files before balancing\n", | |||
|  |     "                (balancing changes the final row count).\n", | |||
|  |     "\n", | |||
|  |     "        Yields:\n", | |||
|  |     "            (lgb.Dataset, str): the batch dataset (free_raw_data=False) and a name of\n", | |||
|  |     "            the form batch_<id>_files_<n>.\n", | |||
|  |     "\n", | |||
|  |     "        Raises:\n", | |||
|  |     "            ValueError: if the data directory holds no training .npz files. Without\n", | |||
|  |     "                this check no batch could ever be produced and the loop would spin forever.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(f\"准备改进的训练批次生成器...\")\n", | |||
|  |     "        print(f\"   每批次选择文件数: {n_files_per_batch}\")\n", | |||
|  |     "        print(f\"   每批次目标样本数: {batch_size:,}\")\n", | |||
|  |     "\n", | |||
|  |     "        # List every training file. os.listdir returns entries in arbitrary order,\n", | |||
|  |     "        # so sort to keep random.sample reproducible under a fixed `random` seed.\n", | |||
|  |     "        all_train_files = sorted(\n", | |||
|  |     "            f for f in os.listdir(self.pipeline.data_dir)\n", | |||
|  |     "            if f.endswith('.npz') and 'train' in f\n", | |||
|  |     "        )\n", | |||
|  |     "\n", | |||
|  |     "        if not all_train_files:\n", | |||
|  |     "            raise ValueError(f\"未找到训练文件 (*.npz 且文件名包含 'train'): {self.pipeline.data_dir}\")\n", | |||
|  |     "\n", | |||
|  |     "        if len(all_train_files) < n_files_per_batch:\n", | |||
|  |     "            print(f\"   可用文件数({len(all_train_files)})少于每批次需要的文件数({n_files_per_batch})\")\n", | |||
|  |     "            n_files_per_batch = len(all_train_files)\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"   总计可用训练文件: {len(all_train_files)}\")\n", | |||
|  |     "\n", | |||
|  |     "        batch_id = 0\n", | |||
|  |     "        while True:  # infinite generator; a file can be drawn again in later batches\n", | |||
|  |     "            batch_id += 1\n", | |||
|  |     "\n", | |||
|  |     "            # Randomly pick n files for this batch\n", | |||
|  |     "            selected_files = random.sample(all_train_files, n_files_per_batch)\n", | |||
|  |     "\n", | |||
|  |     "            print(f\"   批次 {batch_id} - 随机选择的文件:\")\n", | |||
|  |     "            for i, f in enumerate(selected_files, 1):\n", | |||
|  |     "                print(f\"      {i}. {f}\")\n", | |||
|  |     "\n", | |||
|  |     "            # Load features/labels from each selected file\n", | |||
|  |     "            all_features = []\n", | |||
|  |     "            all_labels = []\n", | |||
|  |     "\n", | |||
|  |     "            for filename in selected_files:\n", | |||
|  |     "                # The context manager closes the .npz file handle right after the array is read.\n", | |||
|  |     "                # NOTE(review): allow_pickle=True can execute arbitrary code from a crafted\n", | |||
|  |     "                # file; only point data_dir at trusted data.\n", | |||
|  |     "                with np.load(os.path.join(self.pipeline.data_dir, filename), allow_pickle=True) as data:\n", | |||
|  |     "                    trials = data['neural_logits_concatenated']\n", | |||
|  |     "\n", | |||
|  |     "                # Extract features and labels\n", | |||
|  |     "                features, labels = extract_features_labels_batch(trials)\n", | |||
|  |     "\n", | |||
|  |     "                if features.shape[0] > 0:\n", | |||
|  |     "                    all_features.append(features)\n", | |||
|  |     "                    all_labels.append(labels)\n", | |||
|  |     "\n", | |||
|  |     "                # Release the per-file data\n", | |||
|  |     "                del data, trials\n", | |||
|  |     "                gc.collect()\n", | |||
|  |     "\n", | |||
|  |     "            if not all_features:\n", | |||
|  |     "                print(f\"   批次 {batch_id} 无有效数据\")\n", | |||
|  |     "                continue\n", | |||
|  |     "\n", | |||
|  |     "            # Pool the data of all selected files\n", | |||
|  |     "            combined_features = np.vstack(all_features)\n", | |||
|  |     "            combined_labels = np.hstack(all_labels)\n", | |||
|  |     "            del all_features, all_labels\n", | |||
|  |     "\n", | |||
|  |     "            print(f\"   合并后总样本数: {combined_features.shape[0]:,}\")\n", | |||
|  |     "\n", | |||
|  |     "            # Randomly subsample down to batch_size\n", | |||
|  |     "            if combined_features.shape[0] > batch_size:\n", | |||
|  |     "                sample_indices = np.random.choice(\n", | |||
|  |     "                    combined_features.shape[0],\n", | |||
|  |     "                    size=batch_size,\n", | |||
|  |     "                    replace=False\n", | |||
|  |     "                )\n", | |||
|  |     "                sampled_features = combined_features[sample_indices]\n", | |||
|  |     "                sampled_labels = combined_labels[sample_indices]\n", | |||
|  |     "                print(f\"   随机采样到: {batch_size:,} 样本\")\n", | |||
|  |     "            else:\n", | |||
|  |     "                # Fewer rows than batch_size: use them all\n", | |||
|  |     "                sampled_features = combined_features\n", | |||
|  |     "                sampled_labels = combined_labels\n", | |||
|  |     "                print(f\"   样本不足,使用全部: {sampled_features.shape[0]:,} 样本\")\n", | |||
|  |     "\n", | |||
|  |     "            # Apply the balancing (under/oversampling) strategy\n", | |||
|  |     "            features_balanced, labels_balanced = self.pipeline._apply_full_sampling(\n", | |||
|  |     "                sampled_features, sampled_labels\n", | |||
|  |     "            )\n", | |||
|  |     "\n", | |||
|  |     "            # Apply PCA\n", | |||
|  |     "            if features_balanced.shape[0] > 0:\n", | |||
|  |     "                features_pca = self.pipeline._apply_pca_transform(features_balanced)\n", | |||
|  |     "\n", | |||
|  |     "                # Label distribution of this batch\n", | |||
|  |     "                batch_counts = Counter(labels_balanced)\n", | |||
|  |     "\n", | |||
|  |     "                print(f\"   批次 {batch_id} 最终结果:\")\n", | |||
|  |     "                print(f\"      平衡后样本数: {features_pca.shape[0]:,}\")\n", | |||
|  |     "                print(f\"      特征维度: {features_pca.shape[1]}\")\n", | |||
|  |     "                print(f\"      分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n", | |||
|  |     "                print(f\"   \" + \"=\"*50)\n", | |||
|  |     "\n", | |||
|  |     "                # free_raw_data=False keeps the raw data so that incremental training does not fail\n", | |||
|  |     "                yield lgb.Dataset(features_pca, label=labels_balanced, free_raw_data=False), f\"batch_{batch_id}_files_{len(selected_files)}\"\n", | |||
|  |     "\n", | |||
|  |     "            # Release this batch's data\n", | |||
|  |     "            del combined_features, combined_labels\n", | |||
|  |     "            del sampled_features, sampled_labels, features_balanced, labels_balanced\n", | |||
|  |     "            gc.collect()\n", | |||
|  |     "    \n", | |||
|  |     "    def prepare_full_data(self):\n", | |||
|  |     "        \"\"\"一次性准备所有训练和验证数据\"\"\"\n", | |||
|  |     "        print(\"准备全量训练和验证数据...\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 1. 准备验证数据 (保持原始分布)\n", | |||
|  |     "        X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", | |||
|  |     "        if X_val is None:\n", | |||
|  |     "            raise ValueError(\"无法加载验证数据\")\n", | |||
|  |     "        val_counts = Counter(y_val)\n", | |||
|  |     "        print(f\"   验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", | |||
|  |     "        print(f\"   验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n", | |||
|  |     "        val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n", | |||
|  |     "        \n", | |||
|  |     "        # 2. 准备训练数据 (应用完整采样和PCA策略)\n", | |||
|  |     "        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", | |||
|  |     "        if X_train is None:\n", | |||
|  |     "            raise ValueError(\"无法加载训练数据\")\n", | |||
|  |     "        train_counts = Counter(y_train)\n", | |||
|  |     "        print(f\"   训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", | |||
|  |     "        print(f\"   训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", | |||
|  |     "        train_data = lgb.Dataset(X_train, label=y_train)\n", | |||
|  |     "        \n", | |||
|  |     "        return train_data, val_data, X_val, y_val\n", | |||
|  |     "    \n", | |||
|  |     "    def prepare_training_data(self):\n", | |||
|  |     "        \"\"\"准备训练数据(仅PCA,保持原始分布)\"\"\"\n", | |||
|  |     "        print(\"准备训练数据...\")\n", | |||
|  |     "        # 2. 准备训练数据 (应用完整采样和PCA策略)\n", | |||
|  |     "        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", | |||
|  |     "        if X_train is None:\n", | |||
|  |     "            raise ValueError(\"无法加载训练数据\")\n", | |||
|  |     "        train_counts = Counter(y_train)\n", | |||
|  |     "        print(f\"   训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", | |||
|  |     "        print(f\"   训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", | |||
|  |     "        \n", | |||
|  |     "        return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n", | |||
|  |     "                \n", | |||
|  |     "    # 带重启的余弦退火调度器函数\n", | |||
|  |     "    def _cosine_annealing_with_warm_restarts(self, current_round):\n", | |||
|  |     "        \"\"\"带重启的余弦退火调度器 (SGDR)\"\"\"\n", | |||
|  |     "        eta_max = self.initial_learning_rate\n", | |||
|  |     "        eta_min = self.min_learning_rate\n", | |||
|  |     "        \n", | |||
|  |     "        # 计算当前在哪个重启周期中\n", | |||
|  |     "        t_cur = current_round\n", | |||
|  |     "        t_i = self.t_0\n", | |||
|  |     "        \n", | |||
|  |     "        # 找到当前的重启周期\n", | |||
|  |     "        cycle = 0\n", | |||
|  |     "        while t_cur >= t_i:\n", | |||
|  |     "            t_cur -= t_i\n", | |||
|  |     "            cycle += 1\n", | |||
|  |     "            t_i *= self.t_mult\n", | |||
|  |     "        \n", | |||
|  |     "        # 在当前周期内的位置\n", | |||
|  |     "        progress = t_cur / t_i\n", | |||
|  |     "        \n", | |||
|  |     "        # 计算学习率\n", | |||
|  |     "        lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * progress))\n", | |||
|  |     "        \n", | |||
|  |     "        return lr\n", | |||
|  |     "    \n", | |||
|  |     "    def train_incremental(self, num_boost_round=100, early_stopping_rounds=10, \n", | |||
|  |     "                         n_files_per_batch=4, batch_size=8000, max_batches=None):\n", | |||
|  |     "        \"\"\"增量分批训练 - 支持自定义批次参数\"\"\"\n", | |||
|  |     "        print(f\"开始智能分批训练...\")\n", | |||
|  |     "        print(f\"   训练轮数 (每批次): {num_boost_round}\")\n", | |||
|  |     "        print(f\"   早停轮数: {early_stopping_rounds}\")\n", | |||
|  |     "        print(f\"   每批次文件数: {n_files_per_batch}\")\n", | |||
|  |     "        print(f\"   每批次样本数: {batch_size:,}\")\n", | |||
|  |     "        if max_batches:\n", | |||
|  |     "            print(f\"   最大批次数: {max_batches}\")\n", | |||
|  |     "        print(\"=\" * 60)\n", | |||
|  |     "        \n", | |||
|  |     "        # 准备验证数据\n", | |||
|  |     "        val_data = self.prepare_validation_data()\n", | |||
|  |     "        \n", | |||
|  |     "        print(f\"开始分批增量训练...\")\n", | |||
|  |     "        total_start_time = time.time()\n", | |||
|  |     "        \n", | |||
|  |     "        # 初始化训练历史\n", | |||
|  |     "        self.training_history = []\n", | |||
|  |     "        \n", | |||
|  |     "        # 创建改进的生成器\n", | |||
|  |     "        batch_generator = self.get_training_batch_generator(n_files_per_batch, batch_size)\n", | |||
|  |     "        \n", | |||
|  |     "        for train_data, batch_name in batch_generator:\n", | |||
|  |     "            self.batch_count += 1\n", | |||
|  |     "            batch_start_time = time.time()\n", | |||
|  |     "            \n", | |||
|  |     "            # 检查是否达到最大批次数\n", | |||
|  |     "            if max_batches and self.batch_count > max_batches:\n", | |||
|  |     "                print(f\"达到最大批次数 {max_batches},停止训练\")\n", | |||
|  |     "                break\n", | |||
|  |     "            \n", | |||
|  |     "            # 先构建数据集,使得可以安全访问 num_data()\n", | |||
|  |     "            try:\n", | |||
|  |     "                train_data.construct()\n", | |||
|  |     "            except Exception:\n", | |||
|  |     "                pass\n", | |||
|  |     "\n", | |||
|  |     "            print(f\"\\n批次 {self.batch_count}: {batch_name}\")\n", | |||
|  |     "            try:\n", | |||
|  |     "                print(f\"   样本数: {train_data.num_data():,}\")\n", | |||
|  |     "            except Exception:\n", | |||
|  |     "                print(\"   样本数: (未构建,跳过显示)\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 计算当前批次的学习率\n", | |||
|  |     "            current_lr = self._cosine_annealing_with_warm_restarts(\n", | |||
|  |     "                (self.batch_count - 1) * num_boost_round\n", | |||
|  |     "            )\n", | |||
|  |     "            \n", | |||
|  |     "            # 更新训练参数中的学习率\n", | |||
|  |     "            current_params = self.params.copy()\n", | |||
|  |     "            current_params['learning_rate'] = current_lr\n", | |||
|  |     "            \n", | |||
|  |     "            try:\n", | |||
|  |     "                # 训练参数\n", | |||
|  |     "                train_params = {\n", | |||
|  |     "                    'params': current_params,\n", | |||
|  |     "                    'train_set': train_data,\n", | |||
|  |     "                    'num_boost_round': num_boost_round,\n", | |||
|  |     "                    'valid_sets': [val_data],\n", | |||
|  |     "                    'valid_names': ['validation'],\n", | |||
|  |     "                    'callbacks': [\n", | |||
|  |     "                        lgb.log_evaluation(period=1)\n", | |||
|  |     "                    ]\n", | |||
|  |     "                }\n", | |||
|  |     "                \n", | |||
|  |     "                # 如果有早停设置\n", | |||
|  |     "                if early_stopping_rounds:\n", | |||
|  |     "                    train_params['callbacks'].append(\n", | |||
|  |     "                        lgb.early_stopping(early_stopping_rounds, verbose=False)\n", | |||
|  |     "                    )\n", | |||
|  |     "                \n", | |||
|  |     "                # 增量训练\n", | |||
|  |     "                if self.model is None:\n", | |||
|  |     "                    # 第一次训练\n", | |||
|  |     "                    print(f\"   首次训练 (学习率: {current_lr:.6f})\")\n", | |||
|  |     "                    self.model = lgb.train(**train_params)\n", | |||
|  |     "                else:\n", | |||
|  |     "                    # 增量训练\n", | |||
|  |     "                    print(f\"   增量训练 (学习率: {current_lr:.6f})\")\n", | |||
|  |     "                    train_params['init_model'] = self.model\n", | |||
|  |     "                    self.model = lgb.train(**train_params)\n", | |||
|  |     "                \n", | |||
|  |     "                # 验证 - 修复数组比较的歧义性问题\n", | |||
|  |     "                # 优先使用缓存的验证集数组,退回到val_data中的数据\n", | |||
|  |     "                Xv = getattr(self, '_X_val_np', None) or val_data.get_data()\n", | |||
|  |     "                yv = getattr(self, '_y_val_np', None) or val_data.get_label()\n", | |||
|  |     "                val_pred = self.model.predict(Xv)\n", | |||
|  |     "                \n", | |||
|  |     "                # 确保yv是1维数组,避免数组比较的歧义\n", | |||
|  |     "                if hasattr(yv, 'shape') and len(yv.shape) > 1:\n", | |||
|  |     "                    yv = yv.flatten()\n", | |||
|  |     "                if hasattr(yv, 'shape') and yv.shape[0] == 0:\n", | |||
|  |     "                    yv = np.array(yv, dtype=int)\n", | |||
|  |     "                \n", | |||
|  |     "                # 计算验证准确率\n", | |||
|  |     "                pred_labels = np.argmax(val_pred, axis=1)\n", | |||
|  |     "                val_accuracy = float(np.mean(pred_labels == yv))\n", | |||
|  |     "                \n", | |||
|  |     "                # 记录训练历史\n", | |||
|  |     "                batch_time = time.time() - batch_start_time\n", | |||
|  |     "                try:\n", | |||
|  |     "                    samples_cnt = train_data.num_data()\n", | |||
|  |     "                except Exception:\n", | |||
|  |     "                    samples_cnt = None\n", | |||
|  |     "                self.training_history.append({\n", | |||
|  |     "                    'batch': self.batch_count,\n", | |||
|  |     "                    'batch_name': batch_name,\n", | |||
|  |     "                    'val_accuracy': val_accuracy,\n", | |||
|  |     "                    'time': batch_time,\n", | |||
|  |     "                    'num_trees': self.model.num_trees(),\n", | |||
|  |     "                    'learning_rate': current_lr,\n", | |||
|  |     "                    'samples': samples_cnt\n", | |||
|  |     "                })\n", | |||
|  |     "                \n", | |||
|  |     "                print(f\"   批次完成:\")\n", | |||
|  |     "                print(f\"      验证准确率: {val_accuracy:.4f}\")\n", | |||
|  |     "                print(f\"      训练时间: {batch_time:.1f}秒\")\n", | |||
|  |     "                print(f\"      模型树数: {self.model.num_trees()}\")\n", | |||
|  |     "                print(f\"      当前学习率: {current_lr:.6f}\")\n", | |||
|  |     "                \n", | |||
|  |     "            except Exception as e:\n", | |||
|  |     "                print(f\"   批次训练失败: {e}\")\n", | |||
|  |     "                continue\n", | |||
|  |     "        \n", | |||
|  |     "        # 训练完成\n", | |||
|  |     "        total_time = time.time() - total_start_time\n", | |||
|  |     "        print(f\"\\n增量训练完成!\")\n", | |||
|  |     "        print(f\"   总批次数: {len(self.training_history)}\")\n", | |||
|  |     "        print(f\"   总训练时间: {total_time:.1f}秒\")\n", | |||
|  |     "        \n", | |||
|  |     "        if self.training_history:\n", | |||
|  |     "            best_batch = max(self.training_history, key=lambda x: x['val_accuracy'])\n", | |||
|  |     "            print(f\"   最佳准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n", | |||
|  |     "            final_accuracy = self.training_history[-1]['val_accuracy']\n", | |||
|  |     "            print(f\"   最终准确率: {final_accuracy:.4f}\")\n", | |||
|  |     "        \n", | |||
|  |     "        return self.model\n", | |||
|  |     "\n", | |||
|  |     "    @staticmethod\n", | |||
|  |     "    def _ctc_collapse(seq, blank=0, drop_sep40=False):\n", | |||
|  |     "        out = []\n", | |||
|  |     "        prev = None\n", | |||
|  |     "        for s in seq:\n", | |||
|  |     "            if s == prev:\n", | |||
|  |     "                continue\n", | |||
|  |     "            prev = s\n", | |||
|  |     "            if s == blank:\n", | |||
|  |     "                continue\n", | |||
|  |     "            if drop_sep40 and s == 40:\n", | |||
|  |     "                continue\n", | |||
|  |     "            out.append(int(s))\n", | |||
|  |     "        return out\n", | |||
|  |     "\n", | |||
|  |     "    @staticmethod\n", | |||
|  |     "    def _levenshtein(a, b):\n", | |||
|  |     "        # a, b are lists of ints\n", | |||
|  |     "        n, m = len(a), len(b)\n", | |||
|  |     "        if n == 0:\n", | |||
|  |     "            return m\n", | |||
|  |     "        if m == 0:\n", | |||
|  |     "            return n\n", | |||
|  |     "        dp = list(range(m + 1))\n", | |||
|  |     "        for i in range(1, n + 1):\n", | |||
|  |     "            prev = dp[0]\n", | |||
|  |     "            dp[0] = i\n", | |||
|  |     "            ai = a[i - 1]\n", | |||
|  |     "            for j in range(1, m + 1):\n", | |||
|  |     "                tmp = dp[j]\n", | |||
|  |     "                cost = 0 if ai == b[j - 1] else 1\n", | |||
|  |     "                dp[j] = min(dp[j] + 1,      # deletion\n", | |||
|  |     "                            dp[j - 1] + 1,  # insertion\n", | |||
|  |     "                            prev + cost)    # substitution\n", | |||
|  |     "                prev = tmp\n", | |||
|  |     "        return dp[m]\n", | |||
|  |     "\n", | |||
|  |     "    def evaluate_val_per_experiment(self, fraction=0.33, random_state=42, drop_sep40=False, max_trials_per_file=None):\n", | |||
|  |     "        \"\"\"使用所有验证文件,每个文件抽取33%的trial,按trial计算PER并求均值\"\"\"\n", | |||
|  |     "        if self.model is None:\n", | |||
|  |     "            raise RuntimeError(\"模型尚未训练,无法评估PER\")\n", | |||
|  |     "\n", | |||
|  |     "        rng = np.random.RandomState(random_state)\n", | |||
|  |     "        val_files = [f for f in os.listdir(self.pipeline.data_dir) if f.endswith('.npz') and 'val' in f]\n", | |||
|  |     "        if not val_files:\n", | |||
|  |     "            raise FileNotFoundError(\"未找到验证集npz文件\")\n", | |||
|  |     "\n", | |||
|  |     "        results_by_file = {}\n", | |||
|  |     "        per_list = []\n", | |||
|  |     "        corpus_edit = 0\n", | |||
|  |     "        corpus_len = 0\n", | |||
|  |     "        total_trials = 0\n", | |||
|  |     "\n", | |||
|  |     "        for vf in sorted(val_files):\n", | |||
|  |     "            data = np.load(os.path.join(self.pipeline.data_dir, vf), allow_pickle=True)\n", | |||
|  |     "            trials = data['neural_logits_concatenated']\n", | |||
|  |     "            n_trials = len(trials)\n", | |||
|  |     "            if n_trials == 0:\n", | |||
|  |     "                results_by_file[vf] = {'n': 0, 'mean_PER': None}\n", | |||
|  |     "                continue\n", | |||
|  |     "            k = max(1, int(np.ceil(n_trials * fraction)))\n", | |||
|  |     "            idx = np.arange(n_trials)\n", | |||
|  |     "            idx = rng.choice(idx, size=k, replace=False)\n", | |||
|  |     "            if max_trials_per_file is not None:\n", | |||
|  |     "                k = min(k, max_trials_per_file)\n", | |||
|  |     "                idx = idx[:k]\n", | |||
|  |     "\n", | |||
|  |     "            trial_pers = []\n", | |||
|  |     "            for ti in idx:\n", | |||
|  |     "                tr = trials[ti]\n", | |||
|  |     "                X_trial = tr[:, :7168]\n", | |||
|  |     "                rnn_logits = tr[:, 7168:]\n", | |||
|  |     "                # 变换到PCA空间\n", | |||
|  |     "                X_trial_pca = self.pipeline._apply_pca_transform(X_trial)\n", | |||
|  |     "                # 预测\n", | |||
|  |     "                pred_proba = self.model.predict(X_trial_pca)\n", | |||
|  |     "                y_pred_seq = np.argmax(pred_proba, axis=1)\n", | |||
|  |     "                y_true_seq = np.argmax(rnn_logits, axis=1)\n", | |||
|  |     "                # CTC折叠\n", | |||
|  |     "                pred_collapsed = self._ctc_collapse(y_pred_seq, blank=0, drop_sep40=drop_sep40)\n", | |||
|  |     "                true_collapsed = self._ctc_collapse(y_true_seq, blank=0, drop_sep40=drop_sep40)\n", | |||
|  |     "                if len(true_collapsed) == 0:\n", | |||
|  |     "                    continue\n", | |||
|  |     "                ed = self._levenshtein(pred_collapsed, true_collapsed)\n", | |||
|  |     "                per = ed / len(true_collapsed)\n", | |||
|  |     "                trial_pers.append(per)\n", | |||
|  |     "                corpus_edit += ed\n", | |||
|  |     "                corpus_len += len(true_collapsed)\n", | |||
|  |     "                total_trials += 1\n", | |||
|  |     "\n", | |||
|  |     "            if trial_pers:\n", | |||
|  |     "                results_by_file[vf] = {\n", | |||
|  |     "                    'n': len(trial_pers),\n", | |||
|  |     "                    'mean_PER': float(np.mean(trial_pers))\n", | |||
|  |     "                }\n", | |||
|  |     "                per_list.extend(trial_pers)\n", | |||
|  |     "            else:\n", | |||
|  |     "                results_by_file[vf] = {'n': 0, 'mean_PER': None}\n", | |||
|  |     "\n", | |||
|  |     "            del data, trials\n", | |||
|  |     "            gc.collect()\n", | |||
|  |     "\n", | |||
|  |     "        overall_mean = float(np.mean(per_list)) if per_list else None\n", | |||
|  |     "        corpus_per = float(corpus_edit / corpus_len) if corpus_len > 0 else None\n", | |||
|  |     "\n", | |||
|  |     "        summary = {\n", | |||
|  |     "            'overall_mean_PER': overall_mean,\n", | |||
|  |     "            'corpus_PER': corpus_per,\n", | |||
|  |     "            'total_trials': total_trials,\n", | |||
|  |     "            'per_file': results_by_file\n", | |||
|  |     "        }\n", | |||
|  |     "        print(\"验证集PER评估完成\")\n", | |||
|  |     "        print(f\"   文件数: {len(val_files)}  评估trial数: {total_trials}\")\n", | |||
|  |     "        print(f\"   平均PER(逐trial取均值): {overall_mean}\")\n", | |||
|  |     "        print(f\"   语料级PER(总编辑距离/总长度): {corpus_per}\")\n", | |||
|  |     "        return summary" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 41, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "智能分批训练器创建完成\n", | |||
|  |       "   LightGBM参数已配置:CPU模式\n", | |||
|  |       "   学习率调度: 带重启的余弦退火 (从 0.1 到 0.001)\n", | |||
|  |       "   重启参数: T_0=50, T_mult=2\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001, t_0=50, t_mult=2)" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "开始智能分批训练...\n", | |||
|  |       "   训练轮数 (每批次): 10\n", | |||
|  |       "   早停轮数: 10\n", | |||
|  |       "   每批次文件数: 2\n", | |||
|  |       "   每批次样本数: 4,000\n", | |||
|  |       "   最大批次数: 100\n", | |||
|  |       "============================================================\n", | |||
|  |       "准备验证数据...\n", | |||
|  |       "\n", | |||
|  |       "🔄 步骤3: 处理val数据...\n", | |||
|  |       "   采样策略: 禁用\n", | |||
|  |       "  正在加载文件 1/41: t15.2023.11.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 2/41: t15.2023.12.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 2/41: t15.2023.12.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 3/41: t15.2023.10.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 3/41: t15.2023.10.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 4/41: t15.2023.10.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 4/41: t15.2023.10.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 5/41: t15.2025.01.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 5/41: t15.2025.01.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 6/41: t15.2023.12.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 6/41: t15.2023.12.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 7/41: t15.2024.03.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 7/41: t15.2024.03.08_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 8/41: t15.2024.03.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 8/41: t15.2024.03.15_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 9/41: t15.2025.03.14_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 9/41: t15.2025.03.14_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 10/41: t15.2024.02.25_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 10/41: t15.2024.02.25_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 11/41: t15.2025.03.30_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 11/41: t15.2025.03.30_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 12/41: t15.2023.09.29_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 12/41: t15.2023.09.29_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 13/41: t15.2023.09.01_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 13/41: t15.2023.09.01_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 14/41: t15.2024.03.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 14/41: t15.2024.03.17_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 15/41: t15.2025.01.12_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 15/41: t15.2025.01.12_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 16/41: t15.2023.09.03_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 16/41: t15.2023.09.03_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 17/41: t15.2023.12.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 17/41: t15.2023.12.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 18/41: t15.2023.08.20_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 18/41: t15.2023.08.20_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 19/41: t15.2024.06.14_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 19/41: t15.2024.06.14_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 20/41: t15.2023.08.18_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 20/41: t15.2023.08.18_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 21/41: t15.2023.10.01_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 21/41: t15.2023.10.01_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 22/41: t15.2024.07.19_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 22/41: t15.2024.07.19_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 23/41: t15.2023.09.24_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 23/41: t15.2023.09.24_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 24/41: t15.2023.11.03_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 24/41: t15.2023.11.03_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 25/41: t15.2024.07.28_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 25/41: t15.2024.07.28_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 26/41: t15.2023.11.26_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 26/41: t15.2023.11.26_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 27/41: t15.2023.12.03_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 27/41: t15.2023.12.03_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 28/41: t15.2025.03.16_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 28/41: t15.2025.03.16_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 29/41: t15.2023.11.04_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 29/41: t15.2023.11.04_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 30/41: t15.2023.08.13_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 30/41: t15.2023.08.13_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 31/41: t15.2024.07.21_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 31/41: t15.2024.07.21_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 32/41: t15.2023.08.25_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 32/41: t15.2023.08.25_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 33/41: t15.2025.04.13_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 33/41: t15.2025.04.13_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 34/41: t15.2024.05.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 34/41: t15.2024.05.10_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 35/41: t15.2023.10.06_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 35/41: t15.2023.10.06_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 36/41: t15.2023.08.27_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 36/41: t15.2023.08.27_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 37/41: t15.2023.11.19_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 37/41: t15.2023.11.19_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 38/41: t15.2023.10.20_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 38/41: t15.2023.10.20_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 39/41: t15.2023.10.13_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 39/41: t15.2023.10.13_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 40/41: t15.2023.12.29_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 40/41: t15.2023.12.29_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 41/41: t15.2023.10.22_val_concatenated.npz\n", | |||
|  |       "  正在加载文件 41/41: t15.2023.10.22_val_concatenated.npz\n", | |||
|  |       "   ✅ 处理完成: 321,773 样本, 1219 特征\n", | |||
|  |       "   验证数据准备完成: 321,773 样本\n", | |||
|  |       "   验证集分布 (标签0: 238,705, 标签40: 35,425)\n", | |||
|  |       "开始分批增量训练...\n", | |||
|  |       "准备改进的训练批次生成器...\n", | |||
|  |       "   每批次选择文件数: 2\n", | |||
|  |       "   每批次目标样本数: 4,000\n", | |||
|  |       "   总计可用训练文件: 45\n", | |||
|  |       "   批次 1 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2025.03.16_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.08.20_train_concatenated.npz\n", | |||
|  |       "   ✅ 处理完成: 321,773 样本, 1219 特征\n", | |||
|  |       "   验证数据准备完成: 321,773 样本\n", | |||
|  |       "   验证集分布 (标签0: 238,705, 标签40: 35,425)\n", | |||
|  |       "开始分批增量训练...\n", | |||
|  |       "准备改进的训练批次生成器...\n", | |||
|  |       "   每批次选择文件数: 2\n", | |||
|  |       "   每批次目标样本数: 4,000\n", | |||
|  |       "   总计可用训练文件: 45\n", | |||
|  |       "   批次 1 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2025.03.16_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.08.20_train_concatenated.npz\n", | |||
|  |       "   合并后总样本数: 74,732\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   合并后总样本数: 74,732\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   批次 1 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,096\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "   批次 1 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,096\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "\n", | |||
|  |       "批次 1: batch_1_files_2\n", | |||
|  |       "   样本数: 3,096\n", | |||
|  |       "   首次训练 (学习率: 0.100000)\n", | |||
|  |       "\n", | |||
|  |       "批次 1: batch_1_files_2\n", | |||
|  |       "   样本数: 3,096\n", | |||
|  |       "   首次训练 (学习率: 0.100000)\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", | |||
|  |       "  _log_warning('Overriding the parameters from Reference Dataset.')\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "[1]\tvalidation's multi_logloss: 2.38768\n", | |||
|  |       "[2]\tvalidation's multi_logloss: 2.27059\n", | |||
|  |       "[2]\tvalidation's multi_logloss: 2.27059\n", | |||
|  |       "[3]\tvalidation's multi_logloss: 2.18685\n", | |||
|  |       "[3]\tvalidation's multi_logloss: 2.18685\n", | |||
|  |       "[4]\tvalidation's multi_logloss: 2.11433\n", | |||
|  |       "[4]\tvalidation's multi_logloss: 2.11433\n", | |||
|  |       "[5]\tvalidation's multi_logloss: 2.05362\n", | |||
|  |       "[5]\tvalidation's multi_logloss: 2.05362\n", | |||
|  |       "[6]\tvalidation's multi_logloss: 1.99934\n", | |||
|  |       "[6]\tvalidation's multi_logloss: 1.99934\n", | |||
|  |       "[7]\tvalidation's multi_logloss: 1.95025\n", | |||
|  |       "[7]\tvalidation's multi_logloss: 1.95025\n", | |||
|  |       "[8]\tvalidation's multi_logloss: 1.90743\n", | |||
|  |       "[8]\tvalidation's multi_logloss: 1.90743\n", | |||
|  |       "[9]\tvalidation's multi_logloss: 1.87524\n", | |||
|  |       "[9]\tvalidation's multi_logloss: 1.87524\n", | |||
|  |       "[10]\tvalidation's multi_logloss: 1.84247\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "[10]\tvalidation's multi_logloss: 1.84247\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "   批次 2 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2024.07.28_train_concatenated.npz\n", | |||
|  |       "      2. t15.2024.03.03_train_concatenated.npz\n", | |||
|  |       "   批次 2 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2024.07.28_train_concatenated.npz\n", | |||
|  |       "      2. t15.2024.03.03_train_concatenated.npz\n", | |||
|  |       "   合并后总样本数: 68,118\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   合并后总样本数: 68,118\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   批次 2 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,345\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "   批次 2 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,345\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "\n", | |||
|  |       "批次 2: batch_2_files_2\n", | |||
|  |       "   样本数: 3,345\n", | |||
|  |       "   增量训练 (学习率: 0.050500)\n", | |||
|  |       "\n", | |||
|  |       "批次 2: batch_2_files_2\n", | |||
|  |       "   样本数: 3,345\n", | |||
|  |       "   增量训练 (学习率: 0.050500)\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", | |||
|  |       "  _log_warning('Overriding the parameters from Reference Dataset.')\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "[11]\tvalidation's multi_logloss: 1.85013\n", | |||
|  |       "[12]\tvalidation's multi_logloss: 1.84353\n", | |||
|  |       "[12]\tvalidation's multi_logloss: 1.84353\n", | |||
|  |       "[13]\tvalidation's multi_logloss: 1.84134\n", | |||
|  |       "[13]\tvalidation's multi_logloss: 1.84134\n", | |||
|  |       "[14]\tvalidation's multi_logloss: 1.83556\n", | |||
|  |       "[14]\tvalidation's multi_logloss: 1.83556\n", | |||
|  |       "[15]\tvalidation's multi_logloss: 1.82583\n", | |||
|  |       "[15]\tvalidation's multi_logloss: 1.82583\n", | |||
|  |       "[16]\tvalidation's multi_logloss: 1.81337\n", | |||
|  |       "[16]\tvalidation's multi_logloss: 1.81337\n", | |||
|  |       "[17]\tvalidation's multi_logloss: 1.80265\n", | |||
|  |       "[17]\tvalidation's multi_logloss: 1.80265\n", | |||
|  |       "[18]\tvalidation's multi_logloss: 1.79249\n", | |||
|  |       "[18]\tvalidation's multi_logloss: 1.79249\n", | |||
|  |       "[19]\tvalidation's multi_logloss: 1.78209\n", | |||
|  |       "[19]\tvalidation's multi_logloss: 1.78209\n", | |||
|  |       "[20]\tvalidation's multi_logloss: 1.7724\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "[20]\tvalidation's multi_logloss: 1.7724\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "   批次 3 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2024.07.28_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.10.20_train_concatenated.npz\n", | |||
|  |       "   批次 3 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2024.07.28_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.10.20_train_concatenated.npz\n", | |||
|  |       "   合并后总样本数: 57,627\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   合并后总样本数: 57,627\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   批次 3 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,493\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "   批次 3 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,493\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "\n", | |||
|  |       "批次 3: batch_3_files_2\n", | |||
|  |       "   样本数: 3,493\n", | |||
|  |       "   增量训练 (学习率: 0.100000)\n", | |||
|  |       "\n", | |||
|  |       "批次 3: batch_3_files_2\n", | |||
|  |       "   样本数: 3,493\n", | |||
|  |       "   增量训练 (学习率: 0.100000)\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", | |||
|  |       "  _log_warning('Overriding the parameters from Reference Dataset.')\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "[21]\tvalidation's multi_logloss: 1.84376\n", | |||
|  |       "[22]\tvalidation's multi_logloss: 1.83698\n", | |||
|  |       "[22]\tvalidation's multi_logloss: 1.83698\n", | |||
|  |       "[23]\tvalidation's multi_logloss: 1.81596\n", | |||
|  |       "[23]\tvalidation's multi_logloss: 1.81596\n", | |||
|  |       "[24]\tvalidation's multi_logloss: 1.79844\n", | |||
|  |       "[24]\tvalidation's multi_logloss: 1.79844\n", | |||
|  |       "[25]\tvalidation's multi_logloss: 1.77603\n", | |||
|  |       "[25]\tvalidation's multi_logloss: 1.77603\n", | |||
|  |       "[26]\tvalidation's multi_logloss: 1.76072\n", | |||
|  |       "[26]\tvalidation's multi_logloss: 1.76072\n", | |||
|  |       "[27]\tvalidation's multi_logloss: 1.73975\n", | |||
|  |       "[27]\tvalidation's multi_logloss: 1.73975\n", | |||
|  |       "[28]\tvalidation's multi_logloss: 1.72273\n", | |||
|  |       "[28]\tvalidation's multi_logloss: 1.72273\n", | |||
|  |       "[29]\tvalidation's multi_logloss: 1.70835\n", | |||
|  |       "[29]\tvalidation's multi_logloss: 1.70835\n", | |||
|  |       "[30]\tvalidation's multi_logloss: 1.69313\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "[30]\tvalidation's multi_logloss: 1.69313\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "   批次 4 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2024.07.28_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.11.03_train_concatenated.npz\n", | |||
|  |       "   批次 4 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2024.07.28_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.11.03_train_concatenated.npz\n", | |||
|  |       "   合并后总样本数: 76,842\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   合并后总样本数: 76,842\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   批次 4 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,368\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "   批次 4 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,368\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "\n", | |||
|  |       "批次 4: batch_4_files_2\n", | |||
|  |       "   样本数: 3,368\n", | |||
|  |       "   增量训练 (学习率: 0.085502)\n", | |||
|  |       "\n", | |||
|  |       "批次 4: batch_4_files_2\n", | |||
|  |       "   样本数: 3,368\n", | |||
|  |       "   增量训练 (学习率: 0.085502)\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", | |||
|  |       "  _log_warning('Overriding the parameters from Reference Dataset.')\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "[31]\tvalidation's multi_logloss: 1.97656\n", | |||
|  |       "[32]\tvalidation's multi_logloss: 1.96596\n", | |||
|  |       "[32]\tvalidation's multi_logloss: 1.96596\n", | |||
|  |       "[33]\tvalidation's multi_logloss: 1.94311\n", | |||
|  |       "[33]\tvalidation's multi_logloss: 1.94311\n", | |||
|  |       "[34]\tvalidation's multi_logloss: 1.92177\n", | |||
|  |       "[34]\tvalidation's multi_logloss: 1.92177\n", | |||
|  |       "[35]\tvalidation's multi_logloss: 1.90267\n", | |||
|  |       "[35]\tvalidation's multi_logloss: 1.90267\n", | |||
|  |       "[36]\tvalidation's multi_logloss: 1.88543\n", | |||
|  |       "[36]\tvalidation's multi_logloss: 1.88543\n", | |||
|  |       "[37]\tvalidation's multi_logloss: 1.86867\n", | |||
|  |       "[37]\tvalidation's multi_logloss: 1.86867\n", | |||
|  |       "[38]\tvalidation's multi_logloss: 1.85562\n", | |||
|  |       "[38]\tvalidation's multi_logloss: 1.85562\n", | |||
|  |       "[39]\tvalidation's multi_logloss: 1.8434\n", | |||
|  |       "[39]\tvalidation's multi_logloss: 1.8434\n", | |||
|  |       "[40]\tvalidation's multi_logloss: 1.83282\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "[40]\tvalidation's multi_logloss: 1.83282\n", | |||
|  |       "   批次训练失败: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()\n", | |||
|  |       "   批次 5 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2023.09.03_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.12.08_train_concatenated.npz\n", | |||
|  |       "   批次 5 - 随机选择的文件:\n", | |||
|  |       "      1. t15.2023.09.03_train_concatenated.npz\n", | |||
|  |       "      2. t15.2023.12.08_train_concatenated.npz\n", | |||
|  |       "   合并后总样本数: 112,126\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   合并后总样本数: 112,126\n", | |||
|  |       "   随机采样到: 4,000 样本\n", | |||
|  |       "   批次 5 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,549\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "   批次 5 最终结果:\n", | |||
|  |       "      平衡后样本数: 3,549\n", | |||
|  |       "      特征维度: 1219\n", | |||
|  |       "      分布: 标签0=388, 标签40=388\n", | |||
|  |       "   ==================================================\n", | |||
|  |       "\n", | |||
|  |       "批次 5: batch_5_files_2\n", | |||
|  |       "   样本数: 3,549\n", | |||
|  |       "   增量训练 (学习率: 0.050500)\n", | |||
|  |       "\n", | |||
|  |       "批次 5: batch_5_files_2\n", | |||
|  |       "   样本数: 3,549\n", | |||
|  |       "   增量训练 (学习率: 0.050500)\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "ename": "KeyboardInterrupt", | |||
|  |      "evalue": "", | |||
|  |      "output_type": "error", | |||
|  |      "traceback": [ | |||
|  |       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |||
|  |       "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)", | |||
|  |       "\u001b[0;32m/tmp/ipykernel_36/1413771336.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;31m# 开始使用改进的训练器\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m model = trainer.train_incremental(\n\u001b[0m\u001b[1;32m     12\u001b[0m     \u001b[0mnum_boost_round\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mIMPROVED_TRAINING_PARAMS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'num_boost_round'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m     \u001b[0mearly_stopping_rounds\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mIMPROVED_TRAINING_PARAMS\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'early_stopping_rounds'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/tmp/ipykernel_36/3359360572.py\u001b[0m in \u001b[0;36mtrain_incremental\u001b[0;34m(self, num_boost_round, early_stopping_rounds, n_files_per_batch, batch_size, max_batches)\u001b[0m\n\u001b[1;32m    317\u001b[0m                     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"   增量训练 (学习率: {current_lr:.6f})\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    318\u001b[0m                     \u001b[0mtrain_params\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'init_model'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 319\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlgb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mtrain_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    320\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    321\u001b[0m                 \u001b[0;31m# 验证\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/engine.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, feature_name, categorical_feature, keep_training_booster, callbacks)\u001b[0m\n\u001b[1;32m    220\u001b[0m                     \u001b[0mtrain_data_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalid_names\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    221\u001b[0m                 \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 222\u001b[0;31m             \u001b[0mreduced_valid_sets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalid_data\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_reference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_set\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    223\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mvalid_names\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalid_names\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    224\u001b[0m                 \u001b[0mname_valid_sets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalid_names\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36mset_reference\u001b[0;34m(self, reference)\u001b[0m\n\u001b[1;32m   2831\u001b[0m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_categorical_feature\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcategorical_feature\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2832\u001b[0m             \u001b[0;34m.\u001b[0m\u001b[0mset_feature_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeature_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2833\u001b[0;31m             \u001b[0;34m.\u001b[0m\u001b[0m_set_predictor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predictor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2834\u001b[0m         \u001b[0;31m# we're done if self and reference share a common upstream reference\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2835\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_ref_chain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintersection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_ref_chain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m_set_predictor\u001b[0;34m(self, predictor)\u001b[0m\n\u001b[1;32m   2799\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2800\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predictor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpredictor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2801\u001b[0;31m             self._set_init_score_by_predictor(\n\u001b[0m\u001b[1;32m   2802\u001b[0m                 \u001b[0mpredictor\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_predictor\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2803\u001b[0m                 \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m_set_init_score_by_predictor\u001b[0;34m(self, predictor, data, used_indices)\u001b[0m\n\u001b[1;32m   1970\u001b[0m         \u001b[0mnum_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnum_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1971\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mpredictor\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1972\u001b[0;31m             init_score: Union[np.ndarray, scipy.sparse.spmatrix] = predictor.predict(\n\u001b[0m\u001b[1;32m   1973\u001b[0m                 \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1974\u001b[0m                 \u001b[0mraw_score\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, data, start_iteration, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, validate_features)\u001b[0m\n\u001b[1;32m   1157\u001b[0m             )\n\u001b[1;32m   1158\u001b[0m         \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndarray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1159\u001b[0;31m             preds, nrow = self.__pred_for_np2d(\n\u001b[0m\u001b[1;32m   1160\u001b[0m                 \u001b[0mmat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1161\u001b[0m                 \u001b[0mstart_iteration\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstart_iteration\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m__pred_for_np2d\u001b[0;34m(self, mat, start_iteration, num_iteration, predict_type)\u001b[0m\n\u001b[1;32m   1304\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnrow\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1305\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1306\u001b[0;31m             return self.__inner_predict_np2d(\n\u001b[0m\u001b[1;32m   1307\u001b[0m                 \u001b[0mmat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1308\u001b[0m                 \u001b[0mstart_iteration\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mstart_iteration\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m__inner_predict_np2d\u001b[0;34m(self, mat, start_iteration, num_iteration, predict_type, preds)\u001b[0m\n\u001b[1;32m   1257\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Wrong length of pre-allocated predict array\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1258\u001b[0m         \u001b[0mout_num_preds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mctypes\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_int64\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1259\u001b[0;31m         _safe_call(_LIB.LGBM_BoosterPredictForMat(\n\u001b[0m\u001b[1;32m   1260\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1261\u001b[0m             \u001b[0mptr_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |||
|  |       "\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# Improved training parameters (quick smoke-test settings)\n", | |||
|  |     "IMPROVED_TRAINING_PARAMS = {\n", | |||
|  |     "    'num_boost_round': 10,          # boosting rounds per batch\n", | |||
|  |     "    'early_stopping_rounds': 10,    # patience; NOTE(review): not smaller than num_boost_round,\n", | |||
|  |     "                                    # so it presumably cannot trigger within one batch - confirm\n", | |||
|  |     "    'n_files_per_batch': 2,         # reduced to 2 for quick validation\n", | |||
|  |     "    'batch_size': 4000,             # halved for quick validation\n", | |||
|  |     "    'max_batches': 100,             # run at most 100 batches as a smoke test\n", | |||
|  |     "}\n", | |||
|  |     "\n", | |||
|  |     "# Every key above matches a keyword parameter of train_incremental,\n", | |||
|  |     "# so the dict is unpacked straight into the call.\n", | |||
|  |     "model = trainer.train_incremental(**IMPROVED_TRAINING_PARAMS)\n", | |||
|  |     "\n", | |||
|  |     "# After training, compute validation PER once, sampling 33% of the trials in each\n", | |||
|  |     "# file (max_trials_per_file presumably caps the per-file count at 5).\n", | |||
|  |     "per_summary = trainer.evaluate_val_per_experiment(\n", | |||
|  |     "    fraction=0.33,\n", | |||
|  |     "    random_state=42,\n", | |||
|  |     "    drop_sep40=False,\n", | |||
|  |     "    max_trials_per_file=5,\n", | |||
|  |     ")\n", | |||
|  |     "print(per_summary)" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 📊 训练结果分析" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# 📊 Training-result analysis and visualisation\n", | |||
|  |     "# Requires `trainer` from the training cells above.\n", | |||
|  |     "\n", | |||
|  |     "print(\"📊 分析智能分批训练结果...\")\n", | |||
|  |     "print(\"=\" * 60)\n", | |||
|  |     "\n", | |||
|  |     "# Plot the per-batch training progress\n", | |||
|  |     "trainer.plot_training_progress()\n", | |||
|  |     "\n", | |||
|  |     "# Save the final model (skipped when training produced none)\n", | |||
|  |     "final_model_path = \"smart_pipeline_final_model.txt\"\n", | |||
|  |     "if trainer.model is not None:\n", | |||
|  |     "    trainer.model.save_model(final_model_path)\n", | |||
|  |     "    print(f\"\\n💾 最终模型已保存: {final_model_path}\")\n", | |||
|  |     "\n", | |||
|  |     "# Detailed analysis; each history entry is read for the keys\n", | |||
|  |     "# 'val_accuracy', 'batch', 'time' and 'num_trees'.\n", | |||
|  |     "history = trainer.training_history\n", | |||
|  |     "if history:\n", | |||
|  |     "    print(\"\\n📈 详细训练分析:\")\n", | |||
|  |     "    print(f\"   🎯 训练批次总数: {len(history)}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Best batch by validation accuracy\n", | |||
|  |     "    best_batch = max(history, key=lambda h: h['val_accuracy'])\n", | |||
|  |     "    print(f\"   🏆 最佳验证准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n", | |||
|  |     "\n", | |||
|  |     "    # Training efficiency\n", | |||
|  |     "    total_training_time = sum(h['time'] for h in history)\n", | |||
|  |     "    avg_batch_time = total_training_time / len(history)\n", | |||
|  |     "    print(f\"   ⏱️ 总训练时间: {total_training_time:.1f}秒\")\n", | |||
|  |     "    print(f\"   ⏱️ 平均批次时间: {avg_batch_time:.1f}秒\")\n", | |||
|  |     "\n", | |||
|  |     "    # Model complexity after the last batch\n", | |||
|  |     "    final_trees = history[-1]['num_trees']\n", | |||
|  |     "    print(f\"   🌳 最终模型树数: {final_trees}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Convergence: spread (max - min) of the last 3 validation accuracies.\n", | |||
|  |     "    # This is a range, not a variance, so the label reads 极差 (range).\n", | |||
|  |     "    recent_accs = [h['val_accuracy'] for h in history[-3:]]\n", | |||
|  |     "    if len(recent_accs) >= 2:\n", | |||
|  |     "        acc_stability = max(recent_accs) - min(recent_accs)\n", | |||
|  |     "        print(f\"   📈 准确率稳定性: {acc_stability:.4f} (最近3批次极差)\")\n", | |||
|  |     "\n", | |||
|  |     "        if acc_stability < 0.01:\n", | |||
|  |     "            print(\"   ✅ 模型已收敛 (准确率变化 < 1%)\")\n", | |||
|  |     "        else:\n", | |||
|  |     "            print(\"   ⚠️ 模型可能需要更多训练\")\n", | |||
|  |     "\n", | |||
|  |     "print(\"\\n🎉 智能分批训练分析完成!\")\n", | |||
|  |     "print(\"   💡 使用了改进的数据平衡策略和PCA降维\")\n", | |||
|  |     "print(\"   💡 训练集应用了下采样+过采样,验证集保持原始分布\")\n", | |||
|  |     "print(\"   💡 实现了内存友好的分批处理\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🧪 模型性能评估" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# 🧪 Model performance evaluation\n", | |||
|  |     "# Requires `trainer` and `pipeline` from earlier cells.\n", | |||
|  |     "\n", | |||
|  |     "import time  # used for prediction timing; imported here so the cell runs on a fresh kernel\n", | |||
|  |     "from collections import Counter\n", | |||
|  |     "\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "from sklearn.metrics import classification_report, confusion_matrix  # unused here; kept in case later cells rely on these names\n", | |||
|  |     "\n", | |||
|  |     "\n", | |||
|  |     "def evaluate_model_performance(model, pipeline, data_type='val',\n", | |||
|  |     "                               high_conf_threshold=0.9, minority_count_threshold=200):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Evaluate a trained model on one data split and print a performance report.\n", | |||
|  |     "\n", | |||
|  |     "    Parameters\n", | |||
|  |     "    ----------\n", | |||
|  |     "    model : object whose ``predict(X)`` returns a 2-D array of per-class\n", | |||
|  |     "        probabilities (one column per class), e.g. a LightGBM Booster.\n", | |||
|  |     "    pipeline : object providing ``step3_process_data(data_type, apply_sampling=False)``\n", | |||
|  |     "        that returns ``(X, y)``; either may be None when loading fails.\n", | |||
|  |     "    data_type : str\n", | |||
|  |     "        Split name forwarded to the pipeline (default ``'val'``).\n", | |||
|  |     "    high_conf_threshold : float\n", | |||
|  |     "        Max-probability cutoff above which a prediction counts as high-confidence.\n", | |||
|  |     "    minority_count_threshold : int\n", | |||
|  |     "        Labels with fewer true samples than this (other than 0 and 40) are\n", | |||
|  |     "        treated as minority classes.\n", | |||
|  |     "\n", | |||
|  |     "    Returns\n", | |||
|  |     "    -------\n", | |||
|  |     "    dict or None\n", | |||
|  |     "        Keys 'accuracy', 'prediction_time', 'label_accuracies', 'confidence_stats';\n", | |||
|  |     "        None if the data could not be loaded.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(f\"🧪 评估模型在{data_type}数据集上的性能...\")\n", | |||
|  |     "\n", | |||
|  |     "    # Load the split without down/over-sampling so it keeps its original distribution\n", | |||
|  |     "    X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n", | |||
|  |     "\n", | |||
|  |     "    if X is None or y is None:\n", | |||
|  |     "        print(f\"❌ 无法加载{data_type}数据\")\n", | |||
|  |     "        return None\n", | |||
|  |     "\n", | |||
|  |     "    y = np.asarray(y)  # element-wise comparisons below need an array, not a list\n", | |||
|  |     "    print(f\"   📊 数据集大小: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", | |||
|  |     "\n", | |||
|  |     "    # Predict: the predicted label is the argmax over the probability columns\n", | |||
|  |     "    start_time = time.time()\n", | |||
|  |     "    y_pred_proba = model.predict(X)\n", | |||
|  |     "    y_pred = y_pred_proba.argmax(axis=1)\n", | |||
|  |     "    pred_time = time.time() - start_time\n", | |||
|  |     "\n", | |||
|  |     "    accuracy = (y_pred == y).mean()\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"   ⏱️ 预测时间: {pred_time:.2f}秒\")\n", | |||
|  |     "    print(f\"   🎯 整体准确率: {accuracy:.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Per-label breakdown\n", | |||
|  |     "    true_counts = Counter(y)\n", | |||
|  |     "    pred_counts = Counter(y_pred)\n", | |||
|  |     "    n_classes = y_pred_proba.shape[1]  # argmax can only yield labels 0..n_classes-1\n", | |||
|  |     "\n", | |||
|  |     "    print(\"\\n📊 标签分布对比:\")\n", | |||
|  |     "    print(\"标签 | 真实数量 | 预测数量 | 准确率\")\n", | |||
|  |     "    print(\"-\" * 40)\n", | |||
|  |     "\n", | |||
|  |     "    label_accuracies = {}\n", | |||
|  |     "    for label in range(n_classes):\n", | |||
|  |     "        true_count = true_counts.get(label, 0)\n", | |||
|  |     "        if true_count == 0:\n", | |||
|  |     "            continue  # label absent from this split\n", | |||
|  |     "        label_acc = (y_pred[y == label] == label).mean()\n", | |||
|  |     "        label_accuracies[label] = label_acc\n", | |||
|  |     "        pred_count = pred_counts.get(label, 0)\n", | |||
|  |     "        print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Key labels: the ones down-sampled during training\n", | |||
|  |     "    print(\"\\n🔍 关键标签性能分析:\")\n", | |||
|  |     "    key_labels = [0, 40]\n", | |||
|  |     "    for label in key_labels:\n", | |||
|  |     "        if label in label_accuracies:\n", | |||
|  |     "            acc = label_accuracies[label]\n", | |||
|  |     "            count = true_counts.get(label, 0)\n", | |||
|  |     "            print(f\"   标签 {label} (下采样目标): 准确率 {acc:.4f}, 样本数 {count:,}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Minority classes; \"first 5\" follows the order labels first appear in y, not a ranking\n", | |||
|  |     "    minority_labels = [label for label, count in true_counts.items()\n", | |||
|  |     "                       if count < minority_count_threshold and label not in key_labels]\n", | |||
|  |     "    if minority_labels:\n", | |||
|  |     "        minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n", | |||
|  |     "        avg_minority_acc = np.mean(minority_accs)\n", | |||
|  |     "        print(f\"   少数类平均准确率 (前5个): {avg_minority_acc:.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Confidence analysis on the top-class probability\n", | |||
|  |     "    max_proba = y_pred_proba.max(axis=1)\n", | |||
|  |     "    high_conf = max_proba > high_conf_threshold\n", | |||
|  |     "    print(\"\\n📈 预测置信度分析:\")\n", | |||
|  |     "    print(f\"   平均置信度: {max_proba.mean():.4f}\")\n", | |||
|  |     "    print(f\"   置信度中位数: {np.median(max_proba):.4f}\")\n", | |||
|  |     "    print(f\"   高置信度预测 (>{high_conf_threshold}): {high_conf.sum():,} / {len(max_proba):,} ({high_conf.mean():.2%})\")\n", | |||
|  |     "\n", | |||
|  |     "    return {\n", | |||
|  |     "        'accuracy': accuracy,\n", | |||
|  |     "        'prediction_time': pred_time,\n", | |||
|  |     "        'label_accuracies': label_accuracies,\n", | |||
|  |     "        'confidence_stats': {\n", | |||
|  |     "            'mean': max_proba.mean(),\n", | |||
|  |     "            'median': np.median(max_proba),\n", | |||
|  |     "            'high_confidence_ratio': high_conf.mean(),\n", | |||
|  |     "        },\n", | |||
|  |     "    }\n", | |||
|  |     "\n", | |||
|  |     "\n", | |||
|  |     "# Evaluate the model on the validation split\n", | |||
|  |     "if trainer.model is not None:\n", | |||
|  |     "    print(\"🧪 开始模型性能评估...\")\n", | |||
|  |     "\n", | |||
|  |     "    val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n", | |||
|  |     "\n", | |||
|  |     "    print(\"\\n\" + \"=\" * 60)\n", | |||
|  |     "    print(\"🎉 智能分批训练+数据平衡 评估完成!\")\n", | |||
|  |     "    print(\"✅ 实现了数据平衡和PCA降维的完整流程\")\n", | |||
|  |     "    print(\"✅ 使用了内存友好的分批训练策略\")\n", | |||
|  |     "    print(\"✅ 保持了验证集的原始分布以确保评估客观性\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"❌ 模型尚未训练完成,请等待训练结束后运行此评估\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# ✅ 余弦退火已更新为带重启版本\n", | |||
|  |     "\n", | |||
|  |     "print(\"🎉 余弦退火调度器更新完成!\")\n", | |||
|  |     "\n", | |||
|  |     "# 检查trainer是否已创建,如果未创建则先创建\n", | |||
|  |     "if 'trainer' not in globals():\n", | |||
|  |     "    print(\"⚠️ 训练器尚未创建,请先运行前面的代码创建训练器\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(f\"✅ 当前使用:带重启的余弦退火 (SGDR)\")\n", | |||
|  |     "    print(f\"   🔄 重启参数: T_0={trainer.t_0}, T_mult={trainer.t_mult}\")\n", | |||
|  |     "    print(f\"   📈 学习率范围: {trainer.initial_learning_rate} → {trainer.min_learning_rate}\")\n", | |||
|  |     "\n", | |||
|  |     "    # 可视化新的学习率调度\n", | |||
|  |     "    import matplotlib.pyplot as plt\n", | |||
|  |     "    import numpy as np\n", | |||
|  |     "\n", | |||
|  |     "    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", | |||
|  |     "\n", | |||
|  |     "    # 模拟300轮的学习率变化\n", | |||
|  |     "    rounds = list(range(300))\n", | |||
|  |     "    old_lrs = []  # 原始余弦退火\n", | |||
|  |     "    new_lrs = []  # 带重启的余弦退火\n", | |||
|  |     "\n", | |||
|  |     "    for r in rounds:\n", | |||
|  |     "        # 原始余弦退火 (单调递减)\n", | |||
|  |     "        old_lr = trainer.min_learning_rate + 0.5 * (trainer.initial_learning_rate - trainer.min_learning_rate) * (1 + np.cos(np.pi * r / 300))\n", | |||
|  |     "        old_lrs.append(old_lr)\n", | |||
|  |     "        \n", | |||
|  |     "        # 带重启的余弦退火\n", | |||
|  |     "        new_lr = trainer._cosine_annealing_with_warm_restarts(r)\n", | |||
|  |     "        new_lrs.append(new_lr)\n", | |||
|  |     "\n", | |||
|  |     "    # 绘制对比图\n", | |||
|  |     "    ax1.plot(rounds, old_lrs, 'b-', label='原始余弦退火', linewidth=2)\n", | |||
|  |     "    ax1.set_xlabel('Training Round')\n", | |||
|  |     "    ax1.set_ylabel('Learning Rate')\n", | |||
|  |     "    ax1.set_title('原始余弦退火 (单调递减)')\n", | |||
|  |     "    ax1.grid(True, alpha=0.3)\n", | |||
|  |     "    ax1.legend()\n", | |||
|  |     "\n", | |||
|  |     "    ax2.plot(rounds, new_lrs, 'r-', label='带重启的余弦退火', linewidth=2)\n", | |||
|  |     "    ax2.set_xlabel('Training Round')\n", | |||
|  |     "    ax2.set_ylabel('Learning Rate')\n", | |||
|  |     "    ax2.set_title('带重启的余弦退火 (SGDR)')\n", | |||
|  |     "    ax2.grid(True, alpha=0.3)\n", | |||
|  |     "    ax2.legend()\n", | |||
|  |     "\n", | |||
|  |     "    plt.tight_layout()\n", | |||
|  |     "    plt.show()\n", | |||
|  |     "\n", | |||
|  |     "    print(\"📊 学习率调度对比可视化完成\")\n", | |||
|  |     "    print(\"   🔵 原始版本:单调递减的余弦曲线\")\n", | |||
|  |     "    print(\"   🔴 新版本:周期性重启,每次重启后学习率回到最大值\")\n", | |||
|  |     "    print(\"   💡 SGDR的优势:多次重启可以帮助模型跳出局部最优解\")\n", | |||
|  |     "\n", | |||
|  |     "    # 显示重启点\n", | |||
|  |     "    restart_points = []\n", | |||
|  |     "    t_cur = 0\n", | |||
|  |     "    t_i = trainer.t_0\n", | |||
|  |     "    while t_cur < 300:\n", | |||
|  |     "        restart_points.append(t_cur)\n", | |||
|  |     "        t_cur += t_i\n", | |||
|  |     "        t_i *= trainer.t_mult\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"   🔄 在300轮训练中的重启点: {restart_points[:5]}...\")  # 显示前5个重启点" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [] | |||
|  |   } | |||
|  |  ], | |||
|  |  "metadata": { | |||
|  |   "kaggle": { | |||
|  |    "accelerator": "tpu1vmV38", | |||
|  |    "dataSources": [ | |||
|  |     { | |||
|  |      "databundleVersionId": 13056355, | |||
|  |      "sourceId": 106809, | |||
|  |      "sourceType": "competition" | |||
|  |     } | |||
|  |    ], | |||
|  |    "dockerImageVersionId": 31091, | |||
|  |    "isGpuEnabled": false, | |||
|  |    "isInternetEnabled": true, | |||
|  |    "language": "python", | |||
|  |    "sourceType": "notebook" | |||
|  |   }, | |||
|  |   "kernelspec": { | |||
|  |    "display_name": "Python 3 (ipykernel)", | |||
|  |    "language": "python", | |||
|  |    "name": "python3" | |||
|  |   }, | |||
|  |   "language_info": { | |||
|  |    "codemirror_mode": { | |||
|  |     "name": "ipython", | |||
|  |     "version": 3 | |||
|  |    }, | |||
|  |    "file_extension": ".py", | |||
|  |    "mimetype": "text/x-python", | |||
|  |    "name": "python", | |||
|  |    "nbconvert_exporter": "python", | |||
|  |    "pygments_lexer": "ipython3", | |||
|  |    "version": "3.11.13" | |||
|  |   } | |||
|  |  }, | |||
|  |  "nbformat": 4, | |||
|  |  "nbformat_minor": 4 | |||
|  | } |