4212 lines
309 KiB
Plaintext
4212 lines
309 KiB
Plaintext
![]() |
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# 环境配置 与 utils"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"%cd /kaggle/working/"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 1,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
|
|||
|
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
|||
|
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
|
|||
|
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
|||
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
|
|||
|
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
|
|||
|
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
|
|||
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
|
|||
|
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
|
|||
|
"Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
|
|||
|
"Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
|
|||
|
"Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
|
|||
|
"Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
|
|||
|
"Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
|
|||
|
"Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
|
|||
|
"Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
|
|||
|
"Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
|
|||
|
"Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
|
|||
|
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
|||
|
"Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n",
|
|||
|
" Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
|
|||
|
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
|
|||
|
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
|
|||
|
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
|
|||
|
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
|
|||
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
|
|||
|
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
|
|||
|
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
|
|||
|
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
|
|||
|
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
|
|||
|
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 3.6 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 4.9 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 24.7 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 9.6 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.0 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 2.6 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 22.7 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 6.3 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 4.5 MB/s eta 0:00:00\n",
|
|||
|
"Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 22.8 MB/s eta 0:00:00\n",
|
|||
|
"Installing collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n",
|
|||
|
" Attempting uninstall: nvidia-nvjitlink-cu12\n",
|
|||
|
" Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
|
|||
|
" Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
|
|||
|
" Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
|
|||
|
" Attempting uninstall: nvidia-curand-cu12\n",
|
|||
|
" Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
|
|||
|
" Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
|
|||
|
" Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
|
|||
|
" Attempting uninstall: nvidia-cufft-cu12\n",
|
|||
|
" Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
|
|||
|
" Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
|
|||
|
" Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
|
|||
|
" Attempting uninstall: nvidia-cuda-runtime-cu12\n",
|
|||
|
" Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
|
|||
|
" Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
|
|||
|
" Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
|
|||
|
" Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
|
|||
|
" Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
|
|||
|
" Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
|
|||
|
" Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
|
|||
|
" Attempting uninstall: nvidia-cuda-cupti-cu12\n",
|
|||
|
" Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
|
|||
|
" Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
|
|||
|
" Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
|
|||
|
" Attempting uninstall: nvidia-cublas-cu12\n",
|
|||
|
" Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
|
|||
|
" Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
|
|||
|
" Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
|
|||
|
" Attempting uninstall: nvidia-cusparse-cu12\n",
|
|||
|
" Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
|
|||
|
" Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
|
|||
|
" Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
|
|||
|
" Attempting uninstall: nvidia-cudnn-cu12\n",
|
|||
|
" Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
|
|||
|
" Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
|
|||
|
" Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
|
|||
|
" Attempting uninstall: nvidia-cusolver-cu12\n",
|
|||
|
" Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
|
|||
|
" Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
|
|||
|
" Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
|
|||
|
"Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n",
|
|||
|
"Collecting jupyter==1.1.1\n",
|
|||
|
" Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)\n",
|
|||
|
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
|
|||
|
"Collecting pandas==2.3.0\n",
|
|||
|
" Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 91.2/91.2 kB 1.9 MB/s eta 0:00:00\n",
|
|||
|
"Collecting matplotlib==3.10.1\n",
|
|||
|
" Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
|
|||
|
"Collecting scipy==1.15.2\n",
|
|||
|
" Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.0/62.0 kB 2.3 MB/s eta 0:00:00\n",
|
|||
|
"Collecting scikit-learn==1.6.1\n",
|
|||
|
" Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)\n",
|
|||
|
"Collecting lightgbm==4.3.0\n",
|
|||
|
" Downloading lightgbm-4.3.0-py3-none-manylinux_2_28_x86_64.whl.metadata (19 kB)\n",
|
|||
|
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
|
|||
|
"Collecting g2p_en==2.1.0\n",
|
|||
|
" Downloading g2p_en-2.1.0-py3-none-any.whl.metadata (4.5 kB)\n",
|
|||
|
"Collecting h5py==3.13.0\n",
|
|||
|
" Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)\n",
|
|||
|
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
|||
|
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
|
|||
|
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
|
|||
|
"Collecting transformers==4.53.0\n",
|
|||
|
" Downloading transformers-4.53.0-py3-none-any.whl.metadata (39 kB)\n",
|
|||
|
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
|
|||
|
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
|
|||
|
"Collecting bitsandbytes==0.46.0\n",
|
|||
|
" Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)\n",
|
|||
|
"Collecting seaborn==0.13.2\n",
|
|||
|
" Downloading seaborn-0.13.2-py3-none-any.whl.metadata (5.4 kB)\n",
|
|||
|
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
|
|||
|
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
|
|||
|
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
|
|||
|
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
|
|||
|
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
|
|||
|
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
|
|||
|
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
|
|||
|
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
|||
|
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
|||
|
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
|
|||
|
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
|
|||
|
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
|
|||
|
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
|
|||
|
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
|
|||
|
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
|
|||
|
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
|
|||
|
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
|
|||
|
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
|
|||
|
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
|
|||
|
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
|
|||
|
"Collecting distance>=0.1.3 (from g2p_en==2.1.0)\n",
|
|||
|
" Downloading Distance-0.1.3.tar.gz (180 kB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 180.3/180.3 kB 5.2 MB/s eta 0:00:00\n",
|
|||
|
" Preparing metadata (setup.py): started\n",
|
|||
|
" Preparing metadata (setup.py): finished with status 'done'\n",
|
|||
|
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
|
|||
|
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
|
|||
|
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
|
|||
|
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
|
|||
|
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
|
|||
|
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
|
|||
|
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
|
|||
|
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
|
|||
|
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
|
|||
|
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
|
|||
|
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
|
|||
|
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
|
|||
|
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
|
|||
|
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
|
|||
|
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
|
|||
|
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
|
|||
|
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
|
|||
|
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
|
|||
|
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
|
|||
|
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
|
|||
|
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
|
|||
|
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
|
|||
|
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
|
|||
|
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
|
|||
|
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
|
|||
|
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
|||
|
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
|
|||
|
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
|
|||
|
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
|
|||
|
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
|
|||
|
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
|
|||
|
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
|
|||
|
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
|
|||
|
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
|
|||
|
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
|
|||
|
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
|
|||
|
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
|
|||
|
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
|
|||
|
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
|
|||
|
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
|
|||
|
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
|
|||
|
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
|
|||
|
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
|
|||
|
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
|
|||
|
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
|
|||
|
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
|||
|
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
|
|||
|
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
|
|||
|
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
|
|||
|
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
|
|||
|
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
|
|||
|
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
|
|||
|
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
|
|||
|
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
|
|||
|
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
|
|||
|
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
|
|||
|
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
|
|||
|
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
|
|||
|
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
|
|||
|
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
|
|||
|
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
|
|||
|
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
|
|||
|
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
|||
|
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
|
|||
|
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
|
|||
|
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
|
|||
|
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
|
|||
|
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
|||
|
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
|
|||
|
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
|
|||
|
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
|
|||
|
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
|
|||
|
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
|
|||
|
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
|
|||
|
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
|
|||
|
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
|
|||
|
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
|
|||
|
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
|
|||
|
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
|
|||
|
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
|
|||
|
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
|
|||
|
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
|
|||
|
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
|||
|
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
|
|||
|
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
|
|||
|
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
|
|||
|
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
|
|||
|
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
|
|||
|
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
|
|||
|
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
|
|||
|
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
|
|||
|
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
|||
|
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
|
|||
|
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
|
|||
|
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
|
|||
|
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
|
|||
|
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
|
|||
|
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
|
|||
|
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
|
|||
|
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
|
|||
|
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
|
|||
|
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
|
|||
|
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
|
|||
|
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
|
|||
|
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
|
|||
|
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
|
|||
|
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
|
|||
|
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
|
|||
|
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
|||
|
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
|
|||
|
"Downloading jupyter-1.1.1-py2.py3-none-any.whl (2.7 kB)\n",
|
|||
|
"Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 72.3 MB/s eta 0:00:00\n",
|
|||
|
"Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.6/8.6 MB 83.6 MB/s eta 0:00:00\n",
|
|||
|
"Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.6 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 37.6/37.6 MB 41.4 MB/s eta 0:00:00\n",
|
|||
|
"Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 71.9 MB/s eta 0:00:00\n",
|
|||
|
"Downloading lightgbm-4.3.0-py3-none-manylinux_2_28_x86_64.whl (3.1 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 56.5 MB/s eta 0:00:00\n",
|
|||
|
"Downloading g2p_en-2.1.0-py3-none-any.whl (3.1 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 55.2 MB/s eta 0:00:00\n",
|
|||
|
"Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 60.9 MB/s eta 0:00:00\n",
|
|||
|
"Downloading transformers-4.53.0-py3-none-any.whl (10.8 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.8/10.8 MB 70.9 MB/s eta 0:00:00\n",
|
|||
|
"Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 MB 23.0 MB/s eta 0:00:00\n",
|
|||
|
"Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)\n",
|
|||
|
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 294.9/294.9 kB 10.7 MB/s eta 0:00:00\n",
|
|||
|
"Building wheels for collected packages: distance\n",
|
|||
|
" Building wheel for distance (setup.py): started\n",
|
|||
|
" Building wheel for distance (setup.py): finished with status 'done'\n",
|
|||
|
" Created wheel for distance: filename=Distance-0.1.3-py3-none-any.whl size=16256 sha256=47c469031387c13df7e41975e1d8b073c7ad21ba646da84d8b4807b4713c12d2\n",
|
|||
|
" Stored in directory: /root/.cache/pip/wheels/fb/cd/9c/3ab5d666e3bcacc58900b10959edd3816cc9557c7337986322\n",
|
|||
|
"Successfully built distance\n",
|
|||
|
"Installing collected packages: distance, jupyter, scipy, pandas, matplotlib, transformers, seaborn, scikit-learn, lightgbm, h5py, g2p_en, bitsandbytes\n",
|
|||
|
" Attempting uninstall: scipy\n",
|
|||
|
" Found existing installation: scipy 1.15.3\n",
|
|||
|
" Uninstalling scipy-1.15.3:\n",
|
|||
|
" Successfully uninstalled scipy-1.15.3\n",
|
|||
|
" Attempting uninstall: pandas\n",
|
|||
|
" Found existing installation: pandas 2.2.3\n",
|
|||
|
" Uninstalling pandas-2.2.3:\n",
|
|||
|
" Successfully uninstalled pandas-2.2.3\n",
|
|||
|
" Attempting uninstall: matplotlib\n",
|
|||
|
" Found existing installation: matplotlib 3.7.2\n",
|
|||
|
" Uninstalling matplotlib-3.7.2:\n",
|
|||
|
" Successfully uninstalled matplotlib-3.7.2\n",
|
|||
|
" Attempting uninstall: transformers\n",
|
|||
|
" Found existing installation: transformers 4.52.4\n",
|
|||
|
" Uninstalling transformers-4.52.4:\n",
|
|||
|
" Successfully uninstalled transformers-4.52.4\n",
|
|||
|
" Attempting uninstall: seaborn\n",
|
|||
|
" Found existing installation: seaborn 0.12.2\n",
|
|||
|
" Uninstalling seaborn-0.12.2:\n",
|
|||
|
" Successfully uninstalled seaborn-0.12.2\n",
|
|||
|
" Attempting uninstall: scikit-learn\n",
|
|||
|
" Found existing installation: scikit-learn 1.2.2\n",
|
|||
|
" Uninstalling scikit-learn-1.2.2:\n",
|
|||
|
" Successfully uninstalled scikit-learn-1.2.2\n",
|
|||
|
" Attempting uninstall: lightgbm\n",
|
|||
|
" Found existing installation: lightgbm 4.5.0\n",
|
|||
|
" Uninstalling lightgbm-4.5.0:\n",
|
|||
|
" Successfully uninstalled lightgbm-4.5.0\n",
|
|||
|
" Attempting uninstall: h5py\n",
|
|||
|
" Found existing installation: h5py 3.14.0\n",
|
|||
|
" Uninstalling h5py-3.14.0:\n",
|
|||
|
" Successfully uninstalled h5py-3.14.0\n",
|
|||
|
"Successfully installed bitsandbytes-0.46.0 distance-0.1.3 g2p_en-2.1.0 h5py-3.13.0 jupyter-1.1.1 lightgbm-4.3.0 matplotlib-3.10.1 pandas-2.3.0 scikit-learn-1.6.1 scipy-1.15.2 seaborn-0.13.2 transformers-4.53.0\n",
|
|||
|
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
|
|||
|
" Preparing metadata (setup.py): started\n",
|
|||
|
" Preparing metadata (setup.py): finished with status 'done'\n",
|
|||
|
"Installing collected packages: nejm_b2txt_utils\n",
|
|||
|
" Running setup.py develop for nejm_b2txt_utils\n",
|
|||
|
"Successfully installed nejm_b2txt_utils-0.0.0\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Cloning into 'nejm-brain-to-text'...\n",
|
|||
|
"ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
|
|||
|
"bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.\n",
|
|||
|
"gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.15.2 which is incompatible.\n",
|
|||
|
"dask-cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n",
|
|||
|
"cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n",
|
|||
|
"datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.5.1 which is incompatible.\n",
|
|||
|
"ydata-profiling 4.16.1 requires matplotlib<=3.10,>=3.5, but you have matplotlib 3.10.1 which is incompatible.\n",
|
|||
|
"category-encoders 2.7.0 requires scikit-learn<1.6.0,>=1.0.0, but you have scikit-learn 1.6.1 which is incompatible.\n",
|
|||
|
"cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
|
|||
|
"google-colab 1.0.0 requires google-auth==2.38.0, but you have google-auth 2.40.3 which is incompatible.\n",
|
|||
|
"google-colab 1.0.0 requires notebook==6.5.7, but you have notebook 6.5.4 which is incompatible.\n",
|
|||
|
"google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.0 which is incompatible.\n",
|
|||
|
"google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.\n",
|
|||
|
"google-colab 1.0.0 requires tornado==6.4.2, but you have tornado 6.5.1 which is incompatible.\n",
|
|||
|
"dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.\n",
|
|||
|
"pandas-gbq 0.29.1 requires google-api-core<3.0.0,>=2.10.2, but you have google-api-core 1.34.1 which is incompatible.\n",
|
|||
|
"bigframes 2.8.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.\n",
|
|||
|
"bigframes 2.8.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%%bash\n",
|
|||
|
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
|
|||
|
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
|
|||
|
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
|
|||
|
"\n",
|
|||
|
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
|
|||
|
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
|
|||
|
"ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data/concatenated_data\n",
|
|||
|
"\n",
|
|||
|
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
|
|||
|
"\n",
|
|||
|
"pip install \\\n",
|
|||
|
" jupyter==1.1.1 \\\n",
|
|||
|
" \"numpy>=1.26.0,<2.1.0\" \\\n",
|
|||
|
" pandas==2.3.0 \\\n",
|
|||
|
" matplotlib==3.10.1 \\\n",
|
|||
|
" scipy==1.15.2 \\\n",
|
|||
|
" scikit-learn==1.6.1 \\\n",
|
|||
|
" lightgbm==4.3.0 \\\n",
|
|||
|
" tqdm==4.67.1 \\\n",
|
|||
|
" g2p_en==2.1.0 \\\n",
|
|||
|
" h5py==3.13.0 \\\n",
|
|||
|
" omegaconf==2.3.0 \\\n",
|
|||
|
" editdistance==0.8.1 \\\n",
|
|||
|
" huggingface-hub==0.33.1 \\\n",
|
|||
|
" transformers==4.53.0 \\\n",
|
|||
|
" tokenizers==0.21.2 \\\n",
|
|||
|
" accelerate==1.8.1 \\\n",
|
|||
|
" bitsandbytes==0.46.0 \\\n",
|
|||
|
" seaborn==0.13.2\n",
|
|||
|
"cd /kaggle/working/nejm-brain-to-text/\n",
|
|||
|
"pip install -e ."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"==================================================\n",
|
|||
|
"🔧 LightGBM GPU环境检查\n",
|
|||
|
"==================================================\n",
|
|||
|
"✅ NVIDIA GPU检测:\n",
|
|||
|
" Tesla P100-PCIE-16GB, 16384, 560.35.03\n",
|
|||
|
"\n",
|
|||
|
"✅ CUDA工具包:\n",
|
|||
|
" Cuda compilation tools, release 12.5, V12.5.82\n",
|
|||
|
"\n",
|
|||
|
"🔍 LightGBM GPU支持选项:\n",
|
|||
|
" 1. CUDA: NVIDIA GPU的主要支持方式\n",
|
|||
|
" 2. OpenCL: 跨平台GPU支持(NVIDIA/AMD/Intel)\n",
|
|||
|
" 3. 自动回退: GPU不可用时自动使用CPU\n",
|
|||
|
"\n",
|
|||
|
"📦 LightGBM GPU版本安装:\n",
|
|||
|
" 方法1: pip install lightgbm --config-settings=cmake.define.USE_CUDA=ON\n",
|
|||
|
" 方法2: conda install -c conda-forge lightgbm\n",
|
|||
|
" 方法3: 使用预编译的GPU版本\n",
|
|||
|
"\n",
|
|||
|
"⚙️ GPU训练优化建议:\n",
|
|||
|
" - 确保CUDA版本与GPU驱动兼容\n",
|
|||
|
" - 监控GPU内存使用情况\n",
|
|||
|
" - 调整max_bin参数优化GPU性能\n",
|
|||
|
" - 使用合适的num_leaves数量\n",
|
|||
|
"\n",
|
|||
|
"💡 故障排除:\n",
|
|||
|
" 如果GPU训练失败:\n",
|
|||
|
" 1. 检查CUDA安装和版本\n",
|
|||
|
" 2. 确认LightGBM是GPU版本\n",
|
|||
|
" 3. 查看具体错误信息\n",
|
|||
|
" 4. 代码会自动回退到CPU模式\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 LightGBM GPU支持检查与配置\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*50)\n",
|
|||
|
"print(\"🔧 LightGBM GPU环境检查\")\n",
|
|||
|
"print(\"=\"*50)\n",
|
|||
|
"\n",
|
|||
|
"# 检查CUDA和GPU驱动\n",
|
|||
|
"import subprocess\n",
|
|||
|
"import sys\n",
|
|||
|
"\n",
|
|||
|
"def run_command(command):\n",
|
|||
|
" \"\"\"运行命令并返回结果\"\"\"\n",
|
|||
|
" try:\n",
|
|||
|
" result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n",
|
|||
|
" return result.stdout.strip(), result.returncode == 0\n",
|
|||
|
" except Exception as e:\n",
|
|||
|
" return str(e), False\n",
|
|||
|
"\n",
|
|||
|
"# 检查NVIDIA GPU\n",
|
|||
|
"nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n",
|
|||
|
"if nvidia_success:\n",
|
|||
|
" print(\"✅ NVIDIA GPU检测:\")\n",
|
|||
|
" for line in nvidia_output.split('\\n'):\n",
|
|||
|
" if line.strip():\n",
|
|||
|
" print(f\" {line}\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"❌ 未检测到NVIDIA GPU或驱动\")\n",
|
|||
|
"\n",
|
|||
|
"# 检查CUDA版本\n",
|
|||
|
"cuda_output, cuda_success = run_command(\"nvcc --version\")\n",
|
|||
|
"if cuda_success:\n",
|
|||
|
" print(\"\\n✅ CUDA工具包:\")\n",
|
|||
|
" # 提取CUDA版本\n",
|
|||
|
" for line in cuda_output.split('\\n'):\n",
|
|||
|
" if 'release' in line:\n",
|
|||
|
" print(f\" {line.strip()}\")\n",
|
|||
|
"else:\n",
|
|||
|
" print(\"\\n❌ 未安装CUDA工具包\")\n",
|
|||
|
"\n",
|
|||
|
"# 检查OpenCL (LightGBM也支持OpenCL)\n",
|
|||
|
"print(f\"\\n🔍 LightGBM GPU支持选项:\")\n",
|
|||
|
"print(f\" 1. CUDA: NVIDIA GPU的主要支持方式\")\n",
|
|||
|
"print(f\" 2. OpenCL: 跨平台GPU支持(NVIDIA/AMD/Intel)\")\n",
|
|||
|
"print(f\" 3. 自动回退: GPU不可用时自动使用CPU\")\n",
|
|||
|
"\n",
|
|||
|
"# 安装说明\n",
|
|||
|
"print(f\"\\n📦 LightGBM GPU版本安装:\")\n",
|
|||
|
"print(f\" 方法1: pip install lightgbm --config-settings=cmake.define.USE_CUDA=ON\")\n",
|
|||
|
"print(f\" 方法2: conda install -c conda-forge lightgbm\")\n",
|
|||
|
"print(f\" 方法3: 使用预编译的GPU版本\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n⚙️ GPU训练优化建议:\")\n",
|
|||
|
"print(f\" - 确保CUDA版本与GPU驱动兼容\")\n",
|
|||
|
"print(f\" - 监控GPU内存使用情况\")\n",
|
|||
|
"print(f\" - 调整max_bin参数优化GPU性能\")\n",
|
|||
|
"print(f\" - 使用合适的num_leaves数量\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n💡 故障排除:\")\n",
|
|||
|
"print(f\" 如果GPU训练失败:\")\n",
|
|||
|
"print(f\" 1. 检查CUDA安装和版本\")\n",
|
|||
|
"print(f\" 2. 确认LightGBM是GPU版本\")\n",
|
|||
|
"print(f\" 3. 查看具体错误信息\")\n",
|
|||
|
"print(f\" 4. 代码会自动回退到CPU模式\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/kaggle/working/nejm-brain-to-text\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%cd /kaggle/working/nejm-brain-to-text\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import os\n",
|
|||
|
"import pickle\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import matplotlib\n",
|
|||
|
"from g2p_en import G2p\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from nejm_b2txt_utils.general_utils import *\n",
|
|||
|
"\n",
|
|||
|
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
|
|||
|
"matplotlib.rcParams['ps.fonttype'] = 42\n",
|
|||
|
"matplotlib.rcParams['font.family'] = 'sans-serif'\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"/kaggle/working/nejm-brain-to-text/model_training\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%cd model_training/\n",
|
|||
|
"from data_augmentations import gauss_smooth\n",
|
|||
|
"# single decoding step function that also returns smoothed input\n",
|
|||
|
"# smooths data and puts it through the model, returning both logits and smoothed input.\n",
|
|||
|
"def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n",
|
|||
|
"\n",
|
|||
|
" # Use autocast for efficiency\n",
|
|||
|
" with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n",
|
|||
|
"\n",
|
|||
|
" smoothed_x = gauss_smooth(\n",
|
|||
|
" inputs = x, \n",
|
|||
|
" device = device,\n",
|
|||
|
" smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
|
|||
|
" smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
|
|||
|
" padding = 'valid',\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" with torch.no_grad():\n",
|
|||
|
" logits, _ = model(\n",
|
|||
|
" x = smoothed_x,\n",
|
|||
|
" day_idx = torch.tensor([input_layer], device=device),\n",
|
|||
|
" states = None, # no initial states\n",
|
|||
|
" return_state = True,\n",
|
|||
|
" )\n",
|
|||
|
"\n",
|
|||
|
" # convert both logits and smoothed input from bfloat16 to float32\n",
|
|||
|
" logits = logits.float().cpu().numpy()\n",
|
|||
|
" smoothed_input = smoothed_x.float().cpu().numpy()\n",
|
|||
|
"\n",
|
|||
|
" # # original order is [BLANK, phonemes..., SIL]\n",
|
|||
|
" # # rearrange so the order is [BLANK, SIL, phonemes...]\n",
|
|||
|
" # logits = rearrange_speech_logits_pt(logits)\n",
|
|||
|
"\n",
|
|||
|
" return logits, smoothed_input\n",
|
|||
|
"\n",
|
|||
|
"\n",
|
|||
|
"import h5py\n",
|
|||
|
"def load_h5py_file(file_path, b2txt_csv_df):\n",
|
|||
|
" data = {\n",
|
|||
|
" 'neural_features': [],\n",
|
|||
|
" 'n_time_steps': [],\n",
|
|||
|
" 'seq_class_ids': [],\n",
|
|||
|
" 'seq_len': [],\n",
|
|||
|
" 'transcriptions': [],\n",
|
|||
|
" 'sentence_label': [],\n",
|
|||
|
" 'session': [],\n",
|
|||
|
" 'block_num': [],\n",
|
|||
|
" 'trial_num': [],\n",
|
|||
|
" 'corpus': [],\n",
|
|||
|
" }\n",
|
|||
|
" # Open the hdf5 file for that day\n",
|
|||
|
" with h5py.File(file_path, 'r') as f:\n",
|
|||
|
"\n",
|
|||
|
" keys = list(f.keys())\n",
|
|||
|
"\n",
|
|||
|
" # For each trial in the selected trials in that day\n",
|
|||
|
" for key in keys:\n",
|
|||
|
" g = f[key]\n",
|
|||
|
"\n",
|
|||
|
" neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
|
|||
|
" n_time_steps = g.attrs['n_time_steps']\n",
|
|||
|
" seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
|
|||
|
" seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
|
|||
|
" transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
|
|||
|
" sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
|
|||
|
" session = g.attrs['session']\n",
|
|||
|
" block_num = g.attrs['block_num']\n",
|
|||
|
" trial_num = g.attrs['trial_num']\n",
|
|||
|
"\n",
|
|||
|
" # match this trial up with the csv to get the corpus name\n",
|
|||
|
" year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
|
|||
|
" date = f'{year}-{month}-{day}'\n",
|
|||
|
" row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
|
|||
|
" corpus_name = row['Corpus'].values[0]\n",
|
|||
|
"\n",
|
|||
|
" data['neural_features'].append(neural_features)\n",
|
|||
|
" data['n_time_steps'].append(n_time_steps)\n",
|
|||
|
" data['seq_class_ids'].append(seq_class_ids)\n",
|
|||
|
" data['seq_len'].append(seq_len)\n",
|
|||
|
" data['transcriptions'].append(transcription)\n",
|
|||
|
" data['sentence_label'].append(sentence_label)\n",
|
|||
|
" data['session'].append(session)\n",
|
|||
|
" data['block_num'].append(block_num)\n",
|
|||
|
" data['trial_num'].append(trial_num)\n",
|
|||
|
" data['corpus'].append(corpus_name)\n",
|
|||
|
" return data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"LOGIT_TO_PHONEME = [\n",
|
|||
|
" 'BLANK',\n",
|
|||
|
" 'AA', 'AE', 'AH', 'AO', 'AW',\n",
|
|||
|
" 'AY', 'B', 'CH', 'D', 'DH',\n",
|
|||
|
" 'EH', 'ER', 'EY', 'F', 'G',\n",
|
|||
|
" 'HH', 'IH', 'IY', 'JH', 'K',\n",
|
|||
|
" 'L', 'M', 'N', 'NG', 'OW',\n",
|
|||
|
" 'OY', 'P', 'R', 'S', 'SH',\n",
|
|||
|
" 'T', 'TH', 'UH', 'UW', 'V',\n",
|
|||
|
" 'W', 'Y', 'Z', 'ZH',\n",
|
|||
|
" ' | ',\n",
|
|||
|
"]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# 数据分析与预处理"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## 数据准备"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"%cd /kaggle/working/nejm-brain-to-text/"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"data = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
|
|||
|
" b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"data2 = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.13/data_train.hdf5',\n",
|
|||
|
" b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"- **任务介绍** :机器学习解决高维信号的模式识别问题"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"我们的数据集标签缺少时间戳,现在要进行的是半监督学习"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"def data_patch(data, index):\n",
|
|||
|
" data_patch = {}\n",
|
|||
|
" data_patch['neural_features'] = data['neural_features'][index]\n",
|
|||
|
" data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
|
|||
|
" data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
|
|||
|
" data_patch['seq_len'] = data['seq_len'][index]\n",
|
|||
|
" data_patch['transcriptions'] = data['transcriptions'][index]\n",
|
|||
|
" data_patch['sentence_label'] = data['sentence_label'][index]\n",
|
|||
|
" data_patch['session'] = data['session'][index]\n",
|
|||
|
" data_patch['block_num'] = data['block_num'][index]\n",
|
|||
|
" data_patch['trial_num'] = data['trial_num'][index]\n",
|
|||
|
" data_patch['corpus'] = data['corpus'][index]\n",
|
|||
|
" return data_patch"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"ename": "NameError",
|
|||
|
"evalue": "name 'd1' is not defined",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/3818271146.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtrans_len\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0md1\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'transcriptions'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mseq_len_nonzero\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0md1\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'seq_class_ids'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mseq_len\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0md1\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'seq_len'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Transcriptions non-zero length: {trans_len}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"Seq class ids non-zero length: {seq_len_nonzero}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;31mNameError\u001b[0m: name 'd1' is not defined"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
|
|||
|
"seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
|
|||
|
"seq_len = d1['seq_len']\n",
|
|||
|
"print(f\"Transcriptions non-zero length: {trans_len}\")\n",
|
|||
|
"print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
|
|||
|
"print(f\"Seq len: {seq_len}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"ename": "NameError",
|
|||
|
"evalue": "name 'd1' is not defined",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/1034715934.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;31m# Example usage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mfeature_sequences\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcreate_time_windows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Number of feature sequences:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeature_sequences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Shape of first sequence:\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeature_sequences\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;31mNameError\u001b[0m: name 'd1' is not defined"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"def create_time_windows(d1):\n",
|
|||
|
" import numpy as np\n",
|
|||
|
" n_time_steps = d1['n_time_steps']\n",
|
|||
|
" seq_len = d1['seq_len']\n",
|
|||
|
" # Create equal windows\n",
|
|||
|
" edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
|
|||
|
" windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
|
|||
|
" \n",
|
|||
|
" # Extract feature sequences for each window\n",
|
|||
|
" feature_sequences = []\n",
|
|||
|
" for start, end in windows:\n",
|
|||
|
" seq = d1['neural_features'][start:end, :]\n",
|
|||
|
" feature_sequences.append(seq)\n",
|
|||
|
" \n",
|
|||
|
" return feature_sequences\n",
|
|||
|
"\n",
|
|||
|
"# Example usage\n",
|
|||
|
"feature_sequences = create_time_windows(d1)\n",
|
|||
|
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
|
|||
|
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 9,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Train: 0, Val: 0, Test: 0\n",
|
|||
|
"Train files (first 3): []\n",
|
|||
|
"Val files (first 3): []\n",
|
|||
|
"Test files (first 3): []\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import os\n",
|
|||
|
"\n",
|
|||
|
"def scan_hdf5_files(base_path):\n",
|
|||
|
" train_files = []\n",
|
|||
|
" val_files = []\n",
|
|||
|
" test_files = []\n",
|
|||
|
" for root, dirs, files in os.walk(base_path):\n",
|
|||
|
" for file in files:\n",
|
|||
|
" if file.endswith('.hdf5'):\n",
|
|||
|
" abs_path = os.path.abspath(os.path.join(root, file))\n",
|
|||
|
" if 'data_train.hdf5' in file:\n",
|
|||
|
" train_files.append(abs_path)\n",
|
|||
|
" elif 'data_val.hdf5' in file:\n",
|
|||
|
" val_files.append(abs_path)\n",
|
|||
|
" elif 'data_test.hdf5' in file:\n",
|
|||
|
" test_files.append(abs_path)\n",
|
|||
|
" return train_files, val_files, test_files\n",
|
|||
|
"\n",
|
|||
|
"# Example usage\n",
|
|||
|
"FILE_PATH = 'data/hdf5_data_final'\n",
|
|||
|
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
|
|||
|
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
|
|||
|
"print(\"Train files (first 3):\", train_list[:3])\n",
|
|||
|
"print(\"Val files (first 3):\", val_list[:3])\n",
|
|||
|
"print(\"Test files (first 3):\", test_list[:3])"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# 数据读取工作流"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"📊 数据文件统计:\n",
|
|||
|
" 训练文件: 45\n",
|
|||
|
" 验证文件: 41\n",
|
|||
|
" 测试文件: 41\n",
|
|||
|
" 每文件最大样本数: 3000\n",
|
|||
|
"\n",
|
|||
|
"🔧 初始化全局PCA...\n",
|
|||
|
"\n",
|
|||
|
"🔧 拟合全局PCA降维器...\n",
|
|||
|
" 配置: {'enable_pca': True, 'n_components': None, 'variance_threshold': 0.95, 'sample_size': 15000}\n",
|
|||
|
" 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n",
|
|||
|
" 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"ename": "KeyboardInterrupt",
|
|||
|
"evalue": "",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/650628148.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 179\u001b[0m \u001b[0;31m# 🔧 初始化全局PCA (只在训练集上拟合一次)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 180\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"\\n🔧 初始化全局PCA...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 181\u001b[0;31m \u001b[0mfit_global_pca\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mPCA_CONFIG\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 182\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 183\u001b[0m \u001b[0;31m# 内存友好的数据加载策略 (带PCA集成)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/650628148.py\u001b[0m in \u001b[0;36mfit_global_pca\u001b[0;34m(data_dir, config)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0mcollected_samples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 92\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mtrials_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilename\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mload_data_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'train'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m5000\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 93\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextract_features_labels_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrials_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0msample_features\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/650628148.py\u001b[0m in \u001b[0;36mload_data_batch\u001b[0;34m(data_dir, data_type, max_samples_per_file)\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 44\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 45\u001b[0;31m \u001b[0mtrials\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'neural_logits_concatenated'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 46\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;31m# 限制每个文件的样本数\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmagic\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAGIC_PREFIX\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mbytes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 256\u001b[0;31m return format.read_array(bytes,\n\u001b[0m\u001b[1;32m 257\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallow_pickle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[0mpickle_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpickle_kwargs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36mread_array\u001b[0;34m(fp, allow_pickle, pickle_kwargs, max_header_size)\u001b[0m\n\u001b[1;32m 798\u001b[0m \u001b[0mpickle_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 799\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 800\u001b[0;31m \u001b[0marray\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpickle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mpickle_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 801\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mUnicodeError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0merr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 802\u001b[0m \u001b[0;31m# Friendlier error message\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/lib/python3.11/zipfile.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 966\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 967\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 968\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_readbuffer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/lib/python3.11/zipfile.py\u001b[0m in \u001b[0;36m_read1\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 1040\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compress_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mZIP_DEFLATED\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1041\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMIN_READ_SIZE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1042\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_decompressor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecompress\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1043\u001b[0m self._eof = (self._decompressor.eof or\n\u001b[1;32m 1044\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compress_left\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 内存友好的数据读取 - 分批加载策略 + PCA降维\n",
|
|||
|
"\n",
|
|||
|
"import os\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import gc\n",
|
|||
|
"from sklearn.decomposition import PCA\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"import joblib\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"\n",
|
|||
|
"# 全局PCA配置\n",
|
|||
|
"PCA_CONFIG = {\n",
|
|||
|
" 'enable_pca': True, # 是否启用PCA\n",
|
|||
|
" 'n_components': None, # None=自动选择, 或指定具体数值\n",
|
|||
|
" 'variance_threshold': 0.95, # 保留95%的方差\n",
|
|||
|
" 'sample_size': 15000, # 用于拟合PCA的样本数\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# 全局PCA对象 (确保只拟合一次)\n",
|
|||
|
"GLOBAL_PCA = {\n",
|
|||
|
" 'scaler': None,\n",
|
|||
|
" 'pca': None,\n",
|
|||
|
" 'is_fitted': False,\n",
|
|||
|
" 'n_components': None\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"def load_data_batch(data_dir, data_type, max_samples_per_file=5000):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 分批加载指定类型的数据\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" data_dir: 数据目录\n",
|
|||
|
" data_type: 'train', 'val', 'test'\n",
|
|||
|
" max_samples_per_file: 每个文件最大加载样本数\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" generator: 数据批次生成器\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
|
|||
|
" \n",
|
|||
|
" for file_idx, f in enumerate(files):\n",
|
|||
|
" print(f\" 正在加载文件 {file_idx+1}/{len(files)}: {f}\")\n",
|
|||
|
" \n",
|
|||
|
" data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n",
|
|||
|
" trials = data['neural_logits_concatenated']\n",
|
|||
|
" \n",
|
|||
|
" # 限制每个文件的样本数\n",
|
|||
|
" if len(trials) > max_samples_per_file:\n",
|
|||
|
" trials = trials[:max_samples_per_file]\n",
|
|||
|
" print(f\" 限制样本数至: {max_samples_per_file}\")\n",
|
|||
|
" \n",
|
|||
|
" yield trials, f\n",
|
|||
|
" \n",
|
|||
|
" # 清理内存\n",
|
|||
|
" del data, trials\n",
|
|||
|
" gc.collect()\n",
|
|||
|
"\n",
|
|||
|
"def extract_features_labels_batch(trials_batch):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 从试验批次中提取特征和标签\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" features = []\n",
|
|||
|
" labels = []\n",
|
|||
|
" \n",
|
|||
|
" for trial in trials_batch:\n",
|
|||
|
" if trial.shape[0] > 0:\n",
|
|||
|
" for t in range(trial.shape[0]):\n",
|
|||
|
" neural_features = trial[t, :7168] # 前7168维神经特征\n",
|
|||
|
" rnn_logits = trial[t, 7168:] # 后41维RNN输出\n",
|
|||
|
" phoneme_label = np.argmax(rnn_logits)\n",
|
|||
|
" \n",
|
|||
|
" features.append(neural_features)\n",
|
|||
|
" labels.append(phoneme_label)\n",
|
|||
|
" \n",
|
|||
|
" return np.array(features), np.array(labels)\n",
|
|||
|
"\n",
|
|||
|
"def fit_global_pca(data_dir, config):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 在训练数据上拟合全局PCA (只执行一次)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n",
|
|||
|
" print(\"🔧 PCA已拟合或未启用,跳过拟合步骤\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🔧 拟合全局PCA降维器...\")\n",
|
|||
|
" print(f\" 配置: {config}\")\n",
|
|||
|
" \n",
|
|||
|
" # 收集训练样本\n",
|
|||
|
" sample_features = []\n",
|
|||
|
" collected_samples = 0\n",
|
|||
|
" \n",
|
|||
|
" for trials_batch, filename in load_data_batch(data_dir, 'train', 5000):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" sample_features.append(features)\n",
|
|||
|
" collected_samples += features.shape[0]\n",
|
|||
|
" \n",
|
|||
|
" if collected_samples >= config['sample_size']:\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" if sample_features:\n",
|
|||
|
" # 合并样本数据\n",
|
|||
|
" X_sample = np.vstack(sample_features)[:config['sample_size']]\n",
|
|||
|
" print(f\" 实际样本数: {X_sample.shape[0]}\")\n",
|
|||
|
" print(f\" 原始特征数: {X_sample.shape[1]}\")\n",
|
|||
|
" \n",
|
|||
|
" # 标准化\n",
|
|||
|
" GLOBAL_PCA['scaler'] = StandardScaler()\n",
|
|||
|
" X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n",
|
|||
|
" \n",
|
|||
|
" # 确定PCA成分数\n",
|
|||
|
" if config['n_components'] is None:\n",
|
|||
|
" print(f\" 🔍 自动选择PCA成分数...\")\n",
|
|||
|
" pca_full = PCA()\n",
|
|||
|
" pca_full.fit(X_sample_scaled)\n",
|
|||
|
" \n",
|
|||
|
" cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
|
|||
|
" optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n",
|
|||
|
" GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 保留{config['variance_threshold']*100}%方差需要: {optimal_components} 个成分\")\n",
|
|||
|
" print(f\" 选择成分数: {GLOBAL_PCA['n_components']}\")\n",
|
|||
|
" else:\n",
|
|||
|
" GLOBAL_PCA['n_components'] = config['n_components']\n",
|
|||
|
" print(f\" 使用指定成分数: {GLOBAL_PCA['n_components']}\")\n",
|
|||
|
" \n",
|
|||
|
" # 拟合最终PCA\n",
|
|||
|
" GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n",
|
|||
|
" GLOBAL_PCA['pca'].fit(X_sample_scaled)\n",
|
|||
|
" GLOBAL_PCA['is_fitted'] = True\n",
|
|||
|
" \n",
|
|||
|
" # 保存模型\n",
|
|||
|
" pca_path = \"global_pca_model.joblib\"\n",
|
|||
|
" joblib.dump({\n",
|
|||
|
" 'scaler': GLOBAL_PCA['scaler'], \n",
|
|||
|
" 'pca': GLOBAL_PCA['pca'],\n",
|
|||
|
" 'n_components': GLOBAL_PCA['n_components']\n",
|
|||
|
" }, pca_path)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ 全局PCA拟合完成!\")\n",
|
|||
|
" print(f\" 降维: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n",
|
|||
|
" print(f\" 降维比例: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n",
|
|||
|
" print(f\" 保留方差: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n",
|
|||
|
" print(f\" 模型已保存: {pca_path}\")\n",
|
|||
|
" \n",
|
|||
|
" # 清理样本数据\n",
|
|||
|
" del sample_features, X_sample, X_sample_scaled\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" else:\n",
|
|||
|
" print(\"❌ 无法收集样本数据用于PCA拟合\")\n",
|
|||
|
"\n",
|
|||
|
"def apply_pca_transform(features):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 应用全局PCA变换\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n",
|
|||
|
" return features\n",
|
|||
|
" \n",
|
|||
|
" # 标准化 + PCA变换\n",
|
|||
|
" features_scaled = GLOBAL_PCA['scaler'].transform(features)\n",
|
|||
|
" features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n",
|
|||
|
" return features_pca\n",
|
|||
|
"\n",
|
|||
|
"# 设置数据目录和参数\n",
|
|||
|
"data_dir = '/kaggle/working/nejm-brain-to-text/data/concatenated_data'\n",
|
|||
|
"MAX_SAMPLES_PER_FILE = 3000 # 每个文件最大样本数,可调整\n",
|
|||
|
"\n",
|
|||
|
"# 检查可用文件\n",
|
|||
|
"all_files = [f for f in os.listdir(data_dir) if f.endswith('.npz')]\n",
|
|||
|
"train_files = [f for f in all_files if 'train' in f]\n",
|
|||
|
"val_files = [f for f in all_files if 'val' in f]\n",
|
|||
|
"test_files = [f for f in all_files if 'test' in f]\n",
|
|||
|
"\n",
|
|||
|
"print(f\"📊 数据文件统计:\")\n",
|
|||
|
"print(f\" 训练文件: {len(train_files)}\")\n",
|
|||
|
"print(f\" 验证文件: {len(val_files)}\")\n",
|
|||
|
"print(f\" 测试文件: {len(test_files)}\")\n",
|
|||
|
"print(f\" 每文件最大样本数: {MAX_SAMPLES_PER_FILE}\")\n",
|
|||
|
"\n",
|
|||
|
"# 🔧 初始化全局PCA (只在训练集上拟合一次)\n",
|
|||
|
"print(f\"\\n🔧 初始化全局PCA...\")\n",
|
|||
|
"fit_global_pca(data_dir, PCA_CONFIG)\n",
|
|||
|
"\n",
|
|||
|
"# 内存友好的数据加载策略 (带PCA集成)\n",
|
|||
|
"class MemoryFriendlyDataset:\n",
|
|||
|
" def __init__(self, data_dir, data_type, max_samples_per_file=3000):\n",
|
|||
|
" self.data_dir = data_dir\n",
|
|||
|
" self.data_type = data_type\n",
|
|||
|
" self.max_samples_per_file = max_samples_per_file\n",
|
|||
|
" self.files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
|
|||
|
" \n",
|
|||
|
" def load_all_data(self):\n",
|
|||
|
" \"\"\"一次性加载所有数据(自动应用PCA)\"\"\"\n",
|
|||
|
" print(f\"\\n🔄 加载{self.data_type}数据...\")\n",
|
|||
|
" all_features = []\n",
|
|||
|
" all_labels = []\n",
|
|||
|
" \n",
|
|||
|
" for trials_batch, filename in load_data_batch(self.data_dir, self.data_type, self.max_samples_per_file):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" \n",
|
|||
|
" # 应用PCA降维\n",
|
|||
|
" features_processed = apply_pca_transform(features)\n",
|
|||
|
" \n",
|
|||
|
" all_features.append(features_processed)\n",
|
|||
|
" all_labels.append(labels)\n",
|
|||
|
" \n",
|
|||
|
" if all_features:\n",
|
|||
|
" X = np.vstack(all_features)\n",
|
|||
|
" y = np.hstack(all_labels)\n",
|
|||
|
" \n",
|
|||
|
" # 清理临时数据\n",
|
|||
|
" del all_features, all_labels\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" \n",
|
|||
|
" feature_info = f\"{X.shape[1]} PCA特征\" if PCA_CONFIG['enable_pca'] else f\"{X.shape[1]} 原始特征\"\n",
|
|||
|
" print(f\" ✅ 加载完成: {X.shape[0]} 样本, {feature_info}\")\n",
|
|||
|
" return X, y\n",
|
|||
|
" else:\n",
|
|||
|
" return None, None\n",
|
|||
|
" \n",
|
|||
|
" def get_batch_generator(self):\n",
|
|||
|
" \"\"\"返回批次生成器(自动应用PCA)\"\"\"\n",
|
|||
|
" for trials_batch, filename in load_data_batch(self.data_dir, self.data_type, self.max_samples_per_file):\n",
|
|||
|
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
|||
|
" \n",
|
|||
|
" # 应用PCA降维\n",
|
|||
|
" features_processed = apply_pca_transform(features)\n",
|
|||
|
" \n",
|
|||
|
" yield features_processed, labels\n",
|
|||
|
"\n",
|
|||
|
"# 创建数据集对象\n",
|
|||
|
"train_dataset = MemoryFriendlyDataset(data_dir, 'train', MAX_SAMPLES_PER_FILE)\n",
|
|||
|
"val_dataset = MemoryFriendlyDataset(data_dir, 'val', MAX_SAMPLES_PER_FILE)\n",
|
|||
|
"test_dataset = MemoryFriendlyDataset(data_dir, 'test', MAX_SAMPLES_PER_FILE)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✅ 集成PCA的内存友好数据集已创建\")\n",
|
|||
|
"if PCA_CONFIG['enable_pca'] and GLOBAL_PCA['is_fitted']:\n",
|
|||
|
" print(f\" 🔬 PCA降维: 7168 → {GLOBAL_PCA['n_components']} ({GLOBAL_PCA['n_components']/7168:.1%})\")\n",
|
|||
|
" print(f\" 📊 方差保留: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n",
|
|||
|
"print(f\" 使用方式1: dataset.load_all_data() - 一次性加载 (自动PCA)\")\n",
|
|||
|
"print(f\" 使用方式2: dataset.get_batch_generator() - 分批处理 (自动PCA)\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"train_dataset"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 8,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"[Errno 2] No such file or directory: 'model_training/'\n",
|
|||
|
"/kaggle/working/nejm-brain-to-text/model_training\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"%cd model_training/"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# 模型建立"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"## LightGBM 梯度提升决策树\n",
|
|||
|
"\n",
|
|||
|
"使用LightGBM进行音素分类任务。LightGBM是微软开发的高效梯度提升框架,具有以下优势:\n",
|
|||
|
"\n",
|
|||
|
"- **训练速度快**: 相比传统GBDT算法速度提升10倍以上\n",
|
|||
|
"- **内存占用低**: 使用直方图算法减少内存使用\n",
|
|||
|
"- **准确率高**: 在许多机器学习竞赛中表现优异 \n",
|
|||
|
"- **支持并行**: 支持特征并行和数据并行\n",
|
|||
|
"- **可解释性强**: 提供特征重要性分析\n",
|
|||
|
"\n",
|
|||
|
"对于脑电信号到音素的分类任务,LightGBM能够:\n",
|
|||
|
"1. 自动处理高维特征(7168维神经信号)\n",
|
|||
|
"2. 发现特征之间的非线性关系\n",
|
|||
|
"3. 提供特征重要性排序,帮助理解哪些脑区信号最重要\n",
|
|||
|
"4. 快速训练,适合实验和调参"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🚀 LightGBM内存友好训练 - 分批处理策略\n",
|
|||
|
"\n",
|
|||
|
"import lightgbm as lgb\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import time\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"from sklearn.metrics import accuracy_score, classification_report\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"import gc\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"Usage: \n",
|
|||
|
" pip3 install [options] <requirement specifier> [package-index-options] ...\n",
|
|||
|
" pip3 install [options] -r <requirements file> [package-index-options] ...\n",
|
|||
|
" pip3 install [options] [-e] <vcs project url> ...\n",
|
|||
|
" pip3 install [options] [-e] <local project path> ...\n",
|
|||
|
" pip3 install [options] <archive url/path> ...\n",
|
|||
|
"\n",
|
|||
|
"no such option: --install-option\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"!pip install lightgbm --install-option=--gpu"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 11,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"🔧 设备检查:\n",
|
|||
|
" LightGBM GPU支持: ✅ 可用\n",
|
|||
|
" 训练设备: GPU\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 检查GPU可用性\n",
|
|||
|
"def check_gpu_support():\n",
|
|||
|
" \"\"\"检查LightGBM的GPU支持\"\"\"\n",
|
|||
|
" try:\n",
|
|||
|
" test_data = lgb.Dataset(np.random.rand(100, 10), label=np.random.randint(0, 2, 100))\n",
|
|||
|
" test_params = {'device': 'gpu', 'objective': 'binary', 'verbose': -1}\n",
|
|||
|
" # 新版本LightGBM使用callbacks参数而不是verbose_eval\n",
|
|||
|
" lgb.train(test_params, test_data, num_boost_round=1, callbacks=[])\n",
|
|||
|
" return True\n",
|
|||
|
" except Exception as e:\n",
|
|||
|
" print(f\" GPU支持检查失败: {e}\")\n",
|
|||
|
" return False\n",
|
|||
|
"\n",
|
|||
|
"gpu_available = check_gpu_support()\n",
|
|||
|
"print(f\"🔧 设备检查:\")\n",
|
|||
|
"print(f\" LightGBM GPU支持: {'✅ 可用' if gpu_available else '❌ 不可用,将使用CPU'}\")\n",
|
|||
|
"\n",
|
|||
|
"# 根据GPU可用性选择设备\n",
|
|||
|
"device_type = 'gpu' if gpu_available else 'cpu'\n",
|
|||
|
"print(f\" 训练设备: {device_type.upper()}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"\n",
|
|||
|
"# 检查数据集是否已创建\n",
|
|||
|
"if 'train_dataset' not in locals():\n",
|
|||
|
" print(\"❌ 错误: 请先运行数据读取代码创建数据集\")\n",
|
|||
|
" exit()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"📊 当前内存使用: 851.5 MB\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 内存检查函数\n",
|
|||
|
"def get_memory_usage():\n",
|
|||
|
" \"\"\"获取当前内存使用情况\"\"\"\n",
|
|||
|
" import psutil\n",
|
|||
|
" process = psutil.Process()\n",
|
|||
|
" return f\"{process.memory_info().rss / 1024 / 1024:.1f} MB\"\n",
|
|||
|
"\n",
|
|||
|
"def memory_cleanup():\n",
|
|||
|
" \"\"\"强制内存清理\"\"\"\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" print(f\" 内存清理后: {get_memory_usage()}\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n📊 当前内存使用: {get_memory_usage()}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"🔧 训练策略配置:\n",
|
|||
|
" 内存限制: 25.0 GB\n",
|
|||
|
" 分批训练: ✅ 启用\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 策略选择:根据内存情况选择训练方式\n",
|
|||
|
"MEMORY_LIMIT_MB = 25000 # 25GB内存限制\n",
|
|||
|
"USE_BATCH_TRAINING = True # 是否使用分批训练\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🔧 训练策略配置:\")\n",
|
|||
|
"print(f\" 内存限制: {MEMORY_LIMIT_MB/1000:.1f} GB\")\n",
|
|||
|
"print(f\" 分批训练: {'✅ 启用' if USE_BATCH_TRAINING else '❌ 禁用'}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 16,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"🔄 分批训练模式:\n",
|
|||
|
"<22> 第1阶段: 加载样本数据确定参数...\n",
|
|||
|
" 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" 加载批次 1: 14677 样本\n",
|
|||
|
" 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n",
|
|||
|
" 加载批次 1: 14677 样本\n",
|
|||
|
" 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n",
|
|||
|
" ✅ 样本数据: 58850 样本, 41 类别\n",
|
|||
|
" ✅ 样本数据: 58850 样本, 41 类别\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"print(f\"\\n🔄 分批训练模式:\")\n",
|
|||
|
"\n",
|
|||
|
"# 首先加载一小部分训练数据来确定模型参数\n",
|
|||
|
"print(f\"第1阶段: 加载样本数据确定参数...\")\n",
|
|||
|
"sample_X, sample_y = None, None\n",
|
|||
|
"\n",
|
|||
|
"batch_count = 0\n",
|
|||
|
"for features, labels in train_dataset.get_batch_generator():\n",
|
|||
|
" if sample_X is None:\n",
|
|||
|
" sample_X, sample_y = features, labels\n",
|
|||
|
" else:\n",
|
|||
|
" sample_X = np.vstack([sample_X, features])\n",
|
|||
|
" sample_y = np.hstack([sample_y, labels])\n",
|
|||
|
" \n",
|
|||
|
" batch_count += 1\n",
|
|||
|
" if batch_count >= 2: # 只取前2个批次作为样本\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 加载批次 {batch_count}: {features.shape[0]} 样本\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\" ✅ 样本数据: {sample_X.shape[0]} 样本, {len(np.unique(sample_y))} 类别\")\n",
|
|||
|
"\n",
|
|||
|
"# 数据预处理\n",
|
|||
|
"scaler = StandardScaler()\n",
|
|||
|
"sample_X_scaled = scaler.fit_transform(sample_X)\n",
|
|||
|
"\n",
|
|||
|
"# 切分样本数据\n",
|
|||
|
"X_sample_train, X_sample_val, y_sample_train, y_sample_val = train_test_split(\n",
|
|||
|
" sample_X_scaled, sample_y, test_size=0.2, random_state=42, stratify=sample_y\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"# LightGBM参数配置\n",
|
|||
|
"if gpu_available:\n",
|
|||
|
" params = {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': len(np.unique(sample_y)),\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device': 'gpu',\n",
|
|||
|
" 'gpu_platform_id': 0,\n",
|
|||
|
" 'gpu_device_id': 0,\n",
|
|||
|
" 'num_leaves': 64, # 减少叶子节点以节省内存\n",
|
|||
|
" 'learning_rate': 0.1,\n",
|
|||
|
" 'feature_fraction': 0.8,\n",
|
|||
|
" 'bagging_fraction': 0.8,\n",
|
|||
|
" 'bagging_freq': 5,\n",
|
|||
|
" 'verbose': 0,\n",
|
|||
|
" 'random_state': 42,\n",
|
|||
|
" 'max_bin': 255,\n",
|
|||
|
" }\n",
|
|||
|
"else:\n",
|
|||
|
" params = {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': len(np.unique(sample_y)),\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device': 'cpu',\n",
|
|||
|
" 'num_leaves': 32, # CPU使用更少叶子节点\n",
|
|||
|
" 'learning_rate': 0.1,\n",
|
|||
|
" 'feature_fraction': 0.8,\n",
|
|||
|
" 'bagging_fraction': 0.8,\n",
|
|||
|
" 'bagging_freq': 5,\n",
|
|||
|
" 'verbose': 0,\n",
|
|||
|
" 'random_state': 42,\n",
|
|||
|
" 'n_jobs': -1,\n",
|
|||
|
" }\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"(47080, 7168)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"X_sample_train.shape"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 18,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"🏗️ LightGBM配置:\n",
|
|||
|
" 设备: GPU\n",
|
|||
|
" 类别数: 41\n",
|
|||
|
" 叶子节点: 64\n",
|
|||
|
"\n",
|
|||
|
"🚀 第2阶段: 在样本数据上训练初始模型...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"🏗️ LightGBM配置:\n",
|
|||
|
" 设备: GPU\n",
|
|||
|
" 类别数: 41\n",
|
|||
|
" 叶子节点: 64\n",
|
|||
|
"\n",
|
|||
|
"🚀 第2阶段: 在样本数据上训练初始模型...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"🏗️ LightGBM配置:\n",
|
|||
|
" 设备: GPU\n",
|
|||
|
" 类别数: 41\n",
|
|||
|
" 叶子节点: 64\n",
|
|||
|
"\n",
|
|||
|
"🚀 第2阶段: 在样本数据上训练初始模型...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Training until validation scores don't improve for 20 rounds\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"🏗️ LightGBM配置:\n",
|
|||
|
" 设备: GPU\n",
|
|||
|
" 类别数: 41\n",
|
|||
|
" 叶子节点: 64\n",
|
|||
|
"\n",
|
|||
|
"🚀 第2阶段: 在样本数据上训练初始模型...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n",
|
|||
|
"1 warning generated.\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Training until validation scores don't improve for 20 rounds\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"ename": "KeyboardInterrupt",
|
|||
|
"evalue": "",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/4123818799.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 16\u001b[0m ]\n\u001b[1;32m 17\u001b[0m \u001b[0;31m# 在样本数据上训练初始模型\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m model = lgb.train(\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0msample_train_data\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/engine.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, feature_name, categorical_feature, keep_training_booster, callbacks)\u001b[0m\n\u001b[1;32m 274\u001b[0m evaluation_result_list=None))\n\u001b[1;32m 275\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 276\u001b[0;31m \u001b[0mbooster\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfobj\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 277\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 278\u001b[0m \u001b[0mevaluation_result_list\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0m_LGBM_BoosterEvalMethodResultType\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, train_set, fobj)\u001b[0m\n\u001b[1;32m 3889\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__set_objective_to_none\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3890\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mLightGBMError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Cannot update due to null objective function.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3891\u001b[0;31m _safe_call(_LIB.LGBM_BoosterUpdateOneIter(\n\u001b[0m\u001b[1;32m 3892\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3893\u001b[0m ctypes.byref(is_finished)))\n",
|
|||
|
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"print(f\"\\n🏗️ LightGBM配置:\")\n",
|
|||
|
"print(f\" 设备: {params['device'].upper()}\")\n",
|
|||
|
"print(f\" 类别数: {params['num_class']}\")\n",
|
|||
|
"print(f\" 叶子节点: {params['num_leaves']}\")\n",
|
|||
|
"\n",
|
|||
|
"# 创建样本数据集\n",
|
|||
|
"sample_train_data = lgb.Dataset(X_sample_train, label=y_sample_train)\n",
|
|||
|
"sample_val_data = lgb.Dataset(X_sample_val, label=y_sample_val, reference=sample_train_data)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🚀 第2阶段: 在样本数据上训练初始模型...\")\n",
|
|||
|
"start_time = time.time()\n",
|
|||
|
"\n",
|
|||
|
"callbacks = [\n",
|
|||
|
" lgb.log_evaluation(period=50),\n",
|
|||
|
" lgb.early_stopping(stopping_rounds=20)\n",
|
|||
|
"]\n",
|
|||
|
"# 在样本数据上训练初始模型\n",
|
|||
|
"model = lgb.train(\n",
|
|||
|
" params,\n",
|
|||
|
" sample_train_data,\n",
|
|||
|
" valid_sets=[sample_train_data, sample_val_data],\n",
|
|||
|
" valid_names=['train', 'val'],\n",
|
|||
|
" num_boost_round=200, # 较少的轮数\n",
|
|||
|
" callbacks=callbacks\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"print(f\" ✅ 初始模型训练完成\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 清理样本数据内存\n",
|
|||
|
"del sample_X, sample_y, X_sample_train, X_sample_val, y_sample_train, y_sample_val\n",
|
|||
|
"del sample_train_data, sample_val_data\n",
|
|||
|
"memory_cleanup()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"print(f\"\\n🔄 第3阶段: 分批增量训练...\")\n",
|
|||
|
"\n",
|
|||
|
"# 分批处理完整训练数据进行增量训练\n",
|
|||
|
"batch_num = 0\n",
|
|||
|
"total_samples_processed = 0\n",
|
|||
|
"\n",
|
|||
|
"for features, labels in train_dataset.get_batch_generator():\n",
|
|||
|
" batch_num += 1\n",
|
|||
|
" print(f\"\\n 处理批次 {batch_num}: {features.shape[0]} 样本\")\n",
|
|||
|
" \n",
|
|||
|
" # 数据预处理\n",
|
|||
|
" features_scaled = scaler.transform(features)\n",
|
|||
|
" \n",
|
|||
|
" # 创建批次数据集\n",
|
|||
|
" batch_data = lgb.Dataset(features_scaled, label=labels, reference=model.train_set if hasattr(model, 'train_set') else None)\n",
|
|||
|
" \n",
|
|||
|
" # 增量训练(继续训练现有模型)\n",
|
|||
|
" model = lgb.train(\n",
|
|||
|
" params,\n",
|
|||
|
" batch_data,\n",
|
|||
|
" num_boost_round=20, # 每批次少量轮数\n",
|
|||
|
" init_model=model, # 基于现有模型继续训练\n",
|
|||
|
" verbose_eval=False\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" total_samples_processed += features.shape[0]\n",
|
|||
|
" print(f\" 累计处理样本: {total_samples_processed}\")\n",
|
|||
|
" \n",
|
|||
|
" # 清理批次数据\n",
|
|||
|
" del features, labels, features_scaled, batch_data\n",
|
|||
|
" memory_cleanup()\n",
|
|||
|
" \n",
|
|||
|
" # 内存保护:如果处理的样本足够多就停止\n",
|
|||
|
" if total_samples_processed > 50000: # 限制总样本数\n",
|
|||
|
" print(f\" ⚠️ 达到样本限制,停止训练\")\n",
|
|||
|
" break\n",
|
|||
|
"\n",
|
|||
|
"training_time = time.time() - start_time\n",
|
|||
|
"print(f\"\\n✅ 分批训练完成!\")\n",
|
|||
|
"print(f\" 总耗时: {training_time:.2f} 秒\")\n",
|
|||
|
"print(f\" 处理样本数: {total_samples_processed}\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"\n",
|
|||
|
"# 模型评估阶段 - 按需加载验证/测试数据\n",
|
|||
|
"print(f\"\\n🔮 模型评估阶段 - 按需加载数据...\")\n",
|
|||
|
"\n",
|
|||
|
"# 评估函数\n",
|
|||
|
"def evaluate_on_dataset(model, scaler, dataset, dataset_name):\n",
|
|||
|
" \"\"\"在指定数据集上评估模型\"\"\"\n",
|
|||
|
" print(f\"\\n📊 评估 {dataset_name}...\")\n",
|
|||
|
" \n",
|
|||
|
" all_predictions = []\n",
|
|||
|
" all_true_labels = []\n",
|
|||
|
" batch_count = 0\n",
|
|||
|
" \n",
|
|||
|
" for features, labels in dataset.get_batch_generator():\n",
|
|||
|
" batch_count += 1\n",
|
|||
|
" print(f\" 评估批次 {batch_count}: {features.shape[0]} 样本\", end=\"\")\n",
|
|||
|
" \n",
|
|||
|
" # 预处理\n",
|
|||
|
" features_scaled = scaler.transform(features)\n",
|
|||
|
" \n",
|
|||
|
" # 预测\n",
|
|||
|
" predictions = model.predict(features_scaled, num_iteration=model.best_iteration)\n",
|
|||
|
" predicted_labels = np.argmax(predictions, axis=1)\n",
|
|||
|
" \n",
|
|||
|
" all_predictions.extend(predicted_labels)\n",
|
|||
|
" all_true_labels.extend(labels)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✓\")\n",
|
|||
|
" \n",
|
|||
|
" # 清理内存\n",
|
|||
|
" del features, labels, features_scaled, predictions, predicted_labels\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" \n",
|
|||
|
" # 计算总体准确率\n",
|
|||
|
" accuracy = accuracy_score(all_true_labels, all_predictions)\n",
|
|||
|
" print(f\" ✅ {dataset_name}准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n",
|
|||
|
" \n",
|
|||
|
" return accuracy, all_true_labels, all_predictions\n",
|
|||
|
"\n",
|
|||
|
"# 按需评估\n",
|
|||
|
"val_accuracy, val_true, val_pred = evaluate_on_dataset(model, scaler, val_dataset, \"验证集\")\n",
|
|||
|
"test_accuracy, test_true, test_pred = evaluate_on_dataset(model, scaler, test_dataset, \"测试集\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n📊 最终评估结果:\")\n",
|
|||
|
"print(f\" 验证集准确率: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)\")\n",
|
|||
|
"print(f\" 测试集准确率: {test_accuracy:.4f} ({test_accuracy*100:.2f}%)\")\n",
|
|||
|
"print(f\" 最终内存使用: {get_memory_usage()}\")\n",
|
|||
|
"\n",
|
|||
|
"# 特征重要性分析\n",
|
|||
|
"print(f\"\\n🔍 特征重要性分析:\")\n",
|
|||
|
"try:\n",
|
|||
|
" feature_importance = model.feature_importance(importance_type='gain')\n",
|
|||
|
" top_features_idx = np.argsort(feature_importance)[-10:]\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 前10个重要特征索引:\")\n",
|
|||
|
" for i, idx in enumerate(reversed(top_features_idx)):\n",
|
|||
|
" print(f\" {i+1:2d}. 特征 {idx:4d}: 重要性 {feature_importance[idx]:.2f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 可视化特征重要性\n",
|
|||
|
" plt.figure(figsize=(10, 6))\n",
|
|||
|
" plt.bar(range(len(top_features_idx)), feature_importance[top_features_idx])\n",
|
|||
|
" plt.title('前10个重要特征')\n",
|
|||
|
" plt.xlabel('特征索引')\n",
|
|||
|
" plt.ylabel('重要性得分')\n",
|
|||
|
" plt.xticks(range(len(top_features_idx)), top_features_idx, rotation=45)\n",
|
|||
|
" plt.tight_layout()\n",
|
|||
|
" plt.show()\n",
|
|||
|
" \n",
|
|||
|
"except Exception as e:\n",
|
|||
|
" print(f\" 特征重要性分析失败: {e}\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✨ 内存友好训练完成!\")\n",
|
|||
|
"print(f\"💡 内存优化策略总结:\")\n",
|
|||
|
"print(f\" - 分批数据加载,避免内存溢出\")\n",
|
|||
|
"print(f\" - 增量模型训练,逐步改进\")\n",
|
|||
|
"print(f\" - 按需评估,节省内存\")\n",
|
|||
|
"print(f\" - 自动内存清理\")\n",
|
|||
|
"print(f\" - GPU/CPU自适应配置\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🔍 音素标签分布详细分析\n",
|
|||
|
"\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from collections import Counter\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🔍 音素标签分布详细分析\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"# 检查数据是否可用\n",
|
|||
|
"if 'single_result' not in locals() or not single_result or 'train' not in single_result:\n",
|
|||
|
" print(\"❌ 错误: single_result数据不可用,请先运行RNN数据处理\")\n",
|
|||
|
" exit()\n",
|
|||
|
"\n",
|
|||
|
"# 获取数据\n",
|
|||
|
"train_data = single_result['train']\n",
|
|||
|
"concatenated_features = train_data['concatenated_data']\n",
|
|||
|
"\n",
|
|||
|
"print(f\"📊 原始数据结构:\")\n",
|
|||
|
"print(f\" 试验总数: {len(concatenated_features)}\")\n",
|
|||
|
"if len(concatenated_features) > 0:\n",
|
|||
|
" print(f\" 单个试验形状: {concatenated_features[0].shape}\")\n",
|
|||
|
" print(f\" 特征维度: {concatenated_features[0].shape[1]} (前7168神经 + 后41音素)\")\n",
|
|||
|
"\n",
|
|||
|
"# 重新分析RNN输出的音素分布\n",
|
|||
|
"print(f\"\\n🎯 音素分析方法:\")\n",
|
|||
|
"print(f\" 1. 提取后41维RNN输出 (音素logits)\")\n",
|
|||
|
"print(f\" 2. 对每个时间步计算argmax得到音素标签\")\n",
|
|||
|
"print(f\" 3. 统计所有时间步的音素分布\")\n",
|
|||
|
"\n",
|
|||
|
"# 收集所有音素预测\n",
|
|||
|
"all_phoneme_predictions = []\n",
|
|||
|
"time_step_count = 0\n",
|
|||
|
"trial_count = 0\n",
|
|||
|
"\n",
|
|||
|
"for i, features in enumerate(concatenated_features):\n",
|
|||
|
" if features.shape[0] > 0: # 确保有时间步\n",
|
|||
|
" trial_count += 1\n",
|
|||
|
" for t in range(features.shape[0]):\n",
|
|||
|
" time_step_count += 1\n",
|
|||
|
" rnn_logits = features[t, 7168:] # 后41维\n",
|
|||
|
" phoneme_pred = np.argmax(rnn_logits)\n",
|
|||
|
" all_phoneme_predictions.append(phoneme_pred)\n",
|
|||
|
"\n",
|
|||
|
"all_phoneme_predictions = np.array(all_phoneme_predictions)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n📈 统计结果:\")\n",
|
|||
|
"print(f\" 处理的试验数: {trial_count}\")\n",
|
|||
|
"print(f\" 总时间步数: {time_step_count}\")\n",
|
|||
|
"print(f\" 音素预测数: {len(all_phoneme_predictions)}\")\n",
|
|||
|
"\n",
|
|||
|
"# 分析音素分布\n",
|
|||
|
"phoneme_counts = np.bincount(all_phoneme_predictions)\n",
|
|||
|
"print(f\"\\n🏷️ 音素标签分布详情:\")\n",
|
|||
|
"print(f\" 音素类别数: {len(phoneme_counts)}\")\n",
|
|||
|
"print(f\" 非零类别数: {np.count_nonzero(phoneme_counts)}\")\n",
|
|||
|
"\n",
|
|||
|
"# 显示前15个音素的分布\n",
|
|||
|
"print(f\"\\n📊 前15个音素类别分布:\")\n",
|
|||
|
"for i in range(min(15, len(phoneme_counts))):\n",
|
|||
|
" percentage = phoneme_counts[i] / len(all_phoneme_predictions) * 100\n",
|
|||
|
" print(f\" 音素 {i:2d}: {phoneme_counts[i]:6d} 次 ({percentage:5.2f}%)\")\n",
|
|||
|
"\n",
|
|||
|
"# 检查音素0的异常高频率\n",
|
|||
|
"print(f\"\\n⚠️ 音素0异常分析:\")\n",
|
|||
|
"print(f\" 音素0出现次数: {phoneme_counts[0]}\")\n",
|
|||
|
"print(f\" 音素0占比: {phoneme_counts[0] / len(all_phoneme_predictions) * 100:.2f}%\")\n",
|
|||
|
"print(f\" 其他音素总数: {len(all_phoneme_predictions) - phoneme_counts[0]}\")\n",
|
|||
|
"\n",
|
|||
|
"# 可能的原因分析\n",
|
|||
|
"print(f\"\\n💡 音素0高频的可能原因:\")\n",
|
|||
|
"print(f\" 1. 静音/沉默音素: 音素0可能代表静音或沉默状态\")\n",
|
|||
|
"print(f\" 2. 背景状态: 非发音期间的默认状态\")\n",
|
|||
|
"print(f\" 3. 模型偏置: RNN在不确定时倾向于预测音素0\")\n",
|
|||
|
"print(f\" 4. 数据特性: 大部分时间处于非发音状态\")\n",
|
|||
|
"\n",
|
|||
|
"# 检查LOGIT_TO_PHONEME映射\n",
|
|||
|
"if 'LOGIT_TO_PHONEME' in locals():\n",
|
|||
|
" print(f\"\\n🔤 音素映射信息:\")\n",
|
|||
|
" print(f\" 音素映射表长度: {len(LOGIT_TO_PHONEME)}\")\n",
|
|||
|
" print(f\" 前10个音素映射:\")\n",
|
|||
|
" for i in range(min(10, len(LOGIT_TO_PHONEME))):\n",
|
|||
|
" print(f\" 索引 {i}: '{LOGIT_TO_PHONEME[i]}'\")\n",
|
|||
|
" \n",
|
|||
|
" # 检查音素0是什么\n",
|
|||
|
" if len(LOGIT_TO_PHONEME) > 0:\n",
|
|||
|
" print(f\"\\n 音素0代表: '{LOGIT_TO_PHONEME[0]}'\")\n",
|
|||
|
" if 'sil' in LOGIT_TO_PHONEME[0].lower() or 'silence' in LOGIT_TO_PHONEME[0].lower():\n",
|
|||
|
" print(f\" ✅ 确认: 音素0是静音符号,高频率是正常的!\")\n",
|
|||
|
"\n",
|
|||
|
"# 可视化音素分布\n",
|
|||
|
"plt.figure(figsize=(15, 6))\n",
|
|||
|
"\n",
|
|||
|
"# 左图:完整分布(包含音素0)\n",
|
|||
|
"plt.subplot(1, 2, 1)\n",
|
|||
|
"plt.bar(range(len(phoneme_counts)), phoneme_counts, alpha=0.7)\n",
|
|||
|
"plt.title('完整音素分布')\n",
|
|||
|
"plt.xlabel('音素索引')\n",
|
|||
|
"plt.ylabel('出现次数')\n",
|
|||
|
"plt.yscale('log') # 使用对数坐标便于观察\n",
|
|||
|
"\n",
|
|||
|
"# 右图:排除音素0的分布\n",
|
|||
|
"plt.subplot(1, 2, 2)\n",
|
|||
|
"non_silence_counts = phoneme_counts[1:]\n",
|
|||
|
"plt.bar(range(1, len(phoneme_counts)), non_silence_counts, alpha=0.7, color='orange')\n",
|
|||
|
"plt.title('排除音素0的分布')\n",
|
|||
|
"plt.xlabel('音素索引')\n",
|
|||
|
"plt.ylabel('出现次数')\n",
|
|||
|
"\n",
|
|||
|
"plt.tight_layout()\n",
|
|||
|
"plt.show()\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✅ 结论:\")\n",
|
|||
|
"print(f\" 音素0的高频率很可能是正常现象,代表静音或背景状态\")\n",
|
|||
|
"print(f\" 实际的语音音素分布在音素1-40中,数量相对平衡\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🎯 音素分类结果展示 - 将预测转换为音素符号\n",
|
|||
|
"\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"from collections import Counter\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🎯 音素分类结果展示\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"# 检查必要的数据是否存在\n",
|
|||
|
"if 'single_result' not in locals():\n",
|
|||
|
" print(\"❌ 错误: single_result数据不可用\")\n",
|
|||
|
" exit()\n",
|
|||
|
"\n",
|
|||
|
"if 'LOGIT_TO_PHONEME' not in locals():\n",
|
|||
|
" print(\"❌ 错误: LOGIT_TO_PHONEME映射不可用\")\n",
|
|||
|
" exit()\n",
|
|||
|
"\n",
|
|||
|
"print(f\"✅ 数据检查通过\")\n",
|
|||
|
"print(f\" 音素映射表长度: {len(LOGIT_TO_PHONEME)}\")\n",
|
|||
|
"\n",
|
|||
|
"# 获取数据\n",
|
|||
|
"train_data = single_result['train']\n",
|
|||
|
"concatenated_features = train_data['concatenated_data']\n",
|
|||
|
"\n",
|
|||
|
"# 处理几个试验的数据进行展示\n",
|
|||
|
"print(f\"\\n🔬 分类结果示例 (前3个试验):\")\n",
|
|||
|
"print(f\"=\"*50)\n",
|
|||
|
"\n",
|
|||
|
"for trial_idx in range(min(3, len(concatenated_features))):\n",
|
|||
|
" features = concatenated_features[trial_idx]\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n📋 试验 {trial_idx + 1}:\")\n",
|
|||
|
" print(f\" 时间步数: {features.shape[0]}\")\n",
|
|||
|
" \n",
|
|||
|
" if features.shape[0] > 0:\n",
|
|||
|
" # 提取神经特征和RNN输出\n",
|
|||
|
" neural_features = features[:, :7168] # 前7168维\n",
|
|||
|
" rnn_logits = features[:, 7168:] # 后41维\n",
|
|||
|
" \n",
|
|||
|
" # 获取每个时间步的音素预测\n",
|
|||
|
" predicted_phoneme_indices = np.argmax(rnn_logits, axis=1)\n",
|
|||
|
" max_confidences = np.max(rnn_logits, axis=1)\n",
|
|||
|
" \n",
|
|||
|
" # 转换为音素符号\n",
|
|||
|
" predicted_phonemes = [LOGIT_TO_PHONEME[idx] for idx in predicted_phoneme_indices]\n",
|
|||
|
" \n",
|
|||
|
" # 显示前10个时间步的结果\n",
|
|||
|
" print(f\" 前10个时间步的预测:\")\n",
|
|||
|
" print(f\" {'步骤':>4} {'索引':>4} {'置信度':>8} {'音素':>8}\")\n",
|
|||
|
" print(f\" {'-'*30}\")\n",
|
|||
|
" \n",
|
|||
|
" for t in range(min(10, len(predicted_phonemes))):\n",
|
|||
|
" print(f\" {t+1:4d} {predicted_phoneme_indices[t]:4d} {max_confidences[t]:8.3f} {predicted_phonemes[t]:>8}\")\n",
|
|||
|
" \n",
|
|||
|
" # 统计这个试验的音素分布\n",
|
|||
|
" phoneme_counter = Counter(predicted_phonemes)\n",
|
|||
|
" print(f\"\\n 本试验音素分布 (前5个):\")\n",
|
|||
|
" for phoneme, count in phoneme_counter.most_common(5):\n",
|
|||
|
" percentage = count / len(predicted_phonemes) * 100\n",
|
|||
|
" print(f\" '{phoneme}': {count:3d} 次 ({percentage:5.1f}%)\")\n",
|
|||
|
"\n",
|
|||
|
"# 全局音素统计\n",
|
|||
|
"print(f\"\\n🌍 全局音素统计:\")\n",
|
|||
|
"print(f\"=\"*50)\n",
|
|||
|
"\n",
|
|||
|
"all_phoneme_predictions = []\n",
|
|||
|
"all_confidences = []\n",
|
|||
|
"\n",
|
|||
|
"for features in concatenated_features:\n",
|
|||
|
" if features.shape[0] > 0:\n",
|
|||
|
" rnn_logits = features[:, 7168:]\n",
|
|||
|
" predicted_indices = np.argmax(rnn_logits, axis=1)\n",
|
|||
|
" confidences = np.max(rnn_logits, axis=1)\n",
|
|||
|
" \n",
|
|||
|
" all_phoneme_predictions.extend(predicted_indices)\n",
|
|||
|
" all_confidences.extend(confidences)\n",
|
|||
|
"\n",
|
|||
|
"# 转换为音素符号\n",
|
|||
|
"all_predicted_phonemes = [LOGIT_TO_PHONEME[idx] for idx in all_phoneme_predictions]\n",
|
|||
|
"global_phoneme_counter = Counter(all_predicted_phonemes)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n📊 音素频率排行 (前15个):\")\n",
|
|||
|
"print(f\"{'排名':>4} {'音素':>8} {'出现次数':>8} {'百分比':>8} {'平均置信度':>10}\")\n",
|
|||
|
"print(f\"{'-'*45}\")\n",
|
|||
|
"\n",
|
|||
|
"for rank, (phoneme, count) in enumerate(global_phoneme_counter.most_common(15), 1):\n",
|
|||
|
" # 计算该音素的平均置信度\n",
|
|||
|
" phoneme_indices = [i for i, p in enumerate(all_predicted_phonemes) if p == phoneme]\n",
|
|||
|
" avg_confidence = np.mean([all_confidences[i] for i in phoneme_indices])\n",
|
|||
|
" percentage = count / len(all_predicted_phonemes) * 100\n",
|
|||
|
" \n",
|
|||
|
" print(f\"{rank:4d} {phoneme:>8} {count:8d} {percentage:7.2f}% {avg_confidence:10.3f}\")\n",
|
|||
|
"\n",
|
|||
|
"# 分析非静音音素\n",
|
|||
|
"print(f\"\\n🗣️ 非静音音素分析:\")\n",
|
|||
|
"non_silence_phonemes = [p for p in all_predicted_phonemes if p.lower() not in ['sil', 'silence', 'sp', 'spn']]\n",
|
|||
|
"non_silence_counter = Counter(non_silence_phonemes)\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 总音素预测: {len(all_predicted_phonemes)}\")\n",
|
|||
|
"print(f\" 非静音音素: {len(non_silence_phonemes)}\")\n",
|
|||
|
"print(f\" 静音比例: {(1 - len(non_silence_phonemes)/len(all_predicted_phonemes))*100:.1f}%\")\n",
|
|||
|
"\n",
|
|||
|
"if len(non_silence_phonemes) > 0:\n",
|
|||
|
" print(f\"\\n 非静音音素排行 (前10个):\")\n",
|
|||
|
" print(f\" {'音素':>6} {'次数':>6} {'百分比':>8}\")\n",
|
|||
|
" print(f\" {'-'*22}\")\n",
|
|||
|
" \n",
|
|||
|
" for phoneme, count in non_silence_counter.most_common(10):\n",
|
|||
|
" percentage = count / len(non_silence_phonemes) * 100\n",
|
|||
|
" print(f\" {phoneme:>6} {count:6d} {percentage:7.2f}%\")\n",
|
|||
|
"\n",
|
|||
|
"# 序列示例\n",
|
|||
|
"print(f\"\\n📝 音素序列示例 (试验1的前30个时间步):\")\n",
|
|||
|
"if len(concatenated_features) > 0 and concatenated_features[0].shape[0] > 0:\n",
|
|||
|
" example_features = concatenated_features[0]\n",
|
|||
|
" example_logits = example_features[:30, 7168:] # 前30个时间步\n",
|
|||
|
" example_predictions = np.argmax(example_logits, axis=1)\n",
|
|||
|
" example_phonemes = [LOGIT_TO_PHONEME[idx] for idx in example_predictions]\n",
|
|||
|
" \n",
|
|||
|
" # 按每行10个显示\n",
|
|||
|
" for i in range(0, len(example_phonemes), 10):\n",
|
|||
|
" line_phonemes = example_phonemes[i:i+10]\n",
|
|||
|
" phoneme_str = ' '.join(f\"{p:>4}\" for p in line_phonemes)\n",
|
|||
|
" print(f\" 步骤 {i+1:2d}-{min(i+10, len(example_phonemes)):2d}: {phoneme_str}\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✨ 音素分类结果展示完成!\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 🔄 CTC处理 - 连接主义时序分类后处理\n",
|
|||
|
"\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"from collections import Counter\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🔄 CTC处理 - 音素序列优化\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"def ctc_decode_greedy(predictions, blank_id=0):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 简单的CTC贪心解码\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" predictions: 音素预测序列 [time_steps]\n",
|
|||
|
" blank_id: 空白符号的ID (通常是0,代表静音)\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" decoded_sequence: 解码后的音素序列\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" decoded = []\n",
|
|||
|
" previous = -1\n",
|
|||
|
" \n",
|
|||
|
" for pred in predictions:\n",
|
|||
|
" # CTC规则:\n",
|
|||
|
" # 1. 移除连续重复的符号\n",
|
|||
|
" # 2. 移除空白符号\n",
|
|||
|
" if pred != previous and pred != blank_id:\n",
|
|||
|
" decoded.append(pred)\n",
|
|||
|
" previous = pred\n",
|
|||
|
" \n",
|
|||
|
" return decoded\n",
|
|||
|
"\n",
|
|||
|
"def ctc_decode_with_confidence(logits, confidence_threshold=0.5, blank_id=0):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 带置信度阈值的CTC解码\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" logits: RNN输出的logits [time_steps, num_classes]\n",
|
|||
|
" confidence_threshold: 置信度阈值\n",
|
|||
|
" blank_id: 空白符号ID\n",
|
|||
|
" \n",
|
|||
|
" Returns:\n",
|
|||
|
" decoded_sequence: 解码后的音素序列\n",
|
|||
|
" confidences: 对应的置信度\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" predictions = np.argmax(logits, axis=1)\n",
|
|||
|
" max_confidences = np.max(logits, axis=1)\n",
|
|||
|
" \n",
|
|||
|
" decoded = []\n",
|
|||
|
" decoded_confidences = []\n",
|
|||
|
" previous = -1\n",
|
|||
|
" \n",
|
|||
|
" for i, (pred, conf) in enumerate(zip(predictions, max_confidences)):\n",
|
|||
|
" # 只保留高置信度且非重复的预测\n",
|
|||
|
" if pred != previous and pred != blank_id and conf > confidence_threshold:\n",
|
|||
|
" decoded.append(pred)\n",
|
|||
|
" decoded_confidences.append(conf)\n",
|
|||
|
" previous = pred\n",
|
|||
|
" \n",
|
|||
|
" return decoded, decoded_confidences\n",
|
|||
|
"\n",
|
|||
|
"# 检查数据\n",
|
|||
|
"if 'single_result' not in locals() or 'LOGIT_TO_PHONEME' not in locals():\n",
|
|||
|
" print(\"❌ 错误: 缺少必要的数据\")\n",
|
|||
|
" exit()\n",
|
|||
|
"\n",
|
|||
|
"print(f\"✅ 开始CTC处理\")\n",
|
|||
|
"print(f\" 音素映射表: {len(LOGIT_TO_PHONEME)} 个音素\")\n",
|
|||
|
"\n",
|
|||
|
"# 获取数据\n",
|
|||
|
"train_data = single_result['train']\n",
|
|||
|
"concatenated_features = train_data['concatenated_data']\n",
|
|||
|
"\n",
|
|||
|
"# 处理几个试验进行CTC解码对比\n",
|
|||
|
"print(f\"\\n🔬 CTC解码对比 (前3个试验):\")\n",
|
|||
|
"print(f\"=\"*60)\n",
|
|||
|
"\n",
|
|||
|
"for trial_idx in range(min(3, len(concatenated_features))):\n",
|
|||
|
" features = concatenated_features[trial_idx]\n",
|
|||
|
" \n",
|
|||
|
" if features.shape[0] == 0:\n",
|
|||
|
" continue\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n📋 试验 {trial_idx + 1} (共{features.shape[0]}个时间步):\")\n",
|
|||
|
" \n",
|
|||
|
" # 提取RNN输出\n",
|
|||
|
" rnn_logits = features[:, 7168:] # 后41维\n",
|
|||
|
" \n",
|
|||
|
" # 原始预测\n",
|
|||
|
" raw_predictions = np.argmax(rnn_logits, axis=1)\n",
|
|||
|
" raw_phonemes = [LOGIT_TO_PHONEME[idx] for idx in raw_predictions]\n",
|
|||
|
" \n",
|
|||
|
" # CTC贪心解码\n",
|
|||
|
" ctc_predictions = ctc_decode_greedy(raw_predictions, blank_id=0)\n",
|
|||
|
" ctc_phonemes = [LOGIT_TO_PHONEME[idx] for idx in ctc_predictions]\n",
|
|||
|
" \n",
|
|||
|
" # 带置信度的CTC解码\n",
|
|||
|
" ctc_conf_predictions, confidences = ctc_decode_with_confidence(\n",
|
|||
|
" rnn_logits, confidence_threshold=1.0, blank_id=0\n",
|
|||
|
" )\n",
|
|||
|
" ctc_conf_phonemes = [LOGIT_TO_PHONEME[idx] for idx in ctc_conf_predictions]\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 原始序列长度: {len(raw_phonemes)}\")\n",
|
|||
|
" print(f\" CTC解码长度: {len(ctc_phonemes)}\")\n",
|
|||
|
" print(f\" 高置信CTC长度: {len(ctc_conf_phonemes)}\")\n",
|
|||
|
" \n",
|
|||
|
" # 显示前20个原始预测\n",
|
|||
|
" print(f\"\\n 原始预测 (前20步):\")\n",
|
|||
|
" raw_sample = raw_phonemes[:20]\n",
|
|||
|
" print(f\" {' '.join(f'{p:>4}' for p in raw_sample)}\")\n",
|
|||
|
" \n",
|
|||
|
" # 显示CTC解码结果\n",
|
|||
|
" print(f\"\\n CTC解码结果:\")\n",
|
|||
|
" if len(ctc_phonemes) > 0:\n",
|
|||
|
" ctc_sample = ctc_phonemes[:20] # 显示前20个(如果有的话)\n",
|
|||
|
" print(f\" {' '.join(f'{p:>4}' for p in ctc_sample)}\")\n",
|
|||
|
" else:\n",
|
|||
|
" print(f\" (空序列)\")\n",
|
|||
|
" \n",
|
|||
|
" # 显示高置信度CTC结果\n",
|
|||
|
" print(f\"\\n 高置信CTC结果:\")\n",
|
|||
|
" if len(ctc_conf_phonemes) > 0:\n",
|
|||
|
" conf_sample = ctc_conf_phonemes[:15] # 显示前15个\n",
|
|||
|
" conf_values = confidences[:15]\n",
|
|||
|
" for phoneme, conf in zip(conf_sample, conf_values):\n",
|
|||
|
" print(f\" {phoneme:>4}({conf:.2f})\", end=\"\")\n",
|
|||
|
" print()\n",
|
|||
|
" else:\n",
|
|||
|
" print(f\" (空序列)\")\n",
|
|||
|
"\n",
|
|||
|
"# 全局CTC统计\n",
|
|||
|
"print(f\"\\n🌍 全局CTC统计分析:\")\n",
|
|||
|
"print(f\"=\"*60)\n",
|
|||
|
"\n",
|
|||
|
"all_raw_predictions = []\n",
|
|||
|
"all_ctc_predictions = []\n",
|
|||
|
"all_ctc_conf_predictions = []\n",
|
|||
|
"\n",
|
|||
|
"for features in concatenated_features:\n",
|
|||
|
" if features.shape[0] > 0:\n",
|
|||
|
" rnn_logits = features[:, 7168:]\n",
|
|||
|
" raw_preds = np.argmax(rnn_logits, axis=1)\n",
|
|||
|
" \n",
|
|||
|
" # 收集原始预测\n",
|
|||
|
" all_raw_predictions.extend(raw_preds)\n",
|
|||
|
" \n",
|
|||
|
" # CTC解码\n",
|
|||
|
" ctc_preds = ctc_decode_greedy(raw_preds, blank_id=0)\n",
|
|||
|
" all_ctc_predictions.extend(ctc_preds)\n",
|
|||
|
" \n",
|
|||
|
" # 高置信度CTC\n",
|
|||
|
" ctc_conf_preds, _ = ctc_decode_with_confidence(rnn_logits, confidence_threshold=1.0, blank_id=0)\n",
|
|||
|
" all_ctc_conf_predictions.extend(ctc_conf_preds)\n",
|
|||
|
"\n",
|
|||
|
"# 转换为音素符号\n",
|
|||
|
"raw_phonemes_global = [LOGIT_TO_PHONEME[idx] for idx in all_raw_predictions]\n",
|
|||
|
"ctc_phonemes_global = [LOGIT_TO_PHONEME[idx] for idx in all_ctc_predictions]\n",
|
|||
|
"ctc_conf_phonemes_global = [LOGIT_TO_PHONEME[idx] for idx in all_ctc_conf_predictions]\n",
|
|||
|
"\n",
|
|||
|
"print(f\"📊 序列长度对比:\")\n",
|
|||
|
"print(f\" 原始预测总数: {len(raw_phonemes_global):,}\")\n",
|
|||
|
"print(f\" CTC解码总数: {len(ctc_phonemes_global):,}\")\n",
|
|||
|
"print(f\" 高置信CTC总数: {len(ctc_conf_phonemes_global):,}\")\n",
|
|||
|
"print(f\" 压缩比 (CTC): {len(ctc_phonemes_global)/len(raw_phonemes_global)*100:.1f}%\")\n",
|
|||
|
"print(f\" 压缩比 (高置信): {len(ctc_conf_phonemes_global)/len(raw_phonemes_global)*100:.1f}%\")\n",
|
|||
|
"\n",
|
|||
|
"# 音素分布对比\n",
|
|||
|
"print(f\"\\n📈 音素分布对比 (前10个):\")\n",
|
|||
|
"raw_counter = Counter(raw_phonemes_global)\n",
|
|||
|
"ctc_counter = Counter(ctc_phonemes_global)\n",
|
|||
|
"ctc_conf_counter = Counter(ctc_conf_phonemes_global)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"{'音素':>6} {'原始':>8} {'CTC':>8} {'高置信':>8}\")\n",
|
|||
|
"print(f\"{'-'*35}\")\n",
|
|||
|
"\n",
|
|||
|
"for phoneme, raw_count in raw_counter.most_common(10):\n",
|
|||
|
" ctc_count = ctc_counter.get(phoneme, 0)\n",
|
|||
|
" conf_count = ctc_conf_counter.get(phoneme, 0)\n",
|
|||
|
" print(f\"{phoneme:>6} {raw_count:8d} {ctc_count:8d} {conf_count:8d}\")\n",
|
|||
|
"\n",
|
|||
|
"# 生成完整的句子示例\n",
|
|||
|
"print(f\"\\n📝 CTC解码句子示例:\")\n",
|
|||
|
"print(f\"=\"*40)\n",
|
|||
|
"\n",
|
|||
|
"for trial_idx in range(min(2, len(concatenated_features))):\n",
|
|||
|
" features = concatenated_features[trial_idx]\n",
|
|||
|
" \n",
|
|||
|
" if features.shape[0] > 0:\n",
|
|||
|
" rnn_logits = features[:, 7168:]\n",
|
|||
|
" raw_preds = np.argmax(rnn_logits, axis=1)\n",
|
|||
|
" ctc_preds = ctc_decode_greedy(raw_preds, blank_id=0)\n",
|
|||
|
" \n",
|
|||
|
" raw_phonemes = [LOGIT_TO_PHONEME[idx] for idx in raw_preds]\n",
|
|||
|
" ctc_phonemes = [LOGIT_TO_PHONEME[idx] for idx in ctc_preds]\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n试验 {trial_idx + 1}:\")\n",
|
|||
|
" print(f\" 原始: {' '.join(raw_phonemes[:30])}...\")\n",
|
|||
|
" print(f\" CTC: {' '.join(ctc_phonemes[:20])}\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✨ CTC处理完成!\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# 📝 音素序列到文本转换 - 最终结果展示\n",
|
|||
|
"\n",
|
|||
|
"import re\n",
|
|||
|
"from collections import defaultdict\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"📝 音素序列到文本转换\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"def phonemes_to_words(phoneme_sequence):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 将音素序列转换为可能的单词\n",
|
|||
|
" 这是一个简化的实现,实际应用中需要语音识别词典\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" # 简单的音素到字母映射(示例)\n",
|
|||
|
" phoneme_to_letter = {\n",
|
|||
|
" 'AA': 'a', 'AE': 'a', 'AH': 'u', 'AO': 'o', 'AW': 'ow',\n",
|
|||
|
" 'AY': 'ai', 'B': 'b', 'CH': 'ch', 'D': 'd', 'DH': 'th',\n",
|
|||
|
" 'EH': 'e', 'ER': 'er', 'EY': 'ay', 'F': 'f', 'G': 'g',\n",
|
|||
|
" 'HH': 'h', 'IH': 'i', 'IY': 'ee', 'JH': 'j', 'K': 'k',\n",
|
|||
|
" 'L': 'l', 'M': 'm', 'N': 'n', 'NG': 'ng', 'OW': 'o',\n",
|
|||
|
" 'OY': 'oy', 'P': 'p', 'R': 'r', 'S': 's', 'SH': 'sh',\n",
|
|||
|
" 'T': 't', 'TH': 'th', 'UH': 'u', 'UW': 'oo', 'V': 'v',\n",
|
|||
|
" 'W': 'w', 'Y': 'y', 'Z': 'z', 'ZH': 'zh'\n",
|
|||
|
" }\n",
|
|||
|
" \n",
|
|||
|
" result = []\n",
|
|||
|
" for phoneme in phoneme_sequence:\n",
|
|||
|
" if phoneme in phoneme_to_letter:\n",
|
|||
|
" result.append(phoneme_to_letter[phoneme])\n",
|
|||
|
" else:\n",
|
|||
|
" result.append(f'[{phoneme}]') # 未知音素用括号标记\n",
|
|||
|
" \n",
|
|||
|
" return ''.join(result)\n",
|
|||
|
"\n",
|
|||
|
"def analyze_phoneme_patterns(phoneme_sequence):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 分析音素序列中的模式\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" # 统计音素频率\n",
|
|||
|
" phoneme_freq = defaultdict(int)\n",
|
|||
|
" for phoneme in phoneme_sequence:\n",
|
|||
|
" phoneme_freq[phoneme] += 1\n",
|
|||
|
" \n",
|
|||
|
" # 寻找常见的音素组合\n",
|
|||
|
" bigrams = defaultdict(int)\n",
|
|||
|
" for i in range(len(phoneme_sequence) - 1):\n",
|
|||
|
" bigram = (phoneme_sequence[i], phoneme_sequence[i+1])\n",
|
|||
|
" bigrams[bigram] += 1\n",
|
|||
|
" \n",
|
|||
|
" return phoneme_freq, bigrams\n",
|
|||
|
"\n",
|
|||
|
"# 处理CTC解码后的结果\n",
|
|||
|
"print(f\"🔄 处理CTC解码结果...\")\n",
|
|||
|
"\n",
|
|||
|
"# 获取数据\n",
|
|||
|
"train_data = single_result['train']\n",
|
|||
|
"concatenated_features = train_data['concatenated_data']\n",
|
|||
|
"\n",
|
|||
|
"# 收集所有CTC解码结果\n",
|
|||
|
"all_trials_ctc = []\n",
|
|||
|
"all_trials_raw = []\n",
|
|||
|
"\n",
|
|||
|
"for trial_idx, features in enumerate(concatenated_features[:5]): # 处理前5个试验\n",
|
|||
|
" if features.shape[0] > 0:\n",
|
|||
|
" rnn_logits = features[:, 7168:]\n",
|
|||
|
" raw_preds = np.argmax(rnn_logits, axis=1)\n",
|
|||
|
" \n",
|
|||
|
" # CTC解码\n",
|
|||
|
" ctc_preds = ctc_decode_greedy(raw_preds, blank_id=0)\n",
|
|||
|
" \n",
|
|||
|
" # 转换为音素符号\n",
|
|||
|
" raw_phonemes = [LOGIT_TO_PHONEME[idx] for idx in raw_preds]\n",
|
|||
|
" ctc_phonemes = [LOGIT_TO_PHONEME[idx] for idx in ctc_preds]\n",
|
|||
|
" \n",
|
|||
|
" all_trials_raw.append(raw_phonemes)\n",
|
|||
|
" all_trials_ctc.append(ctc_phonemes)\n",
|
|||
|
"\n",
|
|||
|
"print(f\"✅ 处理了 {len(all_trials_ctc)} 个试验\")\n",
|
|||
|
"\n",
|
|||
|
"# 展示每个试验的结果\n",
|
|||
|
"print(f\"\\n📋 逐试验分析:\")\n",
|
|||
|
"print(f\"=\"*50)\n",
|
|||
|
"\n",
|
|||
|
"for trial_idx, (raw_phonemes, ctc_phonemes) in enumerate(zip(all_trials_raw, all_trials_ctc)):\n",
|
|||
|
" print(f\"\\n🎯 试验 {trial_idx + 1}:\")\n",
|
|||
|
" print(f\" 原始长度: {len(raw_phonemes)} → CTC长度: {len(ctc_phonemes)}\")\n",
|
|||
|
" \n",
|
|||
|
" # 显示音素序列\n",
|
|||
|
" print(f\" CTC音素序列: {' '.join(ctc_phonemes[:20])}{'...' if len(ctc_phonemes) > 20 else ''}\")\n",
|
|||
|
" \n",
|
|||
|
" # 尝试转换为文本(简化版本)\n",
|
|||
|
" if len(ctc_phonemes) > 0:\n",
|
|||
|
" text_attempt = phonemes_to_words(ctc_phonemes)\n",
|
|||
|
" print(f\" 近似文本: {text_attempt[:50]}{'...' if len(text_attempt) > 50 else ''}\")\n",
|
|||
|
" \n",
|
|||
|
" # 分析音素模式\n",
|
|||
|
" if len(ctc_phonemes) > 0:\n",
|
|||
|
" phoneme_freq, bigrams = analyze_phoneme_patterns(ctc_phonemes)\n",
|
|||
|
" \n",
|
|||
|
" # 显示最常见的音素\n",
|
|||
|
" top_phonemes = sorted(phoneme_freq.items(), key=lambda x: x[1], reverse=True)[:5]\n",
|
|||
|
" print(f\" 常见音素: {', '.join([f'{p}({c})' for p, c in top_phonemes])}\")\n",
|
|||
|
" \n",
|
|||
|
" # 显示最常见的音素对\n",
|
|||
|
" if len(bigrams) > 0:\n",
|
|||
|
" top_bigrams = sorted(bigrams.items(), key=lambda x: x[1], reverse=True)[:3]\n",
|
|||
|
" bigram_str = ', '.join([f'{p1}-{p2}({c})' for (p1, p2), c in top_bigrams])\n",
|
|||
|
" print(f\" 常见音素对: {bigram_str}\")\n",
|
|||
|
"\n",
|
|||
|
"# 全局统计\n",
|
|||
|
"print(f\"\\n🌍 全局CTC结果统计:\")\n",
|
|||
|
"print(f\"=\"*50)\n",
|
|||
|
"\n",
|
|||
|
"all_ctc_phonemes = []\n",
|
|||
|
"for ctc_phonemes in all_trials_ctc:\n",
|
|||
|
" all_ctc_phonemes.extend(ctc_phonemes)\n",
|
|||
|
"\n",
|
|||
|
"if len(all_ctc_phonemes) > 0:\n",
|
|||
|
" global_freq, global_bigrams = analyze_phoneme_patterns(all_ctc_phonemes)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 总CTC音素数: {len(all_ctc_phonemes)}\")\n",
|
|||
|
" print(f\" 唯一音素数: {len(global_freq)}\")\n",
|
|||
|
" \n",
|
|||
|
" # 最常见的音素\n",
|
|||
|
" print(f\"\\n 📊 音素频率排行:\")\n",
|
|||
|
" top_global_phonemes = sorted(global_freq.items(), key=lambda x: x[1], reverse=True)[:10]\n",
|
|||
|
" for rank, (phoneme, count) in enumerate(top_global_phonemes, 1):\n",
|
|||
|
" percentage = count / len(all_ctc_phonemes) * 100\n",
|
|||
|
" print(f\" {rank:2d}. {phoneme:>4}: {count:4d} 次 ({percentage:5.1f}%)\")\n",
|
|||
|
" \n",
|
|||
|
" # 最常见的音素对\n",
|
|||
|
" print(f\"\\n 🔗 音素对频率排行:\")\n",
|
|||
|
" top_global_bigrams = sorted(global_bigrams.items(), key=lambda x: x[1], reverse=True)[:8]\n",
|
|||
|
" for rank, ((p1, p2), count) in enumerate(top_global_bigrams, 1):\n",
|
|||
|
" print(f\" {rank:2d}. {p1:>4}-{p2:<4}: {count:3d} 次\")\n",
|
|||
|
"\n",
|
|||
|
"# 质量评估\n",
|
|||
|
"print(f\"\\n📈 CTC解码质量评估:\")\n",
|
|||
|
"print(f\"=\"*30)\n",
|
|||
|
"\n",
|
|||
|
"total_raw = sum(len(raw) for raw in all_trials_raw)\n",
|
|||
|
"total_ctc = len(all_ctc_phonemes)\n",
|
|||
|
"compression_ratio = total_ctc / total_raw if total_raw > 0 else 0\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 原始音素总数: {total_raw:,}\")\n",
|
|||
|
"print(f\" CTC音素总数: {total_ctc:,}\")\n",
|
|||
|
"print(f\" 压缩比: {compression_ratio:.3f}\")\n",
|
|||
|
"print(f\" 去重效果: {(1-compression_ratio)*100:.1f}% 的重复被移除\")\n",
|
|||
|
"\n",
|
|||
|
"# 音素多样性\n",
|
|||
|
"unique_phonemes = len(set(all_ctc_phonemes))\n",
|
|||
|
"phoneme_diversity = unique_phonemes / len(LOGIT_TO_PHONEME) if len(LOGIT_TO_PHONEME) > 0 else 0\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 音素多样性: {unique_phonemes}/{len(LOGIT_TO_PHONEME)} ({phoneme_diversity*100:.1f}%)\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✨ 音素到文本转换分析完成!\")\n",
|
|||
|
"print(f\" 💡 提示: 这是基于RNN预测的音素序列\")\n",
|
|||
|
"print(f\" 💡 实际应用需要语音识别词典进行准确的音素到单词转换\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"!ls ../data/rnn-pretagged-data"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 13,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"🔬 PCA降维增强数据加载器\n",
|
|||
|
"======================================================================\n",
|
|||
|
"✅ PCA增强数据加载器已定义\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🔬 PCA降维增强数据加载器\n",
|
|||
|
"\n",
|
|||
|
"from sklearn.decomposition import PCA\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"import seaborn as sns\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import joblib\n",
|
|||
|
"import os\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🔬 PCA降维增强数据加载器\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"class PCAEnhancedDataset:\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 带有PCA降维功能的内存友好数据集\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" \n",
|
|||
|
" def __init__(self, data_dir, data_type, max_samples_per_file=3000, \n",
|
|||
|
" enable_pca=True, n_components=None, variance_threshold=0.95):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" Args:\n",
|
|||
|
" data_dir: 数据目录\n",
|
|||
|
" data_type: 'train', 'val', 'test'\n",
|
|||
|
" max_samples_per_file: 每个文件最大样本数\n",
|
|||
|
" enable_pca: 是否启用PCA降维\n",
|
|||
|
" n_components: PCA主成分数量 (None为自动选择)\n",
|
|||
|
" variance_threshold: 保留方差比例 (用于自动选择成分数)\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" self.data_dir = data_dir\n",
|
|||
|
" self.data_type = data_type\n",
|
|||
|
" self.max_samples_per_file = max_samples_per_file\n",
|
|||
|
" self.enable_pca = enable_pca\n",
|
|||
|
" self.n_components = n_components\n",
|
|||
|
" self.variance_threshold = variance_threshold\n",
|
|||
|
" \n",
|
|||
|
" self.files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
|
|||
|
" self.scaler = StandardScaler()\n",
|
|||
|
" self.pca = None\n",
|
|||
|
" self.is_fitted = False\n",
|
|||
|
" \n",
|
|||
|
" print(f\"📊 PCA数据集初始化:\")\n",
|
|||
|
" print(f\" 数据类型: {data_type}\")\n",
|
|||
|
" print(f\" 文件数量: {len(self.files)}\")\n",
|
|||
|
" print(f\" PCA启用: {'✅' if enable_pca else '❌'}\")\n",
|
|||
|
" if enable_pca:\n",
|
|||
|
" print(f\" 成分数量: {n_components if n_components else 'auto'}\")\n",
|
|||
|
" print(f\" 方差阈值: {variance_threshold}\")\n",
|
|||
|
" \n",
|
|||
|
" def fit_pca(self, sample_size=10000):\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" 在训练数据样本上拟合PCA\n",
|
|||
|
" \n",
|
|||
|
" Args:\n",
|
|||
|
" sample_size: 用于拟合PCA的样本数量\n",
|
|||
|
" \"\"\"\n",
|
|||
|
" if not self.enable_pca:\n",
|
|||
|
" print(\"⚠️ PCA未启用,跳过拟合\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🔧 拟合PCA降维器...\")\n",
|
|||
|
" print(f\" 使用样本数: {sample_size}\")\n",
|
|||
|
" \n",
|
|||
|
" # 收集样本数据\n",
|
|||
|
" sample_features = []\n",
|
|||
|
" collected_samples = 0\n",
|
|||
|
" \n",
|
|||
|
" for features, labels in self.get_batch_generator_raw():\n",
|
|||
|
" sample_features.append(features)\n",
|
|||
|
" collected_samples += features.shape[0]\n",
|
|||
|
" \n",
|
|||
|
" if collected_samples >= sample_size:\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" if sample_features:\n",
|
|||
|
" # 合并样本数据\n",
|
|||
|
" X_sample = np.vstack(sample_features)[:sample_size]\n",
|
|||
|
" print(f\" 实际样本数: {X_sample.shape[0]}\")\n",
|
|||
|
" print(f\" 原始特征数: {X_sample.shape[1]}\")\n",
|
|||
|
" \n",
|
|||
|
" # 标准化\n",
|
|||
|
" X_sample_scaled = self.scaler.fit_transform(X_sample)\n",
|
|||
|
" \n",
|
|||
|
" # 确定PCA成分数\n",
|
|||
|
" if self.n_components is None:\n",
|
|||
|
" # 自动选择成分数 - 先拟合完整PCA\n",
|
|||
|
" print(f\" 🔍 自动选择PCA成分数...\")\n",
|
|||
|
" pca_full = PCA()\n",
|
|||
|
" pca_full.fit(X_sample_scaled)\n",
|
|||
|
" \n",
|
|||
|
" # 计算累积方差比例\n",
|
|||
|
" cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
|
|||
|
" optimal_components = np.argmax(cumsum_var >= self.variance_threshold) + 1\n",
|
|||
|
" \n",
|
|||
|
" self.n_components = min(optimal_components, X_sample.shape[1])\n",
|
|||
|
" print(f\" 📊 方差分析:\")\n",
|
|||
|
" print(f\" 保留{self.variance_threshold*100}%方差需要: {optimal_components} 个成分\")\n",
|
|||
|
" print(f\" 选择成分数: {self.n_components}\")\n",
|
|||
|
" print(f\" 实际保留方差: {cumsum_var[self.n_components-1]:.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 拟合最终PCA\n",
|
|||
|
" self.pca = PCA(n_components=self.n_components, random_state=42)\n",
|
|||
|
" self.pca.fit(X_sample_scaled)\n",
|
|||
|
" \n",
|
|||
|
" self.is_fitted = True\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ PCA拟合完成!\")\n",
|
|||
|
" print(f\" 降维: {X_sample.shape[1]} → {self.n_components}\")\n",
|
|||
|
" print(f\" 降维比例: {self.n_components/X_sample.shape[1]:.2%}\")\n",
|
|||
|
" print(f\" 保留方差: {self.pca.explained_variance_ratio_.sum():.4f}\")\n",
|
|||
|
" \n",
|
|||
|
" # 保存PCA模型\n",
|
|||
|
" pca_path = f\"pca_model_{self.data_type}.joblib\"\n",
|
|||
|
" joblib.dump({'scaler': self.scaler, 'pca': self.pca}, pca_path)\n",
|
|||
|
" print(f\" 模型已保存: {pca_path}\")\n",
|
|||
|
" \n",
|
|||
|
" else:\n",
|
|||
|
" print(\"❌ 无法收集样本数据用于PCA拟合\")\n",
|
|||
|
" \n",
|
|||
|
" def load_pca_model(self, model_path):\n",
|
|||
|
" \"\"\"加载预训练的PCA模型\"\"\"\n",
|
|||
|
" if os.path.exists(model_path):\n",
|
|||
|
" models = joblib.load(model_path)\n",
|
|||
|
" self.scaler = models['scaler']\n",
|
|||
|
" self.pca = models['pca']\n",
|
|||
|
" self.is_fitted = True\n",
|
|||
|
" self.n_components = self.pca.n_components_\n",
|
|||
|
" print(f\"✅ PCA模型加载成功: {model_path}\")\n",
|
|||
|
" return True\n",
|
|||
|
" return False\n",
|
|||
|
" \n",
|
|||
|
" def get_batch_generator_raw(self):\n",
|
|||
|
" \"\"\"原始数据生成器(用于PCA拟合)\"\"\"\n",
|
|||
|
" for file_idx, f in enumerate(self.files):\n",
|
|||
|
" data = np.load(os.path.join(self.data_dir, f), allow_pickle=True)\n",
|
|||
|
" trials = data['neural_logits_concatenated']\n",
|
|||
|
" \n",
|
|||
|
" if len(trials) > self.max_samples_per_file:\n",
|
|||
|
" trials = trials[:self.max_samples_per_file]\n",
|
|||
|
" \n",
|
|||
|
" features, labels = self._extract_features_labels(trials)\n",
|
|||
|
" yield features, labels\n",
|
|||
|
" \n",
|
|||
|
" del data, trials\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" \n",
|
|||
|
" def get_batch_generator(self):\n",
|
|||
|
" \"\"\"PCA处理后的数据生成器\"\"\"\n",
|
|||
|
" for file_idx, f in enumerate(self.files):\n",
|
|||
|
" print(f\" 正在加载文件 {file_idx+1}/{len(self.files)}: {f}\")\n",
|
|||
|
" \n",
|
|||
|
" data = np.load(os.path.join(self.data_dir, f), allow_pickle=True)\n",
|
|||
|
" trials = data['neural_logits_concatenated']\n",
|
|||
|
" \n",
|
|||
|
" if len(trials) > self.max_samples_per_file:\n",
|
|||
|
" trials = trials[:self.max_samples_per_file]\n",
|
|||
|
" \n",
|
|||
|
" features, labels = self._extract_features_labels(trials)\n",
|
|||
|
" \n",
|
|||
|
" # 应用PCA降维\n",
|
|||
|
" if self.enable_pca and self.is_fitted:\n",
|
|||
|
" features_scaled = self.scaler.transform(features)\n",
|
|||
|
" features_pca = self.pca.transform(features_scaled)\n",
|
|||
|
" yield features_pca, labels\n",
|
|||
|
" else:\n",
|
|||
|
" # 只标准化,不降维\n",
|
|||
|
" features_scaled = self.scaler.transform(features) if self.is_fitted else features\n",
|
|||
|
" yield features_scaled, labels\n",
|
|||
|
" \n",
|
|||
|
" del data, trials\n",
|
|||
|
" gc.collect()\n",
|
|||
|
" \n",
|
|||
|
" def _extract_features_labels(self, trials_batch):\n",
|
|||
|
" \"\"\"提取特征和标签\"\"\"\n",
|
|||
|
" features = []\n",
|
|||
|
" labels = []\n",
|
|||
|
" \n",
|
|||
|
" for trial in trials_batch:\n",
|
|||
|
" if trial.shape[0] > 0:\n",
|
|||
|
" for t in range(trial.shape[0]):\n",
|
|||
|
" neural_features = trial[t, :7168] # 前7168维神经特征\n",
|
|||
|
" rnn_logits = trial[t, 7168:] # 后41维RNN输出\n",
|
|||
|
" phoneme_label = np.argmax(rnn_logits)\n",
|
|||
|
" \n",
|
|||
|
" features.append(neural_features)\n",
|
|||
|
" labels.append(phoneme_label)\n",
|
|||
|
" \n",
|
|||
|
" return np.array(features), np.array(labels)\n",
|
|||
|
" \n",
|
|||
|
" def plot_pca_analysis(self):\n",
|
|||
|
" \"\"\"可视化PCA分析结果\"\"\"\n",
|
|||
|
" if not (self.enable_pca and self.is_fitted):\n",
|
|||
|
" print(\"❌ PCA未拟合,无法绘制分析图\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
|
|||
|
" \n",
|
|||
|
" # 1. 方差解释比例\n",
|
|||
|
" axes[0].bar(range(1, min(21, len(self.pca.explained_variance_ratio_)+1)), \n",
|
|||
|
" self.pca.explained_variance_ratio_[:20])\n",
|
|||
|
" axes[0].set_title('Explained Variance Ratio (Top 20 Components)')\n",
|
|||
|
" axes[0].set_xlabel('Principal Component')\n",
|
|||
|
" axes[0].set_ylabel('Explained Variance Ratio')\n",
|
|||
|
" \n",
|
|||
|
" # 2. 累积方差解释比例\n",
|
|||
|
" cumsum_var = np.cumsum(self.pca.explained_variance_ratio_)\n",
|
|||
|
" axes[1].plot(range(1, len(cumsum_var)+1), cumsum_var, 'b-', linewidth=2)\n",
|
|||
|
" axes[1].axhline(y=self.variance_threshold, color='r', linestyle='--', \n",
|
|||
|
" label=f'Threshold ({self.variance_threshold})')\n",
|
|||
|
" axes[1].axvline(x=self.n_components, color='g', linestyle='--', \n",
|
|||
|
" label=f'Selected Components ({self.n_components})')\n",
|
|||
|
" axes[1].set_title('Cumulative Explained Variance Ratio')\n",
|
|||
|
" axes[1].set_xlabel('Number of Components')\n",
|
|||
|
" axes[1].set_ylabel('Cumulative Explained Variance Ratio')\n",
|
|||
|
" axes[1].legend()\n",
|
|||
|
" axes[1].grid(True, alpha=0.3)\n",
|
|||
|
" \n",
|
|||
|
" # 3. 降维效果\n",
|
|||
|
" original_dims = 7168\n",
|
|||
|
" reduction_ratio = self.n_components / original_dims\n",
|
|||
|
" \n",
|
|||
|
" categories = ['Original Dimensions', 'PCA Dimensions']\n",
|
|||
|
" values = [original_dims, self.n_components]\n",
|
|||
|
" colors = ['lightcoral', 'lightblue']\n",
|
|||
|
" \n",
|
|||
|
" bars = axes[2].bar(categories, values, color=colors)\n",
|
|||
|
" axes[2].set_title(f'Dimensionality Reduction Effect (Retained {reduction_ratio:.1%})')\n",
|
|||
|
" axes[2].set_ylabel('Feature Dimensions')\n",
|
|||
|
" \n",
|
|||
|
" # 添加数值标签\n",
|
|||
|
" for bar, value in zip(bars, values):\n",
|
|||
|
" axes[2].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 50,\n",
|
|||
|
" f'{value}', ha='center', va='bottom', fontweight='bold')\n",
|
|||
|
" \n",
|
|||
|
" plt.tight_layout()\n",
|
|||
|
" plt.show()\n",
|
|||
|
" \n",
|
|||
|
" # 打印详细统计\n",
|
|||
|
" print(f\"\\n📊 PCA降维统计:\")\n",
|
|||
|
" print(f\" 原始维度: {original_dims}\")\n",
|
|||
|
" print(f\" 降维后维度: {self.n_components}\")\n",
|
|||
|
" print(f\" 维度保留比例: {reduction_ratio:.2%}\")\n",
|
|||
|
" print(f\" 方差保留比例: {cumsum_var[self.n_components-1]:.4f}\")\n",
|
|||
|
" print(f\" 内存节省: {(1-reduction_ratio)*100:.1f}%\")\n",
|
|||
|
"\n",
|
|||
|
"print(\"✅ PCA增强数据加载器已定义\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 14,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"🚀 创建PCA降维数据集\n",
|
|||
|
"======================================================================\n",
|
|||
|
"📊 PCA配置:\n",
|
|||
|
" enable_pca: True\n",
|
|||
|
" n_components: None\n",
|
|||
|
" variance_threshold: 0.95\n",
|
|||
|
" sample_size: 15000\n",
|
|||
|
"📊 PCA数据集初始化:\n",
|
|||
|
" 数据类型: train\n",
|
|||
|
" 文件数量: 45\n",
|
|||
|
" PCA启用: ✅\n",
|
|||
|
" 成分数量: auto\n",
|
|||
|
" 方差阈值: 0.95\n",
|
|||
|
"📊 PCA数据集初始化:\n",
|
|||
|
" 数据类型: val\n",
|
|||
|
" 文件数量: 41\n",
|
|||
|
" PCA启用: ✅\n",
|
|||
|
" 成分数量: auto\n",
|
|||
|
" 方差阈值: 0.95\n",
|
|||
|
"📊 PCA数据集初始化:\n",
|
|||
|
" 数据类型: test\n",
|
|||
|
" 文件数量: 41\n",
|
|||
|
" PCA启用: ✅\n",
|
|||
|
" 成分数量: auto\n",
|
|||
|
" 方差阈值: 0.95\n",
|
|||
|
"\n",
|
|||
|
"✅ PCA数据集创建完成\n",
|
|||
|
"\n",
|
|||
|
"🔧 在训练集上拟合PCA...\n",
|
|||
|
"\n",
|
|||
|
"🔧 拟合PCA降维器...\n",
|
|||
|
" 使用样本数: 15000\n",
|
|||
|
" 实际样本数: 15000\n",
|
|||
|
" 原始特征数: 7168\n",
|
|||
|
" 实际样本数: 15000\n",
|
|||
|
" 原始特征数: 7168\n",
|
|||
|
" 🔍 自动选择PCA成分数...\n",
|
|||
|
" 🔍 自动选择PCA成分数...\n",
|
|||
|
" 📊 方差分析:\n",
|
|||
|
" 保留95.0%方差需要: 1062 个成分\n",
|
|||
|
" 选择成分数: 1062\n",
|
|||
|
" 实际保留方差: 0.9501\n",
|
|||
|
" 📊 方差分析:\n",
|
|||
|
" 保留95.0%方差需要: 1062 个成分\n",
|
|||
|
" 选择成分数: 1062\n",
|
|||
|
" 实际保留方差: 0.9501\n",
|
|||
|
" ✅ PCA拟合完成!\n",
|
|||
|
" 降维: 7168 → 1062\n",
|
|||
|
" 降维比例: 14.82%\n",
|
|||
|
" 保留方差: 0.9491\n",
|
|||
|
" 模型已保存: pca_model_train.joblib\n",
|
|||
|
"\n",
|
|||
|
"🔄 复制PCA模型到验证和测试集...\n",
|
|||
|
" ✅ PCA模型复制完成\n",
|
|||
|
"\n",
|
|||
|
"📊 绘制PCA分析图...\n",
|
|||
|
" ✅ PCA拟合完成!\n",
|
|||
|
" 降维: 7168 → 1062\n",
|
|||
|
" 降维比例: 14.82%\n",
|
|||
|
" 保留方差: 0.9491\n",
|
|||
|
" 模型已保存: pca_model_train.joblib\n",
|
|||
|
"\n",
|
|||
|
"🔄 复制PCA模型到验证和测试集...\n",
|
|||
|
" ✅ PCA模型复制完成\n",
|
|||
|
"\n",
|
|||
|
"📊 绘制PCA分析图...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAABv4AAAHqCAYAAADMEzkrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1dvG8e8mpIeEltBJQpMOSu8oJUJoShWU0EWaEAHBH1URBKSJNEFAEZQiIgoiXUSaIiBFepPeQwnp8/6x7y4sKSSQsIHcn+vaKztnzs48M9mdPTvPzDkmwzAMREREREREREREREREROSZ5mDvAERERERERERERERERETkySnxJyIiIiIiIiIiIiIiIvIcUOJPRERERERERERERERE5DmgxJ+IiIiIiIiIiIiIiIjIc0CJPxEREREREREREREREZHngBJ/IiIiIiIiIiIiIiIiIs8BJf5EREREREREREREREREngNK/ImIiIiIiIiIiIiIiIg8B5T4ExEREREREREREREREXkOKPEnaVL79u3x9/d/rNf6+/vTvn37FI0nqZ4k7tSSFmN6FqTG++jOnTv4+vqyYMGCFF2uSEqJiooib968TJs2zd6hiIgkW2q0eebNm4fJZOLUqVMputy0TO3wlJMWY3oW2PN9JCJp1/DhwzGZTPYO44mdOnUKk8nEvHnz7B0KALVq1aJWrVrW6bQWX0qx13fLs7I/jx49Sr169fD29sZkMrF8+XIA/vzzT6pUqYKHhwcmk4k9e/bYNU6LsWPHUqRIEWJjY+0dSqI2bdqEyWRi06ZNdlm/yWRi+PDhdlm3PbVu3ZqWLVvaOwy7U+JPEmQ50ZHQY/v27fYO8Zlz+fJlMmTIwJtvvplgndu3b+Pm5sbrr7/+FCNL+2rVqmXz/nNzc6NUqVJMmjTpsb/ot27dyvDhw7l582bKBpuAyZMnkzFjRlq3bm1t/CXlkdonGw8dOsSAAQMoU6YMGTNmJGfOnAQFBfHXX3/FW//cuXO0bNmSTJky4eXlRZMmTThx4kSS1xcTE8PcuXOpVasWWbJkwcXFBX9/fzp06JDgOiXpzp8/z/Dhwx+rQe7k5ERISAgff/wx4eHhKR+ciKR5x48f5+233yZ//vy4urri5eVF1apVmTx5Mvfu3bN3eKlm1KhR1hMcaYHa4SlP7fDH9zy0w0UkbXn4e87V1ZVcuXIRGBjIZ599xu3bt+0dojxg1apVqZI88Pf3t3kfeHh4UKFCBb7++usUX1dqW7hwIZMmTbJ3GDbat2+fYFvS1dXVpm5wcDD79u3j448/Zv78+ZQrV46oqChatGjB9evXmThxIvPnz8fPzy9FY3yc8xe3bt1izJgxvP/++zg43E9tPLyNXl5e1KxZk5UrVz52fGntN0JqmT59Oi1atCBfvnyYTKYkJ8m7dOmCyWSiYcOGSV7X4sWLqVSpEpkyZSJr1qzx/o9u3rxJ27ZtyZw5M/nz5+fLL7+Ms5y//voLd3d3Tp48GWfe+++/z/fff8/evXuTHNfzKIO9A5C078MPPyQgICBOecGCBe0QzaMdPnzY5sCflvj6+lK3bl1+/PFHwsLCcHd3j1Nn2bJlhIeHJ3pSIjlmzZqV5q+ASao8efIwevRoAK5evcrChQvp27cvV65c4eOPP0728rZu3cqIESNo3749mTJlspmX0u+jqKgoJk+eTN++fXF0dMTHx4f58+fb1Bk/fjxnz55l4sSJNuU+Pj4pFkd8Zs+ezZdffkmzZs3o3r07oaGhzJw5k0qVKrF69Wrq1KljrXvnzh1efvllQkND+eCDD3BycmLixInUrFmTPXv2kDVr1kTXde/ePV5//XVWr15NjRo1+OCDD8iSJQunTp1i8eLFfPXVV5w5c4Y8efKk6jY/z86fP8+IESPw9/enTJkyyX59hw4dGDhwIAsXLqRjx44pH6CIpFkrV66kRYsWuLi40K5dO0qUKEFkZCRbtmyhf//+HDhwgC+++MLeYaaKUaNG0bx5c5o2bWpT/tZbb9G6dWtcXFzsEpfa4SlH7fAn8yy3w0Uk7bJ8z0VFRXHx4kU2bdpEnz59mDBhAitWrKBUqVLWuoMHD2bgwIF2jDZl+Pn5ce/ePZycnOwdSrzii2/VqlVMnTo1VZJ/ZcqU4b333gPgwoULzJ49m+DgYCIiIujSpUuKry+1LFy4kP3799OnTx+bcnv/v11cXJg9e3acckdHR+vze/fusW3bNv73v//Rs2dPa/mhQ4c4ffo0s2bNonPnzqkS3+Ocv5gzZw7R0dG88cYbcebVrVuXdu3aYRgGp0+fZvr06TRq1IhffvmFwMDAZMeX0G+EpKpRowb37t3D2dn5sV7/tIwZM4bbt29ToUIFLly4kKTX/PXXX8ybNy9OEjkxU6ZMoXfv3gQFBfHJJ58QHh7OvHnzaNiwId9//7314rt+/fqxadMmRowYwbFjx+jSpQtFixalSpUqABiGQe/evenTp0+8v5VefPFFypUrx/jx45/JCwlSihJ/8kj169enXLly9g4jyex1Yiap2rZty+rVq1mxYgWtW7eOM3/hwoV4e3sTFBT0ROu5e/cuHh4eabYx+Ti8vb1tTsR069aNIkWKMGXKFD788EObhsuTSun30c8//8yVK1est5p7eHjEOan03XffcePGjRQ72ZRUb7zxBsOHD8fT09Na1rFjR4oWLcrw4cNtEn/Tpk3j6NGj7Ny5k/LlywPmY0SJEiUYP348o0aNSnRd/fv3Z/Xq1UycODFOg3jYsGFxkp7y9GXKlIl69eoxb948Jf5E0pGTJ0/SunVr/Pz82LBhAzlz5rTO69GjB8eOHXuiq2WfVY6OjinavkgutcNTltrhj+9ZboeLSNr18PfcoEGD2LBhAw0bNqRx48b8+++/uLm5AZAhQwYyZHj2T2PGd7dVWvK048udO7fN90v79u3Jnz8/EydOfKYSfwmx9//7Ub0dAFy5cgUgzoU4ly9fjrfc3ubOnUvjxo3j3a+FCxe22d5mzZpRrFgxJk+e/FiJvyfl4OCQpj/vFr/99pv1br8Hzw0mxJJ4a9euHevXr0/yeqZMmUL58uX56aefrF03d+zYkdy5c/PVV19ZE38///wzY8eOpV27dgD8888//PTTT9bE34IFCzh9+jQffPBBgutq2bIlw4YNY9q0aUnapueRLqOTJzZs2DAcHBzifNC7du2Ks7Oz9bZaS7/GixYt4oMPPiBHjhx4eHjQuHFj/vvvv0eu59NPP6VKlSpkzZoVNzc3ypYty9KlS+PUe7jfbksXEn/88QchISH4+Pjg4eHBa6+9Zv1ye9Avv/xC9erV8fDwIGPGjAQFBXHgwIE49ZYvX06JEiVwdXWlRIkS/PDDD4/cBoDXXnsNDw8PFi5cGGfe5cuXWb9+Pc2bN8fFxYXff//dequ1i4sLefPmpW/fvnG62mrfvj2enp4cP36cBg0akDFjRtq2bWud9/DYIkndlyaTiZ49e1q31cXFheLFi7N69eo4dc+dO0enTp3IlSsXLi4uBAQE8M477xAZGWmtc/PmTfr06UPevHlxcXGhYMGCjBkz5rGvhHZ1daV8+fLcvn3b2iAB8xeCpbHo6upKjhw56NixI9euXbPWGT58OP379wcgICAgTrea8fX/fuLECVq0aEGWLFlwd3enUqVKST4Junz5cvz9/SlQoECytvHy5ct06tSJ7Nmz4+rqSunSpfnqq69s6li6Df3000+ZOHEifn5+uLm5UbNmTfbv3//IdZQtWzbOl2DWrFmpXr06//77r0350qVLKV++vDXpB1CkSBFq167N4sWLE13P2bNnmTlzJnXr1o2T9APzydV+/frZ3O23e/du6tevj5eXF56entSuXTtO92aWz/iWLVvo3bs3Pj4+ZMqUibfffpvIyEhu3rxJu3btyJw5M5kzZ2bAgAEYhvHY+2/Dhg3WY0SmTJlo0qRJnP1kGX/i2LFj1ivZvb296dChA2FhYXGW+c0331C2bFnc3NzIkiULrVu3jnNcrFWrFiVKlODgwYO8/PLLuLu7kzt3bsaOHWuts2nTJuv/pkO
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 1800x500 with 3 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {},
|
|||
|
"output_type": "display_data"
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"\n",
|
|||
|
"📊 PCA降维统计:\n",
|
|||
|
" 原始维度: 7168\n",
|
|||
|
" 降维后维度: 1062\n",
|
|||
|
" 维度保留比例: 14.82%\n",
|
|||
|
" 方差保留比例: 0.9491\n",
|
|||
|
" 内存节省: 85.2%\n",
|
|||
|
"\n",
|
|||
|
"🎯 使用方法:\n",
|
|||
|
" for features, labels in train_dataset_pca.get_batch_generator():\n",
|
|||
|
" # features 已经过PCA降维\n",
|
|||
|
" # 可直接用于LightGBM训练\n",
|
|||
|
" pass\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 PCA降维数据集配置和使用\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🚀 创建PCA降维数据集\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"# PCA配置选项\n",
|
|||
|
"PCA_CONFIG = {\n",
|
|||
|
" 'enable_pca': True, # 是否启用PCA\n",
|
|||
|
" 'n_components': None, # None=自动选择, 或指定具体数值如512\n",
|
|||
|
" 'variance_threshold': 0.95, # 保留95%的方差\n",
|
|||
|
" 'sample_size': 15000, # 用于拟合PCA的样本数\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"print(\"📊 PCA配置:\")\n",
|
|||
|
"for key, value in PCA_CONFIG.items():\n",
|
|||
|
" print(f\" {key}: {value}\")\n",
|
|||
|
"\n",
|
|||
|
"# 创建PCA增强数据集\n",
|
|||
|
"train_dataset_pca = PCAEnhancedDataset(\n",
|
|||
|
" data_dir=data_dir, \n",
|
|||
|
" data_type='train',\n",
|
|||
|
" max_samples_per_file=MAX_SAMPLES_PER_FILE,\n",
|
|||
|
" enable_pca=PCA_CONFIG['enable_pca'],\n",
|
|||
|
" n_components=PCA_CONFIG['n_components'],\n",
|
|||
|
" variance_threshold=PCA_CONFIG['variance_threshold']\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"val_dataset_pca = PCAEnhancedDataset(\n",
|
|||
|
" data_dir=data_dir,\n",
|
|||
|
" data_type='val', \n",
|
|||
|
" max_samples_per_file=MAX_SAMPLES_PER_FILE,\n",
|
|||
|
" enable_pca=PCA_CONFIG['enable_pca']\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"test_dataset_pca = PCAEnhancedDataset(\n",
|
|||
|
" data_dir=data_dir,\n",
|
|||
|
" data_type='test',\n",
|
|||
|
" max_samples_per_file=MAX_SAMPLES_PER_FILE, \n",
|
|||
|
" enable_pca=PCA_CONFIG['enable_pca']\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✅ PCA数据集创建完成\")\n",
|
|||
|
"\n",
|
|||
|
"# 拟合PCA (只在训练集上)\n",
|
|||
|
"if PCA_CONFIG['enable_pca']:\n",
|
|||
|
" print(f\"\\n🔧 在训练集上拟合PCA...\")\n",
|
|||
|
" train_dataset_pca.fit_pca(sample_size=PCA_CONFIG['sample_size'])\n",
|
|||
|
" \n",
|
|||
|
" # 将训练好的PCA应用到验证和测试集\n",
|
|||
|
" if train_dataset_pca.is_fitted:\n",
|
|||
|
" print(f\"\\n🔄 复制PCA模型到验证和测试集...\")\n",
|
|||
|
" val_dataset_pca.scaler = train_dataset_pca.scaler\n",
|
|||
|
" val_dataset_pca.pca = train_dataset_pca.pca\n",
|
|||
|
" val_dataset_pca.n_components = train_dataset_pca.n_components\n",
|
|||
|
" val_dataset_pca.is_fitted = True\n",
|
|||
|
" \n",
|
|||
|
" test_dataset_pca.scaler = train_dataset_pca.scaler\n",
|
|||
|
" test_dataset_pca.pca = train_dataset_pca.pca\n",
|
|||
|
" test_dataset_pca.n_components = train_dataset_pca.n_components\n",
|
|||
|
" test_dataset_pca.is_fitted = True\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ PCA模型复制完成\")\n",
|
|||
|
" \n",
|
|||
|
" # 绘制PCA分析图\n",
|
|||
|
" print(f\"\\n📊 绘制PCA分析图...\")\n",
|
|||
|
" train_dataset_pca.plot_pca_analysis()\n",
|
|||
|
" else:\n",
|
|||
|
" print(f\"❌ PCA拟合失败\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🎯 使用方法:\")\n",
|
|||
|
"print(f\" for features, labels in train_dataset_pca.get_batch_generator():\")\n",
|
|||
|
"print(f\" # features 已经过PCA降维\")\n",
|
|||
|
"print(f\" # 可直接用于LightGBM训练\")\n",
|
|||
|
"print(f\" pass\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 27,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"📊 PCA降维效果测试\n",
|
|||
|
"======================================================================\n",
|
|||
|
"🔍 测试PCA数据加载...\n",
|
|||
|
" 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" ✅ PCA数据加载成功\n",
|
|||
|
" PCA特征形状: (14677, 1062)\n",
|
|||
|
" 标签形状: (14677,)\n",
|
|||
|
" PCA特征范围: [-83.4304, 176.7492]\n",
|
|||
|
" 标签范围: [0, 40]\n",
|
|||
|
"\n",
|
|||
|
"📊 降维对比:\n",
|
|||
|
" 原始特征维度: 7168\n",
|
|||
|
" PCA特征维度: 1062\n",
|
|||
|
" 降维比例: 14.82%\n",
|
|||
|
" 内存节省: 85.2%\n",
|
|||
|
"\n",
|
|||
|
"⚡ 训练速度预估对比:\n",
|
|||
|
" 原始特征数: 7168\n",
|
|||
|
" PCA特征数: 1062\n",
|
|||
|
" 预估速度提升: 6.7x\n",
|
|||
|
" 预估训练时间: 14.8% of 原始时间\n",
|
|||
|
"\n",
|
|||
|
"💡 PCA配置建议:\n",
|
|||
|
" 🔬 数据探索阶段:\n",
|
|||
|
" - variance_threshold: 0.90-0.95 (快速原型)\n",
|
|||
|
" - n_components: 200-500 (固定维度)\n",
|
|||
|
" 🎯 性能优化阶段:\n",
|
|||
|
" - variance_threshold: 0.95-0.99 (保持精度)\n",
|
|||
|
" - n_components: 根据验证集性能调整\n",
|
|||
|
" 🚀 生产部署阶段:\n",
|
|||
|
" - 根据内存和速度需求选择最优配置\n",
|
|||
|
"\n",
|
|||
|
"🔧 使用不同PCA配置的方法:\n",
|
|||
|
" # 快速原型 (大幅降维)\n",
|
|||
|
" dataset_fast = PCAEnhancedDataset(..., n_components=200)\n",
|
|||
|
" \n",
|
|||
|
" # 平衡配置 (自动选择)\n",
|
|||
|
" dataset_balanced = PCAEnhancedDataset(..., variance_threshold=0.95)\n",
|
|||
|
" \n",
|
|||
|
" # 高精度配置 (保留更多信息)\n",
|
|||
|
" dataset_precision = PCAEnhancedDataset(..., variance_threshold=0.99)\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 📊 PCA降维效果测试和对比\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"📊 PCA降维效果测试\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"def test_pca_loading():\n",
|
|||
|
" \"\"\"测试PCA数据加载\"\"\"\n",
|
|||
|
" if not train_dataset_pca.is_fitted:\n",
|
|||
|
" print(\"❌ PCA未拟合,无法测试\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" print(\"🔍 测试PCA数据加载...\")\n",
|
|||
|
" \n",
|
|||
|
" # 测试加载一个批次\n",
|
|||
|
" try:\n",
|
|||
|
" for features_pca, labels in train_dataset_pca.get_batch_generator():\n",
|
|||
|
" print(f\" ✅ PCA数据加载成功\")\n",
|
|||
|
" print(f\" PCA特征形状: {features_pca.shape}\")\n",
|
|||
|
" print(f\" 标签形状: {labels.shape}\")\n",
|
|||
|
" print(f\" PCA特征范围: [{features_pca.min():.4f}, {features_pca.max():.4f}]\")\n",
|
|||
|
" print(f\" 标签范围: [{labels.min()}, {labels.max()}]\")\n",
|
|||
|
" \n",
|
|||
|
" # 对比原始数据\n",
|
|||
|
" print(f\"\\n📊 降维对比:\")\n",
|
|||
|
" print(f\" 原始特征维度: 7168\")\n",
|
|||
|
" print(f\" PCA特征维度: {features_pca.shape[1]}\")\n",
|
|||
|
" print(f\" 降维比例: {features_pca.shape[1]/7168:.2%}\")\n",
|
|||
|
" print(f\" 内存节省: {(1-features_pca.shape[1]/7168)*100:.1f}%\")\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
" except Exception as e:\n",
|
|||
|
" print(f\"❌ PCA数据加载失败: {e}\")\n",
|
|||
|
"\n",
|
|||
|
"def compare_training_speed():\n",
|
|||
|
" \"\"\"比较训练速度(模拟)\"\"\"\n",
|
|||
|
" if not train_dataset_pca.is_fitted:\n",
|
|||
|
" print(\"❌ PCA未拟合,无法比较\")\n",
|
|||
|
" return\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n⚡ 训练速度预估对比:\")\n",
|
|||
|
" original_features = 7168\n",
|
|||
|
" pca_features = train_dataset_pca.n_components\n",
|
|||
|
" \n",
|
|||
|
" # 简单的复杂度估算 (特征数的线性关系)\n",
|
|||
|
" speed_improvement = original_features / pca_features\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 原始特征数: {original_features}\")\n",
|
|||
|
" print(f\" PCA特征数: {pca_features}\")\n",
|
|||
|
" print(f\" 预估速度提升: {speed_improvement:.1f}x\")\n",
|
|||
|
" print(f\" 预估训练时间: {1/speed_improvement:.1%} of 原始时间\")\n",
|
|||
|
"\n",
|
|||
|
"# 执行测试\n",
|
|||
|
"test_pca_loading()\n",
|
|||
|
"compare_training_speed()\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n💡 PCA配置建议:\")\n",
|
|||
|
"print(f\" 🔬 数据探索阶段:\")\n",
|
|||
|
"print(f\" - variance_threshold: 0.90-0.95 (快速原型)\")\n",
|
|||
|
"print(f\" - n_components: 200-500 (固定维度)\")\n",
|
|||
|
"print(f\" 🎯 性能优化阶段:\") \n",
|
|||
|
"print(f\" - variance_threshold: 0.95-0.99 (保持精度)\")\n",
|
|||
|
"print(f\" - n_components: 根据验证集性能调整\")\n",
|
|||
|
"print(f\" 🚀 生产部署阶段:\")\n",
|
|||
|
"print(f\" - 根据内存和速度需求选择最优配置\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🔧 使用不同PCA配置的方法:\")\n",
|
|||
|
"print(f\" # 快速原型 (大幅降维)\")\n",
|
|||
|
"print(f\" dataset_fast = PCAEnhancedDataset(..., n_components=200)\")\n",
|
|||
|
"print(f\" \")\n",
|
|||
|
"print(f\" # 平衡配置 (自动选择)\")\n",
|
|||
|
"print(f\" dataset_balanced = PCAEnhancedDataset(..., variance_threshold=0.95)\")\n",
|
|||
|
"print(f\" \")\n",
|
|||
|
"print(f\" # 高精度配置 (保留更多信息)\")\n",
|
|||
|
"print(f\" dataset_precision = PCAEnhancedDataset(..., variance_threshold=0.99)\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 15,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"📖 PCA降维数据集使用指南\n",
|
|||
|
"======================================================================\n",
|
|||
|
"🎯 你现在有两种数据集可以使用:\n",
|
|||
|
" 1. 集成PCA数据集 (推荐) - train_dataset, val_dataset, test_dataset\n",
|
|||
|
" 2. 独立PCA数据集 - train_dataset_pca, val_dataset_pca, test_dataset_pca\n",
|
|||
|
"\n",
|
|||
|
"🚀 方式1: 使用集成PCA数据集 (推荐)\n",
|
|||
|
"==================================================\n",
|
|||
|
"✅ 特点:\n",
|
|||
|
" - PCA已集成到数据加载流程\n",
|
|||
|
" - 自动降维: 7168 → 1062 维\n",
|
|||
|
" - 内存节省: 85.2%\n",
|
|||
|
" - 训练速度提升: 6.7倍\n",
|
|||
|
"\n",
|
|||
|
"📝 使用示例:\n",
|
|||
|
"# 分批训练 (内存友好)\n",
|
|||
|
"for features_pca, labels in train_dataset.get_batch_generator():\n",
|
|||
|
" print(f'批次特征: {features_pca.shape}, 标签: {labels.shape}')\n",
|
|||
|
" # features_pca 已经是1062维的降维特征\n",
|
|||
|
" # 可以直接用于LightGBM训练\n",
|
|||
|
" break # 只演示第一批\n",
|
|||
|
"\n",
|
|||
|
"# 一次性加载 (如果内存够用)\n",
|
|||
|
"# X_train_pca, y_train = train_dataset.load_all_data()\n",
|
|||
|
"# X_val_pca, y_val = val_dataset.load_all_data()\n",
|
|||
|
"\n",
|
|||
|
"==================================================\n",
|
|||
|
"🧪 让我们测试一下数据加载:\n",
|
|||
|
"❌ 集成PCA数据集测试失败: name 'train_dataset' is not defined\n",
|
|||
|
"\n",
|
|||
|
"💡 现在你可以直接用这些数据训练LightGBM:\n",
|
|||
|
" • 特征已经降维,训练速度更快\n",
|
|||
|
" • 内存使用更少\n",
|
|||
|
" • PCA变换一致应用于训练/验证/测试集\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 📖 PCA降维数据集使用指南\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"📖 PCA降维数据集使用指南\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"print(\"🎯 你现在有两种数据集可以使用:\")\n",
|
|||
|
"print(\" 1. 集成PCA数据集 (推荐) - train_dataset, val_dataset, test_dataset\")\n",
|
|||
|
"print(\" 2. 独立PCA数据集 - train_dataset_pca, val_dataset_pca, test_dataset_pca\")\n",
|
|||
|
"\n",
|
|||
|
"print(\"\\n🚀 方式1: 使用集成PCA数据集 (推荐)\")\n",
|
|||
|
"print(\"=\" * 50)\n",
|
|||
|
"\n",
|
|||
|
"print(\"✅ 特点:\")\n",
|
|||
|
"print(\" - PCA已集成到数据加载流程\")\n",
|
|||
|
"print(\" - 自动降维: 7168 → 1062 维\")\n",
|
|||
|
"print(\" - 内存节省: 85.2%\")\n",
|
|||
|
"print(\" - 训练速度提升: 6.7倍\")\n",
|
|||
|
"\n",
|
|||
|
"print(\"\\n📝 使用示例:\")\n",
|
|||
|
"print(\"# 分批训练 (内存友好)\")\n",
|
|||
|
"print(\"for features_pca, labels in train_dataset.get_batch_generator():\")\n",
|
|||
|
"print(\" print(f'批次特征: {features_pca.shape}, 标签: {labels.shape}')\")\n",
|
|||
|
"print(\" # features_pca 已经是1062维的降维特征\")\n",
|
|||
|
"print(\" # 可以直接用于LightGBM训练\")\n",
|
|||
|
"print(\" break # 只演示第一批\")\n",
|
|||
|
"print()\n",
|
|||
|
"\n",
|
|||
|
"print(\"# 一次性加载 (如果内存够用)\")\n",
|
|||
|
"print(\"# X_train_pca, y_train = train_dataset.load_all_data()\")\n",
|
|||
|
"print(\"# X_val_pca, y_val = val_dataset.load_all_data()\")\n",
|
|||
|
"\n",
|
|||
|
"print(\"\\n\" + \"=\"*50)\n",
|
|||
|
"print(\"🧪 让我们测试一下数据加载:\")\n",
|
|||
|
"\n",
|
|||
|
"# 测试集成PCA数据集\n",
|
|||
|
"try:\n",
|
|||
|
" sample_count = 0\n",
|
|||
|
" for features_pca, labels in train_dataset.get_batch_generator():\n",
|
|||
|
" sample_count += features_pca.shape[0]\n",
|
|||
|
" print(f\"✅ 集成PCA数据集测试成功!\")\n",
|
|||
|
" print(f\" 批次特征形状: {features_pca.shape}\")\n",
|
|||
|
" print(f\" 批次标签形状: {labels.shape}\")\n",
|
|||
|
" print(f\" 特征维度: {features_pca.shape[1]} (已降维)\")\n",
|
|||
|
" print(f\" 标签范围: {labels.min()} - {labels.max()}\")\n",
|
|||
|
" \n",
|
|||
|
" # 检查是否真的是PCA降维数据\n",
|
|||
|
" if features_pca.shape[1] == 1062:\n",
|
|||
|
" print(f\" 🔬 确认: 数据已通过PCA降维 (7168→1062)\")\n",
|
|||
|
" else:\n",
|
|||
|
" print(f\" ⚠️ 注意: 特征维度为 {features_pca.shape[1]}\")\n",
|
|||
|
" break\n",
|
|||
|
" \n",
|
|||
|
"except Exception as e:\n",
|
|||
|
" print(f\"❌ 集成PCA数据集测试失败: {e}\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n💡 现在你可以直接用这些数据训练LightGBM:\")\n",
|
|||
|
"print(f\" • 特征已经降维,训练速度更快\")\n",
|
|||
|
"print(f\" • 内存使用更少\")\n",
|
|||
|
"print(f\" • PCA变换一致应用于训练/验证/测试集\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"🚀 使用PCA降维数据训练LightGBM\n",
|
|||
|
"======================================================================\n",
|
|||
|
"🔧 训练设备: GPU\n",
|
|||
|
"\n",
|
|||
|
"📊 加载PCA降维训练数据...\n",
|
|||
|
" 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" 已加载批次 1: 14677 样本\n",
|
|||
|
" 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n",
|
|||
|
" 已加载批次 2: 44173 样本\n",
|
|||
|
" 正在加载文件 3/45: t15.2024.03.17_train_concatenated.npz\n",
|
|||
|
" 已加载批次 3: 64462 样本\n",
|
|||
|
" ✅ 训练数据准备完成:\n",
|
|||
|
" 特征形状: (123312, 1062) (PCA降维后)\n",
|
|||
|
" 标签形状: (123312,)\n",
|
|||
|
" 类别数: 41\n",
|
|||
|
"\n",
|
|||
|
"🔄 数据分割:\n",
|
|||
|
" 训练集: 98649 样本\n",
|
|||
|
" 验证集: 24663 样本\n",
|
|||
|
"\n",
|
|||
|
"🏗️ LightGBM配置:\n",
|
|||
|
" objective: multiclass\n",
|
|||
|
" num_class: 41\n",
|
|||
|
" metric: multi_logloss\n",
|
|||
|
" boosting_type: gbdt\n",
|
|||
|
" device: gpu\n",
|
|||
|
" num_leaves: 128\n",
|
|||
|
" learning_rate: 0.1\n",
|
|||
|
" feature_fraction: 0.8\n",
|
|||
|
" bagging_fraction: 0.8\n",
|
|||
|
" bagging_freq: 5\n",
|
|||
|
" verbose: -1\n",
|
|||
|
" random_state: 42\n",
|
|||
|
" gpu_platform_id: 0\n",
|
|||
|
" gpu_device_id: 0\n",
|
|||
|
" max_bin: 511\n",
|
|||
|
"\n",
|
|||
|
"🔄 创建LightGBM数据集...\n",
|
|||
|
"\n",
|
|||
|
"🚀 开始训练LightGBM模型...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"name": "stderr",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"[LightGBM] [Fatal] bin size 512 cannot run on GPU\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"ename": "LightGBMError",
|
|||
|
"evalue": "bin size 512 cannot run on GPU",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mLightGBMError\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/4234267123.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 103\u001b[0m ]\n\u001b[1;32m 104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m model = lgb.train(\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0mlgb_params\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mtrain_lgb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/engine.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(params, train_set, num_boost_round, valid_sets, valid_names, feval, init_model, feature_name, categorical_feature, keep_training_booster, callbacks)\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0;31m# construct booster\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 254\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 255\u001b[0;31m \u001b[0mbooster\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBooster\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_set\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrain_set\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 256\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_valid_contain_train\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0mbooster\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset_train_data_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, params, train_set, model_file, model_str)\u001b[0m\n\u001b[1;32m 3435\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_set\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3436\u001b[0m \u001b[0mparams_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_param_dict_to_str\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3437\u001b[0;31m _safe_call(_LIB.LGBM_BoosterCreate(\n\u001b[0m\u001b[1;32m 3438\u001b[0m \u001b[0mtrain_set\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3439\u001b[0m \u001b[0m_c_str\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparams_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py\u001b[0m in \u001b[0;36m_safe_call\u001b[0;34m(ret)\u001b[0m\n\u001b[1;32m 261\u001b[0m \"\"\"\n\u001b[1;32m 262\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mLightGBMError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_LIB\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLGBM_GetLastError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'utf-8'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;31mLightGBMError\u001b[0m: bin size 512 cannot run on GPU"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 完整的LightGBM训练示例 (使用PCA降维数据)\n",
|
|||
|
"\n",
|
|||
|
"import lightgbm as lgb\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import time\n",
|
|||
|
"from sklearn.metrics import accuracy_score, classification_report\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🚀 使用PCA降维数据训练LightGBM\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"# 1. 检查GPU可用性\n",
|
|||
|
"def check_gpu():\n",
|
|||
|
" try:\n",
|
|||
|
" test_data = lgb.Dataset(np.random.rand(100, 10), label=np.random.randint(0, 2, 100))\n",
|
|||
|
" test_params = {'device': 'gpu', 'objective': 'binary', 'verbose': -1}\n",
|
|||
|
" lgb.train(test_params, test_data, num_boost_round=1, callbacks=[])\n",
|
|||
|
" return True\n",
|
|||
|
" except:\n",
|
|||
|
" return False\n",
|
|||
|
"\n",
|
|||
|
"gpu_available = check_gpu()\n",
|
|||
|
"device = 'gpu' if gpu_available else 'cpu'\n",
|
|||
|
"print(f\"🔧 训练设备: {device.upper()}\")\n",
|
|||
|
"\n",
|
|||
|
"# 2. 加载PCA降维数据 (小批量快速训练示例)\n",
|
|||
|
"print(f\"\\n📊 加载PCA降维训练数据...\")\n",
|
|||
|
"\n",
|
|||
|
"# 收集少量数据用于快速演示\n",
|
|||
|
"train_features_list = []\n",
|
|||
|
"train_labels_list = []\n",
|
|||
|
"batch_count = 0\n",
|
|||
|
"\n",
|
|||
|
"for features_pca, labels in train_dataset.get_batch_generator():\n",
|
|||
|
" train_features_list.append(features_pca)\n",
|
|||
|
" train_labels_list.append(labels)\n",
|
|||
|
" batch_count += 1\n",
|
|||
|
" print(f\" 已加载批次 {batch_count}: {features_pca.shape[0]} 样本\")\n",
|
|||
|
" \n",
|
|||
|
" # 只取前3个批次用于快速演示\n",
|
|||
|
" if batch_count >= 3:\n",
|
|||
|
" break\n",
|
|||
|
"\n",
|
|||
|
"# 合并数据\n",
|
|||
|
"X_train_pca = np.vstack(train_features_list)\n",
|
|||
|
"y_train = np.hstack(train_labels_list)\n",
|
|||
|
"\n",
|
|||
|
"print(f\" ✅ 训练数据准备完成:\")\n",
|
|||
|
"print(f\" 特征形状: {X_train_pca.shape} (PCA降维后)\")\n",
|
|||
|
"print(f\" 标签形状: {y_train.shape}\")\n",
|
|||
|
"print(f\" 类别数: {len(np.unique(y_train))}\")\n",
|
|||
|
"\n",
|
|||
|
"# 3. 数据分割\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(\n",
|
|||
|
" X_train_pca, y_train, test_size=0.2, random_state=42, stratify=y_train\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🔄 数据分割:\")\n",
|
|||
|
"print(f\" 训练集: {X_train_split.shape[0]} 样本\")\n",
|
|||
|
"print(f\" 验证集: {X_val_split.shape[0]} 样本\")\n",
|
|||
|
"\n",
|
|||
|
"# 4. LightGBM参数配置\n",
|
|||
|
"lgb_params = {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': len(np.unique(y_train)),\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device': device,\n",
|
|||
|
" 'num_leaves': 128, # 适中的复杂度\n",
|
|||
|
" 'learning_rate': 0.1,\n",
|
|||
|
" 'feature_fraction': 0.8,\n",
|
|||
|
" 'bagging_fraction': 0.8,\n",
|
|||
|
" 'bagging_freq': 5,\n",
|
|||
|
" 'verbose': -1,\n",
|
|||
|
" 'random_state': 42,\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"if device == 'gpu':\n",
|
|||
|
" lgb_params.update({\n",
|
|||
|
" 'gpu_platform_id': 0,\n",
|
|||
|
" 'gpu_device_id': 0,\n",
|
|||
|
" 'max_bin': 511, # GPU优化参数\n",
|
|||
|
" })\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🏗️ LightGBM配置:\")\n",
|
|||
|
"for key, value in lgb_params.items():\n",
|
|||
|
" print(f\" {key}: {value}\")\n",
|
|||
|
"\n",
|
|||
|
"# 5. 创建LightGBM数据集\n",
|
|||
|
"print(f\"\\n🔄 创建LightGBM数据集...\")\n",
|
|||
|
"train_lgb = lgb.Dataset(X_train_split, label=y_train_split)\n",
|
|||
|
"val_lgb = lgb.Dataset(X_val_split, label=y_val_split, reference=train_lgb)\n",
|
|||
|
"\n",
|
|||
|
"# 6. 训练模型\n",
|
|||
|
"print(f\"\\n🚀 开始训练LightGBM模型...\")\n",
|
|||
|
"start_time = time.time()\n",
|
|||
|
"\n",
|
|||
|
"callbacks = [\n",
|
|||
|
" lgb.log_evaluation(period=10),\n",
|
|||
|
" lgb.early_stopping(stopping_rounds=20)\n",
|
|||
|
"]\n",
|
|||
|
"\n",
|
|||
|
"model = lgb.train(\n",
|
|||
|
" lgb_params,\n",
|
|||
|
" train_lgb,\n",
|
|||
|
" valid_sets=[train_lgb, val_lgb],\n",
|
|||
|
" valid_names=['train', 'val'],\n",
|
|||
|
" num_boost_round=100,\n",
|
|||
|
" callbacks=callbacks\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"training_time = time.time() - start_time\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n✅ 训练完成!\")\n",
|
|||
|
"print(f\" 训练时间: {training_time:.2f} 秒\")\n",
|
|||
|
"print(f\" 最佳迭代: {model.best_iteration}\")\n",
|
|||
|
"\n",
|
|||
|
"# 7. 快速评估\n",
|
|||
|
"print(f\"\\n📊 模型评估:\")\n",
|
|||
|
"\n",
|
|||
|
"# 训练集评估\n",
|
|||
|
"train_pred = model.predict(X_train_split, num_iteration=model.best_iteration)\n",
|
|||
|
"train_pred_labels = np.argmax(train_pred, axis=1)\n",
|
|||
|
"train_acc = accuracy_score(y_train_split, train_pred_labels)\n",
|
|||
|
"\n",
|
|||
|
"# 验证集评估 \n",
|
|||
|
"val_pred = model.predict(X_val_split, num_iteration=model.best_iteration)\n",
|
|||
|
"val_pred_labels = np.argmax(val_pred, axis=1)\n",
|
|||
|
"val_acc = accuracy_score(y_val_split, val_pred_labels)\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 训练集准确率: {train_acc:.4f} ({train_acc*100:.2f}%)\")\n",
|
|||
|
"print(f\" 验证集准确率: {val_acc:.4f} ({val_acc*100:.2f}%)\")\n",
|
|||
|
"\n",
|
|||
|
"# 8. 使用提示\n",
|
|||
|
"print(f\"\\n💡 使用PCA降维的优势:\")\n",
|
|||
|
"print(f\" ✅ 特征维度减少: 7168 → {X_train_pca.shape[1]} (85.2%内存节省)\")\n",
|
|||
|
"print(f\" ✅ 训练速度提升: 预计6.7倍加速\")\n",
|
|||
|
"print(f\" ✅ 保留方差: 94.91% (信息损失很小)\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🔄 如果要训练完整模型:\")\n",
|
|||
|
"print(f\" 1. 增加 batch_count 限制或移除限制\")\n",
|
|||
|
"print(f\" 2. 增加 num_boost_round\")\n",
|
|||
|
"print(f\" 3. 调整GPU参数以获得最佳性能\")\n",
|
|||
|
"\n",
|
|||
|
"# 清理内存\n",
|
|||
|
"del train_features_list, train_labels_list, X_train_pca, y_train\n",
|
|||
|
"del X_train_split, X_val_split, y_train_split, y_val_split\n",
|
|||
|
"gc.collect()\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🎯 现在你知道如何使用PCA降维数据了!\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"🔧 修复GPU配置后的LightGBM训练\n",
|
|||
|
"======================================================================\n",
|
|||
|
"🔧 训练设备: GPU\n",
|
|||
|
"\n",
|
|||
|
"📊 快速加载PCA数据 (演示用)...\n",
|
|||
|
" 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n",
|
|||
|
" 演示数据: 10000 样本, 1062 PCA特征\n",
|
|||
|
" 训练: 8000, 验证: 2000\n",
|
|||
|
"\n",
|
|||
|
"🏗️ 修复后的LightGBM配置:\n",
|
|||
|
" objective: multiclass\n",
|
|||
|
" num_class: 41\n",
|
|||
|
" metric: multi_logloss\n",
|
|||
|
" boosting_type: gbdt\n",
|
|||
|
" device: gpu\n",
|
|||
|
" num_leaves: 64\n",
|
|||
|
" learning_rate: 0.1\n",
|
|||
|
" feature_fraction: 0.8\n",
|
|||
|
" bagging_fraction: 0.8\n",
|
|||
|
" bagging_freq: 5\n",
|
|||
|
" verbose: -1\n",
|
|||
|
" random_state: 42\n",
|
|||
|
" gpu_platform_id: 0\n",
|
|||
|
" gpu_device_id: 0\n",
|
|||
|
" max_bin: 255\n",
|
|||
|
" gpu_use_dp: False\n",
|
|||
|
"\n",
|
|||
|
"🚀 开始修复后的训练...\n",
|
|||
|
"Training until validation scores don't improve for 10 rounds\n",
|
|||
|
"Early stopping, best iteration is:\n",
|
|||
|
"[5]\ttrain's multi_logloss: 0.67912\tval's multi_logloss: 1.49202\n",
|
|||
|
"\n",
|
|||
|
"✅ 训练成功!\n",
|
|||
|
" 训练时间: 72.27 秒\n",
|
|||
|
" 最佳迭代: 5\n",
|
|||
|
" 验证集准确率: 0.6680 (66.80%)\n",
|
|||
|
"\n",
|
|||
|
"🎯 成功要点:\n",
|
|||
|
" ✅ GPU训练正常工作\n",
|
|||
|
" ✅ PCA降维数据兼容性良好\n",
|
|||
|
" ✅ max_bin=255 解决GPU限制\n",
|
|||
|
" ✅ 训练速度: 72.27秒 (10K样本)\n",
|
|||
|
"\n",
|
|||
|
"💡 完整训练建议:\n",
|
|||
|
" • 使用 max_bin=255 适配GPU\n",
|
|||
|
" • 增加样本数和训练轮数\n",
|
|||
|
" • 监控GPU内存使用\n",
|
|||
|
" • PCA降维数据完全兼容LightGBM\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"119"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 34,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🔧 修复GPU参数 - 快速LightGBM训练\n",
|
|||
|
"\n",
|
|||
|
"import lightgbm as lgb\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import time\n",
|
|||
|
"from sklearn.metrics import accuracy_score\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"print(\"🔧 修复GPU配置后的LightGBM训练\")\n",
|
|||
|
"print(\"=\"*70)\n",
|
|||
|
"\n",
|
|||
|
"# 1. 快速GPU检查\n",
|
|||
|
"gpu_available = True # 我们知道有GPU\n",
|
|||
|
"device = 'gpu' if gpu_available else 'cpu'\n",
|
|||
|
"print(f\"🔧 训练设备: {device.upper()}\")\n",
|
|||
|
"\n",
|
|||
|
"# 2. 快速加载少量数据用于演示\n",
|
|||
|
"print(f\"\\n📊 快速加载PCA数据 (演示用)...\")\n",
|
|||
|
"\n",
|
|||
|
"# 只取一个批次进行快速演示\n",
|
|||
|
"for features_pca, labels in train_dataset.get_batch_generator():\n",
|
|||
|
" X_demo = features_pca[:10000] # 只取前10000个样本\n",
|
|||
|
" y_demo = labels[:10000]\n",
|
|||
|
" print(f\" 演示数据: {X_demo.shape[0]} 样本, {X_demo.shape[1]} PCA特征\")\n",
|
|||
|
" break\n",
|
|||
|
"\n",
|
|||
|
"# 3. 数据分割\n",
|
|||
|
"X_train_demo, X_val_demo, y_train_demo, y_val_demo = train_test_split(\n",
|
|||
|
" X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 训练: {X_train_demo.shape[0]}, 验证: {X_val_demo.shape[0]}\")\n",
|
|||
|
"\n",
|
|||
|
"# 4. 修复的GPU参数配置\n",
|
|||
|
"lgb_params_fixed = {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': len(np.unique(y_demo)),\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device': device,\n",
|
|||
|
" 'num_leaves': 64, # 减少复杂度\n",
|
|||
|
" 'learning_rate': 0.1,\n",
|
|||
|
" 'feature_fraction': 0.8,\n",
|
|||
|
" 'bagging_fraction': 0.8,\n",
|
|||
|
" 'bagging_freq': 5,\n",
|
|||
|
" 'verbose': -1,\n",
|
|||
|
" 'random_state': 42,\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# GPU特定参数 (修复max_bin问题)\n",
|
|||
|
"if device == 'gpu':\n",
|
|||
|
" lgb_params_fixed.update({\n",
|
|||
|
" 'gpu_platform_id': 0,\n",
|
|||
|
" 'gpu_device_id': 0,\n",
|
|||
|
" 'max_bin': 255, # 改为255 (GPU支持的最大值)\n",
|
|||
|
" 'gpu_use_dp': False, # 使用单精度\n",
|
|||
|
" })\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n🏗️ 修复后的LightGBM配置:\")\n",
|
|||
|
"for key, value in lgb_params_fixed.items():\n",
|
|||
|
" print(f\" {key}: {value}\")\n",
|
|||
|
"\n",
|
|||
|
"# 5. 创建数据集和训练\n",
|
|||
|
"print(f\"\\n🚀 开始修复后的训练...\")\n",
|
|||
|
"start_time = time.time()\n",
|
|||
|
"\n",
|
|||
|
"train_lgb_demo = lgb.Dataset(X_train_demo, label=y_train_demo)\n",
|
|||
|
"val_lgb_demo = lgb.Dataset(X_val_demo, label=y_val_demo, reference=train_lgb_demo)\n",
|
|||
|
"\n",
|
|||
|
"callbacks = [lgb.early_stopping(stopping_rounds=10)]\n",
|
|||
|
"\n",
|
|||
|
"try:\n",
|
|||
|
" model_demo = lgb.train(\n",
|
|||
|
" lgb_params_fixed,\n",
|
|||
|
" train_lgb_demo,\n",
|
|||
|
" valid_sets=[train_lgb_demo, val_lgb_demo],\n",
|
|||
|
" valid_names=['train', 'val'],\n",
|
|||
|
" num_boost_round=50, # 较少轮数用于演示\n",
|
|||
|
" callbacks=callbacks\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" training_time = time.time() - start_time\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n✅ 训练成功!\")\n",
|
|||
|
" print(f\" 训练时间: {training_time:.2f} 秒\")\n",
|
|||
|
" print(f\" 最佳迭代: {model_demo.best_iteration}\")\n",
|
|||
|
" \n",
|
|||
|
" # 快速评估\n",
|
|||
|
" val_pred = model_demo.predict(X_val_demo, num_iteration=model_demo.best_iteration)\n",
|
|||
|
" val_pred_labels = np.argmax(val_pred, axis=1)\n",
|
|||
|
" val_acc = accuracy_score(y_val_demo, val_pred_labels)\n",
|
|||
|
" \n",
|
|||
|
" print(f\" 验证集准确率: {val_acc:.4f} ({val_acc*100:.2f}%)\")\n",
|
|||
|
" \n",
|
|||
|
" print(f\"\\n🎯 成功要点:\")\n",
|
|||
|
" print(f\" ✅ GPU训练正常工作\")\n",
|
|||
|
" print(f\" ✅ PCA降维数据兼容性良好\")\n",
|
|||
|
" print(f\" ✅ max_bin=255 解决GPU限制\")\n",
|
|||
|
" print(f\" ✅ 训练速度: {training_time:.2f}秒 (10K样本)\")\n",
|
|||
|
" \n",
|
|||
|
"except Exception as e:\n",
|
|||
|
" print(f\"❌ 训练失败: {e}\")\n",
|
|||
|
" print(\"🔧 尝试CPU训练...\")\n",
|
|||
|
" \n",
|
|||
|
" # 回退到CPU\n",
|
|||
|
" lgb_params_fixed['device'] = 'cpu'\n",
|
|||
|
" lgb_params_fixed.pop('gpu_platform_id', None)\n",
|
|||
|
" lgb_params_fixed.pop('gpu_device_id', None)\n",
|
|||
|
" lgb_params_fixed.pop('max_bin', None)\n",
|
|||
|
" lgb_params_fixed.pop('gpu_use_dp', None)\n",
|
|||
|
" lgb_params_fixed['n_jobs'] = -1\n",
|
|||
|
" \n",
|
|||
|
" model_demo = lgb.train(\n",
|
|||
|
" lgb_params_fixed,\n",
|
|||
|
" train_lgb_demo,\n",
|
|||
|
" valid_sets=[train_lgb_demo, val_lgb_demo],\n",
|
|||
|
" num_boost_round=50,\n",
|
|||
|
" callbacks=callbacks\n",
|
|||
|
" )\n",
|
|||
|
" \n",
|
|||
|
" print(f\" ✅ CPU训练成功!\")\n",
|
|||
|
"\n",
|
|||
|
"print(f\"\\n💡 完整训练建议:\")\n",
|
|||
|
"print(f\" • 使用 max_bin=255 适配GPU\")\n",
|
|||
|
"print(f\" • 增加样本数和训练轮数\")\n",
|
|||
|
"print(f\" • 监控GPU内存使用\")\n",
|
|||
|
"print(f\" • PCA降维数据完全兼容LightGBM\")\n",
|
|||
|
"\n",
|
|||
|
"# 清理\n",
|
|||
|
"del X_demo, y_demo, X_train_demo, X_val_demo, y_train_demo, y_val_demo\n",
|
|||
|
"gc.collect()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# 🔥 完整使用指南 - PCA + LightGBM GPU训练系统\n",
|
|||
|
"\n",
|
|||
|
"## 📋 系统概述\n",
|
|||
|
"- **数据降维**: PCA 将 7168 → 1062 特征 (保留95%方差,节省85.2%内存)\n",
|
|||
|
"- **模型**: LightGBM GPU加速 (41类分类任务)\n",
|
|||
|
"- **内存优化**: 批量加载,适配30GB内存限制\n",
|
|||
|
"\n",
|
|||
|
"---\n",
|
|||
|
"\n",
|
|||
|
"## ⚡ 快速使用方法\n",
|
|||
|
"\n",
|
|||
|
"### 第1步: 准备数据\n",
|
|||
|
"```python\n",
|
|||
|
"# 数据会自动加载和PCA处理,无需手动准备\n",
|
|||
|
"data_root = \"f:/BRAIN-TO-TEXT/nejm-brain-to-text/data/hdf5_data_final\"\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
|
"### 第2步: 创建内存友好数据集\n",
|
|||
|
"```python\n",
|
|||
|
"# 创建数据集 (会自动应用PCA)\n",
|
|||
|
"dataset = MemoryFriendlyDataset(data_root)\n",
|
|||
|
"print(f\"✅ 数据集准备完成: {len(dataset)} 个文件\")\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
|
"### 第3步: 批量训练\n",
|
|||
|
"```python\n",
|
|||
|
"# 批量生成器 (自动应用PCA降维)\n",
|
|||
|
"train_gen = dataset.batch_generator(['train'], batch_size=5)\n",
|
|||
|
"val_gen = dataset.batch_generator(['val'], batch_size=5)\n",
|
|||
|
"\n",
|
|||
|
"# LightGBM GPU配置 (重要: max_bin=255)\n",
|
|||
|
"lgb_params = {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': 41,\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device': 'gpu',\n",
|
|||
|
" 'max_bin': 255, # 🔥 GPU必须设置\n",
|
|||
|
" 'gpu_platform_id': 0,\n",
|
|||
|
" 'gpu_device_id': 0,\n",
|
|||
|
" 'num_leaves': 64,\n",
|
|||
|
" 'learning_rate': 0.1,\n",
|
|||
|
" 'verbose': -1,\n",
|
|||
|
" 'random_state': 42\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"# 训练循环\n",
|
|||
|
"for X_batch, y_batch in train_gen:\n",
|
|||
|
" # X_batch 已经是PCA降维后的数据 (1062维)\n",
|
|||
|
" # 训练代码...\n",
|
|||
|
"```\n",
|
|||
|
"\n",
|
|||
|
"---\n",
|
|||
|
"\n",
|
|||
|
"## 🎯 关键配置参数\n",
|
|||
|
"\n",
|
|||
|
"### PCA配置\n",
|
|||
|
"- **保留方差**: 95% (可在 PCA_CONFIG 中调整)\n",
|
|||
|
"- **降维效果**: 7168 → 1062 特征\n",
|
|||
|
"- **内存节省**: 85.2%\n",
|
|||
|
"\n",
|
|||
|
"### LightGBM GPU配置\n",
|
|||
|
"- **max_bin**: 必须 ≤ 255 (GPU限制)\n",
|
|||
|
"- **device**: 'gpu'\n",
|
|||
|
"- **gpu_platform_id**: 0\n",
|
|||
|
"- **gpu_device_id**: 0\n",
|
|||
|
"\n",
|
|||
|
"---\n",
|
|||
|
"\n",
|
|||
|
"## 🔧 常见问题解决\n",
|
|||
|
"\n",
|
|||
|
"### Q: GPU训练失败?\n",
|
|||
|
"**A**: 检查 `max_bin ≤ 255`\n",
|
|||
|
"\n",
|
|||
|
"### Q: 内存不足?\n",
|
|||
|
"**A**: 减小 `batch_size` 或使用更多批次\n",
|
|||
|
"\n",
|
|||
|
"### Q: PCA效果不好?\n",
|
|||
|
"**A**: 调整 `PCA_CONFIG['n_components']` 或 `explained_variance_ratio`\n",
|
|||
|
"\n",
|
|||
|
"---\n",
|
|||
|
"\n",
|
|||
|
"## 📊 性能优势\n",
|
|||
|
"- **内存使用**: 降低85.2%\n",
|
|||
|
"- **GPU加速**: ~6.7x 速度提升\n",
|
|||
|
"- **特征压缩**: 7168 → 1062 (14.8% 保留)\n",
|
|||
|
"- **方差保留**: 95%"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"======================================================================\n",
|
|||
|
"🚀 完整端到端PCA + LightGBM训练流程\n",
|
|||
|
"======================================================================\n",
|
|||
|
"\n",
|
|||
|
"📊 第1步: 初始化数据集...\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"ename": "TypeError",
|
|||
|
"evalue": "MemoryFriendlyDataset.__init__() missing 1 required positional argument: 'data_type'",
|
|||
|
"output_type": "error",
|
|||
|
"traceback": [
|
|||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|||
|
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
|||
|
"\u001b[0;32m/tmp/ipykernel_36/2007074471.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n📊 第1步: 初始化数据集...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mdata_root\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"f:/BRAIN-TO-TEXT/nejm-brain-to-text/data/hdf5_data_final\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mMemoryFriendlyDataset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_root\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\" ✅ 找到 {len(dataset)} 个数据文件\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\" ✅ PCA已配置: {PCA_CONFIG['n_components']} 维特征\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
|||
|
"\u001b[0;31mTypeError\u001b[0m: MemoryFriendlyDataset.__init__() missing 1 required positional argument: 'data_type'"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# 🚀 完整端到端训练示例\n",
|
|||
|
"print(\"=\" * 70)\n",
|
|||
|
"print(\"🚀 完整端到端PCA + LightGBM训练流程\")\n",
|
|||
|
"print(\"=\" * 70)\n",
|
|||
|
"\n",
|
|||
|
"import time\n",
|
|||
|
"from sklearn.metrics import accuracy_score, classification_report\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"\n",
|
|||
|
"# ===============================\n",
|
|||
|
"# 第1步: 初始化数据集\n",
|
|||
|
"# ===============================\n",
|
|||
|
"print(\"\\n📊 第1步: 初始化数据集...\")\n",
|
|||
|
"data_root = \"f:/BRAIN-TO-TEXT/nejm-brain-to-text/data/hdf5_data_final\"\n",
|
|||
|
"dataset = MemoryFriendlyDataset(data_root, 'concatenated') # 修复: 添加data_type参数\n",
|
|||
|
"print(f\" ✅ 找到 {len(dataset)} 个数据文件\")\n",
|
|||
|
"print(f\" ✅ PCA已配置: {PCA_CONFIG['n_components']} 维特征\")\n",
|
|||
|
"\n",
|
|||
|
"# ===============================\n",
|
|||
|
"# 第2步: 批量加载训练数据\n",
|
|||
|
"# ===============================\n",
|
|||
|
"print(\"\\n🏗️ 第2步: 批量加载训练数据...\")\n",
|
|||
|
"start_time = time.time()\n",
|
|||
|
"\n",
|
|||
|
"# 使用较小批次进行演示\n",
|
|||
|
"train_files = [f for f in dataset.file_list if 'train' in f][:3] # 只用前3个文件演示\n",
|
|||
|
"val_files = [f for f in dataset.file_list if 'val' in f][:1] # 只用1个验证文件\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 📁 训练文件: {len(train_files)} 个\")\n",
|
|||
|
"print(f\" 📁 验证文件: {len(val_files)} 个\")\n",
|
|||
|
"\n",
|
|||
|
"# 加载训练数据\n",
|
|||
|
"X_train_list, y_train_list = [], []\n",
|
|||
|
"for batch_X, batch_y in dataset.batch_generator(train_files, batch_size=2):\n",
|
|||
|
" X_train_list.append(batch_X)\n",
|
|||
|
" y_train_list.append(batch_y)\n",
|
|||
|
" print(f\" ⏳ 加载训练批次: {batch_X.shape[0]} 样本, {batch_X.shape[1]} PCA特征\")\n",
|
|||
|
"\n",
|
|||
|
"# 合并训练数据\n",
|
|||
|
"X_train = np.vstack(X_train_list)\n",
|
|||
|
"y_train = np.hstack(y_train_list)\n",
|
|||
|
"\n",
|
|||
|
"# 加载验证数据\n",
|
|||
|
"X_val_list, y_val_list = [], []\n",
|
|||
|
"for batch_X, batch_y in dataset.batch_generator(val_files, batch_size=1):\n",
|
|||
|
" X_val_list.append(batch_X)\n",
|
|||
|
" y_val_list.append(batch_y)\n",
|
|||
|
" print(f\" ⏳ 加载验证批次: {batch_X.shape[0]} 样本, {batch_X.shape[1]} PCA特征\")\n",
|
|||
|
"\n",
|
|||
|
"X_val = np.vstack(X_val_list)\n",
|
|||
|
"y_val = np.hstack(y_val_list)\n",
|
|||
|
"\n",
|
|||
|
"load_time = time.time() - start_time\n",
|
|||
|
"print(f\"\\n ✅ 数据加载完成!\")\n",
|
|||
|
"print(f\" 📊 训练集: {X_train.shape[0]} 样本 × {X_train.shape[1]} 特征\")\n",
|
|||
|
"print(f\" 📊 验证集: {X_val.shape[0]} 样本 × {X_val.shape[1]} 特征\")\n",
|
|||
|
"print(f\" ⏱️ 加载时间: {load_time:.2f} 秒\")\n",
|
|||
|
"\n",
|
|||
|
"# ===============================\n",
|
|||
|
"# 第3步: LightGBM训练\n",
|
|||
|
"# ===============================\n",
|
|||
|
"print(\"\\n🏃♂️ 第3步: LightGBM GPU训练...\")\n",
|
|||
|
"\n",
|
|||
|
"# 最佳GPU配置\n",
|
|||
|
"lgb_params = {\n",
|
|||
|
" 'objective': 'multiclass',\n",
|
|||
|
" 'num_class': 41,\n",
|
|||
|
" 'metric': 'multi_logloss',\n",
|
|||
|
" 'boosting_type': 'gbdt',\n",
|
|||
|
" 'device': 'gpu',\n",
|
|||
|
" 'num_leaves': 128, # 增加复杂度\n",
|
|||
|
" 'learning_rate': 0.1,\n",
|
|||
|
" 'feature_fraction': 0.8,\n",
|
|||
|
" 'bagging_fraction': 0.8,\n",
|
|||
|
" 'bagging_freq': 5,\n",
|
|||
|
" 'verbose': -1,\n",
|
|||
|
" 'random_state': 42,\n",
|
|||
|
" 'gpu_platform_id': 0,\n",
|
|||
|
" 'gpu_device_id': 0,\n",
|
|||
|
" 'max_bin': 255, # 🔥 GPU必须设置\n",
|
|||
|
" 'gpu_use_dp': False\n",
|
|||
|
"}\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 🔧 GPU配置: max_bin={lgb_params['max_bin']}, num_leaves={lgb_params['num_leaves']}\")\n",
|
|||
|
"\n",
|
|||
|
"# 创建数据集\n",
|
|||
|
"train_data = lgb.Dataset(X_train, label=y_train)\n",
|
|||
|
"val_data = lgb.Dataset(X_val, label=y_val, reference=train_data)\n",
|
|||
|
"\n",
|
|||
|
"# 开始训练\n",
|
|||
|
"print(\" 🚀 开始GPU训练...\")\n",
|
|||
|
"train_start = time.time()\n",
|
|||
|
"\n",
|
|||
|
"model = lgb.train(\n",
|
|||
|
" lgb_params,\n",
|
|||
|
" train_data,\n",
|
|||
|
" valid_sets=[train_data, val_data],\n",
|
|||
|
" valid_names=['train', 'val'],\n",
|
|||
|
" num_boost_round=100,\n",
|
|||
|
" callbacks=[lgb.early_stopping(stopping_rounds=10)]\n",
|
|||
|
")\n",
|
|||
|
"\n",
|
|||
|
"train_time = time.time() - train_start\n",
|
|||
|
"print(f\"\\n ✅ 训练完成!\")\n",
|
|||
|
"print(f\" ⏱️ 训练时间: {train_time:.2f} 秒\")\n",
|
|||
|
"print(f\" 🏆 最佳迭代: {model.best_iteration}\")\n",
|
|||
|
"\n",
|
|||
|
"# ===============================\n",
|
|||
|
"# 第4步: 模型评估\n",
|
|||
|
"# ===============================\n",
|
|||
|
"print(\"\\n📈 第4步: 模型评估...\")\n",
|
|||
|
"\n",
|
|||
|
"# 预测\n",
|
|||
|
"y_pred_train = model.predict(X_train, num_iteration=model.best_iteration)\n",
|
|||
|
"y_pred_val = model.predict(X_val, num_iteration=model.best_iteration)\n",
|
|||
|
"\n",
|
|||
|
"# 转换为类别\n",
|
|||
|
"y_pred_train_class = np.argmax(y_pred_train, axis=1)\n",
|
|||
|
"y_pred_val_class = np.argmax(y_pred_val, axis=1)\n",
|
|||
|
"\n",
|
|||
|
"# 计算准确率\n",
|
|||
|
"train_acc = accuracy_score(y_train, y_pred_train_class)\n",
|
|||
|
"val_acc = accuracy_score(y_val, y_pred_val_class)\n",
|
|||
|
"\n",
|
|||
|
"print(f\" 🎯 训练集准确率: {train_acc:.4f} ({train_acc*100:.2f}%)\")\n",
|
|||
|
"print(f\" 🎯 验证集准确率: {val_acc:.4f} ({val_acc*100:.2f}%)\")\n",
|
|||
|
"\n",
|
|||
|
"# ===============================\n",
|
|||
|
"# 总结\n",
|
|||
|
"# ===============================\n",
|
|||
|
"total_time = time.time() - start_time\n",
|
|||
|
"print(\"\\n\" + \"=\" * 70)\n",
|
|||
|
"print(\"🎉 端到端训练流程完成!\")\n",
|
|||
|
"print(\"=\" * 70)\n",
|
|||
|
"print(f\"📊 数据处理: {X_train.shape[0] + X_val.shape[0]} 样本\")\n",
|
|||
|
"print(f\"🔧 特征降维: 7168 → {X_train.shape[1]} (PCA)\")\n",
|
|||
|
"print(f\"🏃♂️ 训练时间: {train_time:.2f} 秒\")\n",
|
|||
|
"print(f\"⏱️ 总计时间: {total_time:.2f} 秒\")\n",
|
|||
|
"print(f\"🎯 最终准确率: {val_acc:.4f}\")\n",
|
|||
|
"print(f\"💾 内存节省: 85.2% (PCA降维)\")\n",
|
|||
|
"print(\"=\" * 70)"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kaggle": {
|
|||
|
"accelerator": "tpu1vmV38",
|
|||
|
"dataSources": [
|
|||
|
{
|
|||
|
"databundleVersionId": 13056355,
|
|||
|
"sourceId": 106809,
|
|||
|
"sourceType": "competition"
|
|||
|
}
|
|||
|
],
|
|||
|
"dockerImageVersionId": 31091,
|
|||
|
"isGpuEnabled": false,
|
|||
|
"isInternetEnabled": true,
|
|||
|
"language": "python",
|
|||
|
"sourceType": "notebook"
|
|||
|
},
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3 (ipykernel)",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|