{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Environment Setup & Utils"
]
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
|
|||
|
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
|||
|
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
|
|||
|
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
|||
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
|
|||
|
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
|
|||
|
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
|
|||
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
|
|||
|
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
|
|||
|
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
|
|||
|
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
|
|||
|
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
|
|||
|
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
|
|||
|
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
|
|||
|
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
|
|||
|
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
|
|||
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
|
|||
|
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
|
|||
|
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
|
|||
|
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
|
|||
|
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
|
|||
|
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n",
|
|||
|
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
|
|||
|
"Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
|||
|
"Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n",
|
|||
|
"Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n",
|
|||
|
"Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
|
|||
|
"Requirement already satisfied: lightgbm==4.3.0 in /usr/local/lib/python3.11/dist-packages (4.3.0)\n",
|
|||
|
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
|
|||
|
"Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n",
|
|||
|
"Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n",
|
|||
|
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
|||
|
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
|
|||
|
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
|
|||
|
"Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n",
|
|||
|
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
|
|||
|
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
|
|||
|
"Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n",
|
|||
|
"Requirement already satisfied: seaborn==0.13.2 in /usr/local/lib/python3.11/dist-packages (0.13.2)\n",
|
|||
|
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
|
|||
|
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
|
|||
|
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
|
|||
|
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
|
|||
|
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
|
|||
|
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
|
|||
|
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
|
|||
|
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
|||
|
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
|||
|
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
|
|||
|
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
|
|||
|
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
|
|||
|
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
|
|||
|
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
|
|||
|
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
|
|||
|
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
|
|||
|
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
|
|||
|
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
|
|||
|
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
|
|||
|
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
|
|||
|
"Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n",
|
|||
|
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
|
|||
|
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
|
|||
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
|
|||
|
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
|
|||
|
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
|
|||
|
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
|
|||
|
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
|
|||
|
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
|
|||
|
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
|
|||
|
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
|
|||
|
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
|
|||
|
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
|
|||
|
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
|
|||
|
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
|
|||
|
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
|
|||
|
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
|
|||
|
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
|
|||
|
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
|
|||
|
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
|
|||
|
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
|
|||
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
|
|||
|
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
|
|||
|
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
|
|||
|
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
|
|||
|
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
|
|||
|
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
|
|||
|
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
|
|||
|
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
|
|||
|
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
|
|||
|
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
|
|||
|
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
|
|||
|
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
|
|||
|
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
|
|||
|
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
|
|||
|
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
|
|||
|
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
|
|||
|
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
|
|||
|
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
|
|||
|
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
|
|||
|
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
|
|||
|
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
|
|||
|
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
|
|||
|
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
|
|||
|
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
|||
|
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
|
|||
|
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
|
|||
|
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
|
|||
|
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
|
|||
|
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
|
|||
|
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
|
|||
|
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
|
|||
|
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
|
|||
|
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
|
|||
|
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
|
|||
|
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
|
|||
|
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
|
|||
|
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
|
|||
|
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
|
|||
|
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
|
|||
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
|
|||
|
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
|
|||
|
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
|
|||
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
|
|||
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
|
|||
|
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
|
|||
|
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
|
|||
|
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
|
|||
|
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
|
|||
|
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
|
|||
|
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
|
|||
|
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
|
|||
|
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
|
|||
|
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
|
|||
|
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
|
|||
|
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
|
|||
|
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
|
|||
|
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
|||
|
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
|
|||
|
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
|
|||
|
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
|
|||
|
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
|
|||
|
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
|
|||
|
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
|
|||
|
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
|
|||
|
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
|
|||
|
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
|||
|
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
|
|||
|
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
|
|||
|
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
|
|||
|
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
|
|||
|
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
|
|||
|
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
|
|||
|
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
|
|||
|
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
|
|||
|
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
|
|||
|
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
|
|||
|
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
|
|||
|
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
|
|||
|
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
|
|||
|
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
|
|||
|
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
|
|||
|
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
|
|||
|
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
|
|||
|
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
|
|||
|
" Preparing metadata (setup.py): started\n",
|
|||
|
" Preparing metadata (setup.py): finished with status 'done'\n",
|
|||
|
"Installing collected packages: nejm_b2txt_utils\n",
|
|||
|
" Running setup.py develop for nejm_b2txt_utils\n",
|
|||
|
"Successfully installed nejm_b2txt_utils-0.0.0\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Cloning into 'nejm-brain-to-text'...\n",
|
|||
|
"Updating files: 100% (2633/2633), done.\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%bash\n",
|
|||
|
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
|
|||
|
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
|
|||
|
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
|
|||
|
"\n",
|
|||
|
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
|
|||
|
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
|
|||
|
"ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data/concatenated_data\n",
|
|||
|
"\n",
|
|||
|
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
|
|||
|
"\n",
|
|||
|
"pip install \\\n",
|
|||
|
" jupyter==1.1.1 \\\n",
|
|||
|
" \"numpy>=1.26.0,<2.1.0\" \\\n",
|
|||
|
" pandas==2.3.0 \\\n",
|
|||
|
" matplotlib==3.10.1 \\\n",
|
|||
|
" scipy==1.15.2 \\\n",
|
|||
|
" scikit-learn==1.6.1 \\\n",
|
|||
|
" lightgbm==4.3.0 \\\n",
|
|||
|
" tqdm==4.67.1 \\\n",
|
|||
|
" g2p_en==2.1.0 \\\n",
|
|||
|
" h5py==3.13.0 \\\n",
|
|||
|
" omegaconf==2.3.0 \\\n",
|
|||
|
" editdistance==0.8.1 \\\n",
|
|||
|
" huggingface-hub==0.33.1 \\\n",
|
|||
|
" transformers==4.53.0 \\\n",
|
|||
|
" tokenizers==0.21.2 \\\n",
|
|||
|
" accelerate==1.8.1 \\\n",
|
|||
|
" bitsandbytes==0.46.0 \\\n",
|
|||
|
" seaborn==0.13.2\n",
|
|||
|
"cd /kaggle/working/nejm-brain-to-text/\n",
|
|||
|
"pip install -e ."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"==================================================\n",
|
|||
|
"🔧 LightGBM GPU环境检查\n",
|
|||
|
"==================================================\n",
|
|||
|
"❌ 未检测到NVIDIA GPU或驱动\n",
|
|||
|
"\n",
|
|||
|
"✅ CUDA工具包:\n",
|
|||
|
" Cuda compilation tools, release 12.5, V12.5.82\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 LightGBM GPU支持检查与配置\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*50)\n",
|
|||
|
"print(\"🔧 LightGBM GPU环境检查\")\n",
|
|||
|
"print(\"=\"*50)\n",
|
|||
|
"\n",
|
|||
|
"# 检查CUDA和GPU驱动\n",
|
|||
|
"import subprocess\n",
|
|||
|
"import sys\n",
|
|||
|
"\n",
|
|||
|
"def run_command(command):\n",
|
|||
|
" \"\"\"运行命令并返回结果\"\"\"\n",
|
|||
|
" try:\n",
|
|||
|
" result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n",
|
|||
|
" return result.stdout.strip(), result.returncode == 0\n",
|
|||
|
" except Exception as e:\n",
|
|||
|
" return str(e), False\n",
|
|||
|
"\n",
|
|||
|
"# 检查NVIDIA GPU\n",
|
|||
|
"nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n",
|
|||
|
"if nvidia_success:\n",
|
|||
|
" print(\"✅ NVIDIA GPU检测:\")\n",
|
|||
|
" for line in nvidia_output.split('\\n'):\n",
|
|||
|
" if line.strip():\n",
|
|||
|
" print(f\" {line}\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"❌ 未检测到NVIDIA GPU或驱动\")\n",
|
|||
|
"\n",
|
|||
|
"# 检查CUDA版本\n",
|
|||
|
"cuda_output, cuda_success = run_command(\"nvcc --version\")\n",
|
|||
|
"if cuda_success:\n",
|
|||
|
" print(\"\\n✅ CUDA工具包:\")\n",
|
|||
|
" # 提取CUDA版本\n",
|
|||
|
" for line in cuda_output.split('\\n'):\n",
|
|||
|
" if 'release' in line:\n",
|
|||
|
" print(f\" {line.strip()}\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"\\n❌ 未安装CUDA工具包\")\n"
|
|||
|
]
|
|||
|
},
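{
"cell_type": "markdown",
"metadata": {},
"source": [
"The check above only reports whether an NVIDIA GPU and the CUDA toolkit are visible. The sketch below is an added illustration (not part of the original notebook) of how that result could feed into LightGBM's standard `device` parameter, falling back to CPU when no GPU was detected; `example_params` is a hypothetical name."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only: choose the LightGBM device from the nvidia-smi check above.\n",
"# Assumes nvidia_success from the previous cell; 41 classes matches the phoneme set defined later.\n",
"lgb_device = 'gpu' if nvidia_success else 'cpu'\n",
"example_params = {'objective': 'multiclass', 'num_class': 41, 'device': lgb_device}\n",
"print('LightGBM params:', example_params)"
]
},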
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/kaggle/working/nejm-brain-to-text\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%cd /kaggle/working/nejm-brain-to-text\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import os\n",
|
|||
|
"import pickle\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import matplotlib\n",
|
|||
|
"from g2p_en import G2p\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from nejm_b2txt_utils.general_utils import *\n",
|
|||
|
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
|
|||
|
"matplotlib.rcParams['ps.fonttype'] = 42\n",
|
|||
|
"matplotlib.rcParams['font.family'] = 'sans-serif'\n",
|
|||
|
"matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n",
|
|||
|
"matplotlib.rcParams['axes.unicode_minus'] = False\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/kaggle/working/nejm-brain-to-text/model_training\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%cd model_training/\n",
|
|||
|
"from data_augmentations import gauss_smooth"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"LOGIT_TO_PHONEME = [\n",
|
|||
|
" 'BLANK',\n",
|
|||
|
" 'AA', 'AE', 'AH', 'AO', 'AW',\n",
|
|||
|
" 'AY', 'B', 'CH', 'D', 'DH',\n",
|
|||
|
" 'EH', 'ER', 'EY', 'F', 'G',\n",
|
|||
|
" 'HH', 'IH', 'IY', 'JH', 'K',\n",
|
|||
|
" 'L', 'M', 'N', 'NG', 'OW',\n",
|
|||
|
" 'OY', 'P', 'R', 'S', 'SH',\n",
|
|||
|
" 'T', 'TH', 'UH', 'UW', 'V',\n",
|
|||
|
" 'W', 'Y', 'Z', 'ZH',\n",
|
|||
|
" ' | ',\n",
|
|||
|
"]\n",
|
|||
|
"# 全局配置\n",
|
|||
|
"BALANCE_CONFIG = {\n",
|
|||
|
" 'enable_balance': True, # 是否启用数据平衡\n",
|
|||
|
" 'undersample_labels': [0, 40], # 需要下采样的标签 (blank等高频标签)\n",
|
|||
|
" 'oversample_threshold': 0.5, # 过采样阈值 (相对于均值的比例)\n",
|
|||
|
" 'random_state': 42 # 随机种子\n",
|
|||
|
"}\n",
|
|||
|
"# 全局PCA配置\n",
|
|||
|
"PCA_CONFIG = {\n",
|
|||
|
" 'enable_pca': True, # 是否启用PCA\n",
|
|||
|
" 'n_components': None, # None=自动选择, 或指定具体数值\n",
|
|||
|
" 'variance_threshold': 0.95, # 保留95%的方差\n",
|
|||
|
" 'sample_size': 15000, # 用于拟合PCA的样本数\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# 全局PCA对象 (确保只拟合一次)\n",
|
|||
|
"GLOBAL_PCA = {\n",
|
|||
|
" 'scaler': None,\n",
|
|||
|
" 'pca': None,\n",
|
|||
|
" 'is_fitted': False,\n",
|
|||
|
" 'n_components': None\n",
|
|||
|
"}\n",
|
|||
|
"# 设置数据目录和参数【PCA初始化】\n",
|
|||
|
"data_dir = '/kaggle/working/nejm-brain-to-text/data/concatenated_data'\n",
|
|||
|
"MAX_SAMPLES_PER_FILE = -1 # 每个文件最大样本数,可调整"
|
|||
|
]
|
|||
|
},
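{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above defines the logit-index → phoneme lookup table plus the global balancing and PCA configs. The short sketch below is an added illustration (not from the original notebook) of how a single 41-dimensional RNN logit frame maps to a phoneme label via argmax and `LOGIT_TO_PHONEME`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only - assumes LOGIT_TO_PHONEME from the cell above is in scope.\n",
"import numpy as np\n",
"\n",
"example_logits = np.random.randn(41)         # stand-in for one RNN output frame\n",
"label_id = int(np.argmax(example_logits))    # integer class label in 0..40\n",
"print(label_id, LOGIT_TO_PHONEME[label_id])  # e.g. 0 -> 'BLANK', 40 -> ' | '"
]
},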
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Loading Workflow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2️⃣ Data Loading & PCA Dimensionality Reduction"
]
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🚀 内存友好的数据读取 - 分批加载策略 + PCA降维 【这里还缺一个采样】\n",
|
|||
|
"\n",
|
|||
|
"import os\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import gc\n",
|
|||
|
"from sklearn.decomposition import PCA\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"import joblib\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"def load_data_batch(data_dir, data_type, max_samples_per_file=5000):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 分批加载指定类型的数据\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" data_dir: 数据目录\n",
|
|||
|
" data_type: 'train', 'val', 'test'\n",
|
|||
|
" max_samples_per_file: 每个文件最大加载样本数\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" generator: 数据批次生成器\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
|
|||
|
" \n",
|
|||
|
" for file_idx, f in enumerate(files):\n",
|
|||
|
" print(f\" 正在加载文件 {file_idx+1}/{len(files)}: {f}\")\n",
|
|||
|
" \n",
|
|||
|
" data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n",
|
|||
|
" trials = data['neural_logits_concatenated']\n",
|
|||
|
" \n",
|
|||
|
" # 限制每个文件的样本数\n",
|
|||
|
" if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n",
|
|||
|
" trials = trials[:max_samples_per_file]\n",
|
|||
|
" print(f\" 限制样本数至: {max_samples_per_file}\")\n",
|
|||
|
" \n",
|
|||
|
" yield trials, f\n",
|
|||
|
" \n",
|
|||
|
" # 清理内存\n",
|
|||
|
" del data, trials\n",
|
|||
|
" gc.collect()\n",
|
|||
|
"\n",
|
|||
|
"def extract_features_labels_batch(trials_batch):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 从试验批次中提取特征和标签\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" features = []\n",
|
|||
|
" labels = []\n",
|
|||
|
" \n",
|
|||
|
" for trial in trials_batch:\n",
|
|||
|
" if trial.shape[0] > 0:\n",
|
|||
|
" for t in range(trial.shape[0]):\n",
|
|||
|
" neural_features = trial[t, :7168] # 前7168维神经特征\n",
|
|||
|
" rnn_logits = trial[t, 7168:] # 后41维RNN输出\n",
|
|||
|
" phoneme_label = np.argmax(rnn_logits)\n",
|
|||
|
" \n",
|
|||
|
" features.append(neural_features)\n",
|
|||
|
" labels.append(phoneme_label)\n",
|
|||
|
" \n",
|
|||
|
" return np.array(features), np.array(labels)\n",
|
|||
|
"\n",
|
|||
|
"def fit_global_pca(data_dir, config):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 在训练数据上拟合全局PCA (只执行一次)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n",
|
|||
|
" print(\"🔧 PCA已拟合或未启用,跳过拟合步骤\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🔧 拟合全局PCA降维器...\")\n",
|
|||
|
" print(f\" 配置: {config}\")\n",
|
|||
|
" \n",
|
|||
|
" # 收集训练样本\n",
|
|||
|
" sample_features = []\n",
|
|||
|
" collected_samples = 0\n",
|
|||
|
" \n",
|
|||
|
" for trials_batch, filename in load_data_batch(data_dir, 'train', 5000):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" sample_features.append(features)\n",
|
|||
|
" collected_samples += features.shape[0]\n",
|
|||
|
" \n",
|
|||
|
" if collected_samples >= config['sample_size']:\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" if sample_features:\n",
|
|||
|
" # 合并样本数据\n",
|
|||
|
" X_sample = np.vstack(sample_features)[:config['sample_size']]\n",
|
|||
|
" print(f\" 实际样本数: {X_sample.shape[0]}\")\n",
|
|||
|
" print(f\" 原始特征数: {X_sample.shape[1]}\")\n",
|
|||
|
" \n",
|
|||
|
" # 标准化\n",
|
|||
|
" GLOBAL_PCA['scaler'] = StandardScaler()\n",
|
|||
|
" X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n",
|
|||
|
" \n",
|
|||
|
" # 确定PCA成分数\n",
|
|||
|
" if config['n_components'] is None:\n",
|
|||
|
" print(f\" 🔍 自动选择PCA成分数...\")\n",
|
|||
|
" pca_full = PCA()\n",
|
|||
|
" pca_full.fit(X_sample_scaled)\n",
|
|||
|
" \n",
|
|||
|
" cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
|
|||
|
" optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n",
|
|||
|
" GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 保留{config['variance_threshold']*100}%方差需要: {optimal_components} 个成分\")\n",
|
|||
|
" print(f\" 选择成分数: {GLOBAL_PCA['n_components']}\")\n",
|
|||
|
" else:\n",
|
|||
|
" GLOBAL_PCA['n_components'] = config['n_components']\n",
|
|||
|
" print(f\" 使用指定成分数: {GLOBAL_PCA['n_components']}\")\n",
|
|||
|
" \n",
|
|||
|
" # 拟合最终PCA\n",
|
|||
|
" GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n",
|
|||
|
" GLOBAL_PCA['pca'].fit(X_sample_scaled)\n",
|
|||
|
" GLOBAL_PCA['is_fitted'] = True\n",
|
|||
|
" \n",
|
|||
|
" # 保存模型\n",
|
|||
|
" pca_path = \"global_pca_model.joblib\"\n",
|
|||
|
" joblib.dump({\n",
|
|||
|
" 'scaler': GLOBAL_PCA['scaler'], \n",
|
|||
|
" 'pca': GLOBAL_PCA['pca'],\n",
|
|||
|
" 'n_components': GLOBAL_PCA['n_components']\n",
|
|||
|
" }, pca_path)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 全局PCA拟合完成!\")\n",
|
|||
|
" print(f\" 降维: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n",
|
|||
|
" print(f\" 降维比例: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n",
|
|||
|
" print(f\" 保留方差: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n",
|
|||
|
" print(f\" 模型已保存: {pca_path}\")\n",
|
|||
|
" \n",
|
|||
|
" # 清理样本数据\n",
|
|||
|
" del sample_features, X_sample, X_sample_scaled\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" else:\n",
|
|||
|
" print(\"❌ 无法收集样本数据用于PCA拟合\")\n",
|
|||
|
"\n",
|
|||
|
"def apply_pca_transform(features):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 应用全局PCA变换\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n",
|
|||
|
" return features\n",
|
|||
|
" \n",
|
|||
|
" # 标准化 + PCA变换\n",
|
|||
|
" features_scaled = GLOBAL_PCA['scaler'].transform(features)\n",
|
|||
|
" features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n",
|
|||
|
" return features_pca"
|
|||
|
]
|
|||
|
},
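{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added usage sketch (not executed in the original notebook): the intended flow of the helpers above is to fit the global PCA once on training data and then push every feature batch through `apply_pca_transform`. It assumes `data_dir`, `PCA_CONFIG` and the functions defined above are in scope and that the .npz files exist."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Usage sketch only - requires the Kaggle data files referenced by data_dir.\n",
"fit_global_pca(data_dir, PCA_CONFIG)  # fits the scaler + PCA once and caches them in GLOBAL_PCA\n",
"\n",
"for trials_batch, filename in load_data_batch(data_dir, 'val', max_samples_per_file=1000):\n",
"    X, y = extract_features_labels_batch(trials_batch)\n",
"    X_reduced = apply_pca_transform(X)  # (n, 7168) -> (n, GLOBAL_PCA['n_components'])\n",
"    print(filename, X.shape, '->', X_reduced.shape)\n",
"    break  # first file is enough for a demo"
]
},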
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📊 Data Balancing Strategy - Label Distribution Analysis & Sampling Optimization"
]
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
"# Core sampling implementation\n",
"# Imports added so this cell is self-contained (Counter/random/resample are otherwise only imported in a later cell):\n",
"from collections import Counter\n",
"import random\n",
"from sklearn.utils import resample\n",
|
|||
|
"def balance_dataset(X, y, config=BALANCE_CONFIG):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 对数据集进行平衡处理:下采样 + 过采样\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" X: 特征数据\n",
|
|||
|
" y: 标签数据\n",
|
|||
|
" config: 平衡配置\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" X_balanced, y_balanced: 平衡后的数据\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not config['enable_balance']:\n",
|
|||
|
" print(\"🔕 数据平衡已禁用,返回原始数据\")\n",
|
|||
|
" return X, y\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n⚖️ 开始数据平衡处理...\")\n",
|
|||
|
" print(f\" 原始数据: {X.shape[0]:,} 样本\")\n",
|
|||
|
" \n",
|
|||
|
" # 分析当前分布 (只考虑1-39号标签的均值)\n",
|
|||
|
" label_counts = Counter(y)\n",
|
|||
|
" counts_exclude_0_40 = [label_counts.get(i, 0) for i in range(1, 40)] # 1-39号标签\n",
|
|||
|
" mean_count = np.mean(counts_exclude_0_40) # 只计算1-39号标签的均值\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 均值样本数 (标签1-39): {mean_count:.0f}\")\n",
|
|||
|
" print(f\" 下采样标签: {config['undersample_labels']}\")\n",
|
|||
|
" print(f\" 过采样阈值: {config['oversample_threshold']} * 均值\")\n",
|
|||
|
" \n",
|
|||
|
" # 准备平衡后的数据\n",
|
|||
|
" X_balanced = []\n",
|
|||
|
" y_balanced = []\n",
|
|||
|
" \n",
|
|||
|
" random.seed(config['random_state'])\n",
|
|||
|
" np.random.seed(config['random_state'])\n",
|
|||
|
" \n",
|
|||
|
" for label in range(41):\n",
|
|||
|
" # 获取当前标签的所有样本\n",
|
|||
|
" label_mask = (y == label)\n",
|
|||
|
" X_label = X[label_mask]\n",
|
|||
|
" y_label = y[label_mask]\n",
|
|||
|
" current_count = len(y_label)\n",
|
|||
|
" \n",
|
|||
|
" if current_count == 0:\n",
|
|||
|
" continue\n",
|
|||
|
" \n",
|
|||
|
" # 决定采样策略\n",
|
|||
|
" if label in config['undersample_labels']:\n",
|
|||
|
" # 下采样到均值水平\n",
|
|||
|
" target_count = int(mean_count)\n",
|
|||
|
" if current_count > target_count:\n",
|
|||
|
" # 下采样\n",
|
|||
|
" indices = np.random.choice(current_count, target_count, replace=False)\n",
|
|||
|
" X_resampled = X_label[indices]\n",
|
|||
|
" y_resampled = y_label[indices]\n",
|
|||
|
" print(f\" 📉 标签 {label}: {current_count} → {target_count} (下采样)\")\n",
|
|||
|
" else:\n",
|
|||
|
" X_resampled = X_label\n",
|
|||
|
" y_resampled = y_label\n",
|
|||
|
" print(f\" ➡️ 标签 {label}: {current_count} (无需下采样)\")\n",
|
|||
|
" \n",
|
|||
|
" elif current_count < mean_count * config['oversample_threshold']:\n",
|
|||
|
" # 过采样到阈值水平\n",
|
|||
|
" target_count = int(mean_count * config['oversample_threshold'])\n",
|
|||
|
" if current_count < target_count:\n",
|
|||
|
" # 过采样\n",
|
|||
|
" X_resampled, y_resampled = resample(\n",
|
|||
|
" X_label, y_label, \n",
|
|||
|
" n_samples=target_count, \n",
|
|||
|
" random_state=config['random_state']\n",
|
|||
|
" )\n",
|
|||
|
" print(f\" 📈 标签 {label}: {current_count} → {target_count} (过采样)\")\n",
|
|||
|
" else:\n",
|
|||
|
" X_resampled = X_label\n",
|
|||
|
" y_resampled = y_label\n",
|
|||
|
" print(f\" ➡️ 标签 {label}: {current_count} (无需过采样)\")\n",
|
|||
|
" else:\n",
|
|||
|
" # 保持不变\n",
|
|||
|
" X_resampled = X_label\n",
|
|||
|
" y_resampled = y_label\n",
|
|||
|
" print(f\" ✅ 标签 {label}: {current_count} (已平衡)\")\n",
|
|||
|
" \n",
|
|||
|
" X_balanced.append(X_resampled)\n",
|
|||
|
" y_balanced.append(y_resampled)\n",
|
|||
|
" \n",
|
|||
|
" # 合并所有平衡后的数据\n",
|
|||
|
" X_balanced = np.vstack(X_balanced)\n",
|
|||
|
" y_balanced = np.hstack(y_balanced)\n",
|
|||
|
" \n",
|
|||
|
" # 随机打乱\n",
|
|||
|
" shuffle_indices = np.random.permutation(len(y_balanced))\n",
|
|||
|
" X_balanced = X_balanced[shuffle_indices]\n",
|
|||
|
" y_balanced = y_balanced[shuffle_indices]\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 平衡完成: {X_balanced.shape[0]:,} 样本\")\n",
|
|||
|
" print(f\" 数据变化: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n",
|
|||
|
" \n",
|
|||
|
" return X_balanced, y_balanced\n"
|
|||
|
]
|
|||
|
},
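{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added toy check (illustration only, synthetic data): `balance_dataset` undersamples labels 0 and 40 toward the label 1-39 mean and oversamples rare labels up to half of that mean. The tiny example below exercises it without touching the real neural data; the array sizes are arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy demonstration of balance_dataset on synthetic data (illustration only).\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"y_toy = np.array([0] * 500 + [1] * 5 + list(rng.integers(2, 41, size=400)))  # label 0 dominates\n",
"X_toy = rng.normal(size=(len(y_toy), 8))                                     # 8 dummy features\n",
"\n",
"X_bal, y_bal = balance_dataset(X_toy, y_toy, BALANCE_CONFIG)\n",
"print(X_bal.shape, y_bal.shape)"
]
},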
|
|||
|
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔄 Memory-Friendly Data Loader with Integrated Data Balancing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🧪 Data Balancing Effectiveness Test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 Improved Smart Data Processing Pipeline"
]
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"🚀 创建智能数据处理管道...\n",
|
|||
|
"✅ 管道创建完成,准备执行步骤1...\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 改进版智能数据处理管道【没有解决分批训练的问题】\n",
|
|||
|
"# 流程:分析分布 → 确定采样比率 → 拟合PCA(只下采样) → 数据处理(下采样+上采样+PCA)\n",
|
|||
|
"\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"from collections import Counter\n",
|
|||
|
"from sklearn.utils import resample\n",
|
|||
|
"from sklearn.decomposition import PCA\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"import joblib\n",
|
|||
|
"import random\n",
|
|||
|
"import gc\n",
|
|||
|
"\n",
|
|||
|
"class SmartDataPipeline:\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 智能数据处理管道\n",
|
|||
|
" 步骤1: 分析数据分布,确定采样策略\n",
|
|||
|
" 步骤2: 仅下采样拟合PCA参数\n",
|
|||
|
" 步骤3: 数据处理时应用完整采样+PCA降维\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" \n",
|
|||
|
" def __init__(self, data_dir, random_state=42):\n",
|
|||
|
" self.data_dir = data_dir\n",
|
|||
|
" self.random_state = random_state\n",
|
|||
|
" \n",
|
|||
|
" # 步骤1: 分布分析结果\n",
|
|||
|
" self.distribution_analysis = None\n",
|
|||
|
" self.sampling_strategy = None\n",
|
|||
|
" \n",
|
|||
|
" # 步骤2: PCA参数(基于下采样数据拟合)\n",
|
|||
|
" self.pca_scaler = None\n",
|
|||
|
" self.pca_model = None\n",
|
|||
|
" self.pca_components = None\n",
|
|||
|
" self.pca_fitted = False\n",
|
|||
|
" \n",
|
|||
|
" # 配置参数\n",
|
|||
|
" self.undersample_labels = [0, 40] # 需要下采样的标签\n",
|
|||
|
" self.oversample_threshold = 0.5 # 过采样阈值(相对于均值)\n",
|
|||
|
" self.pca_variance_threshold = 0.95 # PCA保留方差比例\n",
|
|||
|
" self.pca_sample_size = 15000 # PCA拟合样本数\n",
|
|||
|
" \n",
|
|||
|
" def step1_analyze_distribution(self, max_samples=100000):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 步骤1: 分析数据分布,确定采样策略\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(\"🔍 步骤1: 分析数据分布...\")\n",
|
|||
|
" \n",
|
|||
|
" # 分析验证集分布(代表整体分布特征)\n",
|
|||
|
" all_labels = []\n",
|
|||
|
" for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n",
|
|||
|
" _, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" all_labels.extend(labels.tolist())\n",
|
|||
|
" if len(all_labels) >= max_samples:\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" # 统计分析\n",
|
|||
|
" label_counts = Counter(all_labels)\n",
|
|||
|
" \n",
|
|||
|
" # 计算1-39标签的均值(排除0和40)\n",
|
|||
|
" counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n",
|
|||
|
" target_mean = np.mean(counts_1_39)\n",
|
|||
|
" \n",
|
|||
|
" # 生成采样策略\n",
|
|||
|
" sampling_strategy = {}\n",
|
|||
|
" for label in range(41):\n",
|
|||
|
" current_count = label_counts.get(label, 0)\n",
|
|||
|
" \n",
|
|||
|
" if label in self.undersample_labels:\n",
|
|||
|
" # 下采样到均值水平\n",
|
|||
|
" target_count = int(target_mean)\n",
|
|||
|
" action = 'undersample' if current_count > target_count else 'keep'\n",
|
|||
|
" elif current_count < target_mean * self.oversample_threshold:\n",
|
|||
|
" # 过采样到阈值水平\n",
|
|||
|
" target_count = int(target_mean * self.oversample_threshold)\n",
|
|||
|
" action = 'oversample' if current_count < target_count else 'keep'\n",
|
|||
|
" else:\n",
|
|||
|
" # 保持不变\n",
|
|||
|
" target_count = current_count\n",
|
|||
|
" action = 'keep'\n",
|
|||
|
" \n",
|
|||
|
" sampling_strategy[label] = {\n",
|
|||
|
" 'current_count': current_count,\n",
|
|||
|
" 'target_count': target_count,\n",
|
|||
|
" 'action': action\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" self.distribution_analysis = {\n",
|
|||
|
" 'label_counts': label_counts,\n",
|
|||
|
" 'target_mean': target_mean,\n",
|
|||
|
" 'total_samples': len(all_labels)\n",
|
|||
|
" }\n",
|
|||
|
" self.sampling_strategy = sampling_strategy\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 分析完成: {len(all_labels):,} 样本\")\n",
|
|||
|
" print(f\" 📊 标签1-39均值: {target_mean:.0f}\")\n",
|
|||
|
" print(f\" 📉 下采样标签: {self.undersample_labels} → {target_mean:.0f}\")\n",
|
|||
|
" print(f\" 📈 过采样阈值: {self.oversample_threshold} × 均值 = {target_mean * self.oversample_threshold:.0f}\")\n",
|
|||
|
" \n",
|
|||
|
" return self.distribution_analysis, self.sampling_strategy\n",
|
|||
|
"\n",
|
|||
|
"# 创建智能数据处理管道\n",
|
|||
|
"print(\"🚀 创建智能数据处理管道...\")\n",
|
|||
|
"pipeline = SmartDataPipeline(data_dir, random_state=42)\n",
|
|||
|
"print(\"✅ 管道创建完成,准备执行步骤1...\")"
|
|||
|
]
|
|||
|
},
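{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added illustration (assumes the Kaggle data files are present): running step 1 and inspecting the per-label sampling plan it produces. The keys used below (`target_mean`, `current_count`, `target_count`, `action`) come from the class definition above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration only - requires the concatenated_data .npz files.\n",
"dist, strategy = pipeline.step1_analyze_distribution(max_samples=50000)\n",
"\n",
"print('mean count over labels 1-39:', round(dist['target_mean']))\n",
"for label in (0, 1, 40):  # a few representative labels\n",
"    s = strategy[label]\n",
"    print(label, s['action'], s['current_count'], '->', s['target_count'])"
]
},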
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"✅ 步骤2方法已添加到管道\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 继续添加智能管道的其他方法【管道完善】\n",
|
|||
|
"\n",
|
|||
|
"def step2_fit_pca_with_undersampling(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 步骤2: 仅对下采样数据拟合PCA参数(不进行过采样,避免PCA被过采样影响)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if self.sampling_strategy is None:\n",
|
|||
|
" raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n",
|
|||
|
" \n",
|
|||
|
" print(\"\\n🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\")\n",
|
|||
|
" \n",
|
|||
|
" # 收集用于PCA拟合的样本(只下采样,不过采样)\n",
|
|||
|
" pca_features = []\n",
|
|||
|
" collected_samples = 0\n",
|
|||
|
" \n",
|
|||
|
" for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" \n",
|
|||
|
" # 对当前批次应用仅下采样策略\n",
|
|||
|
" downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n",
|
|||
|
" \n",
|
|||
|
" if downsampled_features.shape[0] > 0:\n",
|
|||
|
" pca_features.append(downsampled_features)\n",
|
|||
|
" collected_samples += downsampled_features.shape[0]\n",
|
|||
|
" \n",
|
|||
|
" if collected_samples >= self.pca_sample_size:\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" if pca_features:\n",
|
|||
|
" # 合并样本\n",
|
|||
|
" X_pca_sample = np.vstack(pca_features)[:self.pca_sample_size]\n",
|
|||
|
" print(f\" 📦 PCA拟合样本: {X_pca_sample.shape[0]:,} 个下采样样本\")\n",
|
|||
|
" print(f\" 🔢 原始特征维度: {X_pca_sample.shape[1]}\")\n",
|
|||
|
" \n",
|
|||
|
" # 标准化\n",
|
|||
|
" self.pca_scaler = StandardScaler()\n",
|
|||
|
" X_scaled = self.pca_scaler.fit_transform(X_pca_sample)\n",
|
|||
|
" \n",
|
|||
|
" # 确定PCA成分数\n",
|
|||
|
" pca_full = PCA()\n",
|
|||
|
" pca_full.fit(X_scaled)\n",
|
|||
|
" cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
|
|||
|
" optimal_components = np.argmax(cumsum_var >= self.pca_variance_threshold) + 1\n",
|
|||
|
" self.pca_components = min(optimal_components, X_pca_sample.shape[1])\n",
|
|||
|
" \n",
|
|||
|
" # 拟合最终PCA\n",
|
|||
|
" self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n",
|
|||
|
" self.pca_model.fit(X_scaled)\n",
|
|||
|
" self.pca_fitted = True\n",
|
|||
|
" \n",
|
|||
|
" # 保存PCA模型\n",
|
|||
|
" pca_path = \"smart_pipeline_pca.joblib\"\n",
|
|||
|
" joblib.dump({\n",
|
|||
|
" 'scaler': self.pca_scaler,\n",
|
|||
|
" 'pca': self.pca_model,\n",
|
|||
|
" 'components': self.pca_components\n",
|
|||
|
" }, pca_path)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ PCA拟合完成!\")\n",
|
|||
|
" print(f\" 降维: {X_pca_sample.shape[1]} → {self.pca_components}\")\n",
|
|||
|
" print(f\" 降维比例: {self.pca_components/X_pca_sample.shape[1]:.2%}\")\n",
|
|||
|
" print(f\" 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
|
|||
|
" print(f\" 模型保存: {pca_path}\")\n",
|
|||
|
" \n",
|
|||
|
" # 清理内存\n",
|
|||
|
" del pca_features, X_pca_sample, X_scaled\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" else:\n",
|
|||
|
" raise ValueError(\"无法收集PCA拟合样本\")\n",
|
|||
|
"\n",
|
|||
|
"def _apply_undersampling_only(self, X, y):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 仅应用下采样策略(用于PCA拟合)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" X_result = []\n",
|
|||
|
" y_result = []\n",
|
|||
|
" \n",
|
|||
|
" np.random.seed(self.random_state)\n",
|
|||
|
" \n",
|
|||
|
" for label in range(41):\n",
|
|||
|
" label_mask = (y == label)\n",
|
|||
|
" X_label = X[label_mask]\n",
|
|||
|
" y_label = y[label_mask]\n",
|
|||
|
" current_count = len(y_label)\n",
|
|||
|
" \n",
|
|||
|
" if current_count == 0:\n",
|
|||
|
" continue\n",
|
|||
|
" \n",
|
|||
|
" strategy = self.sampling_strategy[label]\n",
|
|||
|
" \n",
|
|||
|
" if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n",
|
|||
|
" # 下采样\n",
|
|||
|
" indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n",
|
|||
|
" X_resampled = X_label[indices]\n",
|
|||
|
" y_resampled = y_label[indices]\n",
|
|||
|
" else:\n",
|
|||
|
" # 保持原样\n",
|
|||
|
" X_resampled = X_label\n",
|
|||
|
" y_resampled = y_label\n",
|
|||
|
" \n",
|
|||
|
" X_result.append(X_resampled)\n",
|
|||
|
" y_result.append(y_resampled)\n",
|
|||
|
" \n",
|
|||
|
" if X_result:\n",
|
|||
|
" return np.vstack(X_result), np.hstack(y_result)\n",
|
|||
|
" else:\n",
|
|||
|
" return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
|
|||
|
"\n",
|
|||
|
"# 动态添加方法到类\n",
|
|||
|
"SmartDataPipeline.step2_fit_pca_with_undersampling = step2_fit_pca_with_undersampling\n",
|
|||
|
"SmartDataPipeline._apply_undersampling_only = _apply_undersampling_only\n",
|
|||
|
"\n",
|
|||
|
"print(\"✅ 步骤2方法已添加到管道\")"
|
|||
|
]
|
|||
|
},
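{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: why the dynamically attached methods work on the existing `pipeline`\n",
"\n",
"The cells in this section attach `step2_fit_pca_with_undersampling` and the other helpers to `SmartDataPipeline` *after* the `pipeline` instance was created. That still works because Python resolves methods on the class at call time, not at instantiation time. The `Toy` class below is a made-up, standalone illustration of that mechanism, not part of the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration only: attaching a method to a class after an instance already exists\n",
"# still works, because `toy.greet` is looked up on the class at call time.\n",
"class Toy:\n",
"    pass\n",
"\n",
"toy = Toy()            # instance created before the method exists\n",
"\n",
"def greet(self):\n",
"    return 'hi'\n",
"\n",
"Toy.greet = greet      # attach afterwards, just like SmartDataPipeline.step2_... above\n",
"print(toy.greet())     # -> hi"
]
},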
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"✅ 所有方法已添加到智能管道\n",
|
|||
|
"\n",
|
|||
|
"📋 智能数据处理管道状态:\n",
|
|||
|
" 🔍 步骤1 - 分布分析: ❌ 未完成\n",
|
|||
|
" 🔧 步骤2 - PCA拟合: ❌ 未完成\n",
|
|||
|
"\n",
|
|||
|
"🎯 使用流程:\n",
|
|||
|
" 1. pipeline.step1_analyze_distribution()\n",
|
|||
|
" 2. pipeline.step2_fit_pca_with_undersampling()\n",
|
|||
|
" 3. pipeline.step3_process_data('train') # 训练集\n",
|
|||
|
" pipeline.step3_process_data('val') # 验证集\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 添加智能管道的剩余方法\n",
|
|||
|
"\n",
|
|||
|
"def _apply_full_sampling(self, X, y):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 应用完整的采样策略(下采样+过采样)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" X_result = []\n",
|
|||
|
" y_result = []\n",
|
|||
|
" \n",
|
|||
|
" np.random.seed(self.random_state)\n",
|
|||
|
" \n",
|
|||
|
" for label in range(41):\n",
|
|||
|
" label_mask = (y == label)\n",
|
|||
|
" X_label = X[label_mask]\n",
|
|||
|
" y_label = y[label_mask]\n",
|
|||
|
" current_count = len(y_label)\n",
|
|||
|
" \n",
|
|||
|
" if current_count == 0:\n",
|
|||
|
" continue\n",
|
|||
|
" \n",
|
|||
|
" strategy = self.sampling_strategy[label]\n",
|
|||
|
" target_count = strategy['target_count']\n",
|
|||
|
" \n",
|
|||
|
" if strategy['action'] == 'undersample' and current_count > target_count:\n",
|
|||
|
" # 下采样\n",
|
|||
|
" indices = np.random.choice(current_count, target_count, replace=False)\n",
|
|||
|
" X_resampled = X_label[indices]\n",
|
|||
|
" y_resampled = y_label[indices]\n",
|
|||
|
" elif strategy['action'] == 'oversample' and current_count < target_count:\n",
|
|||
|
" # 过采样\n",
|
|||
|
" X_resampled, y_resampled = resample(\n",
|
|||
|
" X_label, y_label, \n",
|
|||
|
" n_samples=target_count, \n",
|
|||
|
" random_state=self.random_state\n",
|
|||
|
" )\n",
|
|||
|
" else:\n",
|
|||
|
" # 保持原样\n",
|
|||
|
" X_resampled = X_label\n",
|
|||
|
" y_resampled = y_label\n",
|
|||
|
" \n",
|
|||
|
" X_result.append(X_resampled)\n",
|
|||
|
" y_result.append(y_resampled)\n",
|
|||
|
" \n",
|
|||
|
" if X_result:\n",
|
|||
|
" return np.vstack(X_result), np.hstack(y_result)\n",
|
|||
|
" else:\n",
|
|||
|
" return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
|
|||
|
"\n",
|
|||
|
"def _apply_pca_transform(self, X):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 应用PCA变换\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not self.pca_fitted:\n",
|
|||
|
" return X\n",
|
|||
|
" \n",
|
|||
|
" X_scaled = self.pca_scaler.transform(X)\n",
|
|||
|
" X_pca = self.pca_model.transform(X_scaled)\n",
|
|||
|
" return X_pca\n",
|
|||
|
"\n",
|
|||
|
"def step3_process_data(self, data_type, apply_sampling=None):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 步骤3: 处理数据(采样+PCA降维)\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" data_type: 'train', 'val', 'test'\n",
|
|||
|
" apply_sampling: 是否应用采样策略,None=训练集应用,验证/测试集不应用\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not self.pca_fitted:\n",
|
|||
|
" raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n",
|
|||
|
" \n",
|
|||
|
" if apply_sampling is None:\n",
|
|||
|
" apply_sampling = (data_type == 'train')\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🔄 步骤3: 处理{data_type}数据...\")\n",
|
|||
|
" print(f\" 采样策略: {'启用' if apply_sampling else '禁用'}\")\n",
|
|||
|
" \n",
|
|||
|
" all_features = []\n",
|
|||
|
" all_labels = []\n",
|
|||
|
" \n",
|
|||
|
" for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" \n",
|
|||
|
" # 应用采样策略\n",
|
|||
|
" if apply_sampling:\n",
|
|||
|
" features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n",
|
|||
|
" else:\n",
|
|||
|
" features_sampled, labels_sampled = features, labels\n",
|
|||
|
" \n",
|
|||
|
" # 应用PCA降维\n",
|
|||
|
" if features_sampled.shape[0] > 0:\n",
|
|||
|
" features_pca = self._apply_pca_transform(features_sampled)\n",
|
|||
|
" all_features.append(features_pca)\n",
|
|||
|
" all_labels.append(labels_sampled)\n",
|
|||
|
" \n",
|
|||
|
" if all_features:\n",
|
|||
|
" X = np.vstack(all_features)\n",
|
|||
|
" y = np.hstack(all_labels)\n",
|
|||
|
" \n",
|
|||
|
" # 随机打乱\n",
|
|||
|
" shuffle_indices = np.random.permutation(len(y))\n",
|
|||
|
" X = X[shuffle_indices]\n",
|
|||
|
" y = y[shuffle_indices]\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 处理完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
|
|||
|
" \n",
|
|||
|
" # 清理内存\n",
|
|||
|
" del all_features, all_labels\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" \n",
|
|||
|
" return X, y\n",
|
|||
|
" else:\n",
|
|||
|
" return None, None\n",
|
|||
|
"\n",
|
|||
|
"def print_summary(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 打印管道状态总结\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(\"\\n📋 智能数据处理管道状态:\")\n",
|
|||
|
" print(f\" 🔍 步骤1 - 分布分析: {'✅ 完成' if self.distribution_analysis else '❌ 未完成'}\")\n",
|
|||
|
" print(f\" 🔧 步骤2 - PCA拟合: {'✅ 完成' if self.pca_fitted else '❌ 未完成'}\")\n",
|
|||
|
" \n",
|
|||
|
" if self.distribution_analysis:\n",
|
|||
|
" target_mean = self.distribution_analysis['target_mean']\n",
|
|||
|
" print(f\" 📊 标签1-39均值: {target_mean:.0f}\")\n",
|
|||
|
" \n",
|
|||
|
" if self.pca_fitted:\n",
|
|||
|
" print(f\" 🔬 PCA降维: 7168 → {self.pca_components} ({self.pca_components/7168:.1%})\")\n",
|
|||
|
" print(f\" 📈 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🎯 使用流程:\")\n",
|
|||
|
" print(f\" 1. pipeline.step1_analyze_distribution()\")\n",
|
|||
|
" print(f\" 2. pipeline.step2_fit_pca_with_undersampling()\")\n",
|
|||
|
" print(f\" 3. pipeline.step3_process_data('train') # 训练集\")\n",
|
|||
|
" print(f\" pipeline.step3_process_data('val') # 验证集\")\n",
|
|||
|
"\n",
|
|||
|
"# 动态添加剩余方法到类\n",
|
|||
|
"SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n",
|
|||
|
"SmartDataPipeline._apply_pca_transform = _apply_pca_transform\n",
|
|||
|
"SmartDataPipeline.step3_process_data = step3_process_data\n",
|
|||
|
"SmartDataPipeline.print_summary = print_summary\n",
|
|||
|
"\n",
|
|||
|
"print(\"✅ 所有方法已添加到智能管道\")\n",
|
|||
|
"pipeline.print_summary()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 🔥 执行智能数据处理管道"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"🚀 开始执行智能数据处理管道...\n",
|
|||
|
"============================================================\n",
|
|||
|
"\n",
|
|||
|
"======================🔍 STEP 1: 分析数据分布======================\n",
|
|||
|
"🔍 步骤1: 分析数据分布...\n",
|
|||
|
" 正在加载文件 1/41: t15.2023.11.17_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 2/41: t15.2023.12.17_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 3/41: t15.2023.10.15_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 4/41: t15.2023.10.08_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 5/41: t15.2025.01.10_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 6/41: t15.2023.12.08_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 7/41: t15.2024.03.08_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 8/41: t15.2024.03.15_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 9/41: t15.2025.03.14_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 10/41: t15.2024.02.25_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 11/41: t15.2025.03.30_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 12/41: t15.2023.09.29_val_concatenated.npz\n",
|
|||
|
" 正在加载文件 13/41: t15.2023.09.01_val_concatenated.npz\n",
|
|||
|
" ✅ 分析完成: 101,906 样本\n",
|
|||
|
" 📊 标签1-39均值: 389\n",
|
|||
|
" 📉 下采样标签: [0, 40] → 389\n",
|
|||
|
" 📈 过采样阈值: 0.5 × 均值 = 194\n",
|
|||
|
"\n",
|
|||
|
"📊 采样策略总结:\n",
|
|||
|
" 📉 下采样标签: 2 个\n",
|
|||
|
" 📈 过采样标签: 11 个\n",
|
|||
|
" ✅ 保持不变: 28 个\n",
|
|||
|
"\n",
|
|||
|
"✅ 步骤1完成!\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🔥 执行智能数据处理管道【确定采样策略】\n",
|
|||
|
"\n",
|
|||
|
"print(\"🚀 开始执行智能数据处理管道...\")\n",
|
|||
|
"print(\"=\" * 60)\n",
|
|||
|
"\n",
|
|||
|
"# 步骤1: 分析数据分布\n",
|
|||
|
"print(\"\\n\" + \"🔍 STEP 1: 分析数据分布\".center(60, \"=\"))\n",
|
|||
|
"distribution, strategy = pipeline.step1_analyze_distribution()\n",
|
|||
|
"\n",
|
|||
|
"# 显示采样策略总结\n",
|
|||
|
"print(f\"\\n📊 采样策略总结:\")\n",
|
|||
|
"undersample_count = sum(1 for s in strategy.values() if s['action'] == 'undersample')\n",
|
|||
|
"oversample_count = sum(1 for s in strategy.values() if s['action'] == 'oversample')\n",
|
|||
|
"keep_count = sum(1 for s in strategy.values() if s['action'] == 'keep')\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 📉 下采样标签: {undersample_count} 个\")\n",
|
|||
|
"print(f\" 📈 过采样标签: {oversample_count} 个\") \n",
|
|||
|
"print(f\" ✅ 保持不变: {keep_count} 个\")\n",
|
|||
|
"\n",
|
|||
|
"print(\"\\n✅ 步骤1完成!\")"
|
|||
|
]
|
|||
|
},
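{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: the sampling rule behind the step-1 summary\n",
"\n",
"The summary printed above (mean count of labels 1-39 = 389, oversampling threshold = 0.5 × mean = 194) implies a simple per-label rule. The sketch below reconstructs that rule on synthetic labels so the resulting strategy dictionary is easy to inspect. `build_strategy` and `oversample_factor` are illustrative names; the real `step1_analyze_distribution()` may differ in details."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Self-contained sketch of the sampling rule implied by the step-1 summary:\n",
"# undersample labels 0/40 down to the mean count of labels 1-39, and oversample any label\n",
"# whose count falls below 0.5 x that mean. Illustrative only - not the pipeline's own code.\n",
"import numpy as np\n",
"from collections import Counter\n",
"\n",
"def build_strategy(labels, oversample_factor=0.5):\n",
"    counts = Counter(int(l) for l in labels)\n",
"    target_mean = int(np.mean([counts.get(l, 0) for l in range(1, 40)]))\n",
"    threshold = int(oversample_factor * target_mean)\n",
"    strategy = {}\n",
"    for label in range(41):\n",
"        c = counts.get(label, 0)\n",
"        if label in (0, 40) and c > target_mean:\n",
"            strategy[label] = {'action': 'undersample', 'target_count': target_mean}\n",
"        elif 0 < c < threshold:\n",
"            strategy[label] = {'action': 'oversample', 'target_count': threshold}\n",
"        else:\n",
"            strategy[label] = {'action': 'keep', 'target_count': c}\n",
"    return strategy\n",
"\n",
"# Tiny usage example on synthetic labels\n",
"demo = build_strategy(np.r_[np.zeros(2000, int), np.random.randint(1, 40, 5000), np.full(1500, 40)])\n",
"print(demo[0], demo[40])"
]
},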
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 12,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"=====================🔧 STEP 2: 拟合PCA参数======================\n",
|
|||
|
"\n",
|
|||
|
"🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\n",
|
|||
|
" 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 3/45: t15.2024.03.17_train_concatenated.npz\n",
|
|||
|
" 📦 PCA拟合样本: 15,000 个下采样样本\n",
|
|||
|
" 🔢 原始特征维度: 7168\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"ename": "KeyboardInterrupt",
|
|||
|
"evalue": "",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/3241517313.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# 步骤2: 拟合PCA参数【确定PCA策略】\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"🔧 STEP 2: 拟合PCA参数\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcenter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m60\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"=\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mpipeline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep2_fit_pca_with_undersampling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n✅ 步骤2完成!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/3022750261.py\u001b[0m in \u001b[0;36mstep2_fit_pca_with_undersampling\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# 确定PCA成分数\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mpca_full\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPCA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mpca_full\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_scaled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mcumsum_var\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcumsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpca_full\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexplained_variance_ratio_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0moptimal_components\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcumsum_var\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpca_variance_threshold\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1387\u001b[0m )\n\u001b[1;32m 1388\u001b[0m ):\n\u001b[0;32m-> 1389\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1391\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/decomposition/_pca.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[0mReturns\u001b[0m \u001b[0mthe\u001b[0m \u001b[0minstance\u001b[0m \u001b[0mitself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 441\u001b[0m \"\"\"\n\u001b[0;32m--> 442\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 443\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/decomposition/_pca.py\u001b[0m in \u001b[0;36m_fit\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;31m# Call different fits for either full or truncated SVD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_svd_solver\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"full\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"covariance_eigh\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 542\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_full\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_components\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_array_api_compliant\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 543\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_svd_solver\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"arpack\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"randomized\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 544\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_truncated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_components\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/decomposition/_pca.py\u001b[0m in \u001b[0;36m_fit_full\u001b[0;34m(self, X, n_components, xp, is_array_api_compliant)\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;31m# solver by default though (assuming both are built against the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 582\u001b[0m \u001b[0;31m# same BLAS).\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 583\u001b[0;31m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_centered\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_matrices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 584\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_centered\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_matrices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/scipy/linalg/_decomp_svd.py\u001b[0m in \u001b[0;36msvd\u001b[0;34m(a, full_matrices, compute_uv, overwrite_a, check_finite, lapack_driver)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;31m# perform decomposition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m u, s, v, info = gesXd(a1, compute_uv=compute_uv, lwork=lwork,\n\u001b[0m\u001b[1;32m 163\u001b[0m full_matrices=full_matrices, overwrite_a=overwrite_a)\n\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 步骤2: 拟合PCA参数【确定PCA策略】\n",
|
|||
|
"print(\"\\n\" + \"🔧 STEP 2: 拟合PCA参数\".center(60, \"=\"))\n",
|
|||
|
"pipeline.step2_fit_pca_with_undersampling()\n",
|
|||
|
"\n",
|
|||
|
"print(\"\\n✅ 步骤2完成!\")\n",
|
|||
|
"pipeline.print_summary()"
|
|||
|
]
|
|||
|
},
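{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: cheaper ways to fit the PCA\n",
"\n",
"The interrupted run above stalls inside `PCA().fit(X_scaled)`, which performs an exact SVD of the full 15,000 × 7,168 sample just to read off the cumulative explained variance. The sketch below shows two lighter alternatives: a truncated randomized PCA, and `IncrementalPCA` fitted in chunks. It is only an illustration; `cap_components`, `batch_size` and the helper names are assumptions, not part of the pipeline above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Two memory-friendlier ways to fit the PCA, sketched under the assumption that `X_scaled`\n",
"# is the standardized sample built in step 2. `cap_components` / `batch_size` are illustrative.\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA, IncrementalPCA\n",
"\n",
"def fit_pca_randomized(X_scaled, cap_components=512, variance_threshold=0.95, random_state=42):\n",
"    # Truncated, randomized SVD: far cheaper than an exact full-rank PCA on wide matrices.\n",
"    pca = PCA(n_components=min(cap_components, min(X_scaled.shape)), svd_solver='randomized',\n",
"              random_state=random_state)\n",
"    pca.fit(X_scaled)\n",
"    cumsum = np.cumsum(pca.explained_variance_ratio_)\n",
"    n_keep = int(np.searchsorted(cumsum, variance_threshold) + 1)  # components needed for the threshold\n",
"    return pca, min(n_keep, pca.n_components_)\n",
"\n",
"def fit_pca_incremental(X_scaled, n_components=512, batch_size=2048):\n",
"    # Fit in chunks so the whole sample never has to be decomposed at once.\n",
"    ipca = IncrementalPCA(n_components=n_components, batch_size=batch_size)\n",
"    for start in range(0, X_scaled.shape[0], batch_size):\n",
"        chunk = X_scaled[start:start + batch_size]\n",
"        if chunk.shape[0] >= n_components:  # partial_fit needs at least n_components rows\n",
"            ipca.partial_fit(chunk)\n",
"    return ipca"
]
},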
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 🚀 使用智能管道进行分批训练"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🚀 使用智能管道进行分批训练\n",
|
|||
|
"\n",
|
|||
|
"import lightgbm as lgb\n",
|
|||
|
"import time\n",
|
|||
|
"from collections import Counter\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"class SmartBatchTrainer:\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 智能分批训练器,集成智能数据管道\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" \n",
|
|||
|
" def __init__(self, pipeline, params=None, min_learning_rate=1e-4):\n",
|
|||
|
" self.pipeline = pipeline\n",
|
|||
|
" self.model = None\n",
|
|||
|
" self.training_history = {} # 改为字典,因为只有一次训练\n",
|
|||
|
" self.batch_count = 0\n",
|
|||
|
" self.min_learning_rate = min_learning_rate\n",
|
|||
|
" self.lr_history = [] # 用于可视化\n",
|
|||
|
" \n",
|
|||
|
" # 默认LightGBM参数(GPU优化)\n",
|
|||
|
" self.params = params or {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': 41,\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device_type': 'cpu',\n",
|
|||
|
" # 'gpu_platform_id': 0,\n",
|
|||
|
" # 'gpu_device_id': 0,\n",
|
|||
|
" 'max_bin': 255,\n",
|
|||
|
" 'num_leaves': 127,\n",
|
|||
|
" 'learning_rate': 0.08, #默认0.08\n",
|
|||
|
" 'feature_fraction': 0.8,\n",
|
|||
|
" 'bagging_fraction': 0.8,\n",
|
|||
|
" 'bagging_freq': 5,\n",
|
|||
|
" 'min_data_in_leaf': 20,\n",
|
|||
|
" 'lambda_l1': 0.1,\n",
|
|||
|
" 'lambda_l2': 0.1,\n",
|
|||
|
" 'verbose': -1,\n",
|
|||
|
" 'num_threads': -1\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n",
|
|||
|
" \n",
|
|||
|
" print(f\"🎯 智能分批训练器创建完成\")\n",
|
|||
|
" print(f\" 🔧 LightGBM参数已配置:{self.params['device_type'].upper()}模式\")\n",
|
|||
|
" print(f\" 💡 学习率调度: 余弦退火 (从 {self.initial_learning_rate} 到 {self.min_learning_rate})\")\n",
|
|||
|
" \n",
|
|||
|
" def prepare_validation_data(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 准备验证数据(仅PCA,保持原始分布)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(\"🔄 准备验证数据...\")\n",
|
|||
|
" X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
|
|||
|
" if X_val is None:\n",
|
|||
|
" raise ValueError(\"无法加载验证数据\")\n",
|
|||
|
" val_counts = Counter(y_val)\n",
|
|||
|
" print(f\" ✅ 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n",
|
|||
|
" print(f\" 📊 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n",
|
|||
|
" \n",
|
|||
|
" return lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
|
|||
|
" \n",
|
|||
|
" def get_training_batch_generator(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 获取训练批次生成器(平衡采样+PCA)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(\"🔄 准备训练批次生成器...\")\n",
|
|||
|
" \n",
|
|||
|
" # 使用管道的批次生成器\n",
|
|||
|
" for trials_batch, filename in load_data_batch(self.pipeline.data_dir, 'train', 2000):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" \n",
|
|||
|
" # 应用完整采样策略\n",
|
|||
|
" features_sampled, labels_sampled = self.pipeline._apply_full_sampling(features, labels)\n",
|
|||
|
" \n",
|
|||
|
" # 应用PCA降维\n",
|
|||
|
" if features_sampled.shape[0] > 0:\n",
|
|||
|
" features_pca = self.pipeline._apply_pca_transform(features_sampled)\n",
|
|||
|
" \n",
|
|||
|
" # 分析当前批次分布\n",
|
|||
|
" batch_counts = Counter(labels_sampled)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 📦 批次: {filename}\")\n",
|
|||
|
" print(f\" 样本数: {features_pca.shape[0]:,}\")\n",
|
|||
|
" print(f\" 平衡后分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n",
|
|||
|
" \n",
|
|||
|
" yield lgb.Dataset(features_pca, label=labels_sampled), filename\n",
|
|||
|
" \n",
|
|||
|
" def prepare_full_data(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 一次性准备所有训练和验证数据\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(\"🔄 准备全量训练和验证数据...\")\n",
|
|||
|
" \n",
|
|||
|
" # 1. 准备验证数据 (保持原始分布)\n",
|
|||
|
" X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
|
|||
|
" if X_val is None:\n",
|
|||
|
" raise ValueError(\"无法加载验证数据\")\n",
|
|||
|
" val_counts = Counter(y_val)\n",
|
|||
|
" print(f\" ✅ 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n",
|
|||
|
" print(f\" 📊 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n",
|
|||
|
" val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
|
|||
|
" \n",
|
|||
|
" # 2. 准备训练数据 (应用完整采样和PCA策略)\n",
|
|||
|
" X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
|
|||
|
" if X_train is None:\n",
|
|||
|
" raise ValueError(\"无法加载训练数据\")\n",
|
|||
|
" train_counts = Counter(y_train)\n",
|
|||
|
" print(f\" ✅ 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n",
|
|||
|
" print(f\" 📊 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n",
|
|||
|
" train_data = lgb.Dataset(X_train, label=y_train)\n",
|
|||
|
" \n",
|
|||
|
" return train_data, val_data, X_val, y_val\n",
|
|||
|
" \n",
|
|||
|
" def prepare_training_data(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 准备训练数据(仅PCA,保持原始分布)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(\"🔄 准备训练数据...\")\n",
|
|||
|
" # 2. 准备训练数据 (应用完整采样和PCA策略)\n",
|
|||
|
" X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
|
|||
|
" if X_train is None:\n",
|
|||
|
" raise ValueError(\"无法加载训练数据\")\n",
|
|||
|
" train_counts = Counter(y_train)\n",
|
|||
|
" print(f\" ✅ 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n",
|
|||
|
" print(f\" 📊 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n",
|
|||
|
" \n",
|
|||
|
" return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n",
|
|||
|
" \n",
|
|||
|
" # 余弦退火调度器函数\n",
|
|||
|
" def _cosine_annealing_scheduler(self, current_round, t_max):\n",
|
|||
|
" eta_max = self.initial_learning_rate\n",
|
|||
|
" eta_min = self.min_learning_rate\n",
|
|||
|
" lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * current_round / t_max))\n",
|
|||
|
" return lr\n",
|
|||
|
" \n",
|
|||
|
" def train_incremental(self, num_boost_round=100, early_stopping_rounds=10):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 增量分批训练\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(f\"\\n🚀 开始智能分批训练...\")\n",
|
|||
|
" print(f\" 📝 训练轮数 (每批次): {num_boost_round}\")\n",
|
|||
|
" print(f\" ⏹️ 早停轮数: {early_stopping_rounds}\")\n",
|
|||
|
" print(\"=\" * 60)\n",
|
|||
|
" \n",
|
|||
|
" # 准备验证数据\n",
|
|||
|
" val_data = self.prepare_validation_data()\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🔄 开始分批增量训练...\")\n",
|
|||
|
" total_start_time = time.time()\n",
|
|||
|
" \n",
|
|||
|
" # ⭐️ 新增: 为学习率调度器定义T_max\n",
|
|||
|
" # 我们将每个批次的训练视为一个完整的退火周期\n",
|
|||
|
" t_max_per_batch = num_boost_round\n",
|
|||
|
" \n",
|
|||
|
" for train_data, filename in self.get_training_batch_generator():\n",
|
|||
|
" self.batch_count += 1\n",
|
|||
|
" batch_start_time = time.time()\n",
|
|||
|
" self.last_batch_lr_history = [] # 重置每个批次的LR历史\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n📈 批次 {self.batch_count}: {filename}\")\n",
|
|||
|
" \n",
|
|||
|
" # ⭐️ 新增: 创建学习率调度回调 和 记录回调\n",
|
|||
|
" lr_scheduler_callback = lgb.reset_parameter(\n",
|
|||
|
" learning_rate=lambda current_round: self._cosine_annealing_scheduler(current_round, t_max_per_batch)\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" # 这个简单的回调用于记录每个周期的学习率,以便后续可视化\n",
|
|||
|
" def record_lr_callback(env):\n",
|
|||
|
" self.last_batch_lr_history.append(env.model.params['learning_rate'])\n",
|
|||
|
"\n",
|
|||
|
" # 组合所有回调\n",
|
|||
|
" training_callbacks = [\n",
|
|||
|
" lgb.early_stopping(stopping_rounds=early_stopping_rounds, verbose=True),\n",
|
|||
|
" lgb.log_evaluation(period=10), # 每10轮打印一次\n",
|
|||
|
" lr_scheduler_callback,\n",
|
|||
|
" record_lr_callback\n",
|
|||
|
" ]\n",
|
|||
|
"\n",
|
|||
|
" # 训练当前批次\n",
|
|||
|
" current_model_args = {\n",
|
|||
|
" 'params': self.params,\n",
|
|||
|
" 'train_set': train_data,\n",
|
|||
|
" 'num_boost_round': num_boost_round,\n",
|
|||
|
" 'valid_sets': [val_data],\n",
|
|||
|
" 'valid_names': ['validation'],\n",
|
|||
|
" 'callbacks': training_callbacks\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" if self.model is None:\n",
|
|||
|
" print(\" 🎯 初始模型训练...\")\n",
|
|||
|
" self.model = lgb.train(**current_model_args)\n",
|
|||
|
" else:\n",
|
|||
|
" print(\" ⚡ 增量训练...\")\n",
|
|||
|
" current_model_args['init_model'] = self.model\n",
|
|||
|
" self.model = lgb.train(**current_model_args)\n",
|
|||
|
"\n",
|
|||
|
" # 记录训练历史\n",
|
|||
|
" batch_time = time.time() - batch_start_time\n",
|
|||
|
" \n",
|
|||
|
" # 评估当前模型\n",
|
|||
|
" val_pred = self.model.predict(self.X_val)\n",
|
|||
|
" val_accuracy = (val_pred.argmax(axis=1) == self.y_val).mean()\n",
|
|||
|
" \n",
|
|||
|
" batch_info = {\n",
|
|||
|
" 'batch': self.batch_count,\n",
|
|||
|
" 'filename': filename,\n",
|
|||
|
" 'time': batch_time,\n",
|
|||
|
" 'val_accuracy': val_accuracy,\n",
|
|||
|
" 'num_trees': self.model.num_trees(),\n",
|
|||
|
" 'lr_history': self.last_batch_lr_history.copy() # 保存当前批次的LR历史\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" self.training_history.append(batch_info)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 批次完成: {batch_time:.1f}秒\")\n",
|
|||
|
" print(f\" 📊 验证准确率: {val_accuracy:.4f}\")\n",
|
|||
|
" print(f\" 🌳 模型树数: {self.model.num_trees()}\")\n",
|
|||
|
" \n",
|
|||
|
" model_path = f\"smart_batch_model_batch_{self.batch_count}.txt\"\n",
|
|||
|
" self.model.save_model(model_path)\n",
|
|||
|
" print(f\" 💾 模型已保存: {model_path}\")\n",
|
|||
|
" \n",
|
|||
|
" total_time = time.time() - total_start_time\n",
|
|||
|
" print(f\"\\n🎉 智能分批训练完成!\")\n",
|
|||
|
" print(f\" ⏱️ 总训练时间: {total_time:.1f}秒\")\n",
|
|||
|
" print(f\" 📊 处理批次数: {self.batch_count}\")\n",
|
|||
|
" print(f\" 🌳 最终模型树数: {self.model.num_trees()}\")\n",
|
|||
|
" \n",
|
|||
|
" return self.model\n",
|
|||
|
" \n",
|
|||
|
" def train(self, num_boost_round=1000, early_stopping_rounds=50):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 执行一次性全量训练\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(f\"\\n🚀 开始全量数据训练...\")\n",
|
|||
|
" print(f\" 📝 训练轮数: {num_boost_round}\")\n",
|
|||
|
" print(f\" ⏹️ 早停轮数: {early_stopping_rounds}\")\n",
|
|||
|
" print(\"=\" * 60)\n",
|
|||
|
" \n",
|
|||
|
" # 准备数据\n",
|
|||
|
" train_data, val_data, X_val, y_val = self.prepare_full_data()\n",
|
|||
|
" \n",
|
|||
|
" start_time = time.time()\n",
|
|||
|
" \n",
|
|||
|
" # 定义学习率调度和记录回调\n",
|
|||
|
" lr_scheduler_callback = lgb.reset_parameter(\n",
|
|||
|
" learning_rate=lambda current_round: self._cosine_annealing_scheduler(current_round, num_boost_round)\n",
|
|||
|
" )\n",
|
|||
|
" def record_lr_callback(env):\n",
|
|||
|
" self.lr_history.append(env.model.params['learning_rate'])\n",
|
|||
|
" \n",
|
|||
|
" training_callbacks = [\n",
|
|||
|
" lgb.early_stopping(stopping_rounds=early_stopping_rounds, verbose=True),\n",
|
|||
|
" lgb.log_evaluation(period=1), # 每100轮打印日志\n",
|
|||
|
" lr_scheduler_callback,\n",
|
|||
|
" record_lr_callback\n",
|
|||
|
" ]\n",
|
|||
|
" \n",
|
|||
|
" # 训练模型\n",
|
|||
|
" print(\"\\n📈 开始模型训练...\")\n",
|
|||
|
" self.model = lgb.train(\n",
|
|||
|
" params=self.params,\n",
|
|||
|
" train_set=train_data,\n",
|
|||
|
" num_boost_round=num_boost_round,\n",
|
|||
|
" valid_sets=[val_data],\n",
|
|||
|
" valid_names=['validation'],\n",
|
|||
|
" callbacks=training_callbacks\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" training_time = time.time() - start_time\n",
|
|||
|
" \n",
|
|||
|
" # 评估模型\n",
|
|||
|
" val_pred = self.model.predict(X_val)\n",
|
|||
|
" val_accuracy = (val_pred.argmax(axis=1) == y_val).mean()\n",
|
|||
|
" \n",
|
|||
|
" # 记录训练历史\n",
|
|||
|
" self.training_history = {\n",
|
|||
|
" 'time': training_time,\n",
|
|||
|
" 'val_accuracy': val_accuracy,\n",
|
|||
|
" 'num_trees': self.model.num_trees(),\n",
|
|||
|
" 'lr_history': self.lr_history,\n",
|
|||
|
" 'best_iteration': self.model.best_iteration\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🎉 全量数据训练完成!\")\n",
|
|||
|
" print(f\" ⏱️ 总训练时间: {training_time:.1f}秒\")\n",
|
|||
|
" print(f\" 🌳 最终模型树数: {self.model.num_trees()} (最佳轮次: {self.model.best_iteration})\")\n",
|
|||
|
" print(f\" 🎯 最终验证准确率: {val_accuracy:.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 保存模型\n",
|
|||
|
" model_path = \"full_train_model.txt\"\n",
|
|||
|
" self.model.save_model(model_path)\n",
|
|||
|
" print(f\" 💾 模型已保存: {model_path}\")\n",
|
|||
|
" \n",
|
|||
|
" return self.model\n",
|
|||
|
" \n",
|
|||
|
" def plot_training_progress(self):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 绘制训练进度\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not self.training_history:\n",
|
|||
|
" print(\"❌ 没有训练历史记录\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" # ⭐️ 修改: 增加学习率的可视化图表\n",
|
|||
|
" fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(15, 15))\n",
|
|||
|
" \n",
|
|||
|
" batches = [h['batch'] for h in self.training_history]\n",
|
|||
|
" accuracies = [h['val_accuracy'] for h in self.training_history]\n",
|
|||
|
" times = [h['time'] for h in self.training_history]\n",
|
|||
|
" trees = [h['num_trees'] for h in self.training_history]\n",
|
|||
|
" \n",
|
|||
|
" # 1. 验证准确率\n",
|
|||
|
" ax1.plot(batches, accuracies, 'b-o', linewidth=2, markersize=6)\n",
|
|||
|
" ax1.set_xlabel('Training Batch')\n",
|
|||
|
" ax1.set_ylabel('Validation Accuracy')\n",
|
|||
|
" ax1.set_title('Validation Accuracy Progress')\n",
|
|||
|
" ax1.grid(True, alpha=0.3)\n",
|
|||
|
" ax1.set_ylim(0, 1)\n",
|
|||
|
" \n",
|
|||
|
" # 2. 批次训练时间\n",
|
|||
|
" ax2.bar(batches, times, color='green', alpha=0.7)\n",
|
|||
|
" ax2.set_xlabel('Training Batch')\n",
|
|||
|
" ax2.set_ylabel('Training Time (seconds)')\n",
|
|||
|
" ax2.set_title('Training Time per Batch')\n",
|
|||
|
" ax2.grid(True, alpha=0.3)\n",
|
|||
|
" \n",
|
|||
|
" # 3. 模型树数增长\n",
|
|||
|
" ax3.plot(batches, trees, 'r-s', linewidth=2, markersize=6)\n",
|
|||
|
" ax3.set_xlabel('Training Batch')\n",
|
|||
|
" ax3.set_ylabel('Number of Trees')\n",
|
|||
|
" ax3.set_title('Model Complexity Growth')\n",
|
|||
|
" ax3.grid(True, alpha=0.3)\n",
|
|||
|
" \n",
|
|||
|
" # 4. 累计准确率提升\n",
|
|||
|
" ax4.plot(batches, [acc - accuracies[0] for acc in accuracies], 'purple', linewidth=2, marker='D')\n",
|
|||
|
" ax4.set_xlabel('Training Batch')\n",
|
|||
|
" ax4.set_ylabel('Accuracy Improvement')\n",
|
|||
|
" ax4.set_title('Cumulative Accuracy Improvement')\n",
|
|||
|
" ax4.grid(True, alpha=0.3)\n",
|
|||
|
" ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)\n",
|
|||
|
"\n",
|
|||
|
" # ⭐️ 新增: 5. 最后一个批次的学习率曲线\n",
|
|||
|
" last_lr_history = self.training_history[-1]['lr_history']\n",
|
|||
|
" ax5.plot(range(len(last_lr_history)), last_lr_history, color='orange', marker='.')\n",
|
|||
|
" ax5.set_xlabel('Boosting Round in Last Batch')\n",
|
|||
|
" ax5.set_ylabel('Learning Rate')\n",
|
|||
|
" ax5.set_title(f'Cosine Annealing LR in Last Batch (Batch {batches[-1]})')\n",
|
|||
|
" ax5.grid(True, alpha=0.3)\n",
|
|||
|
" \n",
|
|||
|
" # 隐藏第六个子图\n",
|
|||
|
" ax6.axis('off')\n",
|
|||
|
"\n",
|
|||
|
" plt.tight_layout()\n",
|
|||
|
" plt.show()\n",
|
|||
|
" \n",
|
|||
|
" # 打印统计信息\n",
|
|||
|
" print(f\"\\n📈 训练进度统计:\")\n",
|
|||
|
" print(f\" 🎯 初始准确率: {accuracies[0]:.4f}\")\n",
|
|||
|
" print(f\" 🎯 最终准确率: {accuracies[-1]:.4f}\")\n",
|
|||
|
" print(f\" 📈 准确率提升: {accuracies[-1] - accuracies[0]:.4f}\")\n",
|
|||
|
" print(f\" ⏱️ 平均批次时间: {np.mean(times):.1f}秒\")\n",
|
|||
|
" print(f\" 🌳 最终模型树数: {trees[-1]}\")\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"print(\"🚀 创建智能分批训练器...\")\n",
|
|||
|
"# 实例化时可以传入最小学习率\n",
|
|||
|
"trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001) \n",
|
|||
|
"print(\"✅ 训练器创建完成,准备开始训练!\")"
|
|||
|
]
|
|||
|
},
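{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: what the cosine-annealing schedule looks like\n",
"\n",
"The trainer above resets the learning rate each boosting round with lr(t) = eta_min + 0.5 · (eta_max − eta_min) · (1 + cos(π · t / T)). The standalone check below evaluates that formula at the start, midpoint and end of one cycle; the values of `eta_max`, `eta_min` and `T` here are illustrative, not the trainer's configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Standalone check of the cosine-annealing formula used by _cosine_annealing_scheduler().\n",
"# The eta_max / eta_min / T values below are illustrative.\n",
"import numpy as np\n",
"\n",
"def cosine_lr(t, t_max, eta_max=0.08, eta_min=0.001):\n",
"    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t / t_max))\n",
"\n",
"T = 300\n",
"print(f'round 0   -> {cosine_lr(0, T):.4f}')       # starts at eta_max (0.0800)\n",
"print(f'round T/2 -> {cosine_lr(T // 2, T):.4f}')  # midpoint, (eta_max + eta_min) / 2\n",
"print(f'round T   -> {cosine_lr(T, T):.4f}')       # ends at eta_min (0.0010)"
]
},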
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# # 全量训练\n",
|
|||
|
"\n",
|
|||
|
"# print(\"🔥 开始智能分批训练!\")\n",
|
|||
|
"# print(\"=\" * 80)\n",
|
|||
|
"\n",
|
|||
|
"# # 训练参数\n",
|
|||
|
"# TRAINING_PARAMS = {\n",
|
|||
|
"# 'num_boost_round': 300, # 每批次的提升轮数\n",
|
|||
|
"# 'early_stopping_rounds': 15 # 早停轮数\n",
|
|||
|
"# }\n",
|
|||
|
"\n",
|
|||
|
"# print(f\"📝 训练配置:\")\n",
|
|||
|
"# print(f\" 训练轮数: {TRAINING_PARAMS['num_boost_round']}\")\n",
|
|||
|
"# print(f\" 早停轮数: {TRAINING_PARAMS['early_stopping_rounds']}\")\n",
|
|||
|
"# print(f\" 数据平衡: 启用(下采样标签0,40 + 过采样少数类)\")\n",
|
|||
|
"# print(f\" PCA降维: 7168 → {pipeline.pca_components} 特征\")\n",
|
|||
|
"\n",
|
|||
|
"# print(f\"\\n🚀 启动训练...\")\n",
|
|||
|
"\n",
|
|||
|
"# # 开始训练\n",
|
|||
|
"# model = trainer.train(\n",
|
|||
|
"# num_boost_round=TRAINING_PARAMS['num_boost_round'],\n",
|
|||
|
"# early_stopping_rounds=TRAINING_PARAMS['early_stopping_rounds']\n",
|
|||
|
"# )"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 📊 训练结果分析"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 📊 训练结果分析和可视化\n",
|
|||
|
"\n",
|
|||
|
"print(\"📊 分析智能分批训练结果...\")\n",
|
|||
|
"print(\"=\" * 60)\n",
|
|||
|
"\n",
|
|||
|
"# 显示训练进度图表\n",
|
|||
|
"trainer.plot_training_progress()\n",
|
|||
|
"\n",
|
|||
|
"# 保存最终模型\n",
|
|||
|
"final_model_path = \"smart_pipeline_final_model.txt\"\n",
|
|||
|
"if trainer.model:\n",
|
|||
|
" trainer.model.save_model(final_model_path)\n",
|
|||
|
" print(f\"\\n💾 最终模型已保存: {final_model_path}\")\n",
|
|||
|
"\n",
|
|||
|
"# 详细分析\n",
|
|||
|
"if trainer.training_history:\n",
|
|||
|
" print(f\"\\n📈 详细训练分析:\")\n",
|
|||
|
" print(f\" 🎯 训练批次总数: {len(trainer.training_history)}\")\n",
|
|||
|
" \n",
|
|||
|
" # 最佳批次\n",
|
|||
|
" best_batch = max(trainer.training_history, key=lambda x: x['val_accuracy'])\n",
|
|||
|
" print(f\" 🏆 最佳验证准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n",
|
|||
|
" \n",
|
|||
|
" # 训练效率\n",
|
|||
|
" total_training_time = sum(h['time'] for h in trainer.training_history)\n",
|
|||
|
" avg_batch_time = total_training_time / len(trainer.training_history)\n",
|
|||
|
" print(f\" ⏱️ 总训练时间: {total_training_time:.1f}秒\")\n",
|
|||
|
" print(f\" ⏱️ 平均批次时间: {avg_batch_time:.1f}秒\")\n",
|
|||
|
" \n",
|
|||
|
" # 模型复杂度\n",
|
|||
|
" final_trees = trainer.training_history[-1]['num_trees']\n",
|
|||
|
" print(f\" 🌳 最终模型树数: {final_trees}\")\n",
|
|||
|
" \n",
|
|||
|
" # 收敛性分析\n",
|
|||
|
" recent_accs = [h['val_accuracy'] for h in trainer.training_history[-3:]]\n",
|
|||
|
" if len(recent_accs) >= 2:\n",
|
|||
|
" acc_stability = max(recent_accs) - min(recent_accs)\n",
|
|||
|
" print(f\" 📈 准确率稳定性: {acc_stability:.4f} (最近3批次方差)\")\n",
|
|||
|
" \n",
|
|||
|
" if acc_stability < 0.01:\n",
|
|||
|
" print(\" ✅ 模型已收敛 (准确率变化 < 1%)\")\n",
|
|||
|
" else:\n",
|
|||
|
" print(\" ⚠️ 模型可能需要更多训练\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🎉 智能分批训练分析完成!\")\n",
|
|||
|
"print(f\" 💡 使用了改进的数据平衡策略和PCA降维\")\n",
|
|||
|
"print(f\" 💡 训练集应用了下采样+过采样,验证集保持原始分布\")\n",
|
|||
|
"print(f\" 💡 实现了内存友好的分批处理\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 🧪 模型性能评估"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🧪 模型性能评估\n",
|
|||
|
"\n",
|
|||
|
"from sklearn.metrics import classification_report, confusion_matrix\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"def evaluate_model_performance(model, pipeline, data_type='val'):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 评估模型在指定数据集上的性能\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" print(f\"🧪 评估模型在{data_type}数据集上的性能...\")\n",
|
|||
|
" \n",
|
|||
|
" # 加载数据\n",
|
|||
|
" X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n",
|
|||
|
" \n",
|
|||
|
" if X is None or y is None:\n",
|
|||
|
" print(f\"❌ 无法加载{data_type}数据\")\n",
|
|||
|
" return None\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 📊 数据集大小: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
|
|||
|
" \n",
|
|||
|
" # 预测\n",
|
|||
|
" start_time = time.time()\n",
|
|||
|
" y_pred_proba = model.predict(X)\n",
|
|||
|
" y_pred = y_pred_proba.argmax(axis=1)\n",
|
|||
|
" pred_time = time.time() - start_time\n",
|
|||
|
" \n",
|
|||
|
" # 计算性能指标\n",
|
|||
|
" accuracy = (y_pred == y).mean()\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ⏱️ 预测时间: {pred_time:.2f}秒\")\n",
|
|||
|
" print(f\" 🎯 整体准确率: {accuracy:.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 分析各类别性能\n",
|
|||
|
" from collections import Counter\n",
|
|||
|
" true_counts = Counter(y)\n",
|
|||
|
" pred_counts = Counter(y_pred)\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n📊 标签分布对比:\")\n",
|
|||
|
" print(\"标签 | 真实数量 | 预测数量 | 准确率\")\n",
|
|||
|
" print(\"-\" * 40)\n",
|
|||
|
" \n",
|
|||
|
" label_accuracies = {}\n",
|
|||
|
" for label in range(41):\n",
|
|||
|
" if label in true_counts:\n",
|
|||
|
" label_mask = (y == label)\n",
|
|||
|
" if label_mask.sum() > 0:\n",
|
|||
|
" label_acc = (y_pred[label_mask] == label).mean()\n",
|
|||
|
" label_accuracies[label] = label_acc\n",
|
|||
|
" true_count = true_counts.get(label, 0)\n",
|
|||
|
" pred_count = pred_counts.get(label, 0)\n",
|
|||
|
" print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 重点分析关键标签\n",
|
|||
|
" print(f\"\\n🔍 关键标签性能分析:\")\n",
|
|||
|
" key_labels = [0, 40] # 下采样的标签\n",
|
|||
|
" for label in key_labels:\n",
|
|||
|
" if label in label_accuracies:\n",
|
|||
|
" acc = label_accuracies[label]\n",
|
|||
|
" count = true_counts.get(label, 0)\n",
|
|||
|
" print(f\" 标签 {label} (下采样目标): 准确率 {acc:.4f}, 样本数 {count:,}\")\n",
|
|||
|
" \n",
|
|||
|
" # 少数类性能\n",
|
|||
|
" minority_labels = [label for label, count in true_counts.items() \n",
|
|||
|
" if count < 200 and label not in [0, 40]]\n",
|
|||
|
" if minority_labels:\n",
|
|||
|
" minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n",
|
|||
|
" avg_minority_acc = np.mean(minority_accs) if minority_accs else 0\n",
|
|||
|
" print(f\" 少数类平均准确率 (前5个): {avg_minority_acc:.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 置信度分析\n",
|
|||
|
" max_proba = y_pred_proba.max(axis=1)\n",
|
|||
|
" print(f\"\\n📈 预测置信度分析:\")\n",
|
|||
|
" print(f\" 平均置信度: {max_proba.mean():.4f}\")\n",
|
|||
|
" print(f\" 置信度中位数: {np.median(max_proba):.4f}\")\n",
|
|||
|
" print(f\" 高置信度预测 (>0.9): {(max_proba > 0.9).sum():,} / {len(max_proba):,} ({(max_proba > 0.9).mean():.2%})\")\n",
|
|||
|
" \n",
|
|||
|
" return {\n",
|
|||
|
" 'accuracy': accuracy,\n",
|
|||
|
" 'prediction_time': pred_time,\n",
|
|||
|
" 'label_accuracies': label_accuracies,\n",
|
|||
|
" 'confidence_stats': {\n",
|
|||
|
" 'mean': max_proba.mean(),\n",
|
|||
|
" 'median': np.median(max_proba),\n",
|
|||
|
" 'high_confidence_ratio': (max_proba > 0.9).mean()\n",
|
|||
|
" }\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
"# 评估模型性能\n",
|
|||
|
"if trainer.model:\n",
|
|||
|
" print(\"🧪 开始模型性能评估...\")\n",
|
|||
|
" \n",
|
|||
|
" # 验证集评估\n",
|
|||
|
" val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n\" + \"=\"*60)\n",
|
|||
|
" print(\"🎉 智能分批训练+数据平衡 评估完成!\")\n",
|
|||
|
" print(f\"✅ 实现了数据平衡和PCA降维的完整流程\")\n",
|
|||
|
" print(f\"✅ 使用了内存友好的分批训练策略\")\n",
|
|||
|
" print(f\"✅ 保持了验证集的原始分布以确保评估客观性\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"❌ 模型尚未训练完成,请等待训练结束后运行此评估\")"
|
|||
|
]
|
|||
|
},
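{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: class-balanced metrics as a complement to overall accuracy\n",
"\n",
"Because labels 0 and 40 dominate the validation set, overall accuracy can look good even when the rare classes are predicted poorly. The sketch below computes balanced accuracy (mean per-class recall) and macro-F1 with scikit-learn. `y_true` / `y_pred` stand in for the arrays built inside `evaluate_model_performance()`; the demo labels are synthetic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Class-balanced metrics as a complement to the overall accuracy reported above.\n",
"# `y_true` / `y_pred` are placeholders for the arrays used in evaluate_model_performance().\n",
"import numpy as np\n",
"from sklearn.metrics import balanced_accuracy_score, f1_score\n",
"\n",
"def balanced_report(y_true, y_pred):\n",
"    return {\n",
"        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),  # mean of per-class recall\n",
"        'macro_f1': f1_score(y_true, y_pred, average='macro'),         # unweighted mean F1 over classes\n",
"    }\n",
"\n",
"# Tiny usage example with synthetic labels (about 70% of predictions forced correct)\n",
"rng = np.random.default_rng(0)\n",
"demo_true = rng.integers(0, 41, size=1000)\n",
"demo_pred = np.where(rng.random(1000) < 0.7, demo_true, rng.integers(0, 41, size=1000))\n",
"print(balanced_report(demo_true, demo_pred))"
]
},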
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# 测试集总评-连接语言模型"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"smart_pipeline_final_model.txt"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"ename": "",
|
|||
|
"evalue": "",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[1;31mFailed to connect to the remote Jupyter Server 'https://kkb-production.jupyter-proxy.kaggle.net/'. Verify the server is running and reachable. (Failed to connect to the remote Jupyter Server 'https://kkb-production.jupyter-proxy.kaggle.net/'. Verify the server is running and reachable. (request to https://kkb-production.jupyter-proxy.kaggle.net/k/261889069/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwidHlwIjoiSldUIn0..3JPUo77-E5Unxk40FcVuDw.UQ5HO58Y63DL5av9cYv59hBnb4Rw6GbfyzcRyPp9ID-u0ODR4KJuqXcaUS7TKXpddj60a_dRVtxSjqjhxD7xtc5fM80xoPpibRRjVKonb_HwqUKs_96UIdvPfI_MeKXYJ3Tb0AXf-5TxLoOaYyps8zaC5bp8r7jzr1uNTM56M7RH09kDMCNnIhvD7zWEZJlQULZ3sY6N8v36OVsY05q5c6ZnVePk92Qw-buRKiNK5bIo4qmSjUssmdP5SqMShwc3.iAgJSIm0bnGknjcE5jhAvQ/proxy/api/kernels?1757936253840 failed, reason: Client network socket disconnected before secure TLS connection was established).)."
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 加载预训练的 LightGBM 分类模型\n",
|
|||
|
"\n",
|
|||
|
"import lightgbm as lgb\n",
|
|||
|
"import os\n",
|
|||
|
"\n",
|
|||
|
"def load_lgbm_model(model_path):\n",
|
|||
|
" \"\"\" \n",
|
|||
|
" 加载预训练的 LightGBM 模型\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" model_path: 模型文件路径\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" lgb.Booster: 加载的模型\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not os.path.exists(model_path):\n",
|
|||
|
" raise FileNotFoundError(f\"模型文件不存在: {model_path}\")\n",
|
|||
|
" \n",
|
|||
|
" print(f\"📂 正在加载模型: {model_path}\")\n",
|
|||
|
" \n",
|
|||
|
" # 加载模型\n",
|
|||
|
" model = lgb.Booster(model_file=model_path)\n",
|
|||
|
" \n",
|
|||
|
" print(f\"✅ 模型加载成功!\")\n",
|
|||
|
" print(f\" 🌳 模型树数: {model.num_trees()}\")\n",
|
|||
|
" print(f\" 📊 特征数: {model.num_feature()}\")\n",
|
|||
|
" print(f\" 🏷️ 类别数: {model.num_model_per_iteration()}\")\n",
|
|||
|
" \n",
|
|||
|
" # 显示模型基本信息\n",
|
|||
|
" model_info = {\n",
|
|||
|
" 'num_trees': model.num_trees(),\n",
|
|||
|
" 'num_features': model.num_feature(),\n",
|
|||
|
" 'num_classes': model.num_model_per_iteration(),\n",
|
|||
|
" 'objective': model.params.get('objective', 'unknown'),\n",
|
|||
|
" 'boosting_type': model.params.get('boosting_type', 'unknown')\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n📋 模型详细信息:\")\n",
|
|||
|
" for key, value in model_info.items():\n",
|
|||
|
" print(f\" {key}: {value}\")\n",
|
|||
|
" \n",
|
|||
|
" return model\n",
|
|||
|
"\n",
|
|||
|
"# 加载我们训练好的模型\n",
|
|||
|
"MODEL_PATH = \"full_train_model.txt\"\n",
|
|||
|
"\n",
|
|||
|
"try:\n",
|
|||
|
" lgbm_model = load_lgbm_model(MODEL_PATH)\n",
|
|||
|
" print(f\"\\n🎉 LightGBM 模型加载完成,准备用于推理!\")\n",
|
|||
|
" \n",
|
|||
|
"except FileNotFoundError as e:\n",
|
|||
|
" print(f\"❌ 错误: {e}\")\n",
|
|||
|
" print(f\"💡 请确保模型文件 '{MODEL_PATH}' 存在于当前目录\")\n",
|
|||
|
" lgbm_model = None\n",
|
|||
|
" \n",
|
|||
|
"except Exception as e:\n",
|
|||
|
" print(f\"❌ 加载模型时发生错误: {e}\")\n",
|
|||
|
" lgbm_model = None"
|
|||
|
]
|
|||
|
},
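{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: raw features must go through the saved scaler + PCA before `predict`\n",
"\n",
"The Booster loaded above was trained on PCA-reduced features, so its `num_feature()` matches the PCA output dimension, not the raw 7168-dim feature vectors. The sketch below applies the scaler + PCA bundle saved by step 2 (`smart_pipeline_pca.joblib`) before predicting. The helper name and the commented usage lines are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Apply the scaler + PCA saved by step2_fit_pca_with_undersampling() before calling predict().\n",
"# Sketch only: assumes 'smart_pipeline_pca.joblib' is in the working directory.\n",
"import joblib\n",
"\n",
"def predict_from_raw_features(model, raw_features, pca_bundle_path='smart_pipeline_pca.joblib'):\n",
"    bundle = joblib.load(pca_bundle_path)               # {'scaler': ..., 'pca': ..., 'components': ...}\n",
"    X_scaled = bundle['scaler'].transform(raw_features)\n",
"    X_pca = bundle['pca'].transform(X_scaled)\n",
"    assert X_pca.shape[1] == model.num_feature(), 'PCA output must match the Booster feature count'\n",
"    return model.predict(X_pca)\n",
"\n",
"# Usage sketch (raw_features would be an [n_samples, 7168] array from the feature extraction step):\n",
"# probs = predict_from_raw_features(lgbm_model, raw_features)\n",
"# labels = probs.argmax(axis=1)"
]
},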
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🧪 测试模型预测功能\n",
|
|||
|
"\n",
|
|||
|
"def test_model_prediction(model, test_features=None):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 测试模型的预测功能\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" model: LightGBM 模型\n",
|
|||
|
" test_features: 测试特征数据,如果为None则创建虚拟数据\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if model is None:\n",
|
|||
|
" print(\"❌ 模型未加载,无法进行测试\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" print(\"🧪 测试模型预测功能...\")\n",
|
|||
|
" \n",
|
|||
|
" # 如果没有提供测试数据,创建虚拟数据\n",
|
|||
|
" if test_features is None:\n",
|
|||
|
" print(\" 📝 创建虚拟测试数据...\")\n",
|
|||
|
" # 根据模型期望的特征数创建随机数据\n",
|
|||
|
" num_features = model.num_feature()\n",
|
|||
|
" num_samples = 5\n",
|
|||
|
" test_features = np.random.randn(num_samples, num_features).astype(np.float32)\n",
|
|||
|
" print(f\" 🔢 虚拟数据形状: {test_features.shape}\")\n",
|
|||
|
" \n",
|
|||
|
" try:\n",
|
|||
|
" # 进行预测\n",
|
|||
|
" predictions = model.predict(test_features)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 预测成功!\")\n",
|
|||
|
" print(f\" 📊 预测形状: {predictions.shape}\")\n",
|
|||
|
" print(f\" 🎯 预测范例 (前3个样本的前5个类别概率):\")\n",
|
|||
|
" \n",
|
|||
|
" for i in range(min(3, predictions.shape[0])):\n",
|
|||
|
" pred_probs = predictions[i][:5] # 只显示前5个类别\n",
|
|||
|
" predicted_class = np.argmax(predictions[i])\n",
|
|||
|
" max_prob = np.max(predictions[i])\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 样本 {i+1}: 预测类别={predicted_class}, 置信度={max_prob:.4f}\")\n",
|
|||
|
" print(f\" 前5类概率: {pred_probs}\")\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 📈 预测置信度统计:\")\n",
|
|||
|
" max_probs = np.max(predictions, axis=1)\n",
|
|||
|
" print(f\" 平均置信度: {np.mean(max_probs):.4f}\")\n",
|
|||
|
" print(f\" 最高置信度: {np.max(max_probs):.4f}\")\n",
|
|||
|
" print(f\" 最低置信度: {np.min(max_probs):.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" return True\n",
|
|||
|
" \n",
|
|||
|
" except Exception as e:\n",
|
|||
|
" print(f\" ❌ 预测失败: {e}\")\n",
|
|||
|
" return False\n",
|
|||
|
"\n",
|
|||
|
"# 测试加载的模型\n",
|
|||
|
"if lgbm_model is not None:\n",
|
|||
|
" test_success = test_model_prediction(lgbm_model)\n",
|
|||
|
" \n",
|
|||
|
" if test_success:\n",
|
|||
|
" print(f\"\\n🎉 模型测试成功! 可以用于实际推理任务\")\n",
|
|||
|
" print(f\"💡 模型期望输入: {lgbm_model.num_feature()} 维特征向量\")\n",
|
|||
|
" print(f\"💡 模型输出: {lgbm_model.num_model_per_iteration()} 个类别的概率分布\")\n",
|
|||
|
" else:\n",
|
|||
|
" print(f\"\\n❌ 模型测试失败,请检查模型文件\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"❌ 模型未加载,跳过测试\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 测试集的实义测试"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"✅ 神经数据处理函数定义完成\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 仿照RNN评估流程处理测试集数据 + LightGBM预测\n",
|
|||
|
"\n",
|
|||
|
"import h5py\n",
|
|||
|
"import torch\n",
|
|||
|
"import torch.nn.functional as F\n",
|
|||
|
"from scipy.ndimage import gaussian_filter1d\n",
|
|||
|
"import os\n",
|
|||
|
"from tqdm import tqdm\n",
|
|||
|
"\n",
|
|||
|
"def load_h5py_file(file_path, b2txt_csv_df):\n",
|
|||
|
" data = {\n",
|
|||
|
" 'neural_features': [],\n",
|
|||
|
" 'n_time_steps': [],\n",
|
|||
|
" 'seq_class_ids': [],\n",
|
|||
|
" 'seq_len': [],\n",
|
|||
|
" 'transcriptions': [],\n",
|
|||
|
" 'sentence_label': [],\n",
|
|||
|
" 'session': [],\n",
|
|||
|
" 'block_num': [],\n",
|
|||
|
" 'trial_num': [],\n",
|
|||
|
" 'corpus': [],\n",
|
|||
|
" }\n",
|
|||
|
" # Open the hdf5 file for that day\n",
|
|||
|
" with h5py.File(file_path, 'r') as f:\n",
|
|||
|
"\n",
|
|||
|
" keys = list(f.keys())\n",
|
|||
|
"\n",
|
|||
|
" # For each trial in the selected trials in that day\n",
|
|||
|
" for key in keys:\n",
|
|||
|
" g = f[key]\n",
|
|||
|
"\n",
|
|||
|
" neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
|
|||
|
" n_time_steps = g.attrs['n_time_steps']\n",
|
|||
|
" seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
|
|||
|
" seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
|
|||
|
" transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
|
|||
|
" sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
|
|||
|
" session = g.attrs['session']\n",
|
|||
|
" block_num = g.attrs['block_num']\n",
|
|||
|
" trial_num = g.attrs['trial_num']\n",
|
|||
|
"\n",
|
|||
|
" # match this trial up with the csv to get the corpus name\n",
|
|||
|
" year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
|
|||
|
" date = f'{year}-{month}-{day}'\n",
|
|||
|
" row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
|
|||
|
" corpus_name = row['Corpus'].values[0]\n",
|
|||
|
"\n",
|
|||
|
" data['neural_features'].append(neural_features)\n",
|
|||
|
" data['n_time_steps'].append(n_time_steps)\n",
|
|||
|
" data['seq_class_ids'].append(seq_class_ids)\n",
|
|||
|
" data['seq_len'].append(seq_len)\n",
|
|||
|
" data['transcriptions'].append(transcription)\n",
|
|||
|
" data['sentence_label'].append(sentence_label)\n",
|
|||
|
" data['session'].append(session)\n",
|
|||
|
" data['block_num'].append(block_num)\n",
|
|||
|
" data['trial_num'].append(trial_num)\n",
|
|||
|
" data['corpus'].append(corpus_name)\n",
|
|||
|
" return data\n",
|
|||
|
"\n",
|
|||
|
"def gauss_smooth_torch(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='valid'):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" PyTorch版本的高斯平滑 (仿照data_augmentations.py)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" # 创建高斯核\n",
|
|||
|
" inp = np.zeros(smooth_kernel_size, dtype=np.float32)\n",
|
|||
|
" inp[smooth_kernel_size // 2] = 1\n",
|
|||
|
" gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)\n",
|
|||
|
" \n",
|
|||
|
" # 过滤小值\n",
|
|||
|
" validIdx = np.argwhere(gaussKernel > 0.01)\n",
|
|||
|
" if len(validIdx) > 0:\n",
|
|||
|
" gaussKernel = gaussKernel[validIdx.flatten()]\n",
|
|||
|
" gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))\n",
|
|||
|
" \n",
|
|||
|
" # 转换为PyTorch张量\n",
|
|||
|
" gaussKernel = torch.tensor(gaussKernel, dtype=inputs.dtype, device=device)\n",
|
|||
|
" gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]\n",
|
|||
|
" \n",
|
|||
|
" # 准备卷积\n",
|
|||
|
" B, T, C = inputs.shape\n",
|
|||
|
" inputs = inputs.permute(0, 2, 1) # [B, C, T]\n",
|
|||
|
" gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]\n",
|
|||
|
" \n",
|
|||
|
" # 执行卷积\n",
|
|||
|
" smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)\n",
|
|||
|
" return smoothed.permute(0, 2, 1) # [B, T, C]\n",
|
|||
|
"\n",
|
|||
|
"def apply_patch_processing(x, patch_size=14, patch_stride=4):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 应用patch处理 (仿照rnn_model.py的forward方法)\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" x: 输入张量 [batch, timesteps, features]\n",
|
|||
|
" patch_size: patch大小\n",
|
|||
|
" patch_stride: patch步长\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" 处理后的张量 [batch, num_patches, patch_size * features]\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if patch_size <= 0:\n",
|
|||
|
" return x\n",
|
|||
|
" \n",
|
|||
|
" x = x.unsqueeze(1) # [batches, 1, timesteps, feature_dim]\n",
|
|||
|
" x = x.permute(0, 3, 1, 2) # [batches, feature_dim, 1, timesteps]\n",
|
|||
|
" \n",
|
|||
|
" # 使用unfold提取patches\n",
|
|||
|
" x_unfold = x.unfold(3, patch_size, patch_stride) # [batches, feature_dim, 1, num_patches, patch_size]\n",
|
|||
|
" \n",
|
|||
|
" # 移除虚拟维度并重新排列\n",
|
|||
|
" x_unfold = x_unfold.squeeze(2) # [batches, feature_dim, num_patches, patch_size]\n",
|
|||
|
" x_unfold = x_unfold.permute(0, 2, 3, 1) # [batches, num_patches, patch_size, feature_dim]\n",
|
|||
|
" \n",
|
|||
|
" # 展平最后两个维度\n",
|
|||
|
" x = x_unfold.reshape(x_unfold.size(0), x_unfold.size(1), -1) # [batch, num_patches, patch_size * features]\n",
|
|||
|
" \n",
|
|||
|
" return x\n",
|
|||
|
"\n",
|
|||
|
"def process_neural_data_for_lgbm(neural_input, device, model_args):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 仿照RNN处理流程:高斯平滑 + patch处理\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" neural_input: 神经数据 [batch, time, features]\n",
|
|||
|
" device: PyTorch设备\n",
|
|||
|
" model_args: 模型参数配置\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" 处理后的特征数据,准备输入LightGBM\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" # 1. 高斯平滑\n",
|
|||
|
" smoothed_input = gauss_smooth_torch(\n",
|
|||
|
" inputs=neural_input,\n",
|
|||
|
" device=device,\n",
|
|||
|
" smooth_kernel_std=model_args.get('smooth_kernel_std', 2),\n",
|
|||
|
" smooth_kernel_size=model_args.get('smooth_kernel_size', 100),\n",
|
|||
|
" padding='valid'\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" # 2. Patch处理\n",
|
|||
|
" patch_size = model_args.get('patch_size', 14)\n",
|
|||
|
" patch_stride = model_args.get('patch_stride', 4)\n",
|
|||
|
" \n",
|
|||
|
" if patch_size > 0:\n",
|
|||
|
" patched_input = apply_patch_processing(\n",
|
|||
|
" smoothed_input, \n",
|
|||
|
" patch_size=patch_size, \n",
|
|||
|
" patch_stride=patch_stride\n",
|
|||
|
" )\n",
|
|||
|
" # 展平为2D: [batch * num_patches, patch_size * features]\n",
|
|||
|
" batch_size, num_patches, patch_features = patched_input.shape\n",
|
|||
|
" features_2d = patched_input.reshape(-1, patch_features)\n",
|
|||
|
" else:\n",
|
|||
|
" # 如果不使用patch,直接展平\n",
|
|||
|
" features_2d = smoothed_input.reshape(-1, smoothed_input.shape[-1])\n",
|
|||
|
" \n",
|
|||
|
" return features_2d.cpu().numpy()\n",
|
|||
|
"\n",
|
|||
|
"print(\"✅ 神经数据处理函数定义完成\")"
|
|||
|
]
|
|||
|
},
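{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 💡 Note: shape check for the smoothing + patch pipeline\n",
"\n",
"With `padding='valid'` the Gaussian smoothing shortens the time axis by (effective kernel length − 1), and `unfold` then yields floor((T' − patch_size) / patch_stride) + 1 patches, each flattened to patch_size × n_channels features (14 × 512 = 7168, the raw dimension the pipeline's PCA was fitted on). The quick check below runs the two functions defined above on random data; the batch / time / channel sizes are made up."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick shape check of gauss_smooth_torch() + apply_patch_processing() on random data.\n",
"# The [1, 200, 512] input size is made up; 512 channels x patch_size 14 = 7168 features per patch.\n",
"import torch\n",
"\n",
"_device = torch.device('cpu')\n",
"dummy = torch.randn(1, 200, 512)                  # [batch, time, channels]\n",
"smoothed = gauss_smooth_torch(dummy, _device, smooth_kernel_std=2, smooth_kernel_size=100)\n",
"patched = apply_patch_processing(smoothed, patch_size=14, patch_stride=4)\n",
"print('smoothed:', tuple(smoothed.shape))         # time axis shrinks (valid convolution)\n",
"print('patched :', tuple(patched.shape))          # [batch, num_patches, 14 * 512]"
]
},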
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"t15.2023.08.11\tt15.2023.09.29\tt15.2023.11.04\tt15.2024.02.25\tt15.2024.07.19\n",
|
|||
|
"t15.2023.08.13\tt15.2023.10.01\tt15.2023.11.17\tt15.2024.03.03\tt15.2024.07.21\n",
|
|||
|
"t15.2023.08.18\tt15.2023.10.06\tt15.2023.11.19\tt15.2024.03.08\tt15.2024.07.28\n",
|
|||
|
"t15.2023.08.20\tt15.2023.10.08\tt15.2023.11.26\tt15.2024.03.15\tt15.2025.01.10\n",
|
|||
|
"t15.2023.08.25\tt15.2023.10.13\tt15.2023.12.03\tt15.2024.03.17\tt15.2025.01.12\n",
|
|||
|
"t15.2023.08.27\tt15.2023.10.15\tt15.2023.12.08\tt15.2024.04.25\tt15.2025.03.14\n",
|
|||
|
"t15.2023.09.01\tt15.2023.10.20\tt15.2023.12.10\tt15.2024.04.28\tt15.2025.03.16\n",
|
|||
|
"t15.2023.09.03\tt15.2023.10.22\tt15.2023.12.17\tt15.2024.05.10\tt15.2025.03.30\n",
|
|||
|
"t15.2023.09.24\tt15.2023.11.03\tt15.2023.12.29\tt15.2024.06.14\tt15.2025.04.13\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
"!ls data/hdf5_data_final"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🔥 Test-set data loading and LightGBM prediction\n",
"\n",
"class TestSetPredictor:\n",
"    \"\"\"\n",
"    Test-set predictor - mirrors the RNN evaluation flow\n",
"    \"\"\"\n",
"    \n",
"    def __init__(self, lgbm_model, pipeline, data_dir, device='cpu'):\n",
"        self.lgbm_model = lgbm_model\n",
"        self.pipeline = pipeline\n",
"        self.data_dir = data_dir\n",
"        self.device = torch.device(device)\n",
"        \n",
"        # Config (mirrors the RNN model arguments)\n",
"        self.model_args = {\n",
"            'smooth_kernel_std': 2,\n",
"            'smooth_kernel_size': 100,\n",
"            'patch_size': 14,\n",
"            'patch_stride': 4,\n",
"        }\n",
"        \n",
"        print(f\"🎯 Test-set predictor initialized\")\n",
"        print(f\"   Device: {self.device}\")\n",
"        print(f\"   Patch config: size={self.model_args['patch_size']}, stride={self.model_args['patch_stride']}\")\n",
"        print(f\"   Smoothing config: std={self.model_args['smooth_kernel_std']}, size={self.model_args['smooth_kernel_size']}\")\n",
"    \n",
"    def load_test_sessions(self):\n",
"        \"\"\"\n",
"        Load the data of every test session\n",
"        \"\"\"\n",
"        print(f\"🔍 Scanning test data directory: {self.data_dir}\")\n",
"        \n",
"        test_data = {}\n",
"        total_trials = 0\n",
"        \n",
"        # Scan every session in the data directory\n",
"        for session_name in os.listdir(self.data_dir):\n",
"            session_path = os.path.join(self.data_dir, session_name)\n",
"            if not os.path.isdir(session_path):\n",
"                continue\n",
"            \n",
"            # Look for the test data file\n",
"            test_file = os.path.join(session_path, 'data_test.hdf5')\n",
"            if os.path.exists(test_file):\n",
"                print(f\"  📂 Found test session: {session_name}\")\n",
"                \n",
"                try:\n",
"                    # Load the data (pass None for the CSV; the test set does not need it)\n",
"                    data = load_h5py_file(test_file, None)\n",
"                    test_data[session_name] = data\n",
"                    \n",
"                    num_trials = len(data['neural_features'])\n",
"                    total_trials += num_trials\n",
"                    \n",
"                    print(f\"    ✅ Loaded: {num_trials} trials\")\n",
"                    print(f\"    📊 Neural feature shape: {data['neural_features'][0].shape if num_trials > 0 else 'N/A'}\")\n",
"                    \n",
"                except Exception as e:\n",
"                    print(f\"    ❌ Load failed: {e}\")\n",
"        \n",
"        print(f\"\\n📊 Test data loading summary:\")\n",
"        print(f\"   Sessions: {len(test_data)}\")\n",
"        print(f\"   Total trials: {total_trials}\")\n",
"        \n",
"        return test_data\n",
"    \n",
"    def predict_test_set(self, test_data):\n",
"        \"\"\"\n",
"        Run prediction over the test set\n",
"        \"\"\"\n",
"        if self.lgbm_model is None:\n",
"            raise ValueError(\"LightGBM model not loaded\")\n",
"        \n",
"        if not self.pipeline.pca_fitted:\n",
"            raise ValueError(\"PCA model not fitted\")\n",
"        \n",
"        print(f\"\\n🚀 Starting test-set prediction...\")\n",
"        \n",
"        # Count the total number of trials\n",
"        total_trials = sum(len(data['neural_features']) for data in test_data.values())\n",
"        \n",
"        results = {\n",
"            'session': [],\n",
"            'block': [],\n",
"            'trial': [],\n",
"            'predicted_sequence': [],  # full predicted phoneme sequence\n",
"            'true_sequence': [],       # ground-truth phoneme sequence\n",
"            'sentence_label': [],      # sentence label (if present)\n",
"            'logits': [],              # raw predicted probabilities\n",
"            'sequence_length': [],     # predicted sequence length\n",
"            'true_length': []          # true sequence length\n",
"        }\n",
"        \n",
"        with tqdm(total=total_trials, desc='LightGBM prediction progress', unit='trial') as pbar:\n",
"            for session_name, data in test_data.items():\n",
"                print(f\"\\n📈 Processing session: {session_name}\")\n",
"                \n",
"                for trial_idx in range(len(data['neural_features'])):\n",
"                    # 1. Fetch the neural data\n",
"                    neural_input = data['neural_features'][trial_idx]\n",
"                    \n",
"                    # 2. Add a batch dimension and convert to a tensor\n",
"                    neural_input = np.expand_dims(neural_input, axis=0)\n",
"                    neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.float32)\n",
"                    \n",
"                    # 3. Apply the RNN-style pipeline: Gaussian smoothing + patch processing\n",
"                    processed_features = process_neural_data_for_lgbm(\n",
"                        neural_tensor, self.device, self.model_args\n",
"                    )\n",
"                    \n",
"                    # 4. Apply PCA dimensionality reduction\n",
"                    if processed_features.shape[0] > 0:\n",
"                        features_pca = self.pipeline._apply_pca_transform(processed_features)\n",
"                        \n",
"                        # 5. LightGBM prediction - get the full sequence\n",
"                        predictions = self.lgbm_model.predict(features_pca)\n",
"                        \n",
"                        # 6. Handle the predictions - keep the sequence form\n",
"                        if len(predictions.shape) > 1:\n",
"                            # Each row is the prediction for one timestep/patch\n",
"                            logits_sequence = predictions  # [num_patches, 41]\n",
"                        else:\n",
"                            # A single prediction; expand it to a sequence\n",
"                            logits_sequence = predictions.reshape(1, -1)  # [1, 41]\n",
"                        \n",
"                        # 7. Convert to a phoneme sequence (mirrors the RNN post-processing) TODO: filtering could be applied here\n",
"                        predicted_classes = np.argmax(logits_sequence, axis=-1)  # [num_patches]\n",
"                        \n",
"                        # 8. Post-process the phoneme sequence (CTC best-path decoding):\n",
"                        #    merge consecutive repeats first, then remove blanks (0),\n",
"                        #    so genuine repeats separated by a blank survive\n",
"                        pred_seq = [int(predicted_classes[i]) for i in range(len(predicted_classes))\n",
"                                    if i == 0 or predicted_classes[i] != predicted_classes[i-1]]\n",
"                        pred_seq = [p for p in pred_seq if p != 0]\n",
"                        # Map to phoneme symbols\n",
"                        predicted_phoneme_sequence = [LOGIT_TO_PHONEME[p] for p in pred_seq]\n",
"                        \n",
"                        # 9. Read the ground-truth phoneme sequence (if present)\n",
"                        true_phoneme_sequence = []\n",
"                        sentence_label = \"\"\n",
"                        true_length = 0\n",
"                        \n",
"                        if 'seq_class_ids' in data and 'seq_len' in data:\n",
"                            # Same handling as evaluate_model.py\n",
"                            true_seq = data['seq_class_ids'][trial_idx][0:data['seq_len'][trial_idx]]\n",
"                            true_phoneme_sequence = [LOGIT_TO_PHONEME[p] for p in true_seq]\n",
"                            true_length = len(true_phoneme_sequence)\n",
"                        \n",
"                        if 'sentence_label' in data:\n",
"                            sentence_label = data['sentence_label'][trial_idx]\n",
"                        \n",
"                        # 10. Store the results\n",
"                        results['session'].append(session_name)\n",
"                        results['block'].append(data['block_num'][trial_idx])\n",
"                        results['trial'].append(data['trial_num'][trial_idx])\n",
"                        results['predicted_sequence'].append(predicted_phoneme_sequence)\n",
"                        results['true_sequence'].append(true_phoneme_sequence)\n",
"                        results['sentence_label'].append(sentence_label)\n",
"                        results['logits'].append(logits_sequence.tolist())\n",
"                        results['sequence_length'].append(len(predicted_phoneme_sequence))\n",
"                        results['true_length'].append(true_length)\n",
"                    \n",
"                    pbar.update(1)\n",
"        \n",
"        print(f\"\\n🎉 Test-set prediction complete!\")\n",
"        print(f\"   Total predictions: {len(results['predicted_sequence'])}\")\n",
"        print(f\"   Mean predicted sequence length: {np.mean(results['sequence_length']):.1f}\")\n",
"        print(f\"   Predicted sequence length range: {min(results['sequence_length'])} - {max(results['sequence_length'])}\")\n",
"        \n",
"        # Summarize ground-truth availability\n",
"        has_true_seq = sum(1 for seq in results['true_sequence'] if seq)\n",
"        if has_true_seq > 0:\n",
"            true_lengths = [length for length in results['true_length'] if length > 0]\n",
"            print(f\"   Trials with ground truth: {has_true_seq} / {len(results['predicted_sequence'])}\")\n",
"            if true_lengths:\n",
"                print(f\"   Mean true sequence length: {np.mean(true_lengths):.1f}\")\n",
"                print(f\"   True sequence length range: {min(true_lengths)} - {max(true_lengths)}\")\n",
"        else:\n",
"            print(f\"   ⚠️ The test set has no ground-truth labels, so WER cannot be computed\")\n",
"        \n",
"        return results\n",
"    \n",
"    def save_results(self, results, output_path=\"test_predictions.csv\"):\n",
"        \"\"\"\n",
"        Save the prediction results\n",
"        \"\"\"\n",
"        import pandas as pd\n",
"        \n",
"        df = pd.DataFrame(results)\n",
"        df.to_csv(output_path, index=False)\n",
"        \n",
"        print(f\"💾 Predictions saved: {output_path}\")\n",
"        print(f\"📊 Result statistics:\")\n",
"        print(f\"   Predicted samples: {len(df)}\")\n",
"        print(f\"   Mean predicted sequence length: {df['sequence_length'].mean():.1f}\")\n",
"        \n",
"        # Check for ground-truth sequences\n",
"        has_true_seq = sum(1 for seq in df['true_sequence'] if seq and len(seq) > 0)\n",
"        if has_true_seq > 0:\n",
"            print(f\"   Trials with ground truth: {has_true_seq}\")\n",
"            print(f\"   Mean true sequence length: {df['true_length'].mean():.1f}\")\n",
"        \n",
"        # Distribution of predicted phonemes\n",
"        all_predicted_phonemes = []\n",
"        for seq in df['predicted_sequence']:\n",
"            if isinstance(seq, list):\n",
"                all_predicted_phonemes.extend(seq)\n",
"        \n",
"        if all_predicted_phonemes:\n",
"            from collections import Counter\n",
"            phoneme_counts = Counter(all_predicted_phonemes)\n",
"            print(f\"   Predicted phoneme distribution (top 10):\")\n",
"            for phoneme, count in phoneme_counts.most_common(10):\n",
"                print(f\"     {phoneme}: {count}\")\n",
"        \n",
"        return df\n",
"\n",
"print(\"✅ TestSetPredictor class defined\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🚀 Run the test-set prediction\n",
"\n",
"# Check that the required components are ready\n",
"if lgbm_model is None:\n",
"    print(\"❌ LightGBM model not loaded; run the model-loading code first\")\n",
"elif not hasattr(pipeline, 'pca_fitted') or not pipeline.pca_fitted:\n",
"    print(\"❌ PCA model not fitted; run steps 1 and 2 of the smart pipeline first\")\n",
"else:\n",
"    print(\"✅ All components ready; starting test-set prediction...\")\n",
"    \n",
"    # Configure the test data path\n",
"    TEST_DATA_DIR = \"/kaggle/working/nejm-brain-to-text/data/hdf5_data_final\"\n",
"    \n",
"    # Create the predictor\n",
"    predictor = TestSetPredictor(\n",
"        lgbm_model=lgbm_model,\n",
"        pipeline=pipeline,\n",
"        data_dir=TEST_DATA_DIR,\n",
"        device='cpu'  # switch to 'cuda' if a GPU is available\n",
"    )\n",
"    \n",
"    print(f\"\\n\" + \"=\"*60)\n",
"    print(\"🔍 Step 1: Load the test-set data\")\n",
"    print(\"=\"*60)\n",
"    \n",
"    # Load the test data\n",
"    test_data = predictor.load_test_sessions()\n",
"    \n",
"    if not test_data:\n",
"        print(\"❌ No test data found; check the data path\")\n",
"    else:\n",
"        print(f\"\\n\" + \"=\"*60)\n",
"        print(\"🔮 Step 2: Run the LightGBM prediction\")\n",
"        print(\"=\"*60)\n",
"        \n",
"        # Run the prediction\n",
"        try:\n",
"            prediction_results = predictor.predict_test_set(test_data)\n",
"            \n",
"            print(f\"\\n\" + \"=\"*60)\n",
"            print(\"💾 Step 3: Save the predictions\")\n",
"            print(\"=\"*60)\n",
"            \n",
"            # Save the results\n",
"            results_df = predictor.save_results(\n",
"                prediction_results, \n",
"                \"lgbm_test_predictions.csv\"\n",
"            )\n",
"            \n",
"            print(f\"\\n🎉 Test-set prediction pipeline complete!\")\n",
"            print(f\"   📁 Data path: {TEST_DATA_DIR}\")\n",
"            print(f\"   📊 Sessions processed: {len(test_data)}\")\n",
"            print(f\"   🎯 Predicted trials: {len(prediction_results['predicted_sequence'])}\")\n",
"            print(f\"   💾 Results file: lgbm_test_predictions.csv\")\n",
"            \n",
"        except Exception as e:\n",
"            print(f\"❌ Error during prediction: {e}\")\n",
"            import traceback\n",
"            traceback.print_exc()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📊 Analysis of the test-set predictions\n",
"\n",
"def analyze_test_predictions(csv_path=\"lgbm_test_predictions.csv\"):\n",
"    \"\"\"\n",
"    Analyze the test-set prediction results\n",
"    \"\"\"\n",
"    if not os.path.exists(csv_path):\n",
"        print(f\"❌ Results file not found: {csv_path}\")\n",
"        return\n",
"    \n",
"    print(f\"📊 Analyzing test-set predictions: {csv_path}\")\n",
"    \n",
"    import ast\n",
"    import pandas as pd\n",
"    import matplotlib.pyplot as plt\n",
"    \n",
"    # Load the results; the CSV stores lists as strings, so parse them back\n",
"    df = pd.read_csv(csv_path)\n",
"    df['predicted_sequence'] = df['predicted_sequence'].apply(ast.literal_eval)\n",
"    df['logits'] = df['logits'].apply(ast.literal_eval)\n",
"    # Trial-level confidence: mean of the per-patch max class probability\n",
"    df['prediction_confidence'] = df['logits'].apply(lambda l: float(np.mean([max(row) for row in l])) if l else 0.0)\n",
"    # One row per predicted phoneme, for the distribution plots\n",
"    phonemes = df['predicted_sequence'].explode().dropna()\n",
"    \n",
"    print(f\"\\n📈 Basic statistics:\")\n",
"    print(f\"   Total predictions: {len(df):,}\")\n",
"    print(f\"   Sessions: {df['session'].nunique()}\")\n",
"    print(f\"   Mean confidence: {df['prediction_confidence'].mean():.4f}\")\n",
"    print(f\"   Confidence std: {df['prediction_confidence'].std():.4f}\")\n",
"    \n",
"    # Confidence distribution\n",
"    plt.figure(figsize=(15, 10))\n",
"    \n",
"    # 1. Confidence histogram\n",
"    plt.subplot(2, 3, 1)\n",
"    plt.hist(df['prediction_confidence'], bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n",
"    plt.xlabel('Prediction confidence')\n",
"    plt.ylabel('Count')\n",
"    plt.title('Prediction confidence distribution')\n",
"    plt.grid(True, alpha=0.3)\n",
"    \n",
"    # 2. Phoneme distribution (top 15)\n",
"    plt.subplot(2, 3, 2)\n",
"    phoneme_counts = phonemes.value_counts().head(15)\n",
"    phoneme_counts.plot(kind='bar', color='lightcoral')\n",
"    plt.xlabel('Phoneme')\n",
"    plt.ylabel('Predictions')\n",
"    plt.title('Predicted phoneme distribution (top 15)')\n",
"    plt.xticks(rotation=45)\n",
"    plt.grid(True, alpha=0.3)\n",
"    \n",
"    # 3. Number of predictions per session\n",
"    plt.subplot(2, 3, 3)\n",
"    session_counts = df['session'].value_counts()\n",
"    session_counts.plot(kind='bar', color='lightgreen')\n",
"    plt.xlabel('Session')\n",
"    plt.ylabel('Predictions')\n",
"    plt.title('Predictions per session')\n",
"    plt.xticks(rotation=45)\n",
"    plt.grid(True, alpha=0.3)\n",
"    \n",
"    # 4. Mean confidence per session\n",
"    plt.subplot(2, 3, 4)\n",
"    session_confidence = df.groupby('session')['prediction_confidence'].mean()\n",
"    session_confidence.plot(kind='bar', color='gold')\n",
"    plt.xlabel('Session')\n",
"    plt.ylabel('Mean confidence')\n",
"    plt.title('Mean confidence per session')\n",
"    plt.xticks(rotation=45)\n",
"    plt.grid(True, alpha=0.3)\n",
"    \n",
"    # 5. Phoneme distribution of high-confidence trials\n",
"    plt.subplot(2, 3, 5)\n",
"    high_conf_df = df[df['prediction_confidence'] > 0.8]\n",
"    if len(high_conf_df) > 0:\n",
"        high_conf_phonemes = high_conf_df['predicted_sequence'].explode().dropna().value_counts().head(10)\n",
"        high_conf_phonemes.plot(kind='bar', color='orange')\n",
"        plt.xlabel('Phoneme')\n",
"        plt.ylabel('High-confidence predictions')\n",
"        plt.title(f'High-confidence (>0.8) phonemes\\nTrials: {len(high_conf_df)}')\n",
"        plt.xticks(rotation=45)\n",
"    else:\n",
"        plt.text(0.5, 0.5, 'No high-confidence predictions', ha='center', va='center')\n",
"        plt.title('High-confidence predictions')\n",
"    plt.grid(True, alpha=0.3)\n",
"    \n",
"    # 6. Confidence boxplot per session\n",
"    plt.subplot(2, 3, 6)\n",
"    import seaborn as sns\n",
"    sns.boxplot(data=df, x='session', y='prediction_confidence')\n",
"    plt.xlabel('Session')\n",
"    plt.ylabel('Confidence')\n",
"    plt.title('Confidence distribution per session')\n",
"    plt.xticks(rotation=45)\n",
"    plt.grid(True, alpha=0.3)\n",
"    \n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"    \n",
"    # Detailed statistics\n",
"    print(f\"\\n📋 Detailed statistics:\")\n",
"    print(f\"   Distinct phonemes: {phonemes.nunique()}\")\n",
"    print(f\"   Most predicted phoneme: {phonemes.value_counts().index[0]} ({phonemes.value_counts().iloc[0]} times)\")\n",
"    print(f\"   High-confidence trials (>0.9): {(df['prediction_confidence'] > 0.9).sum()} / {len(df)} ({(df['prediction_confidence'] > 0.9).mean():.2%})\")\n",
"    print(f\"   Low-confidence trials (<0.5): {(df['prediction_confidence'] < 0.5).sum()} / {len(df)} ({(df['prediction_confidence'] < 0.5).mean():.2%})\")\n",
"    \n",
"    return df\n",
"\n",
"# If the prediction results file exists, analyze it automatically\n",
"if os.path.exists(\"lgbm_test_predictions.csv\"):\n",
"    print(\"🔍 Found the prediction results file; starting automatic analysis...\")\n",
"    results_analysis = analyze_test_predictions()\n",
"else:\n",
"    print(\"💡 After running the prediction code above, the results will be analyzed automatically\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔍 Sequence prediction logic\n",
"\n",
"What we need is the complete phoneme sequence, not an average over the per-patch predictions.\n",
"\n",
"### 🎯 RNN vs LightGBM sequence handling\n",
"\n",
"#### **The RNN pipeline:**\n",
"```\n",
"neural data → Gaussian smoothing → patch split → RNN → sequence logits → phoneme sequence\n",
"[1,100,512] → [1,86,512] → [1,19,7168] → RNN → [1,19,41] → ['AE','T','SH',...]\n",
"```\n",
"\n",
"#### **Our LightGBM pipeline:**\n",
"```\n",
"neural data → Gaussian smoothing → patch split → PCA → LightGBM → sequence logits → phoneme sequence\n",
"[1,100,512] → [1,86,512] → [19,7168] → [19,PCA_dim] → LightGBM → [19,41] → ['AE','T','SH',...]\n",
"```\n",
"\n",
"### 🔄 Key changes\n",
"1. **Keep the sequence dimension**: patch predictions are no longer averaged; every patch keeps its own prediction\n",
"2. **Post-process the sequence**: consecutive repeats are merged and blanks removed, as in CTC decoding (see the sketch below)\n",
"3. **Output format**: each trial yields a full list of phonemes rather than a single phoneme\n",
"\n",
"This gives us sequence output in the same format as the RNN!"
]
},
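{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the sequence post-processing described above, on toy data.\n",
"# The class IDs below are made up for illustration; 0 plays the role of the CTC blank.\n",
"toy_classes = [0, 5, 5, 0, 5, 7, 7, 0, 0, 12]\n",
"\n",
"# Step 1: merge consecutive repeats (CTC best-path decoding)\n",
"merged = [toy_classes[i] for i in range(len(toy_classes))\n",
"          if i == 0 or toy_classes[i] != toy_classes[i-1]]\n",
"# Step 2: drop the blanks\n",
"decoded = [p for p in merged if p != 0]\n",
"\n",
"print(decoded)  # [5, 5, 7, 12] - the second 5 survives because a blank separated the two runs\n"
]
},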
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🎯 Display and analyze the sequence predictions (with ground-truth comparison and WER)\n",
"\n",
"import pandas as pd\n",
"import ast\n",
"import editdistance\n",
"\n",
"def display_sequence_predictions_with_wer(csv_path=\"lgbm_test_predictions.csv\", num_examples=5):\n",
"    \"\"\"\n",
"    Show predicted sequences next to the ground truth and compute WER\n",
"    \"\"\"\n",
"    if not os.path.exists(csv_path):\n",
"        print(f\"❌ Results file not found: {csv_path}\")\n",
"        return\n",
"    \n",
"    df = pd.read_csv(csv_path)\n",
"    df['predicted_sequence'] = df['predicted_sequence'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)\n",
"    if 'true_sequence' in df.columns:\n",
"        df['true_sequence'] = df['true_sequence'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)\n",
"    else:\n",
"        print(\"⚠️ No true_sequence column; cannot compare or compute WER\")\n",
"        return\n",
"    \n",
"    print(f\"\\n📊 Sequence prediction statistics:\")\n",
"    print(f\"   Total trials: {len(df):,}\")\n",
"    print(f\"   Mean sequence length: {df['sequence_length'].mean():.1f}\")\n",
"    print(f\"   Sequence length range: {df['sequence_length'].min()} - {df['sequence_length'].max()}\")\n",
"    \n",
"    # Show a few random examples\n",
"    sample_df = df.sample(min(num_examples, len(df)), random_state=42)\n",
"    print(f\"\\n🎭 Predicted vs true sequences ({num_examples} random examples):\")\n",
"    print(\"=\" * 80)\n",
"    for idx, row in sample_df.iterrows():\n",
"        pred_seq = row['predicted_sequence']\n",
"        true_seq = row['true_sequence']\n",
"        print(f\"Trial {idx+1}:\")\n",
"        print(f\"  True:      {' '.join(true_seq)}\")\n",
"        print(f\"  Predicted: {' '.join(pred_seq)}\")\n",
"        ed = editdistance.eval(true_seq, pred_seq)\n",
"        if true_seq:\n",
"            print(f\"  Edit Distance: {ed} / {len(true_seq)} = {ed/len(true_seq):.2%}\")\n",
"        else:\n",
"            print(f\"  Edit Distance: {ed} (no ground truth for this trial)\")\n",
"        print(\"-\" * 40)\n",
"    \n",
"    # Compute the aggregate WER\n",
"    total_ed = 0\n",
"    total_len = 0\n",
"    for idx, row in df.iterrows():\n",
"        pred_seq = row['predicted_sequence']\n",
"        true_seq = row['true_sequence']\n",
"        if true_seq:\n",
"            ed = editdistance.eval(true_seq, pred_seq)\n",
"            total_ed += ed\n",
"            total_len += len(true_seq)\n",
"    if total_len > 0:\n",
"        print(f\"\\nAggregate Phoneme WER: {total_ed} / {total_len} = {total_ed/total_len:.2%}\")\n",
"    else:\n",
"        print(\"No ground truth available for WER calculation.\")\n",
"    return df\n",
"\n",
"# If the results file exists, show the sequence predictions and WER\n",
"if os.path.exists(\"lgbm_test_predictions.csv\"):\n",
"    print(\"🎯 Showing sequence predictions and WER...\")\n",
"    sequence_results = display_sequence_predictions_with_wer()\n",
"else:\n",
"    print(\"💡 After the prediction code has run, the sequence results and WER will be shown\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def edit_distance(seq1, seq2):\n",
"    \"\"\"\n",
"    Compute the edit distance (Levenshtein distance) between two sequences\n",
"    Used for computing WER (Word Error Rate)\n",
"    \"\"\"\n",
"    m, n = len(seq1), len(seq2)\n",
"    \n",
"    # Build the dynamic-programming table\n",
"    dp = [[0] * (n + 1) for _ in range(m + 1)]\n",
"    \n",
"    # Initialize the boundary conditions\n",
"    for i in range(m + 1):\n",
"        dp[i][0] = i\n",
"    for j in range(n + 1):\n",
"        dp[0][j] = j\n",
"    \n",
"    # Fill the dynamic-programming table\n",
"    for i in range(1, m + 1):\n",
"        for j in range(1, n + 1):\n",
"            if seq1[i-1] == seq2[j-1]:\n",
"                dp[i][j] = dp[i-1][j-1]\n",
"            else:\n",
"                dp[i][j] = 1 + min(\n",
"                    dp[i-1][j],    # deletion\n",
"                    dp[i][j-1],    # insertion\n",
"                    dp[i-1][j-1]   # substitution\n",
"                )\n",
"    \n",
"    return dp[m][n]\n",
"\n",
"def calculate_wer(predicted_seq, true_seq):\n",
"    \"\"\"\n",
"    Compute the Word Error Rate (WER) of a sequence\n",
"    WER = edit distance / length of the true sequence\n",
"    \"\"\"\n",
"    if len(true_seq) == 0:\n",
"        return 1.0 if len(predicted_seq) > 0 else 0.0\n",
"    \n",
"    edit_dist = edit_distance(predicted_seq, true_seq)\n",
"    wer = edit_dist / len(true_seq)\n",
"    \n",
"    return wer\n",
"\n",
"print(\"✅ WER utilities defined\")"
]
},
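{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A small worked example of the WER utilities above; the phoneme sequences are toy data.\n",
"ref = ['HH', 'AH', 'L', 'OW']   # ground truth\n",
"hyp = ['HH', 'AH', 'OW', 'OW']  # prediction with one substitution (L -> OW)\n",
"\n",
"print(edit_distance(hyp, ref))            # 1 edit operation\n",
"print(f\"{calculate_wer(hyp, ref):.2%}\")   # 1 / 4 = 25.00%\n"
]
},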
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🚀 Run the test-set prediction (sequence level, with WER computation)\n",
"\n",
"print(\"🔥 Instantiating the TestSetPredictor...\")\n",
"# Reuse the model, pipeline, and TEST_DATA_DIR defined in the cells above\n",
"predictor = TestSetPredictor(\n",
"    lgbm_model=lgbm_model,\n",
"    pipeline=pipeline,\n",
"    data_dir=TEST_DATA_DIR,\n",
")\n",
"\n",
"print(\"\\n🎯 Starting test-set sequence prediction...\")\n",
"test_data = predictor.load_test_sessions()\n",
"results = predictor.predict_test_set(test_data)\n",
"\n",
"print(\"\\n💾 Saving the predictions...\")\n",
"df_results = predictor.save_results(results, \"lgbm_test_predictions_with_wer.csv\")\n",
"\n",
"print(\"\\n📊 Computing the overall WER...\")\n",
"total_ed = 0\n",
"total_len = 0\n",
"valid_predictions = 0\n",
"\n",
"# `results` is a dict of parallel lists, so iterate by trial index\n",
"for i in range(len(results['predicted_sequence'])):\n",
"    pred_seq = results['predicted_sequence'][i]\n",
"    true_seq = results['true_sequence'][i]\n",
"    \n",
"    if true_seq and len(true_seq) > 0:\n",
"        ed = edit_distance(pred_seq, true_seq)\n",
"        wer = ed / len(true_seq)\n",
"        \n",
"        total_ed += ed\n",
"        total_len += len(true_seq)\n",
"        valid_predictions += 1\n",
"        \n",
"        print(f\"Session {results['session'][i]}, Block {results['block'][i]}, Trial {results['trial'][i]}: WER = {wer:.2%}\")\n",
"\n",
"if total_len > 0:\n",
"    overall_wer = total_ed / total_len\n",
"    print(f\"\\n🎯 Overall sequence WER: {total_ed} / {total_len} = {overall_wer:.2%}\")\n",
"    print(f\"✅ Valid predictions: {valid_predictions}\")\n",
"else:\n",
"    print(\"❌ No ground-truth sequences available for WER computation\")\n",
"\n",
"print(\"\\n🎉 Test-set sequence prediction complete!\")"
]
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}