{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## 🎲 改进的随机批次生成器\n", "\n", "这个版本改进了数据生成策略:\n", "- **随机文件选择**: 每次从所有训练文件中随机选择 n=4 个文件\n", "- **随机样本采样**: 从选中的文件中随机采样指定数量的样本\n", "- **提高数据多样性**: 避免按固定顺序处理文件,减少过拟合风险\n", "- **可控批次大小**: 固定每批次样本数,确保训练稳定性" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 环境配置与Utils" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu126\n", "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n", "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n", "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", "Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n", "Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n", "Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", "Requirement already satisfied: lightgbm==4.3.0 in /usr/local/lib/python3.11/dist-packages (4.3.0)\n", "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n", "Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n", "Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n", "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n", "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n", "Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n", "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n", "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n", "Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n", "Requirement already satisfied: seaborn==0.13.2 in /usr/local/lib/python3.11/dist-packages (0.13.2)\n", "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n", "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n", "Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n", "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n", "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n", "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n", "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n", "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n", "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n", "Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n", "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n", "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n", "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n", "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n", "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n", "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n", "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n", "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n", "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n", "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n", "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n", "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n", "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n", "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n", "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n", "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n", "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n", "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n", "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n", "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n", "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n", "Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n", "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n", "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n", "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n", "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n", "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n", "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n", "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n", "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n", "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n", "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n", "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n", "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n", "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n", "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n", "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n", "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n", "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n", "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n", "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n", "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n", "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n", "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n", "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n", "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n", "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n", "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n", "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n", "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n", "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n", "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n", "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n", "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n", "Obtaining file:///kaggle/working/nejm-brain-to-text\n", " Preparing metadata (setup.py): started\n", " Preparing metadata (setup.py): finished with status 'done'\n", "Installing collected packages: nejm_b2txt_utils\n", " Running setup.py develop for nejm_b2txt_utils\n", "Successfully installed nejm_b2txt_utils-0.0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Cloning into 'nejm-brain-to-text'...\n" ] } ], "source": [ "%%bash\n", "rm -rf /kaggle/working/nejm-brain-to-text/\n", "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n", "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n", "\n", "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data/concatenated_data\n", "\n", "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", "\n", "pip install \\\n", " jupyter==1.1.1 \\\n", " \"numpy>=1.26.0,<2.1.0\" \\\n", " pandas==2.3.0 \\\n", " matplotlib==3.10.1 \\\n", " scipy==1.15.2 \\\n", " scikit-learn==1.6.1 \\\n", " lightgbm==4.3.0 \\\n", " tqdm==4.67.1 \\\n", " g2p_en==2.1.0 \\\n", " h5py==3.13.0 \\\n", " omegaconf==2.3.0 \\\n", " editdistance==0.8.1 \\\n", " huggingface-hub==0.33.1 \\\n", " transformers==4.53.0 \\\n", " tokenizers==0.21.2 \\\n", " accelerate==1.8.1 \\\n", " bitsandbytes==0.46.0 \\\n", " seaborn==0.13.2\n", "cd /kaggle/working/nejm-brain-to-text/\n", "pip install -e ." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", "🔧 LightGBM GPU环境检查\n", "==================================================\n", "❌ 未检测到NVIDIA GPU或驱动\n", "\n", "✅ CUDA工具包:\n", " Cuda compilation tools, release 12.5, V12.5.82\n" ] } ], "source": [ "# 🚀 LightGBM GPU支持检查与配置\n", "\n", "print(\"=\"*50)\n", "print(\"🔧 LightGBM GPU环境检查\")\n", "print(\"=\"*50)\n", "\n", "# 检查CUDA和GPU驱动\n", "import subprocess\n", "import sys\n", "\n", "def run_command(command):\n", " \"\"\"运行命令并返回结果\"\"\"\n", " try:\n", " result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n", " return result.stdout.strip(), result.returncode == 0\n", " except Exception as e:\n", " return str(e), False\n", "\n", "# 检查NVIDIA GPU\n", "nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n", "if nvidia_success:\n", " print(\"✅ NVIDIA GPU检测:\")\n", " for line in nvidia_output.split('\\n'):\n", " if line.strip():\n", " print(f\" {line}\")\n", "else:\n", " print(\"❌ 未检测到NVIDIA GPU或驱动\")\n", "\n", "# 检查CUDA版本\n", "cuda_output, cuda_success = run_command(\"nvcc --version\")\n", "if cuda_success:\n", " print(\"\\n✅ CUDA工具包:\")\n", " # 提取CUDA版本\n", " for line in cuda_output.split('\\n'):\n", " if 'release' in line:\n", " print(f\" {line.strip()}\")\n", "else:\n", " print(\"\\n❌ 未安装CUDA工具包\")\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd nejm-brain-to-text\n", "import numpy as np\n", "import os\n", "import pickle\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from g2p_en import G2p\n", "import pandas as pd\n", "import numpy as np\n", "from nejm_b2txt_utils.general_utils import *\n", "matplotlib.rcParams['pdf.fonttype'] = 42\n", "matplotlib.rcParams['ps.fonttype'] = 42\n", "matplotlib.rcParams['font.family'] = 'sans-serif'\n", "matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n", "matplotlib.rcParams['axes.unicode_minus'] = False\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text/model_training\n" ] } ], "source": [ "%cd model_training/\n", "from data_augmentations import gauss_smooth" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "LOGIT_TO_PHONEME = [\n", " 'BLANK',\n", " 'AA', 'AE', 'AH', 'AO', 'AW',\n", " 'AY', 'B', 'CH', 'D', 'DH',\n", " 'EH', 'ER', 'EY', 'F', 'G',\n", " 'HH', 'IH', 'IY', 'JH', 'K',\n", " 'L', 'M', 'N', 'NG', 'OW',\n", " 'OY', 'P', 'R', 'S', 'SH',\n", " 'T', 'TH', 'UH', 'UW', 'V',\n", " 'W', 'Y', 'Z', 'ZH',\n", " ' | ',\n", "]\n", "# 全局配置\n", "BALANCE_CONFIG = {\n", " 'enable_balance': True, # 是否启用数据平衡\n", " 'undersample_labels': [0, 40], # 需要下采样的标签 (blank等高频标签)\n", " 'oversample_threshold': 0.5, # 过采样阈值 (相对于均值的比例)\n", " 'random_state': 42 # 随机种子\n", "}\n", "# 全局PCA配置\n", "PCA_CONFIG = {\n", " 'enable_pca': True, # 是否启用PCA\n", " 'n_components': None, # None=自动选择, 或指定具体数值\n", " 'variance_threshold': 0.95, # 保留95%的方差\n", " 'sample_size': 15000, # 用于拟合PCA的样本数\n", "}\n", "\n", "# 全局PCA对象 (确保只拟合一次)\n", "GLOBAL_PCA = {\n", " 'scaler': None,\n", " 'pca': None,\n", " 'is_fitted': False,\n", " 'n_components': None\n", "}\n", "# 设置数据目录和参数【PCA初始化】\n", "data_dir = '../data/concatenated_data'\n", "MAX_SAMPLES_PER_FILE = -1 # 每个文件最大样本数,可调整" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 数据读取工作流" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2️⃣ 数据加载与PCA降维" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# 🚀 内存友好的数据读取 - 分批加载策略 + PCA降维 【这里还缺一个采样】\n", "import os\n", "import numpy as np\n", "import gc\n", "from sklearn.decomposition import PCA\n", "from sklearn.preprocessing import StandardScaler\n", "import joblib\n", "import matplotlib.pyplot as plt\n", "import random\n", "\n", "\n", "def load_data_batch(data_dir, data_type, max_samples_per_file=5000, verbose=True, random_shuffle_files=True):\n", " \"\"\"\n", " 分批加载指定类型的数据,支持随机文件顺序\n", " \n", " Args:\n", " data_dir: 数据目录\n", " data_type: 'train', 'val', 'test'\n", " max_samples_per_file: 每个文件最大加载样本数\n", " verbose: 是否打印每个文件的加载进度\n", " random_shuffle_files: 是否随机打乱文件加载顺序\n", " \n", " Returns:\n", " generator: 数据批次生成器\n", " \"\"\"\n", " files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n", " \n", " # 随机打乱文件顺序\n", " if random_shuffle_files:\n", " random.shuffle(files)\n", " if verbose:\n", " print(f\" 已随机打乱 {len(files)} 个文件的加载顺序\")\n", " \n", " for file_idx, f in enumerate(files):\n", " if verbose:\n", " print(f\" 正在加载文件 {file_idx+1}/{len(files)}: {f}\")\n", " \n", " data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n", " trials = data['neural_logits_concatenated']\n", " \n", " # 限制每个文件的样本数\n", " if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n", " # 随机选择样本而不是只取前N个\n", " random_indices = np.random.choice(len(trials), max_samples_per_file, replace=False)\n", " trials = trials[random_indices]\n", " if verbose:\n", " print(f\" 随机采样样本数至: {max_samples_per_file}\")\n", " \n", " yield trials, f\n", " \n", " # 清理内存\n", " del data, trials\n", " gc.collect()\n", "\n", "def extract_features_labels_batch(trials_batch, random_shuffle_trials=True):\n", " \"\"\"\n", " 从试验批次中提取特征和标签,支持随机打乱试验顺序\n", " \n", " Args:\n", " trials_batch: 试验批次数据\n", " random_shuffle_trials: 是否随机打乱批次内的试验顺序\n", " \"\"\"\n", " features = []\n", " labels = []\n", " \n", " # 随机打乱试验顺序\n", " if random_shuffle_trials and len(trials_batch) > 1:\n", " trial_indices = list(range(len(trials_batch)))\n", " random.shuffle(trial_indices)\n", " trials_batch = trials_batch[trial_indices]\n", " \n", " for trial in trials_batch:\n", " if trial.shape[0] > 0:\n", " # 随机打乱时间步顺序\n", " time_indices = list(range(trial.shape[0]))\n", " if random_shuffle_trials:\n", " random.shuffle(time_indices)\n", " \n", " for t in time_indices:\n", " neural_features = trial[t, :7168] # 前7168维神经特征\n", " rnn_logits = trial[t, 7168:] # 后41维RNN输出\n", " phoneme_label = np.argmax(rnn_logits)\n", " \n", " features.append(neural_features)\n", " labels.append(phoneme_label)\n", " \n", " return np.array(features), np.array(labels)\n", "\n", "def fit_global_pca(data_dir, config):\n", " \"\"\"\n", " 在训练数据上拟合全局PCA (只执行一次)\n", " \"\"\"\n", " if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n", " print(\"PCA已拟合或未启用,跳过拟合步骤\")\n", " return\n", " \n", " print(f\"拟合全局PCA降维器...\")\n", " print(f\" 配置: {config}\")\n", " \n", " # 收集训练样本(使用随机加载)\n", " sample_features = []\n", " collected_samples = 0\n", " \n", " # 设置随机种子以确保可重现性\n", " random.seed(42)\n", " np.random.seed(42)\n", " \n", " for trials_batch, filename in load_data_batch(\n", " data_dir, 'train', 5000, verbose=False, random_shuffle_files=True\n", " ):\n", " features, labels = extract_features_labels_batch(trials_batch, random_shuffle_trials=True)\n", " sample_features.append(features)\n", " collected_samples += features.shape[0]\n", " \n", " if collected_samples >= config['sample_size']:\n", " break\n", " \n", " if sample_features:\n", " # 合并样本数据\n", " X_sample = np.vstack(sample_features)[:config['sample_size']]\n", " \n", " # 再次随机打乱样本顺序\n", " shuffle_indices = np.random.permutation(len(X_sample))\n", " X_sample = X_sample[shuffle_indices]\n", " \n", " print(f\" 实际样本数: {X_sample.shape[0]}\")\n", " print(f\" 原始特征数: {X_sample.shape[1]}\")\n", " \n", " # 标准化\n", " GLOBAL_PCA['scaler'] = StandardScaler()\n", " X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n", " \n", " # 确定PCA成分数\n", " if config['n_components'] is None:\n", " print(f\" 自动选择PCA成分数...\")\n", " pca_full = PCA()\n", " pca_full.fit(X_sample_scaled)\n", " \n", " cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n", " optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n", " GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n", " \n", " print(f\" 保留{config['variance_threshold']*100}%方差需要: {optimal_components} 个成分\")\n", " print(f\" 选择成分数: {GLOBAL_PCA['n_components']}\")\n", " else:\n", " GLOBAL_PCA['n_components'] = config['n_components']\n", " print(f\" 使用指定成分数: {GLOBAL_PCA['n_components']}\")\n", " \n", " # 拟合最终PCA\n", " GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n", " GLOBAL_PCA['pca'].fit(X_sample_scaled)\n", " GLOBAL_PCA['is_fitted'] = True\n", " \n", " # 保存模型\n", " pca_path = \"global_pca_model.joblib\"\n", " joblib.dump({\n", " 'scaler': GLOBAL_PCA['scaler'], \n", " 'pca': GLOBAL_PCA['pca'],\n", " 'n_components': GLOBAL_PCA['n_components']\n", " }, pca_path)\n", " \n", " print(f\" 全局PCA拟合完成\")\n", " print(f\" 降维: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n", " print(f\" 降维比例: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n", " print(f\" 保留方差: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n", " print(f\" 模型已保存: {pca_path}\")\n", " \n", " # 清理样本数据\n", " del sample_features, X_sample, X_sample_scaled\n", " gc.collect()\n", " else:\n", " print(\"无法收集样本数据用于PCA拟合\")\n", "\n", "def apply_pca_transform(features):\n", " \"\"\"\n", " 应用全局PCA变换\n", " \"\"\"\n", " if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n", " return features\n", " \n", " # 标准化 + PCA变换\n", " features_scaled = GLOBAL_PCA['scaler'].transform(features)\n", " features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n", " return features_pca" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 📊 数据平衡策略 - 标签分布分析与采样优化" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# 【数据平衡核心实现】\n", "def balance_dataset(X, y):\n", " \"\"\"\n", " 对数据集进行平衡处理:只做下采样到第三小的样本数目\n", " \n", " Args:\n", " X: 特征数据\n", " y: 标签数据\n", " config: 平衡配置\n", " \n", " Returns:\n", " X_balanced, y_balanced: 平衡后的数据\n", " \"\"\"\n", " if not config['enable_balance']:\n", " print(\"🔕 数据平衡已禁用,返回原始数据\")\n", " return X, y\n", " \n", " print(f\"\\n⚖️ 开始数据平衡处理(只下采样到第三小样本数)...\")\n", " print(f\" 原始数据: {X.shape[0]:,} 样本\")\n", " \n", " # 分析当前分布,找到第三小的样本数\n", " label_counts = Counter(y)\n", " all_counts = [label_counts.get(i, 0) for i in range(41)] # 所有标签的样本数\n", " non_zero_counts = [count for count in all_counts if count > 0] # 去除0样本的标签\n", " \n", " # 排序找到第三小的样本数\n", " sorted_counts = sorted(non_zero_counts)\n", " if len(sorted_counts) >= 3:\n", " third_smallest_count = sorted_counts[2] # 第三小(索引2)\n", " elif len(sorted_counts) >= 2:\n", " third_smallest_count = sorted_counts[1] # 如果不足3个,用第二小\n", " else:\n", " third_smallest_count = sorted_counts[0] if sorted_counts else 1 # 如果不足2个,用最小的\n", " \n", " print(f\" 所有标签样本数: {sorted_counts[:10]}{'...' if len(sorted_counts) > 10 else ''}\")\n", " print(f\" 第三小样本数: {third_smallest_count}\")\n", " print(f\" 下采样策略: 所有标签都下采样到 {third_smallest_count}\")\n", " \n", " # 准备平衡后的数据\n", " X_balanced = []\n", " y_balanced = []\n", " \n", " random.seed(config['random_state'])\n", " np.random.seed(config['random_state'])\n", " \n", " for label in range(41):\n", " # 获取当前标签的所有样本\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " # 下采样到第三小样本数\n", " if current_count > third_smallest_count:\n", " # 随机下采样\n", " indices = np.random.choice(current_count, third_smallest_count, replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " print(f\" 📉 标签 {label}: {current_count} → {third_smallest_count} (下采样)\")\n", " else:\n", " # 保持所有样本(不进行上采样)\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " print(f\" ✅ 标签 {label}: {current_count} (保持不变)\")\n", " \n", " X_balanced.append(X_resampled)\n", " y_balanced.append(y_resampled)\n", " \n", " # 合并所有平衡后的数据\n", " X_balanced = np.vstack(X_balanced)\n", " y_balanced = np.hstack(y_balanced)\n", " \n", " # 随机打乱\n", " shuffle_indices = np.random.permutation(len(y_balanced))\n", " X_balanced = X_balanced[shuffle_indices]\n", " y_balanced = y_balanced[shuffle_indices]\n", " \n", " # 统计最终结果\n", " final_counts = Counter(y_balanced)\n", " print(f\"\\n ✅ 下采样完成: {X_balanced.shape[0]:,} 样本\")\n", " print(f\" 数据变化: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n", " print(f\" 最终各标签样本数分布:\")\n", " for label in range(41):\n", " count = final_counts.get(label, 0)\n", " if count > 0:\n", " print(f\" 标签 {label}: {count}\")\n", " \n", " return X_balanced, y_balanced" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔄 集成数据平衡的内存友好数据加载器" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🧪 数据平衡效果测试" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🚀 改进版智能数据处理管道" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 创建智能数据处理管道...\n", "✅ 管道创建完成,准备执行步骤1...\n" ] } ], "source": [ "# 🚀 改进版智能数据处理管道【没有解决分批训练的问题】\n", "# 流程:分析分布 → 确定采样比率 → 拟合PCA(只下采样) → 数据处理(下采样+上采样+PCA)\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from collections import Counter\n", "from sklearn.utils import resample\n", "from sklearn.decomposition import PCA\n", "from sklearn.preprocessing import StandardScaler\n", "import joblib\n", "import random\n", "import gc\n", "\n", "class SmartDataPipeline:\n", " \"\"\"\n", " 智能数据处理管道\n", " 步骤1: 分析数据分布,确定采样策略\n", " 步骤2: 仅下采样拟合PCA参数\n", " 步骤3: 数据处理时应用完整采样+PCA降维\n", " \"\"\"\n", " \n", " def __init__(self, data_dir, random_state=42):\n", " self.data_dir = data_dir\n", " self.random_state = random_state\n", " \n", " # 步骤1: 分布分析结果\n", " self.distribution_analysis = None\n", " self.sampling_strategy = None\n", " \n", " # 步骤2: PCA参数(基于下采样数据拟合)\n", " self.pca_scaler = None\n", " self.pca_model = None\n", " self.pca_components = None\n", " self.pca_fitted = False\n", " \n", " # 配置参数\n", " self.undersample_labels = [0, 40] # 需要下采样的标签\n", " self.oversample_threshold = 0.5 # 过采样阈值(相对于均值)\n", " self.pca_variance_threshold = 0.95 # PCA保留方差比例\n", " self.pca_sample_size = 15000 # PCA拟合样本数\n", " \n", " # def step1_analyze_distribution(self, max_samples=100000):\n", " # \"\"\"\n", " # 步骤1: 分析数据分布,确定采样策略\n", " # \"\"\"\n", " # print(\"🔍 步骤1: 分析数据分布...\")\n", " \n", " # # 分析验证集分布(代表整体分布特征)\n", " # all_labels = []\n", " # for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n", " # _, labels = extract_features_labels_batch(trials_batch)\n", " # all_labels.extend(labels.tolist())\n", " # if len(all_labels) >= max_samples:\n", " # break\n", " \n", " # # 统计分析\n", " # label_counts = Counter(all_labels)\n", " \n", " # # 计算1-39标签的均值(排除0和40)\n", " # counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n", " # target_mean = np.mean(counts_1_39)\n", " \n", " # # 生成采样策略\n", " # sampling_strategy = {}\n", " # for label in range(41):\n", " # current_count = label_counts.get(label, 0)\n", " \n", " # if label in self.undersample_labels:\n", " # # 下采样到均值水平\n", " # target_count = int(target_mean)\n", " # action = 'undersample' if current_count > target_count else 'keep'\n", " # elif current_count < target_mean * self.oversample_threshold:\n", " # # 过采样到阈值水平\n", " # target_count = int(target_mean * self.oversample_threshold)\n", " # action = 'oversample' if current_count < target_count else 'keep'\n", " # else:\n", " # # 保持不变\n", " # target_count = current_count\n", " # action = 'keep'\n", " \n", " # sampling_strategy[label] = {\n", " # 'current_count': current_count,\n", " # 'target_count': target_count,\n", " # 'action': action\n", " # }\n", " \n", " # self.distribution_analysis = {\n", " # 'label_counts': label_counts,\n", " # 'target_mean': target_mean,\n", " # 'total_samples': len(all_labels)\n", " # }\n", " # self.sampling_strategy = sampling_strategy\n", " \n", " # print(f\" ✅ 分析完成: {len(all_labels):,} 样本\")\n", " # print(f\" 📊 标签1-39均值: {target_mean:.0f}\")\n", " # print(f\" 📉 下采样标签: {self.undersample_labels} → {target_mean:.0f}\")\n", " # print(f\" 📈 过采样阈值: {self.oversample_threshold} × 均值 = {target_mean * self.oversample_threshold:.0f}\")\n", " \n", " # return self.distribution_analysis, self.sampling_strategy\n", "\n", "# 创建智能数据处理管道\n", "print(\"🚀 创建智能数据处理管道...\")\n", "pipeline = SmartDataPipeline(data_dir, random_state=42)\n", "print(\"✅ 管道创建完成,准备执行步骤1...\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "步骤1和步骤2方法已添加到管道(第三小样本数策略)\n" ] } ], "source": [ "# 继续添加智能管道的其他方法【管道完善】- 修改为第三小样本数策略\n", "\n", "def step1_analyze_distribution(self, max_samples=100000):\n", " \"\"\"\n", " 步骤1: 分析数据分布,确定采样策略(改为第三小样本数策略)\n", " \"\"\"\n", " print(\"🔍 步骤1: 分析数据分布(第三小样本数策略)...\")\n", " \n", " # 分析验证集分布(代表整体分布特征)\n", " all_labels = []\n", " for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n", " _, labels = extract_features_labels_batch(trials_batch)\n", " all_labels.extend(labels.tolist())\n", " if len(all_labels) >= max_samples:\n", " break\n", " \n", " # 统计分析\n", " label_counts = Counter(all_labels)\n", " \n", " # 计算所有标签的样本数,找到第三小的数值\n", " all_counts = [label_counts.get(i, 0) for i in range(41)]\n", " non_zero_counts = [count for count in all_counts if count > 0]\n", " sorted_counts = sorted(non_zero_counts)\n", " \n", " if len(sorted_counts) >= 3:\n", " third_smallest_count = sorted_counts[2] # 第三小\n", " elif len(sorted_counts) >= 2:\n", " third_smallest_count = sorted_counts[1] # 如果不足3个,用第二小\n", " else:\n", " third_smallest_count = sorted_counts[0] if sorted_counts else 1 # 如果不足2个,用最小的\n", " \n", " print(f\" 所有标签样本数: {sorted_counts[:10]}{'...' if len(sorted_counts) > 10 else ''}\")\n", " print(f\" 第三小样本数: {third_smallest_count}\")\n", " \n", " # 生成采样策略:所有标签都下采样到第三小,不进行过采样\n", " sampling_strategy = {}\n", " for label in range(41):\n", " current_count = label_counts.get(label, 0)\n", " \n", " if current_count > third_smallest_count:\n", " # 下采样到第三小样本数\n", " target_count = third_smallest_count\n", " action = 'undersample'\n", " else:\n", " # 保持现有样本数(不进行过采样)\n", " target_count = current_count\n", " action = 'keep'\n", " \n", " sampling_strategy[label] = {\n", " 'current_count': current_count,\n", " 'target_count': target_count,\n", " 'action': action\n", " }\n", " \n", " self.distribution_analysis = {\n", " 'label_counts': label_counts,\n", " 'target_third_smallest': third_smallest_count,\n", " 'total_samples': len(all_labels),\n", " 'sorted_counts': sorted_counts\n", " }\n", " self.sampling_strategy = sampling_strategy\n", " \n", " # 统计采样策略\n", " undersample_count = sum(1 for s in sampling_strategy.values() if s['action'] == 'undersample')\n", " keep_count = sum(1 for s in sampling_strategy.values() if s['action'] == 'keep')\n", " \n", " print(f\" ✅ 分析完成: {len(all_labels):,} 样本\")\n", " print(f\" 📉 下采样标签: {undersample_count} 个 → {third_smallest_count}\")\n", " print(f\" ✅ 保持不变: {keep_count} 个\")\n", " print(f\" 🚫 不进行过采样\")\n", " \n", " return self.distribution_analysis, self.sampling_strategy\n", "\n", "def step2_fit_pca_with_undersampling(self):\n", " \"\"\"\n", " 步骤2: 仅对下采样数据拟合PCA参数(不进行过采样,避免PCA被过采样影响)\n", " \"\"\"\n", " if self.sampling_strategy is None:\n", " raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n", " \n", " print(\"\\n🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\")\n", " \n", " # 收集用于PCA拟合的样本(只下采样,不过采样)\n", " pca_features = []\n", " collected_samples = 0\n", " \n", " for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000, verbose=False):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " \n", " # 对当前批次应用仅下采样策略\n", " downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n", " \n", " if downsampled_features.shape[0] > 0:\n", " pca_features.append(downsampled_features)\n", " collected_samples += downsampled_features.shape[0]\n", " \n", " if collected_samples >= self.pca_sample_size:\n", " break\n", " \n", " if not pca_features:\n", " raise ValueError(\"无法收集用于PCA拟合的样本,请检查数据或采样策略\")\n", " \n", " # 合并样本用于PCA拟合\n", " X_pca_fit = np.vstack(pca_features)[:self.pca_sample_size]\n", " print(f\" 用于PCA拟合的样本数: {X_pca_fit.shape[0]:,}\")\n", " \n", " # 标准化 + PCA\n", " self.pca_scaler = StandardScaler()\n", " X_scaled = self.pca_scaler.fit_transform(X_pca_fit)\n", " \n", " # 自动选择PCA成分数以保留指定方差\n", " if self.pca_components is None:\n", " pca_full = PCA()\n", " pca_full.fit(X_scaled)\n", " cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n", " optimal_components = np.argmax(cumsum_var >= self.pca_variance_threshold) + 1\n", " self.pca_components = optimal_components\n", " \n", " self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n", " self.pca_model.fit(X_scaled)\n", " self.pca_fitted = True\n", " \n", " print(f\" PCA拟合完成: 7168 → {self.pca_components}\")\n", " print(f\" 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n", "\n", "def _apply_undersampling_only(self, X, y):\n", " \"\"\"\n", " 仅对指定标签做下采样(不做过采样)- 修改为第三小样本数策略\n", " \"\"\"\n", " if self.sampling_strategy is None:\n", " raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n", " \n", " X_result = []\n", " y_result = []\n", " \n", " np.random.seed(self.random_state)\n", " \n", " for label in range(41):\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " strategy = self.sampling_strategy[label]\n", " \n", " if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n", " # 下采样到第三小样本数\n", " indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " else:\n", " # 保持原样\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " \n", " X_result.append(X_resampled)\n", " y_result.append(y_resampled)\n", " \n", " if X_result:\n", " return np.vstack(X_result), np.hstack(y_result)\n", " else:\n", " return np.array([]).reshape(0, X.shape[1]), np.array([])\n", "\n", "# 动态添加方法到类\n", "SmartDataPipeline.step1_analyze_distribution = step1_analyze_distribution\n", "SmartDataPipeline.step2_fit_pca_with_undersampling = step2_fit_pca_with_undersampling\n", "SmartDataPipeline._apply_undersampling_only = _apply_undersampling_only\n", "\n", "print(\"步骤1和步骤2方法已添加到管道(第三小样本数策略)\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "所有方法已添加到智能管道(第三小样本数策略)\n", "\n", "智能数据处理管道状态:\n", " 步骤1 - 分布分析: 未完成\n", " 步骤2 - PCA拟合: 未完成\n" ] } ], "source": [ "# 添加智能管道的剩余方法 - 修改为第三小样本数策略\n", "\n", "def _apply_full_sampling(self, X, y):\n", " \"\"\"\n", " 应用完整的采样策略(修改为只下采样到第三小样本数)\n", " \"\"\"\n", " X_result = []\n", " y_result = []\n", " \n", " np.random.seed(self.random_state)\n", " \n", " for label in range(41):\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " strategy = self.sampling_strategy[label]\n", " target_count = strategy['target_count']\n", " \n", " if strategy['action'] == 'undersample' and current_count > target_count:\n", " # 只进行下采样到第三小样本数\n", " indices = np.random.choice(current_count, target_count, replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " else:\n", " # 保持原样(不进行过采样)\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " \n", " X_result.append(X_resampled)\n", " y_result.append(y_resampled)\n", " \n", " if X_result:\n", " return np.vstack(X_result), np.hstack(y_result)\n", " else:\n", " return np.array([]).reshape(0, X.shape[1]), np.array([])\n", "\n", "def _apply_pca_transform(self, X):\n", " \"\"\"\n", " 应用PCA变换\n", " \"\"\"\n", " if not self.pca_fitted:\n", " return X\n", " \n", " X_scaled = self.pca_scaler.transform(X)\n", " X_pca = self.pca_model.transform(X_scaled)\n", " return X_pca\n", "\n", "def step3_process_data(self, data_type, apply_sampling=None):\n", " \"\"\"\n", " 步骤3: 处理数据(采样+PCA降维)\n", " \"\"\"\n", " if not self.pca_fitted:\n", " raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n", " \n", " if apply_sampling is None:\n", " apply_sampling = (data_type == 'train')\n", " \n", " print(f\"\\n处理{data_type}数据...\")\n", " print(f\" 采样策略: {'启用(只下采样)' if apply_sampling else '禁用'}\")\n", " \n", " all_features = []\n", " all_labels = []\n", " \n", " # 在内部关闭加载时的逐文件打印\n", " for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000, verbose=False):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " \n", " if apply_sampling:\n", " features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n", " else:\n", " features_sampled, labels_sampled = features, labels\n", " \n", " if features_sampled.shape[0] > 0:\n", " features_pca = self._apply_pca_transform(features_sampled)\n", " all_features.append(features_pca)\n", " all_labels.append(labels_sampled)\n", " \n", " if all_features:\n", " X = np.vstack(all_features)\n", " y = np.hstack(all_labels)\n", " \n", " shuffle_indices = np.random.permutation(len(y))\n", " X = X[shuffle_indices]\n", " y = y[shuffle_indices]\n", " \n", " print(f\" 完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", " \n", " del all_features, all_labels\n", " gc.collect()\n", " \n", " return X, y\n", " else:\n", " return None, None\n", "\n", "def print_summary(self):\n", " print(\"\\n智能数据处理管道状态:\")\n", " print(f\" 步骤1 - 分布分析: {'完成' if self.distribution_analysis else '未完成'}\")\n", " print(f\" 步骤2 - PCA拟合: {'完成' if self.pca_fitted else '未完成'}\")\n", " \n", " if self.distribution_analysis:\n", " target_third_smallest = self.distribution_analysis['target_third_smallest']\n", " print(f\" 第三小样本数: {target_third_smallest}\")\n", " print(f\" 采样策略: 只下采样,不过采样\")\n", " \n", " if self.pca_fitted:\n", " print(f\" PCA降维: 7168 → {self.pca_components} ({self.pca_components/7168:.1%})\")\n", " print(f\" 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n", "\n", "# 动态添加剩余方法到类\n", "SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n", "SmartDataPipeline._apply_pca_transform = _apply_pca_transform\n", "SmartDataPipeline.step3_process_data = step3_process_data\n", "SmartDataPipeline.print_summary = print_summary\n", "\n", "print(\"所有方法已添加到智能管道(第三小样本数策略)\")\n", "pipeline.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🎯 数据增强模块 - 时序神经数据增强" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🎯 神经数据增强器初始化完成\n", " 白噪声标准差: 1.0\n", " 常数偏移标准差: 0.2\n", " 随机游走标准差: 0.0\n", " 静态增益标准差: 0.0\n", " 随机切割步数: 3\n", " 平滑数据: True\n", " 平滑核大小: 100, 标准差: 2\n", "✅ 数据增强器创建完成!\n" ] } ], "source": [ "# 🎯 时序神经数据增强类\n", "\n", "import numpy as np\n", "import random\n", "from scipy import ndimage\n", "from scipy.ndimage import gaussian_filter1d\n", "import gc\n", "\n", "class NeuralDataAugmenter:\n", " \"\"\"\n", " 时序神经数据增强器\n", " 专门用于处理脑机接口的神经信号数据增强\n", " \"\"\"\n", " \n", " def __init__(self, \n", " white_noise_std=1.0,\n", " constant_offset_std=0.2,\n", " random_walk_std=0.0,\n", " random_walk_axis=-1,\n", " static_gain_std=0.0,\n", " random_cut=3,\n", " smooth_kernel_size=100,\n", " smooth_data=True,\n", " smooth_kernel_std=2,\n", " random_state=42):\n", " \"\"\"\n", " 初始化数据增强器\n", " \n", " Args:\n", " white_noise_std: 白噪声标准差\n", " constant_offset_std: 常数偏移标准差\n", " random_walk_std: 随机游走标准差\n", " random_walk_axis: 随机游走应用的轴\n", " static_gain_std: 静态增益标准差\n", " random_cut: 随机切割时间步数\n", " smooth_kernel_size: 平滑核大小\n", " smooth_data: 是否平滑数据\n", " smooth_kernel_std: 平滑核标准差\n", " random_state: 随机种子\n", " \"\"\"\n", " self.white_noise_std = white_noise_std\n", " self.constant_offset_std = constant_offset_std\n", " self.random_walk_std = random_walk_std\n", " self.random_walk_axis = random_walk_axis\n", " self.static_gain_std = static_gain_std\n", " self.random_cut = random_cut\n", " self.smooth_kernel_size = smooth_kernel_size\n", " self.smooth_data = smooth_data\n", " self.smooth_kernel_std = smooth_kernel_std\n", " self.random_state = random_state\n", " \n", " # 设置随机种子\n", " np.random.seed(random_state)\n", " random.seed(random_state)\n", " \n", " print(f\"🎯 神经数据增强器初始化完成\")\n", " print(f\" 白噪声标准差: {white_noise_std}\")\n", " print(f\" 常数偏移标准差: {constant_offset_std}\")\n", " print(f\" 随机游走标准差: {random_walk_std}\")\n", " print(f\" 静态增益标准差: {static_gain_std}\")\n", " print(f\" 随机切割步数: {random_cut}\")\n", " print(f\" 平滑数据: {smooth_data}\")\n", " if smooth_data:\n", " print(f\" 平滑核大小: {smooth_kernel_size}, 标准差: {smooth_kernel_std}\")\n", " \n", " def reconstruct_time_series(self, flattened_features, original_shape=(14, 512)):\n", " \"\"\"\n", " 将扁平化的特征重建为时序数据\n", " \n", " Args:\n", " flattened_features: 扁平化的特征 (7168维)\n", " original_shape: 原始时序形状 (时间步, 特征维度)\n", " \n", " Returns:\n", " time_series: 重建的时序数据 (time_steps, features)\n", " \"\"\"\n", " # 假设前7168维是神经特征 (14 * 512 = 7168)\n", " neural_data = flattened_features[:7168]\n", " time_series = neural_data.reshape(original_shape)\n", " return time_series\n", " \n", " def add_white_noise(self, data):\n", " \"\"\"添加白噪声\"\"\"\n", " if self.white_noise_std <= 0:\n", " return data\n", " \n", " noise = np.random.normal(0, self.white_noise_std, data.shape)\n", " return data + noise\n", " \n", " def add_constant_offset(self, data):\n", " \"\"\"添加常数偏移\"\"\"\n", " if self.constant_offset_std <= 0:\n", " return data\n", " \n", " # 为每个通道添加不同的常数偏移\n", " offset = np.random.normal(0, self.constant_offset_std, (1, data.shape[1]))\n", " return data + offset\n", " \n", " def add_random_walk(self, data):\n", " \"\"\"添加随机游走\"\"\"\n", " if self.random_walk_std <= 0:\n", " return data\n", " \n", " if self.random_walk_axis == -1: # 沿时间轴\n", " walk = np.random.normal(0, self.random_walk_std, data.shape[0])\n", " walk = np.cumsum(walk) # 累积求和形成随机游走\n", " walk = walk.reshape(-1, 1) # 广播到所有通道\n", " return data + walk\n", " else:\n", " walk = np.random.normal(0, self.random_walk_std, data.shape)\n", " walk = np.cumsum(walk, axis=self.random_walk_axis)\n", " return data + walk\n", " \n", " def apply_static_gain(self, data):\n", " \"\"\"应用静态增益\"\"\"\n", " if self.static_gain_std <= 0:\n", " return data\n", " \n", " # 为每个通道应用不同的增益\n", " gain = 1 + np.random.normal(0, self.static_gain_std, (1, data.shape[1]))\n", " return data * gain\n", " \n", " def random_time_cut(self, data):\n", " \"\"\"随机切割时间步\"\"\"\n", " if self.random_cut <= 0 or data.shape[0] <= self.random_cut:\n", " return data\n", " \n", " # 从开头随机切掉一些时间步\n", " cut_steps = np.random.randint(0, min(self.random_cut + 1, data.shape[0]))\n", " return data[cut_steps:]\n", " \n", " def smooth_data_func(self, data):\n", " \"\"\"平滑数据\"\"\"\n", " if not self.smooth_data or self.smooth_kernel_std <= 0:\n", " return data\n", " \n", " # 对每个通道分别应用高斯平滑\n", " smoothed_data = np.zeros_like(data)\n", " for i in range(data.shape[1]):\n", " smoothed_data[:, i] = gaussian_filter1d(\n", " data[:, i], \n", " sigma=self.smooth_kernel_std\n", " )\n", " return smoothed_data\n", " \n", " def augment_time_series(self, time_series_data):\n", " \"\"\"\n", " 对时序数据应用所有增强方法\n", " \n", " Args:\n", " time_series_data: 时序数据 (time_steps, features)\n", " \n", " Returns:\n", " augmented_data: 增强后的时序数据\n", " \"\"\"\n", " data = time_series_data.copy()\n", " \n", " # 1. 随机时间切割(在其他增强之前)\n", " data = self.random_time_cut(data)\n", " \n", " # 2. 添加白噪声\n", " data = self.add_white_noise(data)\n", " \n", " # 3. 添加常数偏移\n", " data = self.add_constant_offset(data)\n", " \n", " # 4. 添加随机游走\n", " data = self.add_random_walk(data)\n", " \n", " # 5. 应用静态增益\n", " data = self.apply_static_gain(data)\n", " \n", " # 6. 平滑数据(在最后应用)\n", " data = self.smooth_data_func(data)\n", " \n", " return data\n", " \n", " def flatten_time_series(self, time_series_data, target_length=7168):\n", " \"\"\"\n", " 将时序数据重新扁平化为目标长度\n", " \n", " Args:\n", " time_series_data: 时序数据 (time_steps, features)\n", " target_length: 目标扁平化长度\n", " \n", " Returns:\n", " flattened_data: 扁平化的数据\n", " \"\"\"\n", " flattened = time_series_data.flatten()\n", " \n", " # 如果长度不够,用零填充\n", " if len(flattened) < target_length:\n", " padded = np.zeros(target_length)\n", " padded[:len(flattened)] = flattened\n", " return padded\n", " # 如果长度超过,截断\n", " elif len(flattened) > target_length:\n", " return flattened[:target_length]\n", " else:\n", " return flattened\n", " \n", " def augment_neural_features(self, features_batch, augment_ratio=0.5):\n", " \"\"\"\n", " 对神经特征批次进行数据增强\n", " \n", " Args:\n", " features_batch: 特征批次 (batch_size, 7168)\n", " augment_ratio: 增强比例(0-1之间)\n", " \n", " Returns:\n", " augmented_features: 增强后的特征(包含原始和增强的数据)\n", " augmented_indices: 增强样本的索引\n", " \"\"\"\n", " batch_size = features_batch.shape[0]\n", " n_augment = int(batch_size * augment_ratio)\n", " \n", " if n_augment == 0:\n", " return features_batch, []\n", " \n", " # 随机选择要增强的样本\n", " augment_indices = np.random.choice(batch_size, n_augment, replace=False)\n", " \n", " augmented_features = []\n", " \n", " for i, features in enumerate(features_batch):\n", " if i in augment_indices:\n", " # 重建时序数据\n", " time_series = self.reconstruct_time_series(features)\n", " \n", " # 应用数据增强\n", " augmented_time_series = self.augment_time_series(time_series)\n", " \n", " # 重新扁平化\n", " augmented_features_flat = self.flatten_time_series(augmented_time_series)\n", " \n", " augmented_features.append(augmented_features_flat)\n", " else:\n", " augmented_features.append(features)\n", " \n", " return np.array(augmented_features), augment_indices.tolist()\n", "\n", "# 创建数据增强器实例\n", "augmenter = NeuralDataAugmenter(\n", " white_noise_std=1.0,\n", " constant_offset_std=0.2,\n", " random_walk_std=0.0,\n", " random_walk_axis=-1,\n", " static_gain_std=0.0,\n", " random_cut=3,\n", " smooth_kernel_size=100,\n", " smooth_data=True,\n", " smooth_kernel_std=2,\n", " random_state=42\n", ")\n", "\n", "print(\"✅ 数据增强器创建完成!\")" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ 数据增强功能已集成到智能管道(已修复数组比较问题)\n" ] } ], "source": [ "# 🔗 集成数据增强到智能管道\n", "\n", "def extract_features_labels_batch_with_augmentation(trials_batch, random_shuffle_trials=True, \n", " apply_augmentation=True, augment_ratio=0.3):\n", " \"\"\"\n", " 从试验批次中提取特征和标签,支持数据增强\n", " \n", " Args:\n", " trials_batch: 试验批次数据\n", " random_shuffle_trials: 是否随机打乱试验顺序\n", " apply_augmentation: 是否应用数据增强\n", " augment_ratio: 数据增强比例\n", " \"\"\"\n", " features = []\n", " labels = []\n", " \n", " # 随机打乱试验顺序\n", " if random_shuffle_trials and len(trials_batch) > 1:\n", " trial_indices = list(range(len(trials_batch)))\n", " random.shuffle(trial_indices)\n", " trials_batch = trials_batch[trial_indices]\n", " \n", " for trial in trials_batch:\n", " if trial.shape[0] > 0:\n", " # 随机打乱时间步顺序\n", " time_indices = list(range(trial.shape[0]))\n", " if random_shuffle_trials:\n", " random.shuffle(time_indices)\n", " \n", " for t in time_indices:\n", " neural_features = trial[t, :7168] # 前7168维神经特征\n", " rnn_logits = trial[t, 7168:] # 后41维RNN输出\n", " phoneme_label = np.argmax(rnn_logits)\n", " \n", " features.append(neural_features)\n", " labels.append(phoneme_label)\n", " \n", " if not features:\n", " return np.array([]), np.array([])\n", " \n", " features = np.array(features)\n", " labels = np.array(labels)\n", " \n", " # 应用数据增强\n", " if apply_augmentation and len(features) > 0:\n", " print(f\" 应用数据增强 (比例: {augment_ratio})\")\n", " features, augmented_indices = augmenter.augment_neural_features(features, augment_ratio)\n", " print(f\" 增强样本数: {len(augmented_indices)}\")\n", " \n", " return features, labels\n", "\n", "def _apply_full_sampling(self, X, y):\n", " \"\"\"\n", " 应用完整的采样策略(修改为只下采样到第三小样本数)\n", " \"\"\"\n", " X_result = []\n", " y_result = []\n", " \n", " np.random.seed(self.random_state)\n", " \n", " # 确保y是1维numpy数组,避免数组比较的歧义\n", " y = np.asarray(y, dtype=int).flatten()\n", " \n", " for label in range(41):\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " strategy = self.sampling_strategy[label]\n", " target_count = strategy['target_count']\n", " \n", " if strategy['action'] == 'undersample' and current_count > target_count:\n", " # 只进行下采样到第三小样本数\n", " indices = np.random.choice(current_count, target_count, replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " else:\n", " # 保持原样(不进行过采样)\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " \n", " X_result.append(X_resampled)\n", " y_result.append(y_resampled)\n", " \n", " if X_result:\n", " return np.vstack(X_result), np.hstack(y_result)\n", " else:\n", " return np.array([]).reshape(0, X.shape[1]), np.array([])\n", "\n", "def _apply_pca_transform(self, X):\n", " \"\"\"\n", " 应用PCA变换\n", " \"\"\"\n", " if not self.pca_fitted:\n", " return X\n", " \n", " X_scaled = self.pca_scaler.transform(X)\n", " X_pca = self.pca_model.transform(X_scaled)\n", " return X_pca\n", "\n", "def _apply_full_sampling_with_augmentation(self, X, y, apply_augmentation=True, augment_ratio=0.3):\n", " \"\"\"\n", " 应用完整的采样策略(修改为只下采样到第三小样本数)+ 数据增强\n", " \"\"\"\n", " X_result = []\n", " y_result = []\n", " \n", " np.random.seed(self.random_state)\n", " \n", " # 确保y是1维numpy数组,避免数组比较的歧义\n", " y = np.asarray(y, dtype=int).flatten()\n", " \n", " for label in range(41):\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " strategy = self.sampling_strategy[label]\n", " target_count = strategy['target_count']\n", " \n", " if strategy['action'] == 'undersample' and current_count > target_count:\n", " # 只进行下采样到第三小样本数\n", " indices = np.random.choice(current_count, target_count, replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " else:\n", " # 保持原样(不进行过采样)\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " \n", " # 对下采样后的数据应用数据增强\n", " if apply_augmentation and len(X_resampled) > 0:\n", " X_resampled, _ = augmenter.augment_neural_features(X_resampled, augment_ratio)\n", " \n", " X_result.append(X_resampled)\n", " y_result.append(y_resampled)\n", " \n", " if X_result:\n", " return np.vstack(X_result), np.hstack(y_result)\n", " else:\n", " return np.array([]).reshape(0, X.shape[1]), np.array([])\n", "\n", "def step3_process_data(self, data_type, apply_sampling=None):\n", " \"\"\"\n", " 步骤3: 处理数据(采样+PCA降维)- 原始版本(无数据增强)\n", " \"\"\"\n", " if not self.pca_fitted:\n", " raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n", " \n", " if apply_sampling is None:\n", " apply_sampling = (data_type == 'train')\n", " \n", " print(f\"\\n处理{data_type}数据...\")\n", " print(f\" 采样策略: {'启用(只下采样)' if apply_sampling else '禁用'}\")\n", " \n", " all_features = []\n", " all_labels = []\n", " \n", " # 在内部关闭加载时的逐文件打印\n", " for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000, verbose=False):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " \n", " if apply_sampling:\n", " features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n", " else:\n", " features_sampled, labels_sampled = features, labels\n", " \n", " if features_sampled.shape[0] > 0:\n", " features_pca = self._apply_pca_transform(features_sampled)\n", " all_features.append(features_pca)\n", " all_labels.append(labels_sampled)\n", " \n", " if all_features:\n", " X = np.vstack(all_features)\n", " y = np.hstack(all_labels)\n", " \n", " shuffle_indices = np.random.permutation(len(y))\n", " X = X[shuffle_indices]\n", " y = y[shuffle_indices]\n", " \n", " print(f\" 完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", " \n", " del all_features, all_labels\n", " gc.collect()\n", " \n", " return X, y\n", " else:\n", " return None, None\n", "\n", "def step3_process_data_with_augmentation(self, data_type, apply_sampling=None, \n", " apply_augmentation=True, augment_ratio=0.3):\n", " \"\"\"\n", " 步骤3: 处理数据(采样+数据增强+PCA降维)\n", " \"\"\"\n", " if not self.pca_fitted:\n", " raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n", " \n", " if apply_sampling is None:\n", " apply_sampling = (data_type == 'train')\n", " \n", " # 只对训练数据应用数据增强\n", " if data_type != 'train':\n", " apply_augmentation = False\n", " \n", " print(f\"\\n处理{data_type}数据...\")\n", " print(f\" 采样策略: {'启用(只下采样)' if apply_sampling else '禁用'}\")\n", " print(f\" 数据增强: {'启用' if apply_augmentation else '禁用'}\")\n", " if apply_augmentation:\n", " print(f\" 增强比例: {augment_ratio}\")\n", " \n", " all_features = []\n", " all_labels = []\n", " \n", " # 在内部关闭加载时的逐文件打印\n", " for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000, verbose=False):\n", " # 使用带数据增强的特征提取函数\n", " features, labels = extract_features_labels_batch_with_augmentation(\n", " trials_batch, \n", " random_shuffle_trials=True,\n", " apply_augmentation=apply_augmentation,\n", " augment_ratio=augment_ratio\n", " )\n", " \n", " if apply_sampling and len(features) > 0:\n", " features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n", " else:\n", " features_sampled, labels_sampled = features, labels\n", " \n", " if features_sampled.shape[0] > 0:\n", " features_pca = self._apply_pca_transform(features_sampled)\n", " all_features.append(features_pca)\n", " all_labels.append(labels_sampled)\n", " \n", " if all_features:\n", " X = np.vstack(all_features)\n", " y = np.hstack(all_labels)\n", " \n", " shuffle_indices = np.random.permutation(len(y))\n", " X = X[shuffle_indices]\n", " y = y[shuffle_indices]\n", " \n", " print(f\" 完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", " \n", " del all_features, all_labels\n", " gc.collect()\n", " \n", " return X, y\n", " else:\n", " return None, None\n", "\n", "# 动态添加所有方法到类(包括原始版本和增强版本)\n", "# 确保添加所有必需的方法\n", "SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n", "SmartDataPipeline._apply_pca_transform = _apply_pca_transform \n", "SmartDataPipeline._apply_full_sampling_with_augmentation = _apply_full_sampling_with_augmentation\n", "SmartDataPipeline.step3_process_data = step3_process_data\n", "SmartDataPipeline.step3_process_data_with_augmentation = step3_process_data_with_augmentation\n", "\n", "print(\"✅ 数据增强功能已集成到智能管道(已修复数组比较问题)\")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🧪 测试神经数据增强功能...\n", "============================================================\n", "模拟时序数据形状: (5, 14, 512)\n", "扁平化特征形状: (5, 7168)\n", "\n", "测试单个样本的数据增强:\n", " 原始样本统计: mean=-0.007, std=1.004\n", " 重建时序形状: (14, 512)\n", " 增强时序形状: (12, 512)\n", " 增强样本统计: mean=0.005, std=0.599\n", "\n", "测试批量数据增强:\n", " 原始批次形状: (5, 7168)\n", " 增强批次形状: (5, 7168)\n", " 增强样本索引: [3, 2, 1]\n", "\n", "统计比较:\n", " 样本 0: ➡️ 原始\n", " 原始: mean=-0.007, std=1.004\n", " 处理: mean=-0.007, std=1.004\n", " 样本 1: 🎯 增强\n", " 原始: mean=0.008, std=1.000\n", " 处理: mean=0.007, std=0.604\n", " 样本 2: 🎯 增强\n", " 原始: mean=0.011, std=1.005\n", " 处理: mean=0.005, std=0.600\n", " 样本 3: 🎯 增强\n", " 原始: mean=-0.008, std=0.985\n", " 处理: mean=-0.006, std=0.610\n", " 样本 4: ➡️ 原始\n", " 原始: mean=-0.021, std=1.015\n", " 处理: mean=-0.021, std=1.015\n", "\n", "✅ 数据增强测试完成!\n", "\n", "📋 当前数据增强配置:\n", " 白噪声标准差: 1.0\n", " 常数偏移标准差: 0.2\n", " 随机游走标准差: 0.0\n", " 静态增益标准差: 0.0\n", " 随机切割步数: 3\n", " 数据平滑: True\n", " 平滑核标准差: 2\n" ] } ], "source": [ "# 🧪 测试数据增强功能\n", "\n", "print(\"🧪 测试神经数据增强功能...\")\n", "print(\"=\" * 60)\n", "\n", "# 创建模拟的时序神经数据\n", "np.random.seed(42)\n", "batch_size = 5\n", "n_timesteps = 14\n", "n_features = 512\n", "\n", "# 模拟神经信号数据\n", "mock_time_series = np.random.randn(batch_size, n_timesteps, n_features)\n", "print(f\"模拟时序数据形状: {mock_time_series.shape}\")\n", "\n", "# 扁平化为7168维特征\n", "mock_features = mock_time_series.reshape(batch_size, -1)\n", "print(f\"扁平化特征形状: {mock_features.shape}\")\n", "\n", "# 测试单个样本的数据增强\n", "print(f\"\\n测试单个样本的数据增强:\")\n", "original_sample = mock_features[0]\n", "print(f\" 原始样本统计: mean={original_sample.mean():.3f}, std={original_sample.std():.3f}\")\n", "\n", "# 重建时序数据\n", "reconstructed_ts = augmenter.reconstruct_time_series(original_sample)\n", "print(f\" 重建时序形状: {reconstructed_ts.shape}\")\n", "\n", "# 应用数据增强\n", "augmented_ts = augmenter.augment_time_series(reconstructed_ts)\n", "print(f\" 增强时序形状: {augmented_ts.shape}\")\n", "\n", "# 重新扁平化\n", "augmented_sample = augmenter.flatten_time_series(augmented_ts)\n", "print(f\" 增强样本统计: mean={augmented_sample.mean():.3f}, std={augmented_sample.std():.3f}\")\n", "\n", "# 测试批量数据增强\n", "print(f\"\\n测试批量数据增强:\")\n", "augmented_batch, aug_indices = augmenter.augment_neural_features(mock_features, augment_ratio=0.6)\n", "print(f\" 原始批次形状: {mock_features.shape}\")\n", "print(f\" 增强批次形状: {augmented_batch.shape}\")\n", "print(f\" 增强样本索引: {aug_indices}\")\n", "\n", "# 比较原始和增强数据的统计特性\n", "print(f\"\\n统计比较:\")\n", "for i in range(batch_size):\n", " original_stats = f\"mean={mock_features[i].mean():.3f}, std={mock_features[i].std():.3f}\"\n", " augmented_stats = f\"mean={augmented_batch[i].mean():.3f}, std={augmented_batch[i].std():.3f}\"\n", " status = \"🎯 增强\" if i in aug_indices else \"➡️ 原始\"\n", " print(f\" 样本 {i}: {status}\")\n", " print(f\" 原始: {original_stats}\")\n", " print(f\" 处理: {augmented_stats}\")\n", "\n", "print(f\"\\n✅ 数据增强测试完成!\")\n", "\n", "# 显示数据增强配置\n", "print(f\"\\n📋 当前数据增强配置:\")\n", "print(f\" 白噪声标准差: {augmenter.white_noise_std}\")\n", "print(f\" 常数偏移标准差: {augmenter.constant_offset_std}\")\n", "print(f\" 随机游走标准差: {augmenter.random_walk_std}\")\n", "print(f\" 静态增益标准差: {augmenter.static_gain_std}\")\n", "print(f\" 随机切割步数: {augmenter.random_cut}\")\n", "print(f\" 数据平滑: {augmenter.smooth_data}\")\n", "if augmenter.smooth_data:\n", " print(f\" 平滑核标准差: {augmenter.smooth_kernel_std}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔥 执行智能数据处理管道" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 开始执行智能数据处理管道...\n", "============================================================\n", "\n", "======================🔍 STEP 1: 分析数据分布======================\n", "🔍 步骤1: 分析数据分布(第三小样本数策略)...\n", " 已随机打乱 41 个文件的加载顺序\n", " 正在加载文件 1/41: t15.2024.02.25_val_concatenated.npz\n", " 正在加载文件 2/41: t15.2023.10.08_val_concatenated.npz\n", " 正在加载文件 2/41: t15.2023.10.08_val_concatenated.npz\n", " 正在加载文件 3/41: t15.2025.03.30_val_concatenated.npz\n", " 正在加载文件 3/41: t15.2025.03.30_val_concatenated.npz\n", " 正在加载文件 4/41: t15.2023.11.19_val_concatenated.npz\n", " 正在加载文件 4/41: t15.2023.11.19_val_concatenated.npz\n", " 正在加载文件 5/41: t15.2023.11.26_val_concatenated.npz\n", " 正在加载文件 5/41: t15.2023.11.26_val_concatenated.npz\n", " 正在加载文件 6/41: t15.2024.07.28_val_concatenated.npz\n", " 正在加载文件 6/41: t15.2024.07.28_val_concatenated.npz\n", " 正在加载文件 7/41: t15.2024.07.21_val_concatenated.npz\n", " 正在加载文件 7/41: t15.2024.07.21_val_concatenated.npz\n", " 正在加载文件 8/41: t15.2023.09.29_val_concatenated.npz\n", " 正在加载文件 8/41: t15.2023.09.29_val_concatenated.npz\n", " 正在加载文件 9/41: t15.2025.01.10_val_concatenated.npz\n", " 正在加载文件 9/41: t15.2025.01.10_val_concatenated.npz\n", " 正在加载文件 10/41: t15.2025.04.13_val_concatenated.npz\n", " 正在加载文件 10/41: t15.2025.04.13_val_concatenated.npz\n", " 正在加载文件 11/41: t15.2024.07.19_val_concatenated.npz\n", " 正在加载文件 11/41: t15.2024.07.19_val_concatenated.npz\n", " 正在加载文件 12/41: t15.2023.11.04_val_concatenated.npz\n", " 正在加载文件 12/41: t15.2023.11.04_val_concatenated.npz\n", " 正在加载文件 13/41: t15.2023.11.03_val_concatenated.npz\n", " 正在加载文件 13/41: t15.2023.11.03_val_concatenated.npz\n", " 所有标签样本数: [29, 56, 76, 78, 87, 87, 89, 135, 136, 147]...\n", " 第三小样本数: 76\n", " ✅ 分析完成: 105,617 样本\n", " 📉 下采样标签: 38 个 → 76\n", " ✅ 保持不变: 3 个\n", " 🚫 不进行过采样\n", "\n", "📊 采样策略总结:\n", " 📉 下采样标签: 38 个\n", " 📈 过采样标签: 0 个\n", " ✅ 保持不变: 3 个\n", "\n", "✅ 步骤1完成!\n", " 所有标签样本数: [29, 56, 76, 78, 87, 87, 89, 135, 136, 147]...\n", " 第三小样本数: 76\n", " ✅ 分析完成: 105,617 样本\n", " 📉 下采样标签: 38 个 → 76\n", " ✅ 保持不变: 3 个\n", " 🚫 不进行过采样\n", "\n", "📊 采样策略总结:\n", " 📉 下采样标签: 38 个\n", " 📈 过采样标签: 0 个\n", " ✅ 保持不变: 3 个\n", "\n", "✅ 步骤1完成!\n" ] } ], "source": [ "# 🔥 执行智能数据处理管道【确定采样策略】\n", "\n", "print(\"🚀 开始执行智能数据处理管道...\")\n", "print(\"=\" * 60)\n", "\n", "# 步骤1: 分析数据分布\n", "print(\"\\n\" + \"🔍 STEP 1: 分析数据分布\".center(60, \"=\"))\n", "distribution, strategy = pipeline.step1_analyze_distribution()\n", "\n", "# 显示采样策略总结\n", "print(f\"\\n📊 采样策略总结:\")\n", "undersample_count = sum(1 for s in strategy.values() if s['action'] == 'undersample')\n", "oversample_count = sum(1 for s in strategy.values() if s['action'] == 'oversample')\n", "keep_count = sum(1 for s in strategy.values() if s['action'] == 'keep')\n", "\n", "print(f\" 📉 下采样标签: {undersample_count} 个\")\n", "print(f\" 📈 过采样标签: {oversample_count} 个\") \n", "print(f\" ✅ 保持不变: {keep_count} 个\")\n", "\n", "print(\"\\n✅ 步骤1完成!\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "=====================🔧 STEP 2: 拟合PCA参数======================\n", "\n", "🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\n", " 用于PCA拟合的样本数: 15,000\n", " 用于PCA拟合的样本数: 15,000\n", " PCA拟合完成: 7168 → 1243\n", " 保留方差: 0.9489\n", " PCA拟合完成: 7168 → 1243\n", " 保留方差: 0.9489\n" ] } ], "source": [ "# 步骤2: 拟合PCA参数【确定PCA策略】\n", "print(\"\\n\" + \"🔧 STEP 2: 拟合PCA参数\".center(60, \"=\"))\n", "pipeline.step2_fit_pca_with_undersampling()\n", "\n", "# print(\"\\n✅ 步骤2完成!\")\n", "# pipeline.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🚀 使用智能管道进行分批训练" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "# 使用智能管道进行分批训练\n", "import lightgbm as lgb\n", "import time\n", "from collections import Counter\n", "import matplotlib.pyplot as plt\n", "import random\n", "import numpy as np\n", "import os\n", "import gc\n", "\n", "class SmartBatchTrainer:\n", " \"\"\"\n", " 智能分批训练器,集成智能数据管道\n", " \"\"\"\n", " \n", " def __init__(self, pipeline, params=None, min_learning_rate=1e-4, t_0=50, t_mult=2):\n", " self.pipeline = pipeline\n", " self.model = None\n", " self.training_history = {} # 改为字典,因为只有一次训练\n", " self.batch_count = 0\n", " self.min_learning_rate = min_learning_rate\n", " self.lr_history = [] # 用于可视化\n", " \n", " # 带重启的余弦退火参数\n", " self.t_0 = t_0 # 第一个重启周期的长度\n", " self.t_mult = t_mult # 重启周期的乘数\n", " \n", " # 默认LightGBM参数(GPU优化)\n", " self.params = params or {\n", " 'objective': 'multiclass',\n", " 'num_class': 41,\n", " 'metric': 'multi_logloss',\n", " 'boosting_type': 'gbdt',\n", " 'device_type': 'cpu',\n", " # 'gpu_platform_id': 0,\n", " # 'gpu_device_id': 0,\n", " 'max_bin': 255,\n", " 'num_leaves': 127,\n", " 'learning_rate': 0.10, #默认0.08\n", " 'feature_fraction': 0.8,\n", " 'bagging_fraction': 0.8,\n", " 'bagging_freq': 5,\n", " 'min_data_in_leaf': 20,\n", " 'lambda_l1': 0.1,\n", " 'lambda_l2': 0.1,\n", " 'verbose': -1,\n", " 'num_threads': -1\n", " }\n", " \n", " self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n", " \n", " print(f\"智能分批训练器创建完成\")\n", " print(f\" LightGBM参数已配置:{self.params['device_type'].upper()}模式\")\n", " print(f\" 学习率调度: 带重启的余弦退火 (从 {self.initial_learning_rate} 到 {self.min_learning_rate})\")\n", " print(f\" 重启参数: T_0={self.t_0}, T_mult={self.t_mult}\")\n", " \n", " def prepare_validation_data(self):\n", " \"\"\"准备验证数据(仅PCA,保持原始分布)\"\"\"\n", " print(\"准备验证数据...\")\n", " X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", " if X_val is None:\n", " raise ValueError(\"无法加载验证数据\")\n", " val_counts = Counter(y_val)\n", " print(f\" 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", " print(f\" 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n", " \n", " # 缓存原始数组,便于计算accuracy\n", " self._X_val_np = X_val\n", " self._y_val_np = y_val\n", " \n", " return lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n", " \n", " def get_training_batch_generator(self, n_files_per_batch=4, batch_size=8000):\n", " \"\"\"改进的训练批次生成器:每次从所有文件中随机选择n个文件,然后随机采样\"\"\"\n", " print(f\"准备改进的训练批次生成器...\")\n", " print(f\" 每批次选择文件数: {n_files_per_batch}\")\n", " print(f\" 每批次目标样本数: {batch_size:,}\")\n", " \n", " # 获取所有训练文件列表\n", " all_train_files = [f for f in os.listdir(self.pipeline.data_dir) \n", " if f.endswith('.npz') and 'train' in f]\n", " \n", " if len(all_train_files) < n_files_per_batch:\n", " print(f\" 可用文件数({len(all_train_files)})少于每批次需要的文件数({n_files_per_batch})\")\n", " n_files_per_batch = len(all_train_files)\n", " \n", " print(f\" 总计可用训练文件: {len(all_train_files)}\")\n", " \n", " batch_id = 0\n", " while True: # 无限生成器,可以重复采样\n", " batch_id += 1\n", " \n", " # 随机选择n个文件\n", " selected_files = random.sample(all_train_files, n_files_per_batch)\n", " \n", " print(f\" 批次 {batch_id} - 随机选择的文件:\")\n", " for i, f in enumerate(selected_files, 1):\n", " print(f\" {i}. {f}\")\n", " \n", " # 从选中的文件中加载数据\n", " all_features = []\n", " all_labels = []\n", " total_available_samples = 0\n", " \n", " for filename in selected_files:\n", " # 加载文件数据\n", " data = np.load(os.path.join(self.pipeline.data_dir, filename), allow_pickle=True)\n", " trials = data['neural_logits_concatenated']\n", " \n", " # 提取特征和标签\n", " features, labels = extract_features_labels_batch(trials)\n", " \n", " if features.shape[0] > 0:\n", " all_features.append(features)\n", " all_labels.append(labels)\n", " total_available_samples += features.shape[0]\n", " \n", " # 清理单个文件数据\n", " del data, trials\n", " gc.collect()\n", " \n", " if all_features:\n", " # 合并所有选中文件的数据\n", " combined_features = np.vstack(all_features)\n", " combined_labels = np.hstack(all_labels)\n", " \n", " print(f\" 合并后总样本数: {combined_features.shape[0]:,}\")\n", " \n", " # 随机采样到目标batch_size\n", " if combined_features.shape[0] > batch_size:\n", " # 随机选择batch_size个样本\n", " sample_indices = np.random.choice(\n", " combined_features.shape[0], \n", " size=batch_size, \n", " replace=False\n", " )\n", " sampled_features = combined_features[sample_indices]\n", " sampled_labels = combined_labels[sample_indices]\n", " print(f\" 随机采样到: {batch_size:,} 样本\")\n", " else:\n", " # 如果样本不足,使用所有样本\n", " sampled_features = combined_features\n", " sampled_labels = combined_labels\n", " print(f\" 样本不足,使用全部: {sampled_features.shape[0]:,} 样本\")\n", " \n", " # 应用采样策略(平衡处理)\n", " features_balanced, labels_balanced = self.pipeline._apply_full_sampling(\n", " sampled_features, sampled_labels\n", " )\n", " \n", " # 应用PCA降维\n", " if features_balanced.shape[0] > 0:\n", " features_pca = self.pipeline._apply_pca_transform(features_balanced)\n", " \n", " # 分析当前批次分布\n", " batch_counts = Counter(labels_balanced)\n", " \n", " print(f\" 批次 {batch_id} 最终结果:\")\n", " print(f\" 平衡后样本数: {features_pca.shape[0]:,}\")\n", " print(f\" 特征维度: {features_pca.shape[1]}\")\n", " print(f\" 分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n", " print(f\" \" + \"=\"*50)\n", " \n", " # 重要修复:设置 free_raw_data=False 避免增量训练失败\n", " yield lgb.Dataset(features_pca, label=labels_balanced, free_raw_data=False), f\"batch_{batch_id}_files_{len(selected_files)}\"\n", " \n", " # 清理批次数据\n", " del all_features, all_labels, combined_features, combined_labels\n", " del sampled_features, sampled_labels, features_balanced, labels_balanced\n", " gc.collect()\n", " else:\n", " print(f\" 批次 {batch_id} 无有效数据\")\n", " continue\n", " \n", " def prepare_full_data(self):\n", " \"\"\"一次性准备所有训练和验证数据\"\"\"\n", " print(\"准备全量训练和验证数据...\")\n", " \n", " # 1. 准备验证数据 (保持原始分布)\n", " X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", " if X_val is None:\n", " raise ValueError(\"无法加载验证数据\")\n", " val_counts = Counter(y_val)\n", " print(f\" 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", " print(f\" 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n", " val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n", " \n", " # 2. 准备训练数据 (应用完整采样和PCA策略)\n", " X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", " if X_train is None:\n", " raise ValueError(\"无法加载训练数据\")\n", " train_counts = Counter(y_train)\n", " print(f\" 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", " print(f\" 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", " train_data = lgb.Dataset(X_train, label=y_train)\n", " \n", " return train_data, val_data, X_val, y_val\n", " \n", " def prepare_training_data(self):\n", " \"\"\"准备训练数据(仅PCA,保持原始分布)\"\"\"\n", " print(\"准备训练数据...\")\n", " # 2. 准备训练数据 (应用完整采样和PCA策略)\n", " X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", " if X_train is None:\n", " raise ValueError(\"无法加载训练数据\")\n", " train_counts = Counter(y_train)\n", " print(f\" 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", " print(f\" 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", " \n", " return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n", " \n", " # 带重启的余弦退火调度器函数\n", " def _cosine_annealing_with_warm_restarts(self, current_round):\n", " \"\"\"带重启的余弦退火调度器 (SGDR)\"\"\"\n", " eta_max = self.initial_learning_rate\n", " eta_min = self.min_learning_rate\n", " \n", " # 计算当前在哪个重启周期中\n", " t_cur = current_round\n", " t_i = self.t_0\n", " \n", " # 找到当前的重启周期\n", " cycle = 0\n", " while t_cur >= t_i:\n", " t_cur -= t_i\n", " cycle += 1\n", " t_i *= self.t_mult\n", " \n", " # 在当前周期内的位置\n", " progress = t_cur / t_i\n", " \n", " # 计算学习率\n", " lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * progress))\n", " \n", " return lr\n", " \n", " def train_incremental(self, num_boost_round=100, early_stopping_rounds=10, \n", " n_files_per_batch=4, batch_size=8000, max_batches=None):\n", " \"\"\"增量分批训练 - 支持自定义批次参数\"\"\"\n", " print(f\"开始智能分批训练...\")\n", " print(f\" 训练轮数 (每批次): {num_boost_round}\")\n", " print(f\" 早停轮数: {early_stopping_rounds}\")\n", " print(f\" 每批次文件数: {n_files_per_batch}\")\n", " print(f\" 每批次样本数: {batch_size:,}\")\n", " if max_batches:\n", " print(f\" 最大批次数: {max_batches}\")\n", " print(\"=\" * 60)\n", " \n", " # 准备验证数据\n", " val_data = self.prepare_validation_data()\n", " \n", " print(f\"开始分批增量训练...\")\n", " total_start_time = time.time()\n", " \n", " # 初始化训练历史\n", " self.training_history = []\n", " \n", " # 创建改进的生成器\n", " batch_generator = self.get_training_batch_generator(n_files_per_batch, batch_size)\n", " \n", " for train_data, batch_name in batch_generator:\n", " self.batch_count += 1\n", " batch_start_time = time.time()\n", " \n", " # 检查是否达到最大批次数\n", " if max_batches and self.batch_count > max_batches:\n", " print(f\"达到最大批次数 {max_batches},停止训练\")\n", " break\n", " \n", " # 先构建数据集,使得可以安全访问 num_data()\n", " try:\n", " train_data.construct()\n", " except Exception:\n", " pass\n", "\n", " print(f\"\\n批次 {self.batch_count}: {batch_name}\")\n", " try:\n", " print(f\" 样本数: {train_data.num_data():,}\")\n", " except Exception:\n", " print(\" 样本数: (未构建,跳过显示)\")\n", " \n", " # 计算当前批次的学习率\n", " current_lr = self._cosine_annealing_with_warm_restarts(\n", " (self.batch_count - 1) * num_boost_round\n", " )\n", " \n", " # 更新训练参数中的学习率\n", " current_params = self.params.copy()\n", " current_params['learning_rate'] = current_lr\n", " \n", " try:\n", " # 训练参数\n", " train_params = {\n", " 'params': current_params,\n", " 'train_set': train_data,\n", " 'num_boost_round': num_boost_round,\n", " 'valid_sets': [val_data],\n", " 'valid_names': ['validation'],\n", " 'callbacks': [\n", " lgb.log_evaluation(period=1, show_stdv=False) # 1轮打印一次,减少重复\n", " ]\n", " }\n", " \n", " # 如果有早停设置\n", " if early_stopping_rounds:\n", " train_params['callbacks'].append(\n", " lgb.early_stopping(early_stopping_rounds, verbose=False)\n", " )\n", " \n", " # 增量训练\n", " if self.model is None:\n", " # 第一次训练\n", " print(f\" 首次训练 (学习率: {current_lr:.6f})\")\n", " self.model = lgb.train(**train_params)\n", " else:\n", " # 增量训练\n", " print(f\" 增量训练 (学习率: {current_lr:.6f})\")\n", " train_params['init_model'] = self.model\n", " self.model = lgb.train(**train_params)\n", " \n", " # 验证 - 修复数组比较的歧义性问题\n", " # 优先使用缓存的验证集数组,退回到val_data中的数据\n", " Xv = getattr(self, '_X_val_np', None) \n", " yv = getattr(self, '_y_val_np', None)\n", " \n", " if Xv is None or yv is None:\n", " print(\" 警告: 无法获取验证数据,跳过准确率计算\")\n", " val_accuracy = 0.0\n", " else:\n", " val_pred = self.model.predict(Xv)\n", " \n", " # 确保yv是1维numpy数组,避免数组比较的歧义\n", " yv = np.asarray(yv, dtype=int).flatten()\n", " \n", " # 计算验证准确率\n", " pred_labels = np.argmax(val_pred, axis=1)\n", " pred_labels = np.asarray(pred_labels, dtype=int).flatten()\n", " \n", " # 确保两个数组形状一致\n", " if len(pred_labels) != len(yv):\n", " print(f\" 警告: 预测标签数({len(pred_labels)}) != 真实标签数({len(yv)})\")\n", " min_len = min(len(pred_labels), len(yv))\n", " pred_labels = pred_labels[:min_len]\n", " yv = yv[:min_len]\n", " \n", " # 使用更安全的数组比较方式\n", " try:\n", " comparison = np.equal(pred_labels, yv)\n", " val_accuracy = float(np.mean(comparison))\n", " except Exception as e:\n", " print(f\" 数组比较错误: {e}\")\n", " val_accuracy = 0.0\n", " \n", " # 记录训练历史\n", " batch_time = time.time() - batch_start_time\n", " try:\n", " samples_cnt = train_data.num_data()\n", " except Exception:\n", " samples_cnt = None\n", " self.training_history.append({\n", " 'batch': self.batch_count,\n", " 'batch_name': batch_name,\n", " 'val_accuracy': val_accuracy,\n", " 'time': batch_time,\n", " 'num_trees': self.model.num_trees(),\n", " 'learning_rate': current_lr,\n", " 'samples': samples_cnt\n", " })\n", " \n", " print(f\" 批次完成:\")\n", " print(f\" 验证准确率: {val_accuracy:.4f}\")\n", " print(f\" 训练时间: {batch_time:.1f}秒\")\n", " print(f\" 模型树数: {self.model.num_trees()}\")\n", " print(f\" 当前学习率: {current_lr:.6f}\")\n", " \n", " except Exception as e:\n", " print(f\" 批次训练失败: {e}\")\n", " import traceback\n", " traceback.print_exc()\n", " continue\n", " \n", " # 训练完成\n", " total_time = time.time() - total_start_time\n", " print(f\"\\n增量训练完成!\")\n", " print(f\" 总批次数: {len(self.training_history)}\")\n", " print(f\" 总训练时间: {total_time:.1f}秒\")\n", " \n", " if self.training_history:\n", " best_batch = max(self.training_history, key=lambda x: x['val_accuracy'])\n", " print(f\" 最佳准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n", " final_accuracy = self.training_history[-1]['val_accuracy']\n", " print(f\" 最终准确率: {final_accuracy:.4f}\")\n", " \n", " return self.model\n", "\n", " @staticmethod\n", " def _ctc_collapse(seq, blank=0, drop_sep40=False):\n", " out = []\n", " prev = None\n", " for s in seq:\n", " if s == prev:\n", " continue\n", " prev = s\n", " if s == blank:\n", " continue\n", " if drop_sep40 and s == 40:\n", " continue\n", " out.append(int(s))\n", " return out\n", "\n", " @staticmethod\n", " def _levenshtein(a, b):\n", " # a, b are lists of ints\n", " n, m = len(a), len(b)\n", " if n == 0:\n", " return m\n", " if m == 0:\n", " return n\n", " dp = list(range(m + 1))\n", " for i in range(1, n + 1):\n", " prev = dp[0]\n", " dp[0] = i\n", " ai = a[i - 1]\n", " for j in range(1, m + 1):\n", " tmp = dp[j]\n", " cost = 0 if ai == b[j - 1] else 1\n", " dp[j] = min(dp[j] + 1, # deletion\n", " dp[j - 1] + 1, # insertion\n", " prev + cost) # substitution\n", " prev = tmp\n", " return dp[m]\n", "\n", " def evaluate_val_per_experiment(self, fraction=0.33, random_state=42, drop_sep40=False, max_trials_per_file=None):\n", " \"\"\"使用所有验证文件,每个文件抽取33%的trial,按trial计算PER并求均值\"\"\"\n", " if self.model is None:\n", " raise RuntimeError(\"模型尚未训练,无法评估PER\")\n", "\n", " rng = np.random.RandomState(random_state)\n", " val_files = [f for f in os.listdir(self.pipeline.data_dir) if f.endswith('.npz') and 'val' in f]\n", " if not val_files:\n", " raise FileNotFoundError(\"未找到验证集npz文件\")\n", "\n", " results_by_file = {}\n", " per_list = []\n", " corpus_edit = 0\n", " corpus_len = 0\n", " total_trials = 0\n", "\n", " for vf in sorted(val_files):\n", " data = np.load(os.path.join(self.pipeline.data_dir, vf), allow_pickle=True)\n", " trials = data['neural_logits_concatenated']\n", " n_trials = len(trials)\n", " if n_trials == 0:\n", " results_by_file[vf] = {'n': 0, 'mean_PER': None}\n", " continue\n", " k = max(1, int(np.ceil(n_trials * fraction)))\n", " idx = np.arange(n_trials)\n", " idx = rng.choice(idx, size=k, replace=False)\n", " if max_trials_per_file is not None:\n", " k = min(k, max_trials_per_file)\n", " idx = idx[:k]\n", "\n", " trial_pers = []\n", " for ti in idx:\n", " tr = trials[ti]\n", " X_trial = tr[:, :7168]\n", " rnn_logits = tr[:, 7168:]\n", " # 变换到PCA空间\n", " X_trial_pca = self.pipeline._apply_pca_transform(X_trial)\n", " # 预测\n", " pred_proba = self.model.predict(X_trial_pca)\n", " y_pred_seq = np.argmax(pred_proba, axis=1)\n", " y_true_seq = np.argmax(rnn_logits, axis=1)\n", " # CTC折叠\n", " pred_collapsed = self._ctc_collapse(y_pred_seq, blank=0, drop_sep40=drop_sep40)\n", " true_collapsed = self._ctc_collapse(y_true_seq, blank=0, drop_sep40=drop_sep40)\n", " if len(true_collapsed) == 0:\n", " continue\n", " ed = self._levenshtein(pred_collapsed, true_collapsed)\n", " per = ed / len(true_collapsed)\n", " trial_pers.append(per)\n", " corpus_edit += ed\n", " corpus_len += len(true_collapsed)\n", " total_trials += 1\n", "\n", " if trial_pers:\n", " results_by_file[vf] = {\n", " 'n': len(trial_pers),\n", " 'mean_PER': float(np.mean(trial_pers))\n", " }\n", " per_list.extend(trial_pers)\n", " else:\n", " results_by_file[vf] = {'n': 0, 'mean_PER': None}\n", "\n", " del data, trials\n", " gc.collect()\n", "\n", " overall_mean = float(np.mean(per_list)) if per_list else None\n", " corpus_per = float(corpus_edit / corpus_len) if corpus_len > 0 else None\n", "\n", " summary = {\n", " 'overall_mean_PER': overall_mean,\n", " 'corpus_PER': corpus_per,\n", " 'total_trials': total_trials,\n", " 'per_file': results_by_file\n", " }\n", " print(\"验证集PER评估完成\")\n", " print(f\" 文件数: {len(val_files)} 评估trial数: {total_trials}\")\n", " print(f\" 平均PER(逐trial取均值): {overall_mean}\")\n", " print(f\" 语料级PER(总编辑距离/总长度): {corpus_per}\")\n", " return summary" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ SmartBatchTrainer已添加数据增强支持\n" ] } ], "source": [ "# 🔧 为SmartBatchTrainer添加数据增强支持\n", "\n", "def prepare_training_data_with_augmentation(self, apply_augmentation=True, augment_ratio=0.3):\n", " \"\"\"准备训练数据(应用采样+数据增强+PCA)\"\"\"\n", " print(\"准备训练数据(含数据增强)...\")\n", " \n", " if apply_augmentation:\n", " print(f\" 数据增强: 启用 (比例: {augment_ratio})\")\n", " X_train, y_train = self.pipeline.step3_process_data_with_augmentation(\n", " 'train', apply_sampling=True, apply_augmentation=True, augment_ratio=augment_ratio\n", " )\n", " else:\n", " print(f\" 数据增强: 禁用\")\n", " X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", " \n", " if X_train is None:\n", " raise ValueError(\"无法加载训练数据\")\n", " \n", " train_counts = Counter(y_train)\n", " print(f\" 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", " print(f\" 训练集(采样+增强后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", " \n", " return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n", "\n", "def prepare_full_data_with_augmentation(self, apply_augmentation=True, augment_ratio=0.3):\n", " \"\"\"一次性准备所有训练和验证数据(含数据增强)\"\"\"\n", " print(\"准备全量训练和验证数据(含数据增强)...\")\n", " \n", " # 1. 准备验证数据 (保持原始分布,不增强)\n", " X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", " if X_val is None:\n", " raise ValueError(\"无法加载验证数据\")\n", " val_counts = Counter(y_val)\n", " print(f\" 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", " print(f\" 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n", " val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n", " \n", " # 2. 准备训练数据 (应用完整采样和数据增强)\n", " if apply_augmentation:\n", " print(f\" 训练数据增强: 启用 (比例: {augment_ratio})\")\n", " X_train, y_train = self.pipeline.step3_process_data_with_augmentation(\n", " 'train', apply_sampling=True, apply_augmentation=True, augment_ratio=augment_ratio\n", " )\n", " else:\n", " print(f\" 训练数据增强: 禁用\")\n", " X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", " \n", " if X_train is None:\n", " raise ValueError(\"无法加载训练数据\")\n", " train_counts = Counter(y_train)\n", " print(f\" 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", " print(f\" 训练集(采样+增强后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", " train_data = lgb.Dataset(X_train, label=y_train)\n", " \n", " return train_data, val_data, X_val, y_val\n", "\n", "def get_training_batch_generator_with_augmentation(self, n_files_per_batch=4, batch_size=8000, \n", " apply_augmentation=True, augment_ratio=0.3):\n", " \"\"\"改进的训练批次生成器:每次从所有文件中随机选择n个文件,然后随机采样+数据增强\"\"\"\n", " print(f\"准备改进的训练批次生成器(含数据增强)...\")\n", " print(f\" 每批次选择文件数: {n_files_per_batch}\")\n", " print(f\" 每批次目标样本数: {batch_size:,}\")\n", " print(f\" 数据增强: {'启用' if apply_augmentation else '禁用'}\")\n", " if apply_augmentation:\n", " print(f\" 增强比例: {augment_ratio}\")\n", " \n", " # 获取所有训练文件列表\n", " all_train_files = [f for f in os.listdir(self.pipeline.data_dir) \n", " if f.endswith('.npz') and 'train' in f]\n", " \n", " if len(all_train_files) < n_files_per_batch:\n", " print(f\" 可用文件数({len(all_train_files)})少于每批次需要的文件数({n_files_per_batch})\")\n", " n_files_per_batch = len(all_train_files)\n", " \n", " print(f\" 总计可用训练文件: {len(all_train_files)}\")\n", " \n", " batch_id = 0\n", " while True: # 无限生成器,可以重复采样\n", " batch_id += 1\n", " \n", " # 随机选择n个文件\n", " selected_files = random.sample(all_train_files, n_files_per_batch)\n", " \n", " print(f\" 批次 {batch_id} - 随机选择的文件:\")\n", " for i, f in enumerate(selected_files, 1):\n", " print(f\" {i}. {f}\")\n", " \n", " # 从选中的文件中加载数据\n", " all_features = []\n", " all_labels = []\n", " total_available_samples = 0\n", " \n", " for filename in selected_files:\n", " # 加载文件数据\n", " data = np.load(os.path.join(self.pipeline.data_dir, filename), allow_pickle=True)\n", " trials = data['neural_logits_concatenated']\n", " \n", " # 提取特征和标签(带数据增强)\n", " features, labels = extract_features_labels_batch_with_augmentation(\n", " trials, \n", " random_shuffle_trials=True,\n", " apply_augmentation=apply_augmentation,\n", " augment_ratio=augment_ratio\n", " )\n", " \n", " if features.shape[0] > 0:\n", " all_features.append(features)\n", " all_labels.append(labels)\n", " total_available_samples += features.shape[0]\n", " \n", " # 清理单个文件数据\n", " del data, trials\n", " gc.collect()\n", " \n", " if all_features:\n", " # 合并所有选中文件的数据\n", " combined_features = np.vstack(all_features)\n", " combined_labels = np.hstack(all_labels)\n", " \n", " print(f\" 合并后总样本数: {combined_features.shape[0]:,}\")\n", " \n", " # 随机采样到目标batch_size\n", " if combined_features.shape[0] > batch_size:\n", " # 随机选择batch_size个样本\n", " sample_indices = np.random.choice(\n", " combined_features.shape[0], \n", " size=batch_size, \n", " replace=False\n", " )\n", " sampled_features = combined_features[sample_indices]\n", " sampled_labels = combined_labels[sample_indices]\n", " print(f\" 随机采样到: {batch_size:,} 样本\")\n", " else:\n", " # 如果样本不足,使用所有样本\n", " sampled_features = combined_features\n", " sampled_labels = combined_labels\n", " print(f\" 样本不足,使用全部: {sampled_features.shape[0]:,} 样本\")\n", " \n", " # 应用采样策略(平衡处理)\n", " features_balanced, labels_balanced = self.pipeline._apply_full_sampling(\n", " sampled_features, sampled_labels\n", " )\n", " \n", " # 应用PCA降维\n", " if features_balanced.shape[0] > 0:\n", " features_pca = self.pipeline._apply_pca_transform(features_balanced)\n", " \n", " # 分析当前批次分布\n", " batch_counts = Counter(labels_balanced)\n", " \n", " print(f\" 批次 {batch_id} 最终结果:\")\n", " print(f\" 平衡后样本数: {features_pca.shape[0]:,}\")\n", " print(f\" 特征维度: {features_pca.shape[1]}\")\n", " print(f\" 分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n", " print(f\" \" + \"=\"*50)\n", " \n", " # 重要修复:设置 free_raw_data=False 避免增量训练失败\n", " yield lgb.Dataset(features_pca, label=labels_balanced, free_raw_data=False), f\"batch_{batch_id}_files_{len(selected_files)}_aug_{apply_augmentation}\"\n", " \n", " # 清理批次数据\n", " del all_features, all_labels, combined_features, combined_labels\n", " del sampled_features, sampled_labels, features_balanced, labels_balanced\n", " gc.collect()\n", " else:\n", " print(f\" 批次 {batch_id} 无有效数据\")\n", " continue\n", "\n", "# 动态添加数据增强支持的方法到SmartBatchTrainer类\n", "SmartBatchTrainer.prepare_training_data_with_augmentation = prepare_training_data_with_augmentation\n", "SmartBatchTrainer.prepare_full_data_with_augmentation = prepare_full_data_with_augmentation\n", "SmartBatchTrainer.get_training_batch_generator_with_augmentation = get_training_batch_generator_with_augmentation\n", "\n", "print(\"✅ SmartBatchTrainer已添加数据增强支持\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "智能分批训练器创建完成\n", " LightGBM参数已配置:CPU模式\n", " 学习率调度: 带重启的余弦退火 (从 0.1 到 0.001)\n", " 重启参数: T_0=50, T_mult=2\n" ] } ], "source": [ "trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001, t_0=30, t_mult=2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "开始智能分批训练...\n", " 训练轮数 (每批次): 5\n", " 早停轮数: 10\n", " 每批次文件数: 8\n", " 每批次样本数: 50,000\n", " 最大批次数: 100\n", "============================================================\n", "准备验证数据...\n", "\n", "处理val数据...\n", " 采样策略: 禁用\n", " 完成: 321,773 样本, 1243 特征\n", " 验证数据准备完成: 321,773 样本\n", " 验证集分布 (标签0: 238,705, 标签40: 35,425)\n", "开始分批增量训练...\n", "准备改进的训练批次生成器...\n", " 每批次选择文件数: 8\n", " 每批次目标样本数: 50,000\n", " 总计可用训练文件: 45\n", " 批次 1 - 随机选择的文件:\n", " 1. t15.2023.11.26_train_concatenated.npz\n", " 2. t15.2024.04.28_train_concatenated.npz\n", " 3. t15.2023.10.01_train_concatenated.npz\n", " 4. t15.2025.04.13_train_concatenated.npz\n", " 5. t15.2024.02.25_train_concatenated.npz\n", " 6. t15.2023.08.20_train_concatenated.npz\n", " 7. t15.2023.12.08_train_concatenated.npz\n", " 8. t15.2023.10.06_train_concatenated.npz\n", " 完成: 321,773 样本, 1243 特征\n", " 验证数据准备完成: 321,773 样本\n", " 验证集分布 (标签0: 238,705, 标签40: 35,425)\n", "开始分批增量训练...\n", "准备改进的训练批次生成器...\n", " 每批次选择文件数: 8\n", " 每批次目标样本数: 50,000\n", " 总计可用训练文件: 45\n", " 批次 1 - 随机选择的文件:\n", " 1. t15.2023.11.26_train_concatenated.npz\n", " 2. t15.2024.04.28_train_concatenated.npz\n", " 3. t15.2023.10.01_train_concatenated.npz\n", " 4. t15.2025.04.13_train_concatenated.npz\n", " 5. t15.2024.02.25_train_concatenated.npz\n", " 6. t15.2023.08.20_train_concatenated.npz\n", " 7. t15.2023.12.08_train_concatenated.npz\n", " 8. t15.2023.10.06_train_concatenated.npz\n", " 合并后总样本数: 307,619\n", " 合并后总样本数: 307,619\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 1 最终结果:\n", " 平衡后样本数: 2,836\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 1 最终结果:\n", " 平衡后样本数: 2,836\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 1: batch_1_files_8\n", " 样本数: 2,836\n", " 首次训练 (学习率: 0.100000)\n", "\n", "批次 1: batch_1_files_8\n", " 样本数: 2,836\n", " 首次训练 (学习率: 0.100000)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[1]\tvalidation's multi_logloss: 3.61118\n", "[2]\tvalidation's multi_logloss: 3.59449\n", "[2]\tvalidation's multi_logloss: 3.59449\n", "[3]\tvalidation's multi_logloss: 3.60069\n", "[3]\tvalidation's multi_logloss: 3.60069\n", "[4]\tvalidation's multi_logloss: 3.58862\n", "[4]\tvalidation's multi_logloss: 3.58862\n", "[5]\tvalidation's multi_logloss: 3.56763\n", "[5]\tvalidation's multi_logloss: 3.56763\n", " 批次完成:\n", " 验证准确率: 0.0740\n", " 训练时间: 56.9秒\n", " 模型树数: 205\n", " 当前学习率: 0.100000\n", " 批次 2 - 随机选择的文件:\n", " 1. t15.2023.11.04_train_concatenated.npz\n", " 2. t15.2023.10.08_train_concatenated.npz\n", " 3. t15.2024.03.15_train_concatenated.npz\n", " 4. t15.2023.12.08_train_concatenated.npz\n", " 5. t15.2024.04.28_train_concatenated.npz\n", " 6. t15.2023.09.29_train_concatenated.npz\n", " 7. t15.2023.08.27_train_concatenated.npz\n", " 8. t15.2025.03.30_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.0740\n", " 训练时间: 56.9秒\n", " 模型树数: 205\n", " 当前学习率: 0.100000\n", " 批次 2 - 随机选择的文件:\n", " 1. t15.2023.11.04_train_concatenated.npz\n", " 2. t15.2023.10.08_train_concatenated.npz\n", " 3. t15.2024.03.15_train_concatenated.npz\n", " 4. t15.2023.12.08_train_concatenated.npz\n", " 5. t15.2024.04.28_train_concatenated.npz\n", " 6. t15.2023.09.29_train_concatenated.npz\n", " 7. t15.2023.08.27_train_concatenated.npz\n", " 8. t15.2025.03.30_train_concatenated.npz\n", " 合并后总样本数: 312,205\n", " 合并后总样本数: 312,205\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 2 最终结果:\n", " 平衡后样本数: 2,802\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 2 最终结果:\n", " 平衡后样本数: 2,802\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 2: batch_2_files_8\n", " 样本数: 2,802\n", " 增量训练 (学习率: 0.097577)\n", "\n", "批次 2: batch_2_files_8\n", " 样本数: 2,802\n", " 增量训练 (学习率: 0.097577)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[6]\tvalidation's multi_logloss: 3.54169\n", "[7]\tvalidation's multi_logloss: 3.53279\n", "[7]\tvalidation's multi_logloss: 3.53279\n", "[8]\tvalidation's multi_logloss: 3.54027\n", "[8]\tvalidation's multi_logloss: 3.54027\n", "[9]\tvalidation's multi_logloss: 3.522\n", "[9]\tvalidation's multi_logloss: 3.522\n", "[10]\tvalidation's multi_logloss: 3.5052\n", "[10]\tvalidation's multi_logloss: 3.5052\n", " 批次完成:\n", " 验证准确率: 0.1107\n", " 训练时间: 75.2秒\n", " 模型树数: 410\n", " 当前学习率: 0.097577\n", " 批次 3 - 随机选择的文件:\n", " 1. t15.2023.10.01_train_concatenated.npz\n", " 2. t15.2024.07.28_train_concatenated.npz\n", " 3. t15.2025.01.12_train_concatenated.npz\n", " 4. t15.2023.10.22_train_concatenated.npz\n", " 5. t15.2025.03.30_train_concatenated.npz\n", " 6. t15.2023.08.13_train_concatenated.npz\n", " 7. t15.2024.05.10_train_concatenated.npz\n", " 8. t15.2025.04.13_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1107\n", " 训练时间: 75.2秒\n", " 模型树数: 410\n", " 当前学习率: 0.097577\n", " 批次 3 - 随机选择的文件:\n", " 1. t15.2023.10.01_train_concatenated.npz\n", " 2. t15.2024.07.28_train_concatenated.npz\n", " 3. t15.2025.01.12_train_concatenated.npz\n", " 4. t15.2023.10.22_train_concatenated.npz\n", " 5. t15.2025.03.30_train_concatenated.npz\n", " 6. t15.2023.08.13_train_concatenated.npz\n", " 7. t15.2024.05.10_train_concatenated.npz\n", " 8. t15.2025.04.13_train_concatenated.npz\n", " 合并后总样本数: 293,792\n", " 合并后总样本数: 293,792\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 3 最终结果:\n", " 平衡后样本数: 2,871\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 3 最终结果:\n", " 平衡后样本数: 2,871\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 3: batch_3_files_8\n", " 样本数: 2,871\n", " 增量训练 (学习率: 0.090546)\n", "\n", "批次 3: batch_3_files_8\n", " 样本数: 2,871\n", " 增量训练 (学习率: 0.090546)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[11]\tvalidation's multi_logloss: 3.53701\n", "[12]\tvalidation's multi_logloss: 3.55692\n", "[12]\tvalidation's multi_logloss: 3.55692\n", "[13]\tvalidation's multi_logloss: 3.57387\n", "[13]\tvalidation's multi_logloss: 3.57387\n", "[14]\tvalidation's multi_logloss: 3.57509\n", "[14]\tvalidation's multi_logloss: 3.57509\n", "[15]\tvalidation's multi_logloss: 3.57052\n", "[15]\tvalidation's multi_logloss: 3.57052\n", " 批次完成:\n", " 验证准确率: 0.1049\n", " 训练时间: 83.5秒\n", " 模型树数: 451\n", " 当前学习率: 0.090546\n", " 批次 4 - 随机选择的文件:\n", " 1. t15.2024.03.15_train_concatenated.npz\n", " 2. t15.2023.10.15_train_concatenated.npz\n", " 3. t15.2024.02.25_train_concatenated.npz\n", " 4. t15.2023.08.20_train_concatenated.npz\n", " 5. t15.2023.10.22_train_concatenated.npz\n", " 6. t15.2023.12.10_train_concatenated.npz\n", " 7. t15.2023.10.20_train_concatenated.npz\n", " 8. t15.2024.07.28_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1049\n", " 训练时间: 83.5秒\n", " 模型树数: 451\n", " 当前学习率: 0.090546\n", " 批次 4 - 随机选择的文件:\n", " 1. t15.2024.03.15_train_concatenated.npz\n", " 2. t15.2023.10.15_train_concatenated.npz\n", " 3. t15.2024.02.25_train_concatenated.npz\n", " 4. t15.2023.08.20_train_concatenated.npz\n", " 5. t15.2023.10.22_train_concatenated.npz\n", " 6. t15.2023.12.10_train_concatenated.npz\n", " 7. t15.2023.10.20_train_concatenated.npz\n", " 8. t15.2024.07.28_train_concatenated.npz\n", " 合并后总样本数: 335,983\n", " 合并后总样本数: 335,983\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 4 最终结果:\n", " 平衡后样本数: 2,840\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 4 最终结果:\n", " 平衡后样本数: 2,840\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 4: batch_4_files_8\n", " 样本数: 2,840\n", " 增量训练 (学习率: 0.079595)\n", "\n", "批次 4: batch_4_files_8\n", " 样本数: 2,840\n", " 增量训练 (学习率: 0.079595)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[12]\tvalidation's multi_logloss: 3.51209\n", "[13]\tvalidation's multi_logloss: 3.51032\n", "[13]\tvalidation's multi_logloss: 3.51032\n", "[14]\tvalidation's multi_logloss: 3.51553\n", "[14]\tvalidation's multi_logloss: 3.51553\n", "[15]\tvalidation's multi_logloss: 3.5119\n", "[15]\tvalidation's multi_logloss: 3.5119\n", "[16]\tvalidation's multi_logloss: 3.50659\n", "[16]\tvalidation's multi_logloss: 3.50659\n", " 批次完成:\n", " 验证准确率: 0.1218\n", " 训练时间: 89.2秒\n", " 模型树数: 656\n", " 当前学习率: 0.079595\n", " 批次 5 - 随机选择的文件:\n", " 1. t15.2023.08.27_train_concatenated.npz\n", " 2. t15.2023.11.17_train_concatenated.npz\n", " 3. t15.2024.07.19_train_concatenated.npz\n", " 4. t15.2023.09.03_train_concatenated.npz\n", " 5. t15.2023.08.20_train_concatenated.npz\n", " 6. t15.2023.12.10_train_concatenated.npz\n", " 7. t15.2023.12.08_train_concatenated.npz\n", " 8. t15.2023.10.13_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1218\n", " 训练时间: 89.2秒\n", " 模型树数: 656\n", " 当前学习率: 0.079595\n", " 批次 5 - 随机选择的文件:\n", " 1. t15.2023.08.27_train_concatenated.npz\n", " 2. t15.2023.11.17_train_concatenated.npz\n", " 3. t15.2024.07.19_train_concatenated.npz\n", " 4. t15.2023.09.03_train_concatenated.npz\n", " 5. t15.2023.08.20_train_concatenated.npz\n", " 6. t15.2023.12.10_train_concatenated.npz\n", " 7. t15.2023.12.08_train_concatenated.npz\n", " 8. t15.2023.10.13_train_concatenated.npz\n", " 合并后总样本数: 319,334\n", " 合并后总样本数: 319,334\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 5 最终结果:\n", " 平衡后样本数: 2,839\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 5 最终结果:\n", " 平衡后样本数: 2,839\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 5: batch_5_files_8\n", " 样本数: 2,839\n", " 增量训练 (学习率: 0.065796)\n", "\n", "批次 5: batch_5_files_8\n", " 样本数: 2,839\n", " 增量训练 (学习率: 0.065796)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[17]\tvalidation's multi_logloss: 3.48509\n", "[18]\tvalidation's multi_logloss: 3.49194\n", "[18]\tvalidation's multi_logloss: 3.49194\n", "[19]\tvalidation's multi_logloss: 3.50197\n", "[19]\tvalidation's multi_logloss: 3.50197\n", "[20]\tvalidation's multi_logloss: 3.50542\n", "[20]\tvalidation's multi_logloss: 3.50542\n", "[21]\tvalidation's multi_logloss: 3.51177\n", "[21]\tvalidation's multi_logloss: 3.51177\n", " 批次完成:\n", " 验证准确率: 0.1273\n", " 训练时间: 101.3秒\n", " 模型树数: 697\n", " 当前学习率: 0.065796\n", " 批次 6 - 随机选择的文件:\n", " 1. t15.2023.10.13_train_concatenated.npz\n", " 2. t15.2023.10.20_train_concatenated.npz\n", " 3. t15.2024.02.25_train_concatenated.npz\n", " 4. t15.2023.08.27_train_concatenated.npz\n", " 5. t15.2024.04.28_train_concatenated.npz\n", " 6. t15.2025.01.12_train_concatenated.npz\n", " 7. t15.2023.08.25_train_concatenated.npz\n", " 8. t15.2023.12.03_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1273\n", " 训练时间: 101.3秒\n", " 模型树数: 697\n", " 当前学习率: 0.065796\n", " 批次 6 - 随机选择的文件:\n", " 1. t15.2023.10.13_train_concatenated.npz\n", " 2. t15.2023.10.20_train_concatenated.npz\n", " 3. t15.2024.02.25_train_concatenated.npz\n", " 4. t15.2023.08.27_train_concatenated.npz\n", " 5. t15.2024.04.28_train_concatenated.npz\n", " 6. t15.2025.01.12_train_concatenated.npz\n", " 7. t15.2023.08.25_train_concatenated.npz\n", " 8. t15.2023.12.03_train_concatenated.npz\n", " 合并后总样本数: 258,917\n", " 合并后总样本数: 258,917\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 6 最终结果:\n", " 平衡后样本数: 2,826\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 6 最终结果:\n", " 平衡后样本数: 2,826\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 6: batch_6_files_8\n", " 样本数: 2,826\n", " 增量训练 (学习率: 0.050500)\n", "\n", "批次 6: batch_6_files_8\n", " 样本数: 2,826\n", " 增量训练 (学习率: 0.050500)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[18]\tvalidation's multi_logloss: 3.48149\n", "[19]\tvalidation's multi_logloss: 3.47673\n", "[19]\tvalidation's multi_logloss: 3.47673\n", "[20]\tvalidation's multi_logloss: 3.47655\n", "[20]\tvalidation's multi_logloss: 3.47655\n", "[21]\tvalidation's multi_logloss: 3.47924\n", "[21]\tvalidation's multi_logloss: 3.47924\n", "[22]\tvalidation's multi_logloss: 3.47754\n", "[22]\tvalidation's multi_logloss: 3.47754\n", " 批次完成:\n", " 验证准确率: 0.1308\n", " 训练时间: 112.3秒\n", " 模型树数: 820\n", " 当前学习率: 0.050500\n", " 批次 7 - 随机选择的文件:\n", " 1. t15.2023.12.29_train_concatenated.npz\n", " 2. t15.2023.09.29_train_concatenated.npz\n", " 3. t15.2023.09.01_train_concatenated.npz\n", " 4. t15.2023.12.08_train_concatenated.npz\n", " 5. t15.2024.05.10_train_concatenated.npz\n", " 6. t15.2023.10.08_train_concatenated.npz\n", " 7. t15.2025.04.13_train_concatenated.npz\n", " 8. t15.2023.09.24_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1308\n", " 训练时间: 112.3秒\n", " 模型树数: 820\n", " 当前学习率: 0.050500\n", " 批次 7 - 随机选择的文件:\n", " 1. t15.2023.12.29_train_concatenated.npz\n", " 2. t15.2023.09.29_train_concatenated.npz\n", " 3. t15.2023.09.01_train_concatenated.npz\n", " 4. t15.2023.12.08_train_concatenated.npz\n", " 5. t15.2024.05.10_train_concatenated.npz\n", " 6. t15.2023.10.08_train_concatenated.npz\n", " 7. t15.2025.04.13_train_concatenated.npz\n", " 8. t15.2023.09.24_train_concatenated.npz\n", " 合并后总样本数: 338,426\n", " 合并后总样本数: 338,426\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 7 最终结果:\n", " 平衡后样本数: 2,858\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 7 最终结果:\n", " 平衡后样本数: 2,858\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 7: batch_7_files_8\n", " 样本数: 2,858\n", " 增量训练 (学习率: 0.035204)\n", "\n", "批次 7: batch_7_files_8\n", " 样本数: 2,858\n", " 增量训练 (学习率: 0.035204)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[21]\tvalidation's multi_logloss: 3.47751\n", "[22]\tvalidation's multi_logloss: 3.48293\n", "[22]\tvalidation's multi_logloss: 3.48293\n", "[23]\tvalidation's multi_logloss: 3.48763\n", "[23]\tvalidation's multi_logloss: 3.48763\n", "[24]\tvalidation's multi_logloss: 3.48841\n", "[24]\tvalidation's multi_logloss: 3.48841\n", "[25]\tvalidation's multi_logloss: 3.49353\n", "[25]\tvalidation's multi_logloss: 3.49353\n", " 批次完成:\n", " 验证准确率: 0.1311\n", " 训练时间: 137.4秒\n", " 模型树数: 861\n", " 当前学习率: 0.035204\n", " 批次 8 - 随机选择的文件:\n", " 1. t15.2023.12.08_train_concatenated.npz\n", " 2. t15.2025.03.30_train_concatenated.npz\n", " 3. t15.2023.11.03_train_concatenated.npz\n", " 4. t15.2023.09.29_train_concatenated.npz\n", " 5. t15.2024.03.15_train_concatenated.npz\n", " 6. t15.2025.01.10_train_concatenated.npz\n", " 7. t15.2023.08.27_train_concatenated.npz\n", " 8. t15.2023.10.22_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1311\n", " 训练时间: 137.4秒\n", " 模型树数: 861\n", " 当前学习率: 0.035204\n", " 批次 8 - 随机选择的文件:\n", " 1. t15.2023.12.08_train_concatenated.npz\n", " 2. t15.2025.03.30_train_concatenated.npz\n", " 3. t15.2023.11.03_train_concatenated.npz\n", " 4. t15.2023.09.29_train_concatenated.npz\n", " 5. t15.2024.03.15_train_concatenated.npz\n", " 6. t15.2025.01.10_train_concatenated.npz\n", " 7. t15.2023.08.27_train_concatenated.npz\n", " 8. t15.2023.10.22_train_concatenated.npz\n", " 合并后总样本数: 307,185\n", " 合并后总样本数: 307,185\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 8 最终结果:\n", " 平衡后样本数: 2,812\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 8 最终结果:\n", " 平衡后样本数: 2,812\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 8: batch_8_files_8\n", " 样本数: 2,812\n", " 增量训练 (学习率: 0.021405)\n", "\n", "批次 8: batch_8_files_8\n", " 样本数: 2,812\n", " 增量训练 (学习率: 0.021405)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[22]\tvalidation's multi_logloss: 3.47332\n", "[23]\tvalidation's multi_logloss: 3.47467\n", "[23]\tvalidation's multi_logloss: 3.47467\n", "[24]\tvalidation's multi_logloss: 3.4695\n", "[24]\tvalidation's multi_logloss: 3.4695\n", "[25]\tvalidation's multi_logloss: 3.46523\n", "[25]\tvalidation's multi_logloss: 3.46523\n", "[26]\tvalidation's multi_logloss: 3.46263\n", "[26]\tvalidation's multi_logloss: 3.46263\n", " 批次完成:\n", " 验证准确率: 0.1359\n", " 训练时间: 126.0秒\n", " 模型树数: 1066\n", " 当前学习率: 0.021405\n", " 批次 9 - 随机选择的文件:\n", " 1. t15.2023.11.03_train_concatenated.npz\n", " 2. t15.2024.03.08_train_concatenated.npz\n", " 3. t15.2023.09.01_train_concatenated.npz\n", " 4. t15.2023.08.18_train_concatenated.npz\n", " 5. t15.2023.08.27_train_concatenated.npz\n", " 6. t15.2023.11.19_train_concatenated.npz\n", " 7. t15.2023.09.03_train_concatenated.npz\n", " 8. t15.2024.02.25_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1359\n", " 训练时间: 126.0秒\n", " 模型树数: 1066\n", " 当前学习率: 0.021405\n", " 批次 9 - 随机选择的文件:\n", " 1. t15.2023.11.03_train_concatenated.npz\n", " 2. t15.2024.03.08_train_concatenated.npz\n", " 3. t15.2023.09.01_train_concatenated.npz\n", " 4. t15.2023.08.18_train_concatenated.npz\n", " 5. t15.2023.08.27_train_concatenated.npz\n", " 6. t15.2023.11.19_train_concatenated.npz\n", " 7. t15.2023.09.03_train_concatenated.npz\n", " 8. t15.2024.02.25_train_concatenated.npz\n", " 合并后总样本数: 318,419\n", " 合并后总样本数: 318,419\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 9 最终结果:\n", " 平衡后样本数: 2,864\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 9 最终结果:\n", " 平衡后样本数: 2,864\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 9: batch_9_files_8\n", " 样本数: 2,864\n", " 增量训练 (学习率: 0.010454)\n", "\n", "批次 9: batch_9_files_8\n", " 样本数: 2,864\n", " 增量训练 (学习率: 0.010454)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[27]\tvalidation's multi_logloss: 3.45802\n", "[28]\tvalidation's multi_logloss: 3.45485\n", "[28]\tvalidation's multi_logloss: 3.45485\n", "[29]\tvalidation's multi_logloss: 3.45208\n", "[29]\tvalidation's multi_logloss: 3.45208\n", "[30]\tvalidation's multi_logloss: 3.44677\n", "[30]\tvalidation's multi_logloss: 3.44677\n", "[31]\tvalidation's multi_logloss: 3.44218\n", "[31]\tvalidation's multi_logloss: 3.44218\n", " 批次完成:\n", " 验证准确率: 0.1416\n", " 训练时间: 143.2秒\n", " 模型树数: 1271\n", " 当前学习率: 0.010454\n", " 批次 10 - 随机选择的文件:\n", " 1. t15.2024.07.21_train_concatenated.npz\n", " 2. t15.2024.06.14_train_concatenated.npz\n", " 3. t15.2023.12.10_train_concatenated.npz\n", " 4. t15.2024.07.19_train_concatenated.npz\n", " 5. t15.2023.12.03_train_concatenated.npz\n", " 6. t15.2023.11.04_train_concatenated.npz\n", " 7. t15.2023.10.01_train_concatenated.npz\n", " 8. t15.2023.08.25_train_concatenated.npz\n", " 批次完成:\n", " 验证准确率: 0.1416\n", " 训练时间: 143.2秒\n", " 模型树数: 1271\n", " 当前学习率: 0.010454\n", " 批次 10 - 随机选择的文件:\n", " 1. t15.2024.07.21_train_concatenated.npz\n", " 2. t15.2024.06.14_train_concatenated.npz\n", " 3. t15.2023.12.10_train_concatenated.npz\n", " 4. t15.2024.07.19_train_concatenated.npz\n", " 5. t15.2023.12.03_train_concatenated.npz\n", " 6. t15.2023.11.04_train_concatenated.npz\n", " 7. t15.2023.10.01_train_concatenated.npz\n", " 8. t15.2023.08.25_train_concatenated.npz\n", " 合并后总样本数: 262,351\n", " 合并后总样本数: 262,351\n", " 随机采样到: 50,000 样本\n", " 随机采样到: 50,000 样本\n", " 批次 10 最终结果:\n", " 平衡后样本数: 2,823\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", " 批次 10 最终结果:\n", " 平衡后样本数: 2,823\n", " 特征维度: 1243\n", " 分布: 标签0=76, 标签40=76\n", " ==================================================\n", "\n", "批次 10: batch_10_files_8\n", " 样本数: 2,823\n", " 增量训练 (学习率: 0.003423)\n", "\n", "批次 10: batch_10_files_8\n", " 样本数: 2,823\n", " 增量训练 (学习率: 0.003423)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n", " _log_warning('Overriding the parameters from Reference Dataset.')\n" ] } ], "source": [ "# 改进的训练参数\n", "IMPROVED_TRAINING_PARAMS = {\n", " 'num_boost_round': 5, # 每批次的提升轮数\n", " 'early_stopping_rounds': 10, # 早停轮数\n", " 'n_files_per_batch': 8, # 快速验证用,减少到4\n", " 'batch_size': 50000, # 快速验证用,减半\n", " 'max_batches': 100 # 仅跑100个批次做冒烟测试\n", "}\n", "\n", "# 开始使用改进的训练器\n", "model = trainer.train_incremental(\n", " num_boost_round=IMPROVED_TRAINING_PARAMS['num_boost_round'],\n", " early_stopping_rounds=IMPROVED_TRAINING_PARAMS['early_stopping_rounds'],\n", " n_files_per_batch=IMPROVED_TRAINING_PARAMS['n_files_per_batch'],\n", " batch_size=IMPROVED_TRAINING_PARAMS['batch_size'],\n", " max_batches=IMPROVED_TRAINING_PARAMS['max_batches']\n", ")\n", "\n", "# 训练完成后计算一次验证集PER(每个文件取33%试验)\n", "per_summary = trainer.evaluate_val_per_experiment(fraction=0.33, random_state=42, drop_sep40=False, max_trials_per_file=5)\n", "print(per_summary)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 📊 训练结果分析" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 📊 训练结果分析和可视化\n", "\n", "print(\"📊 分析智能分批训练结果...\")\n", "print(\"=\" * 60)\n", "\n", "# 显示训练进度图表\n", "trainer.plot_training_progress()\n", "\n", "# 保存最终模型\n", "final_model_path = \"smart_pipeline_final_model.txt\"\n", "if trainer.model:\n", " trainer.model.save_model(final_model_path)\n", " print(f\"\\n💾 最终模型已保存: {final_model_path}\")\n", "\n", "# 详细分析\n", "if trainer.training_history:\n", " print(f\"\\n📈 详细训练分析:\")\n", " print(f\" 🎯 训练批次总数: {len(trainer.training_history)}\")\n", " \n", " # 最佳批次\n", " best_batch = max(trainer.training_history, key=lambda x: x['val_accuracy'])\n", " print(f\" 🏆 最佳验证准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n", " \n", " # 训练效率\n", " total_training_time = sum(h['time'] for h in trainer.training_history)\n", " avg_batch_time = total_training_time / len(trainer.training_history)\n", " print(f\" ⏱️ 总训练时间: {total_training_time:.1f}秒\")\n", " print(f\" ⏱️ 平均批次时间: {avg_batch_time:.1f}秒\")\n", " \n", " # 模型复杂度\n", " final_trees = trainer.training_history[-1]['num_trees']\n", " print(f\" 🌳 最终模型树数: {final_trees}\")\n", " \n", " # 收敛性分析\n", " recent_accs = [h['val_accuracy'] for h in trainer.training_history[-3:]]\n", " if len(recent_accs) >= 2:\n", " acc_stability = max(recent_accs) - min(recent_accs)\n", " print(f\" 📈 准确率稳定性: {acc_stability:.4f} (最近3批次方差)\")\n", " \n", " if acc_stability < 0.01:\n", " print(\" ✅ 模型已收敛 (准确率变化 < 1%)\")\n", " else:\n", " print(\" ⚠️ 模型可能需要更多训练\")\n", "\n", "print(f\"\\n🎉 智能分批训练分析完成!\")\n", "print(f\" 💡 使用了改进的数据平衡策略和PCA降维\")\n", "print(f\" 💡 训练集应用了下采样+过采样,验证集保持原始分布\")\n", "print(f\" 💡 实现了内存友好的分批处理\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🧪 模型性能评估" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 模型性能评估\n", "\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "import numpy as np\n", "\n", "def evaluate_model_performance(model, pipeline, data_type='val'):\n", " \"\"\"\n", " 评估模型在指定数据集上的性能\n", " \"\"\"\n", " print(f\"🧪 评估模型在{data_type}数据集上的性能...\")\n", " \n", " # 加载数据\n", " X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n", " \n", " if X is None or y is None:\n", " print(f\"❌ 无法加载{data_type}数据\")\n", " return None\n", " \n", " print(f\" 📊 数据集大小: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", " \n", " # 预测\n", " start_time = time.time()\n", " y_pred_proba = model.predict(X)\n", " y_pred = y_pred_proba.argmax(axis=1)\n", " pred_time = time.time() - start_time\n", " \n", " # 计算性能指标\n", " accuracy = (y_pred == y).mean()\n", " \n", " print(f\" ⏱️ 预测时间: {pred_time:.2f}秒\")\n", " print(f\" 🎯 整体准确率: {accuracy:.4f}\")\n", " \n", " # 分析各类别性能\n", " from collections import Counter\n", " true_counts = Counter(y)\n", " pred_counts = Counter(y_pred)\n", " \n", " print(f\"\\n📊 标签分布对比:\")\n", " print(\"标签 | 真实数量 | 预测数量 | 准确率\")\n", " print(\"-\" * 40)\n", " \n", " label_accuracies = {}\n", " for label in range(41):\n", " if label in true_counts:\n", " label_mask = (y == label)\n", " if label_mask.sum() > 0:\n", " label_acc = (y_pred[label_mask] == label).mean()\n", " label_accuracies[label] = label_acc\n", " true_count = true_counts.get(label, 0)\n", " pred_count = pred_counts.get(label, 0)\n", " print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n", " \n", " # 重点分析关键标签\n", " print(f\"\\n🔍 关键标签性能分析:\")\n", " key_labels = [0, 40] # 下采样的标签\n", " for label in key_labels:\n", " if label in label_accuracies:\n", " acc = label_accuracies[label]\n", " count = true_counts.get(label, 0)\n", " print(f\" 标签 {label} (下采样目标): 准确率 {acc:.4f}, 样本数 {count:,}\")\n", " \n", " # 少数类性能\n", " minority_labels = [label for label, count in true_counts.items() \n", " if count < 200 and label not in [0, 40]]\n", " if minority_labels:\n", " minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n", " avg_minority_acc = np.mean(minority_accs) if minority_accs else 0\n", " print(f\" 少数类平均准确率 (前5个): {avg_minority_acc:.4f}\")\n", " \n", " # 置信度分析\n", " max_proba = y_pred_proba.max(axis=1)\n", " print(f\"\\n📈 预测置信度分析:\")\n", " print(f\" 平均置信度: {max_proba.mean():.4f}\")\n", " print(f\" 置信度中位数: {np.median(max_proba):.4f}\")\n", " print(f\" 高置信度预测 (>0.9): {(max_proba > 0.9).sum():,} / {len(max_proba):,} ({(max_proba > 0.9).mean():.2%})\")\n", " \n", " return {\n", " 'accuracy': accuracy,\n", " 'prediction_time': pred_time,\n", " 'label_accuracies': label_accuracies,\n", " 'confidence_stats': {\n", " 'mean': max_proba.mean(),\n", " 'median': np.median(max_proba),\n", " 'high_confidence_ratio': (max_proba > 0.9).mean()\n", " }\n", " }\n", "\n", "# 评估模型性能\n", "if trainer.model:\n", " print(\"🧪 开始模型性能评估...\")\n", " \n", " # 验证集评估\n", " val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n", " \n", " print(f\"\\n\" + \"=\"*60)\n", " print(\"🎉 智能分批训练+数据平衡 评估完成!\")\n", " print(f\"✅ 实现了数据平衡和PCA降维的完整流程\")\n", " print(f\"✅ 使用了内存友好的分批训练策略\")\n", " print(f\"✅ 保持了验证集的原始分布以确保评估客观性\")\n", "else:\n", " print(\"❌ 模型尚未训练完成,请等待训练结束后运行此评估\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ✅ 余弦退火已更新为带重启版本\n", "\n", "print(\"🎉 余弦退火调度器更新完成!\")\n", "\n", "# 检查trainer是否已创建,如果未创建则先创建\n", "if 'trainer' not in globals():\n", " print(\"⚠️ 训练器尚未创建,请先运行前面的代码创建训练器\")\n", "else:\n", " print(f\"✅ 当前使用:带重启的余弦退火 (SGDR)\")\n", " print(f\" 🔄 重启参数: T_0={trainer.t_0}, T_mult={trainer.t_mult}\")\n", " print(f\" 📈 学习率范围: {trainer.initial_learning_rate} → {trainer.min_learning_rate}\")\n", "\n", " # 可视化新的学习率调度\n", " import matplotlib.pyplot as plt\n", " import numpy as np\n", "\n", " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n", "\n", " # 模拟300轮的学习率变化\n", " rounds = list(range(300))\n", " old_lrs = [] # 原始余弦退火\n", " new_lrs = [] # 带重启的余弦退火\n", "\n", " for r in rounds:\n", " # 原始余弦退火 (单调递减)\n", " old_lr = trainer.min_learning_rate + 0.5 * (trainer.initial_learning_rate - trainer.min_learning_rate) * (1 + np.cos(np.pi * r / 300))\n", " old_lrs.append(old_lr)\n", " \n", " # 带重启的余弦退火\n", " new_lr = trainer._cosine_annealing_with_warm_restarts(r)\n", " new_lrs.append(new_lr)\n", "\n", " # 绘制对比图\n", " ax1.plot(rounds, old_lrs, 'b-', label='原始余弦退火', linewidth=2)\n", " ax1.set_xlabel('Training Round')\n", " ax1.set_ylabel('Learning Rate')\n", " ax1.set_title('原始余弦退火 (单调递减)')\n", " ax1.grid(True, alpha=0.3)\n", " ax1.legend()\n", "\n", " ax2.plot(rounds, new_lrs, 'r-', label='带重启的余弦退火', linewidth=2)\n", " ax2.set_xlabel('Training Round')\n", " ax2.set_ylabel('Learning Rate')\n", " ax2.set_title('带重启的余弦退火 (SGDR)')\n", " ax2.grid(True, alpha=0.3)\n", " ax2.legend()\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", " print(\"📊 学习率调度对比可视化完成\")\n", " print(\" 🔵 原始版本:单调递减的余弦曲线\")\n", " print(\" 🔴 新版本:周期性重启,每次重启后学习率回到最大值\")\n", " print(\" 💡 SGDR的优势:多次重启可以帮助模型跳出局部最优解\")\n", "\n", " # 显示重启点\n", " restart_points = []\n", " t_cur = 0\n", " t_i = trainer.t_0\n", " while t_cur < 300:\n", " restart_points.append(t_cur)\n", " t_cur += t_i\n", " t_i *= trainer.t_mult\n", "\n", " print(f\" 🔄 在300轮训练中的重启点: {restart_points[:5]}...\") # 显示前5个重启点" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 测试新的第三小样本数采样策略\n", "\n", "print(\"🧪 测试新的第三小样本数采样策略...\")\n", "print(\"=\" * 60)\n", "\n", "# 模拟数据测试\n", "np.random.seed(42)\n", "random.seed(42)\n", "\n", "# 创建模拟数据\n", "n_samples = 10000\n", "n_features = 100\n", "X_test = np.random.randn(n_samples, n_features)\n", "\n", "# 创建不平衡的标签分布\n", "label_counts_target = {\n", " 0: 5000, 1: 100, 2: 150, 3: 80, 4: 200, 5: 120, 6: 90, 7: 180, 8: 110, 9: 160,\n", " 10: 140, 11: 170, 12: 130, 13: 190, 14: 105, 15: 95, 16: 175, 17: 125, 18: 155, 19: 135,\n", " 20: 145, 21: 165, 22: 115, 23: 185, 24: 85, 25: 195, 26: 75, 27: 205, 28: 70, 29: 210,\n", " 30: 65, 31: 215, 32: 60, 33: 220, 34: 55, 35: 225, 36: 50, 37: 230, 38: 45, 39: 235, 40: 3000\n", "}\n", "\n", "y_test = []\n", "for label, count in label_counts_target.items():\n", " y_test.extend([label] * min(count, n_samples - len(y_test)))\n", "y_test = np.array(y_test)\n", "X_test = X_test[:len(y_test)]\n", "\n", "# 随机打乱\n", "shuffle_idx = np.random.permutation(len(y_test))\n", "X_test = X_test[shuffle_idx]\n", "y_test = y_test[shuffle_idx]\n", "\n", "print(f\"模拟数据创建完成: {X_test.shape[0]:,} 样本, {X_test.shape[1]} 特征\")\n", "\n", "# 显示原始分布\n", "original_counts = Counter(y_test)\n", "all_counts = [original_counts.get(i, 0) for i in range(41)]\n", "non_zero_counts = [c for c in all_counts if c > 0]\n", "sorted_counts = sorted(non_zero_counts)\n", "\n", "print(f\"原始分布前10个最小: {sorted_counts[:10]}\")\n", "print(f\"第三小样本数: {sorted_counts[2] if len(sorted_counts) >= 3 else 'N/A'}\")\n", "\n", "# 测试balance_dataset函数\n", "X_balanced, y_balanced = balance_dataset(X_test, y_test)\n", "\n", "# 显示平衡后的分布\n", "balanced_counts = Counter(y_balanced)\n", "print(f\"\\n平衡后各标签样本数:\")\n", "for label in range(41):\n", " original = original_counts.get(label, 0)\n", " balanced = balanced_counts.get(label, 0)\n", " if original > 0 or balanced > 0:\n", " status = \"📉\" if balanced < original else \"✅\" if balanced == original else \"📈\"\n", " print(f\" {status} 标签 {label:2d}: {original:4d} → {balanced:4d}\")\n", "\n", "print(f\"\\n✅ 测试完成!\")\n", "print(f\" 原始样本数: {len(y_test):,}\")\n", "print(f\" 平衡后样本数: {len(y_balanced):,}\")\n", "print(f\" 数据变化比例: {len(y_balanced)/len(y_test):.2f}x\")" ] } ], "metadata": { "kaggle": { "accelerator": "tpu1vmV38", "dataSources": [ { "databundleVersionId": 13056355, "sourceId": 106809, "sourceType": "competition" } ], "dockerImageVersionId": 31091, "isGpuEnabled": false, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 4 }