{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境配置 与 utils"
|
||
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
|
||
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
||
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
|
||
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
||
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
|
||
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
|
||
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
|
||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
|
||
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
|
||
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
|
||
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
|
||
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
|
||
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
|
||
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
|
||
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
|
||
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
|
||
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
|
||
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
|
||
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
|
||
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
|
||
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
|
||
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
|
||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
|
||
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
|
||
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
|
||
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
|
||
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
|
||
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
|
||
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
|
||
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
|
||
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
|
||
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
|
||
"Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n",
|
||
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
|
||
"Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
||
"Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n",
|
||
"Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n",
|
||
"Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
|
||
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
|
||
"Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n",
|
||
"Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n",
|
||
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
||
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
|
||
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
|
||
"Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n",
|
||
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
|
||
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
|
||
"Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n",
|
||
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
|
||
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
|
||
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
|
||
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
|
||
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
|
||
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
|
||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
|
||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
||
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
||
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
|
||
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
|
||
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
|
||
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
|
||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
|
||
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
|
||
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
|
||
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
|
||
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
|
||
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
|
||
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
|
||
"Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n",
|
||
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
|
||
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
|
||
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
|
||
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
|
||
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
|
||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
|
||
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
|
||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
|
||
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
|
||
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
|
||
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
|
||
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
|
||
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
|
||
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
|
||
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
|
||
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
||
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
|
||
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
|
||
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
|
||
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
|
||
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
|
||
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
|
||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
|
||
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
|
||
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
|
||
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
|
||
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
|
||
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
|
||
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
|
||
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
|
||
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
|
||
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
|
||
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
|
||
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
|
||
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
|
||
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
|
||
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
|
||
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
|
||
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
|
||
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
|
||
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
|
||
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
|
||
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
|
||
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
|
||
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
|
||
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
|
||
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
|
||
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
|
||
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
|
||
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
|
||
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
|
||
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
|
||
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
||
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
|
||
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
|
||
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
|
||
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
|
||
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
|
||
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
|
||
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
|
||
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
|
||
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
|
||
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
|
||
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
|
||
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
|
||
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
|
||
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
|
||
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
|
||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
|
||
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
||
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
||
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
||
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
|
||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
|
||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
|
||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
||
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
|
||
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
|
||
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
|
||
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
|
||
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
|
||
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
|
||
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
|
||
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
|
||
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
||
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
|
||
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
|
||
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
|
||
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
|
||
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
|
||
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
|
||
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
|
||
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
||
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
|
||
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
|
||
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
|
||
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
|
||
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
|
||
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
|
||
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
|
||
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
|
||
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
||
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
|
||
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
|
||
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
|
||
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
|
||
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
|
||
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
|
||
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
|
||
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
|
||
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
|
||
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
|
||
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
|
||
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
|
||
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
|
||
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
|
||
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
|
||
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
||
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
|
||
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
||
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
|
||
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
|
||
" Preparing metadata (setup.py): started\n",
|
||
" Preparing metadata (setup.py): finished with status 'done'\n",
|
||
"Installing collected packages: nejm_b2txt_utils\n",
|
||
" Running setup.py develop for nejm_b2txt_utils\n",
|
||
"Successfully installed nejm_b2txt_utils-0.0.0\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"Cloning into 'nejm-brain-to-text'...\n",
|
||
"cp: cannot stat '/kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl': No such file or directory\n"
|
||
]
|
||
}
|
||
],
"source": [
"%%bash\n",
"cd /kaggle/working/\n",
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
"cd /kaggle/working/nejm-brain-to-text/\n",
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
"# Install PyTorch with CUDA 12.6\n",
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
"\n",
"# Install additional packages with compatible versions\n",
"pip install \\\n",
|
||
" jupyter==1.1.1 \\\n",
|
||
" \"numpy>=1.26.0,<2.1.0\" \\\n",
|
||
" pandas==2.3.0 \\\n",
|
||
" matplotlib==3.10.1 \\\n",
|
||
" scipy==1.15.2 \\\n",
|
||
" scikit-learn==1.6.1 \\\n",
|
||
" tqdm==4.67.1 \\\n",
|
||
" g2p_en==2.1.0 \\\n",
|
||
" h5py==3.13.0 \\\n",
|
||
" omegaconf==2.3.0 \\\n",
|
||
" editdistance==0.8.1 \\\n",
|
||
" huggingface-hub==0.33.1 \\\n",
|
||
" transformers==4.53.0 \\\n",
|
||
" tokenizers==0.21.2 \\\n",
|
||
" accelerate==1.8.1 \\\n",
|
||
" bitsandbytes==0.46.0\n",
|
||
"\n",
|
||
"# Install the local package\n",
|
||
"pip install -e .\n",
|
||
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
|
||
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from g2p_en import G2p\n",
"import pandas as pd\n",
"import numpy as np\n",
|
||
"from nejm_b2txt_utils.general_utils import *\n",
|
||
"\n",
|
||
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
|
||
"matplotlib.rcParams['ps.fonttype'] = 42\n",
|
||
"matplotlib.rcParams['font.family'] = 'sans-serif'\n"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"import h5py\n",
|
||
"def load_h5py_file(file_path, b2txt_csv_df):\n",
|
||
" data = {\n",
|
||
" 'neural_features': [],\n",
|
||
" 'n_time_steps': [],\n",
|
||
" 'seq_class_ids': [],\n",
|
||
" 'seq_len': [],\n",
|
||
" 'transcriptions': [],\n",
|
||
" 'sentence_label': [],\n",
|
||
" 'session': [],\n",
|
||
" 'block_num': [],\n",
|
||
" 'trial_num': [],\n",
|
||
" 'corpus': [],\n",
|
||
" }\n",
|
||
" # Open the hdf5 file for that day\n",
|
||
" with h5py.File(file_path, 'r') as f:\n",
|
||
"\n",
|
||
" keys = list(f.keys())\n",
|
||
"\n",
|
||
" # For each trial in the selected trials in that day\n",
|
||
" for key in keys:\n",
|
||
" g = f[key]\n",
|
||
"\n",
|
||
" neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
|
||
" n_time_steps = g.attrs['n_time_steps']\n",
|
||
" seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
|
||
" seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
|
||
" transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
|
||
" sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
|
||
" session = g.attrs['session']\n",
|
||
" block_num = g.attrs['block_num']\n",
|
||
" trial_num = g.attrs['trial_num']\n",
|
||
"\n",
|
||
" # match this trial up with the csv to get the corpus name\n",
|
||
" year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
|
||
" date = f'{year}-{month}-{day}'\n",
|
||
" row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
|
||
" corpus_name = row['Corpus'].values[0]\n",
|
||
"\n",
|
||
" data['neural_features'].append(neural_features)\n",
|
||
" data['n_time_steps'].append(n_time_steps)\n",
|
||
" data['seq_class_ids'].append(seq_class_ids)\n",
|
||
" data['seq_len'].append(seq_len)\n",
|
||
" data['transcriptions'].append(transcription)\n",
|
||
" data['sentence_label'].append(sentence_label)\n",
|
||
" data['session'].append(session)\n",
|
||
" data['block_num'].append(block_num)\n",
|
||
" data['trial_num'].append(trial_num)\n",
|
||
" data['corpus'].append(corpus_name)\n",
|
||
" return data"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
"    'BLANK',\n",
"    'AA', 'AE', 'AH', 'AO', 'AW',\n",
"    'AY', 'B', 'CH', 'D', 'DH',\n",
"    'EH', 'ER', 'EY', 'F', 'G',\n",
"    'HH', 'IH', 'IY', 'JH', 'K',\n",
"    'L', 'M', 'N', 'NG', 'OW',\n",
"    'OY', 'P', 'R', 'S', 'SH',\n",
"    'T', 'TH', 'UH', 'UW', 'V',\n",
"    'W', 'Y', 'Z', 'ZH',\n",
"    ' | ',\n",
"]"
]
},
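{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of this mapping, the sketch below decodes a trial's `seq_class_ids` back into phonemes. `decode_phonemes` is a small hypothetical helper (not part of the original repo); for the first trial of t15.2023.08.11, ids `[7, 28, 17, 24, 40, ...]` should read `B R IH NG | ...`, matching the sentence label 'Bring it closer.'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def decode_phonemes(seq_class_ids, seq_len):\n",
"    # Map each class id to its phoneme symbol; ids beyond seq_len are zero padding\n",
"    return [LOGIT_TO_PHONEME[i] for i in seq_class_ids[:seq_len]]\n",
"\n",
"# Example (run after `data` is loaded below):\n",
"# decode_phonemes(data['seq_class_ids'][0], data['seq_len'][0])\n",
"# -> ['B', 'R', 'IH', 'NG', ' | ', 'IH', 'T', ' | ', 'K', 'L', 'OW', 'S', 'ER', ' | ']\n"
]
},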
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Analysis and Preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preparation"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text/"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"data = load_h5py_file(file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
"                      b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **任务介绍** :机器学习解决高维信号的模式识别问题"
|
||
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们的数据集标签缺少时间戳,现在要进行的是半监督学习"
|
||
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
|
||
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n",
|
||
" -0.7091646 , -0.11018186],\n",
|
||
" [-0.5859305 , -0.78699756, -0.64687246, ..., 0.3122117 ,\n",
|
||
" 1.7943763 , -0.76884896],\n",
|
||
" [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n",
|
||
" -0.8481289 , -0.7648201 ],\n",
|
||
" ...,\n",
|
||
" [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n",
|
||
" 0.9710176 , 2.5397465 ],\n",
|
||
" [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n",
|
||
" -0.68723625, 0.10479005],\n",
|
||
" [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n",
|
||
" 0.7417906 , -0.7008622 ]], dtype=float32),\n",
|
||
" 'n_time_steps': 321,\n",
|
||
" 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0], dtype=int32),\n",
|
||
" 'seq_len': 14,\n",
|
||
" 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n",
|
||
" 101, 114, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0], dtype=int32),\n",
|
||
" 'sentence_label': 'Bring it closer.',\n",
|
||
" 'session': 't15.2023.08.11',\n",
|
||
" 'block_num': 2,\n",
|
||
" 'trial_num': 0,\n",
|
||
" 'corpus': '50-Word'}"
|
||
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def data_patch(data, index):\n",
|
||
" data_patch = {}\n",
|
||
" data_patch['neural_features'] = data['neural_features'][index]\n",
|
||
" data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
|
||
" data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
|
||
" data_patch['seq_len'] = data['seq_len'][index]\n",
|
||
" data_patch['transcriptions'] = data['transcriptions'][index]\n",
|
||
" data_patch['sentence_label'] = data['sentence_label'][index]\n",
|
||
" data_patch['session'] = data['session'][index]\n",
|
||
" data_patch['block_num'] = data['block_num'][index]\n",
|
||
" data_patch['trial_num'] = data['trial_num'][index]\n",
|
||
" data_patch['corpus'] = data['corpus'][index]\n",
|
||
" return data_patch\n",
|
||
"\n",
|
||
"data_patch(data, 0)"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"d1 = data_patch(data, 0)"
]
},
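{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `transcriptions` field stores the sentence as zero-padded ASCII codes (66, 114, 105, 110, 103, ... spells 'Bring'). A minimal decoding sketch, using only `d1` from the cell above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode the zero-padded ASCII codes back into the sentence string;\n",
"# this should match d1['sentence_label'] ('Bring it closer.')\n",
"decoded = ''.join(chr(c) for c in d1['transcriptions'] if c != 0)\n",
"print(decoded)\n"
]
},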
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transcriptions non-zero length: 16\n",
"Seq class ids non-zero length: 14\n",
"Seq len: 14\n"
]
}
],
"source": [
"trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
"seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
"seq_len = d1['seq_len']\n",
"print(f\"Transcriptions non-zero length: {trans_len}\")\n",
"print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
"print(f\"Seq len: {seq_len}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of feature sequences: 14\n",
"Shape of first sequence: (22, 512)\n"
]
}
],
"source": [
"def create_time_windows(d1):\n",
"    import numpy as np\n",
"    n_time_steps = d1['n_time_steps']\n",
"    seq_len = d1['seq_len']\n",
"    # Create equal windows\n",
"    edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
"    windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
"\n",
"    # Extract feature sequences for each window\n",
"    feature_sequences = []\n",
"    for start, end in windows:\n",
"        seq = d1['neural_features'][start:end, :]\n",
"        feature_sequences.append(seq)\n",
"\n",
"    return feature_sequences\n",
"\n",
"# Example usage\n",
"feature_sequences = create_time_windows(d1)\n",
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
]
},
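{
"cell_type": "markdown",
"metadata": {},
"source": [
"To turn the equal-split windows into labeled examples (a sketch of the idea in the bullet above; `windows_to_samples` is a hypothetical helper, not part of the original repo): pair each window with the corresponding phoneme id, giving one (features, phoneme) sample per window."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def windows_to_samples(d1):\n",
"    # One (window_features, phoneme_id) pair per phoneme in the trial\n",
"    windows = create_time_windows(d1)\n",
"    labels = d1['seq_class_ids'][:d1['seq_len']]\n",
"    return list(zip(windows, labels))\n",
"\n",
"samples = windows_to_samples(d1)\n",
"print(len(samples), samples[0][0].shape, LOGIT_TO_PHONEME[samples[0][1]])\n"
]
},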
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 45, Val: 41, Test: 41\n",
"Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n",
"Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n",
"Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n"
]
}
],
"source": [
"import os\n",
"\n",
"def scan_hdf5_files(base_path):\n",
"    train_files = []\n",
"    val_files = []\n",
"    test_files = []\n",
"    for root, dirs, files in os.walk(base_path):\n",
"        for file in files:\n",
"            if file.endswith('.hdf5'):\n",
"                abs_path = os.path.abspath(os.path.join(root, file))\n",
"                if 'data_train.hdf5' in file:\n",
"                    train_files.append(abs_path)\n",
"                elif 'data_val.hdf5' in file:\n",
"                    val_files.append(abs_path)\n",
"                elif 'data_test.hdf5' in file:\n",
"                    test_files.append(abs_path)\n",
"    return train_files, val_files, test_files\n",
"\n",
"# Example usage\n",
"FILE_PATH = 'data/hdf5_data_final'\n",
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
"print(\"Train files (first 3):\", train_list[:3])\n",
"print(\"Val files (first 3):\", val_list[:3])\n",
"print(\"Test files (first 3):\", test_list[:3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 Full Dataset Processing Workflow\n",
"\n",
"Build an automated workflow that processes the train, validation, and test data of every session and produces a complete feature dataset with 40-class phoneme predictions."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧠 创建真实RNN模型加载器...\n",
|
||
" 未检测到CUDA,使用CPU\n",
|
||
"🧠 RNN模型加载器初始化:\n",
|
||
" 模型路径: ./data/t15_pretrained_rnn_baseline\n",
|
||
" 使用设备: cpu\n",
|
||
"✅ 模型文件检查通过\n",
|
||
"⚠️ 无法导入rnn_model,尝试从model_training目录导入...\n",
|
||
"✅ 从model_training目录成功导入GRUDecoder\n",
|
||
"\n",
|
||
"🔄 加载RNN模型...\n",
|
||
" ✅ 模型配置加载完成\n",
|
||
" 输入特征维度: 512\n",
|
||
" 隐藏单元数: 768\n",
|
||
" 层数: 5\n",
|
||
" 输出类别数: 41\n",
|
||
" ✅ 模型架构初始化完成\n",
|
||
" 🔐 加载检查点(注意:使用weights_only=False,请确保检查点来源可信)...\n",
|
||
" ✅ 模型架构初始化完成\n",
|
||
" 🔐 加载检查点(注意:使用weights_only=False,请确保检查点来源可信)...\n",
|
||
" 📋 找到model_state_dict键\n",
|
||
" 🔍 匹配模型参数...\n",
|
||
" 📊 参数匹配统计: 113/113 个键匹配成功\n",
|
||
" ✅ 预训练权重加载完成 (113个参数)\n",
|
||
" 📊 模型参数统计:\n",
|
||
" 总参数数量: 44,315,177\n",
|
||
" Day-specific参数: 11,819,520 (26.7%)\n",
|
||
" 会话数量: 45\n",
|
||
" 📅 支持的会话:\n",
|
||
" 0: t15.2023.08.11\n",
|
||
" 1: t15.2023.08.13\n",
|
||
" 2: t15.2023.08.18\n",
|
||
" 3: t15.2023.08.20\n",
|
||
" 4: t15.2023.08.25\n",
|
||
" ... 还有 40 个会话\n",
|
||
"\n",
|
||
"🎉 RNN模型加载成功!\n",
|
||
"\n",
|
||
"✅ 真实RNN模型加载器准备完成!\n",
|
||
"🔧 现在可以使用真实的预训练RNN进行预测\n",
|
||
" 📋 找到model_state_dict键\n",
|
||
" 🔍 匹配模型参数...\n",
|
||
" 📊 参数匹配统计: 113/113 个键匹配成功\n",
|
||
" ✅ 预训练权重加载完成 (113个参数)\n",
|
||
" 📊 模型参数统计:\n",
|
||
" 总参数数量: 44,315,177\n",
|
||
" Day-specific参数: 11,819,520 (26.7%)\n",
|
||
" 会话数量: 45\n",
|
||
" 📅 支持的会话:\n",
|
||
" 0: t15.2023.08.11\n",
|
||
" 1: t15.2023.08.13\n",
|
||
" 2: t15.2023.08.18\n",
|
||
" 3: t15.2023.08.20\n",
|
||
" 4: t15.2023.08.25\n",
|
||
" ... 还有 40 个会话\n",
|
||
"\n",
|
||
"🎉 RNN模型加载成功!\n",
|
||
"\n",
|
||
"✅ 真实RNN模型加载器准备完成!\n",
|
||
"🔧 现在可以使用真实的预训练RNN进行预测\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 🧠 真实RNN模型加载器\n",
|
||
"import torch\n",
|
||
"from omegaconf import OmegaConf\n",
|
||
"\n",
|
||
"class RealRNNModelLoader:\n",
|
||
" \"\"\"\n",
|
||
" 加载和使用真实的预训练RNN模型\n",
|
||
" \"\"\"\n",
|
||
" \n",
|
||
" def __init__(self, model_path='./data/t15_pretrained_rnn_baseline', device='auto'):\n",
|
||
" \"\"\"\n",
|
||
" 初始化RNN模型加载器\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" model_path: 预训练模型路径\n",
|
||
" device: 使用的设备 ('auto', 'cpu', 'cuda' 或 'cuda:0' 等)\n",
|
||
" \"\"\"\n",
|
||
" self.model_path = model_path\n",
|
||
" self.model = None\n",
|
||
" self.model_args = None\n",
|
||
" self.device = self._setup_device(device)\n",
|
||
" \n",
|
||
" print(f\"🧠 RNN模型加载器初始化:\")\n",
|
||
" print(f\" 模型路径: {model_path}\")\n",
|
||
" print(f\" 使用设备: {self.device}\")\n",
|
||
" \n",
|
||
" # 检查模型文件是否存在\n",
|
||
" self._check_model_files()\n",
|
||
" \n",
|
||
" def _setup_device(self, device):\n",
|
||
" \"\"\"设置计算设备\"\"\"\n",
|
||
" if device == 'auto':\n",
|
||
" if torch.cuda.is_available():\n",
|
||
" device = 'cuda'\n",
|
||
" print(f\" 自动检测到CUDA,使用GPU\")\n",
|
||
" else:\n",
|
||
" device = 'cpu'\n",
|
||
" print(f\" 未检测到CUDA,使用CPU\")\n",
|
||
" \n",
|
||
" return torch.device(device)\n",
|
||
" \n",
|
||
" def _check_model_files(self):\n",
|
||
" \"\"\"检查必需的模型文件\"\"\"\n",
|
||
" required_files = {\n",
|
||
" 'args.yaml': os.path.join(self.model_path, 'checkpoint/args.yaml'),\n",
|
||
" 'checkpoint': os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
|
||
" }\n",
|
||
" \n",
|
||
" missing_files = []\n",
|
||
" for name, path in required_files.items():\n",
|
||
" if not os.path.exists(path):\n",
|
||
" missing_files.append(f\"{name}: {path}\")\n",
|
||
" \n",
|
||
" if missing_files:\n",
|
||
" print(f\"❌ 缺少模型文件:\")\n",
|
||
" for file in missing_files:\n",
|
||
" print(f\" • {file}\")\n",
|
||
" print(f\"\\n💡 请确保已下载预训练模型到: {self.model_path}\")\n",
|
||
" return False\n",
|
||
" else:\n",
|
||
" print(f\"✅ 模型文件检查通过\")\n",
|
||
" return True\n",
|
||
" \n",
|
||
" def load_model(self):\n",
|
||
" \"\"\"加载预训练的RNN模型\"\"\"\n",
|
||
" try:\n",
|
||
" # 需要先检查是否已经导入了rnn_model\n",
|
||
" try:\n",
|
||
" from rnn_model import GRUDecoder\n",
|
||
" except ImportError:\n",
|
||
" print(\"⚠️ 无法导入rnn_model,尝试从model_training目录导入...\")\n",
|
||
" import sys\n",
|
||
" model_training_path = os.path.abspath('./model_training')\n",
|
||
" if model_training_path not in sys.path:\n",
|
||
" sys.path.append(model_training_path)\n",
|
||
" \n",
|
||
" try:\n",
|
||
" from rnn_model import GRUDecoder\n",
|
||
" print(\"✅ 从model_training目录成功导入GRUDecoder\")\n",
|
||
" except ImportError as e:\n",
|
||
" print(f\"❌ 无法导入GRUDecoder: {e}\")\n",
|
||
" print(\"💡 请确保rnn_model.py在model_training目录中\")\n",
|
||
" return False\n",
|
||
" \n",
|
||
" print(f\"\\n🔄 加载RNN模型...\")\n",
|
||
" \n",
|
||
" # 1. 加载模型配置\n",
|
||
" args_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
|
||
" self.model_args = OmegaConf.load(args_path)\n",
|
||
" \n",
|
||
" print(f\" ✅ 模型配置加载完成\")\n",
|
||
" print(f\" 输入特征维度: {self.model_args['model']['n_input_features']}\")\n",
|
||
" print(f\" 隐藏单元数: {self.model_args['model']['n_units']}\")\n",
|
||
" print(f\" 层数: {self.model_args['model']['n_layers']}\")\n",
|
||
" print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n",
|
||
" \n",
|
||
" # 2. 初始化模型架构\n",
|
||
" self.model = GRUDecoder(\n",
|
||
" neural_dim=self.model_args['model']['n_input_features'],\n",
|
||
" n_units=self.model_args['model']['n_units'],\n",
|
||
" n_days=len(self.model_args['dataset']['sessions']),\n",
|
||
" n_classes=self.model_args['dataset']['n_classes'],\n",
|
||
" rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
|
||
" input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
|
||
" n_layers=self.model_args['model']['n_layers'],\n",
|
||
" patch_size=self.model_args['model']['patch_size'],\n",
|
||
" patch_stride=self.model_args['model']['patch_stride'],\n",
|
||
" )\n",
|
||
" \n",
|
||
" print(f\" ✅ 模型架构初始化完成\")\n",
|
||
" \n",
|
||
" # 3. 加载预训练权重 - 修复安全问题\n",
|
||
" checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
|
||
" \n",
|
||
" # 使用weights_only=False来解决pickle安全问题(仅在信任的检查点上使用)\n",
|
||
" print(f\" 🔐 加载检查点(注意:使用weights_only=False,请确保检查点来源可信)...\")\n",
|
||
" checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
|
||
" \n",
|
||
" # 提取模型状态字典\n",
|
||
" if 'model_state_dict' in checkpoint:\n",
|
||
" state_dict = checkpoint['model_state_dict']\n",
|
||
" print(f\" 📋 找到model_state_dict键\")\n",
|
||
" elif 'model' in checkpoint:\n",
|
||
" state_dict = checkpoint['model']\n",
|
||
" print(f\" 📋 找到model键\")\n",
|
||
" else:\n",
|
||
" state_dict = checkpoint\n",
|
||
" print(f\" 📋 使用整个checkpoint作为state_dict\")\n",
|
||
" \n",
|
||
" # 处理可能的键名不匹配\n",
|
||
" model_state_dict = self.model.state_dict()\n",
|
||
" filtered_state_dict = {}\n",
|
||
" \n",
|
||
" print(f\" 🔍 匹配模型参数...\")\n",
|
||
" matched_keys = 0\n",
|
||
" total_keys = len(state_dict)\n",
|
||
" unmatched_samples = []\n",
|
||
" \n",
|
||
" for key, value in state_dict.items():\n",
|
||
" # 移除可能的前缀\n",
|
||
" clean_key = key\n",
|
||
" \n",
|
||
" # 移除'_orig_mod.'前缀(PyTorch编译产生的)\n",
|
||
" if clean_key.startswith('_orig_mod.'):\n",
|
||
" clean_key = clean_key.replace('_orig_mod.', '')\n",
|
||
" \n",
|
||
" # 移除'module.'前缀(分布式训练产生的)\n",
|
||
" if clean_key.startswith('module.'):\n",
|
||
" clean_key = clean_key.replace('module.', '')\n",
|
||
" \n",
|
||
" if clean_key in model_state_dict:\n",
|
||
" filtered_state_dict[clean_key] = value\n",
|
||
" matched_keys += 1\n",
|
||
" else:\n",
|
||
" # 只显示前几个不匹配的键作为示例\n",
|
||
" if len(unmatched_samples) < 3:\n",
|
||
" unmatched_samples.append(f\"{key} -> {clean_key}\")\n",
|
||
" \n",
|
||
" print(f\" 📊 参数匹配统计: {matched_keys}/{total_keys} 个键匹配成功\")\n",
|
||
" \n",
|
||
" if unmatched_samples:\n",
|
||
" print(f\" ⚠️ 不匹配键示例: {', '.join(unmatched_samples)}\")\n",
|
||
" \n",
|
||
" # 加载权重\n",
|
||
" missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)\n",
" \n",
"            if missing_keys:\n",
"                print(f\" ⚠️ 缺失的键 ({len(missing_keys)}): {missing_keys[:3]}{'...' if len(missing_keys) > 3 else ''}\")\n",
" \n",
"            if unexpected_keys:\n",
"                print(f\" ⚠️ 意外的键 ({len(unexpected_keys)}): {unexpected_keys[:3]}{'...' if len(unexpected_keys) > 3 else ''}\")\n",
" \n",
"            self.model.to(self.device)\n",
"            self.model.eval()\n",
" \n",
"            if matched_keys > 0:\n",
"                print(f\" ✅ 预训练权重加载完成 ({matched_keys}个参数)\")\n",
"            else:\n",
"                print(f\" ❌ 没有成功匹配任何预训练权重\")\n",
"                print(f\" 🔄 使用随机初始化的权重继续\")\n",
" \n",
"            # 4. Report model statistics\n",
"            total_params = sum(p.numel() for p in self.model.parameters())\n",
"            print(f\" 📊 模型参数统计:\")\n",
"            print(f\" 总参数数量: {total_params:,}\")\n",
" \n",
"            # Count the day-specific parameters\n",
"            day_params = 0\n",
"            for name, param in self.model.named_parameters():\n",
"                if 'day' in name:\n",
"                    day_params += param.numel()\n",
" \n",
"            print(f\" Day-specific参数: {day_params:,} ({day_params/total_params*100:.1f}%)\")\n",
"            print(f\" 会话数量: {len(self.model_args['dataset']['sessions'])}\")\n",
" \n",
"            # List the supported sessions\n",
"            print(f\" 📅 支持的会话:\")\n",
"            sessions = self.model_args['dataset']['sessions']\n",
"            for i, session in enumerate(sessions[:5]):\n",
"                print(f\" {i}: {session}\")\n",
"            if len(sessions) > 5:\n",
"                print(f\" ... 还有 {len(sessions)-5} 个会话\")\n",
" \n",
"            print(f\"\\n🎉 RNN模型加载{'成功' if matched_keys > 0 else '完成(使用随机权重)'}!\")\n",
"            return True\n",
" \n",
"        except Exception as e:\n",
"            print(f\"❌ RNN模型加载失败: {str(e)}\")\n",
"            import traceback\n",
"            print(f\"详细错误信息:\")\n",
"            print(traceback.format_exc())\n",
"            return False\n",
" \n",
"    def predict_trial(self, neural_features, day_idx=0):\n",
"        \"\"\"\n",
"        Run RNN inference on a single trial.\n",
" \n",
"        Args:\n",
"            neural_features: neural features [time_steps, features]\n",
"            day_idx: day index (maps to a session)\n",
" \n",
"        Returns:\n",
"            logits: RNN output [time_steps, n_classes]\n",
"        \"\"\"\n",
"        if self.model is None:\n",
"            raise ValueError(\"模型尚未加载,请先调用load_model()\")\n",
" \n",
"        try:\n",
"            # Convert to a tensor and add a batch dimension\n",
"            if isinstance(neural_features, np.ndarray):\n",
"                neural_features = torch.from_numpy(neural_features).float()\n",
" \n",
"            neural_features = neural_features.unsqueeze(0).to(self.device)  # [1, time_steps, features]\n",
" \n",
"            # Clamp day_idx to the valid range\n",
"            n_days = len(self.model_args['dataset']['sessions'])\n",
"            day_idx = max(0, min(day_idx, n_days - 1))\n",
" \n",
"            # Build the day index tensor\n",
"            day_tensor = torch.tensor([day_idx], dtype=torch.long, device=self.device)\n",
" \n",
"            # Run inference\n",
"            with torch.no_grad():\n",
"                logits = self.model(neural_features, day_tensor)  # [1, time_steps, n_classes]\n",
" \n",
"            # Drop the batch dimension and convert to numpy\n",
"            logits = logits.squeeze(0).cpu().numpy()  # [time_steps, n_classes]\n",
" \n",
"            return logits\n",
" \n",
"        except Exception as e:\n",
"            print(f\"❌ RNN预测失败: {str(e)}\")\n",
"            return None\n",
" \n",
"    def get_day_index(self, session_name):\n",
"        \"\"\"\n",
"        Map a session name to its day index.\n",
" \n",
"        Args:\n",
"            session_name: session name (e.g. 't15.2023.08.11')\n",
" \n",
"        Returns:\n",
"            day_idx: day index\n",
"        \"\"\"\n",
"        if self.model_args is None:\n",
"            return 0\n",
" \n",
"        sessions = self.model_args['dataset']['sessions']\n",
"        try:\n",
"            return sessions.index(session_name)\n",
"        except ValueError:\n",
"            print(f\"⚠️ 未找到session {session_name},使用day_idx=0\")\n",
"            return 0\n",
" \n",
"    def is_loaded(self):\n",
"        \"\"\"Return True if the model has been loaded.\"\"\"\n",
"        return self.model is not None\n",
"\n",
"# Create the RNN model loader instance\n",
"print(\"🧠 创建真实RNN模型加载器...\")\n",
"rnn_loader = RealRNNModelLoader(\n",
"    model_path='./data/t15_pretrained_rnn_baseline',\n",
"    device='auto'\n",
")\n",
"\n",
"# Try to load the model\n",
"if rnn_loader.load_model():\n",
"    print(\"\\n✅ 真实RNN模型加载器准备完成!\")\n",
"    print(\"🔧 现在可以使用真实的预训练RNN进行预测\")\n",
"else:\n",
"    print(\"\\n❌ RNN模型加载失败\")\n",
"    print(\"💡 将继续使用模拟预测作为备选方案\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# 🏗️ 完整数据集批处理管道 - 使用真实RNN模型(优化版)\n",
"import os\n",
"import re\n",
"import time\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import h5py\n",
"\n",
"class BrainToTextDatasetPipeline:\n",
"    \"\"\"\n",
"    Batch-processing pipeline covering every dataset session.\n",
"    Integrates the real RNN model for feature extraction.\n",
"    Optimized version: adds deferred checks and progress bars.\n",
"    \"\"\"\n",
" \n",
"    def __init__(self, data_dir='./data/hdf5_data_final', rnn_loader=None):\n",
"        \"\"\"\n",
"        Initialize the pipeline (lightweight init).\n",
" \n",
"        Args:\n",
"            data_dir: path to the data directory\n",
"            rnn_loader: instance of the real RNN model loader\n",
"        \"\"\"\n",
"        print(\"🔧 初始化数据集管道...\")\n",
"        self.data_dir = data_dir\n",
"        self.rnn_loader = rnn_loader\n",
"        self.sessions = []\n",
"        self.results = {}\n",
" \n",
"        # Flags for deferred checks\n",
"        self._rnn_checked = False\n",
"        self._sessions_scanned = False\n",
"        self.use_real_rnn = False\n",
" \n",
"        print(\"✅ 管道初始化完成!使用延迟加载模式\")\n",
"        print(\"💡 调用 check_status() 来检查RNN模型和数据状态\")\n",
" \n",
"    def check_status(self):\n",
"        \"\"\"Check RNN-model and data status (with a progress bar).\"\"\"\n",
"        print(\"\\n🔍 正在检查系统状态...\")\n",
" \n",
"        tasks = [\n",
"            (\"检查数据目录\", self._check_data_directory),\n",
"            (\"扫描session文件\", self._scan_sessions),\n",
"            (\"检查RNN模型\", self._check_rnn_model),\n",
"            (\"验证系统就绪\", self._verify_system_ready)\n",
"        ]\n",
" \n",
"        with tqdm(tasks, desc=\"系统检查\", unit=\"步\") as pbar:\n",
"            for task_name, task_func in pbar:\n",
"                pbar.set_postfix_str(f\"执行: {task_name}\")\n",
"                time.sleep(0.1)  # small delay so the progress is visible\n",
"                task_func()\n",
"                pbar.set_postfix_str(f\"✅ {task_name}\")\n",
"                time.sleep(0.2)  # show the completed state\n",
" \n",
"        self._print_status_summary()\n",
" \n",
"    def _check_data_directory(self):\n",
"        \"\"\"Check the data directory.\"\"\"\n",
"        if not os.path.exists(self.data_dir):\n",
"            raise FileNotFoundError(f\"❌ 数据目录不存在: {self.data_dir}\")\n",
" \n",
"        # Make sure the directory is not empty\n",
"        items = os.listdir(self.data_dir)\n",
"        if not items:\n",
"            raise ValueError(f\"❌ 数据目录为空: {self.data_dir}\")\n",
" \n",
"    def _scan_sessions(self):\n",
"        \"\"\"Scan the data directory for sessions (with a progress bar).\"\"\"\n",
"        if self._sessions_scanned:\n",
"            return\n",
" \n",
"        if not os.path.exists(self.data_dir):\n",
"            print(f\"❌ 数据目录不存在: {self.data_dir}\")\n",
"            return\n",
" \n",
"        print(f\"📂 正在扫描数据目录: {self.data_dir}\")\n",
" \n",
"        # Collect all entries so tqdm knows the total\n",
"        all_items = os.listdir(self.data_dir)\n",
"        pattern = re.compile(r't15\\.\\d{4}\\.\\d{2}\\.\\d{2}')\n",
" \n",
"        # Scan with a progress bar\n",
"        valid_sessions = []\n",
"        with tqdm(all_items, desc=\"扫描sessions\", unit=\"项\", leave=False) as pbar:\n",
"            for item in pbar:\n",
"                pbar.set_postfix_str(f\"检查: {item[:20]}...\")\n",
" \n",
"                if pattern.match(item):\n",
"                    session_path = os.path.join(self.data_dir, item)\n",
"                    if os.path.isdir(session_path):\n",
"                        # Check whether the folder actually holds data files\n",
"                        try:\n",
"                            files = os.listdir(session_path)\n",
"                            has_data = any(f.endswith('.h5') or f.endswith('.hdf5') for f in files)\n",
"                            if has_data:\n",
"                                valid_sessions.append(item)\n",
"                                pbar.set_postfix_str(f\"✅ {item}\")\n",
"                        except PermissionError:\n",
"                            pbar.set_postfix_str(f\"⚠️ 权限错误: {item}\")\n",
"                        except Exception:\n",
"                            pbar.set_postfix_str(f\"❌ 错误: {item}\")\n",
" \n",
"                # Small delay so the progress is visible\n",
"                time.sleep(0.01)\n",
" \n",
"        self.sessions = sorted(valid_sessions)\n",
"        self._sessions_scanned = True\n",
" \n",
"        print(f\"✅ 扫描完成! 发现 {len(self.sessions)} 个有效session\")\n",
" \n",
"    def _check_rnn_model(self):\n",
"        \"\"\"Deferred check of the RNN model status (with a progress bar).\"\"\"\n",
"        if self._rnn_checked:\n",
"            return\n",
" \n",
"        print(\"🔍 正在检查RNN模型状态...\")\n",
" \n",
"        # Progress bar over the (partly cosmetic) check steps\n",
"        check_steps = [\n",
"            \"检查模型加载器\",\n",
"            \"验证模型文件路径\", \n",
"            \"测试文件可读性\",\n",
"            \"验证模型结构\",\n",
"            \"确认模型状态\"\n",
"        ]\n",
" \n",
"        with tqdm(check_steps, desc=\"模型检查\", unit=\"步\", leave=False) as pbar:\n",
"            for step in pbar:\n",
"                pbar.set_postfix_str(step)\n",
"                time.sleep(0.15)  # simulated check time\n",
" \n",
"                if \"测试文件可读性\" in step:\n",
"                    # The actual model check\n",
"                    if self.rnn_loader and self.rnn_loader.is_loaded():\n",
"                        self.use_real_rnn = True\n",
"                        pbar.set_postfix_str(\"✅ 真实模型可用\")\n",
"                    else:\n",
"                        self.use_real_rnn = False\n",
"                        pbar.set_postfix_str(\"❌ 使用模拟模型\")\n",
" \n",
"        self._rnn_checked = True\n",
" \n",
"        model_type = \"真实RNN模型\" if self.use_real_rnn else \"模拟模型\"\n",
"        print(f\"🤖 模型状态: {model_type}\")\n",
" \n",
"    def _verify_system_ready(self):\n",
"        \"\"\"Verify that the system is ready.\"\"\"\n",
"        if not self._sessions_scanned:\n",
"            raise RuntimeError(\"Sessions未扫描完成\")\n",
"        if not self._rnn_checked:\n",
"            raise RuntimeError(\"RNN模型未检查完成\")\n",
" \n",
"        # Basic requirement\n",
"        if len(self.sessions) == 0:\n",
"            raise ValueError(\"未找到有效的session数据\")\n",
" \n",
"    def _print_status_summary(self):\n",
"        \"\"\"Print a status summary.\"\"\"\n",
"        print(f\"\\n📋 系统状态摘要:\")\n",
"        print(\"=\"*50)\n",
"        print(f\"📂 数据目录: {self.data_dir}\")\n",
"        print(f\"📊 有效Sessions: {len(self.sessions)}\")\n",
"        print(f\"🤖 RNN模型: {'真实模型' if self.use_real_rnn else '模拟模型'}\")\n",
"        print(f\"✅ 系统状态: {'就绪' if self._sessions_scanned and self._rnn_checked else '未就绪'}\")\n",
" \n",
"        if len(self.sessions) > 0:\n",
"            print(f\"\\n📝 前5个session:\")\n",
"            for i, session in enumerate(self.sessions[:5]):\n",
"                print(f\" {i+1:2d}. {session}\")\n",
"            if len(self.sessions) > 5:\n",
"                print(f\" ... 还有 {len(self.sessions)-5} 个session\")\n",
" \n",
"    def _load_session_data(self, session_name):\n",
"        \"\"\"Load the data for a single session.\"\"\"\n",
"        session_path = os.path.join(self.data_dir, session_name)\n",
" \n",
"        try:\n",
"            # Find the data files\n",
"            data_files = [f for f in os.listdir(session_path) \n",
"                          if f.endswith('.h5') or f.endswith('.hdf5')]\n",
" \n",
"            if not data_files:\n",
"                return None, \"未找到数据文件\"\n",
" \n",
"            # Use the first data file found\n",
"            data_file = os.path.join(session_path, data_files[0])\n",
" \n",
"            with h5py.File(data_file, 'r') as f:\n",
"                neural_data = []\n",
"                labels = []\n",
" \n",
"                # Iterate over all trials\n",
"                for trial_key in f.keys():\n",
"                    if trial_key.startswith('trial_'):\n",
"                        trial_group = f[trial_key]\n",
" \n",
"                        # Read the neural data and labels\n",
"                        neural_features = trial_group['neural_data'][:]  # [time_steps, features]\n",
"                        trial_labels = trial_group['labels'][:]  # [time_steps]\n",
" \n",
"                        # Sanity-check the shapes\n",
"                        if len(neural_features.shape) == 2 and len(trial_labels.shape) == 1:\n",
"                            # Time steps must line up\n",
"                            if neural_features.shape[0] == trial_labels.shape[0]:\n",
"                                neural_data.append(neural_features)\n",
"                                labels.append(trial_labels)\n",
" \n",
"            if not neural_data:\n",
"                return None, \"未找到有效的试验数据\"\n",
" \n",
"            print(f\" 📊 {session_name}: 加载了 {len(neural_data)} 个试验\")\n",
"            return (neural_data, labels), None\n",
" \n",
"        except Exception as e:\n",
"            return None, f\"加载失败: {str(e)}\"\n",
" \n",
"    def _extract_rnn_features(self, neural_data, session_name):\n",
"        \"\"\"Extract features with the real RNN model.\"\"\"\n",
"        if not self.use_real_rnn:\n",
"            return self._simulate_rnn_predictions(neural_data)\n",
" \n",
"        try:\n",
"            # Map the session to its day index\n",
"            day_idx = self.rnn_loader.get_day_index(session_name)\n",
" \n",
"            rnn_predictions = []\n",
" \n",
"            # Progress bar over the per-trial predictions\n",
"            with tqdm(neural_data, desc=f\"RNN预测 {session_name}\", unit=\"试验\", leave=False) as pbar:\n",
"                for trial_idx, neural_features in enumerate(pbar):\n",
"                    pbar.set_postfix_str(f\"试验 {trial_idx+1}\")\n",
" \n",
"                    # Predict with the real RNN model\n",
"                    logits = self.rnn_loader.predict_trial(neural_features, day_idx)\n",
" \n",
"                    if logits is not None:\n",
"                        rnn_predictions.append(logits)\n",
"                        pbar.set_postfix_str(f\"✅ 试验 {trial_idx+1}\")\n",
"                    else:\n",
"                        # Fall back to simulated data if the prediction failed\n",
"                        print(f\" ⚠️ 试验 {trial_idx} RNN预测失败,使用模拟数据\")\n",
"                        simulated = self._simulate_single_trial_prediction(neural_features)\n",
"                        rnn_predictions.append(simulated)\n",
"                        pbar.set_postfix_str(f\"⚠️ 试验 {trial_idx+1} (模拟)\")\n",
" \n",
"            return rnn_predictions\n",
" \n",
"        except Exception as e:\n",
"            print(f\" ❌ RNN特征提取失败: {str(e)}\")\n",
"            print(f\" 🔄 回退到模拟预测\")\n",
"            return self._simulate_rnn_predictions(neural_data)\n",
" \n",
"    def _simulate_single_trial_prediction(self, neural_features):\n",
"        \"\"\"Generate a simulated RNN prediction for one trial.\"\"\"\n",
"        time_steps = neural_features.shape[0]\n",
"        n_phonemes = 40\n",
" \n",
"        # Simulated logits (with a somewhat realistic distribution)\n",
"        logits = np.random.randn(time_steps, n_phonemes) * 2.0\n",
" \n",
"        # Add some time-dependent structure\n",
"        for t in range(time_steps):\n",
"            # The silence class is more likely at the start and end\n",
"            if t < 5 or t > time_steps - 5:\n",
"                logits[t, 0] += 2.0  # silence class\n",
" \n",
"            # Add some phonetically plausible patterns\n",
"            if t % 10 < 3:  # pretend consonants\n",
"                logits[t, 1:15] += 1.0\n",
"            else:  # pretend vowels\n",
"                logits[t, 15:25] += 1.0\n",
" \n",
"        return logits\n",
" \n",
" # def _simulate_rnn_predictions(self, neural_data):\n",
|
||
" # \"\"\"为所有试验生成模拟RNN预测\"\"\"\n",
|
||
" # print(\" 🎭 使用模拟RNN预测\")\n",
|
||
" # predictions = []\n",
|
||
" \n",
|
||
" # for neural_features in neural_data:\n",
|
||
" # pred = self._simulate_single_trial_prediction(neural_features)\n",
|
||
" # predictions.append(pred)\n",
|
||
" \n",
|
||
" # return predictions\n",
|
||
" \n",
"    def _compute_confidence_metrics(self, logits):\n",
"        \"\"\"Compute confidence metrics.\"\"\"\n",
"        # Convert to probabilities (max-subtracted softmax for numerical stability)\n",
"        shifted = logits - np.max(logits, axis=-1, keepdims=True)\n",
"        probs = np.exp(shifted) / np.sum(np.exp(shifted), axis=-1, keepdims=True)\n",
" \n",
"        # Entropy (uncertainty measure)\n",
"        entropy = -np.sum(probs * np.log(probs + 1e-8), axis=-1)\n",
" \n",
"        # Maximum probability (confidence measure)\n",
"        max_prob = np.max(probs, axis=-1)\n",
" \n",
"        # Top-2 gap (decision-boundary measure)\n",
"        sorted_probs = np.sort(probs, axis=-1)\n",
"        top2_margin = sorted_probs[:, -1] - sorted_probs[:, -2]\n",
" \n",
"        return {\n",
"            'entropy': entropy,\n",
"            'max_prob': max_prob,\n",
"            'top2_margin': top2_margin,\n",
"            'mean_entropy': np.mean(entropy),\n",
"            'mean_max_prob': np.mean(max_prob),\n",
"            'mean_top2_margin': np.mean(top2_margin)\n",
"        }\n",
" \n",
"    def _process_single_session(self, session_name):\n",
"        \"\"\"Process a single session.\"\"\"\n",
"        print(f\"\\n🔄 处理session: {session_name}\")\n",
" \n",
"        # Load the data\n",
"        data_result, error = self._load_session_data(session_name)\n",
"        if data_result is None:\n",
"            print(f\" ❌ 数据加载失败: {error}\")\n",
"            return None\n",
" \n",
"        neural_data, labels = data_result\n",
" \n",
"        # Extract RNN features\n",
"        print(f\" 🧠 提取RNN特征...\")\n",
"        rnn_predictions = self._extract_rnn_features(neural_data, session_name)\n",
" \n",
"        if not rnn_predictions:\n",
"            print(f\" ❌ RNN特征提取失败\")\n",
"            return None\n",
" \n",
"        # Process every trial\n",
"        session_features = []\n",
" \n",
"        for trial_idx, (neural_features, trial_labels, rnn_logits) in enumerate(\n",
"                zip(neural_data, labels, rnn_predictions)):\n",
" \n",
"            # Align the lengths\n",
"            min_length = min(len(neural_features), len(trial_labels), len(rnn_logits))\n",
"            neural_features = neural_features[:min_length]\n",
"            trial_labels = trial_labels[:min_length]\n",
"            rnn_logits = rnn_logits[:min_length]\n",
" \n",
"            # Confidence metrics\n",
"            confidence_metrics = self._compute_confidence_metrics(rnn_logits)\n",
" \n",
"            # Build the DataFrame\n",
"            trial_df = pd.DataFrame()\n",
" \n",
"            # Neural features\n",
"            for feat_idx in range(neural_features.shape[1]):\n",
"                trial_df[f'neural_feat_{feat_idx:03d}'] = neural_features[:, feat_idx]\n",
" \n",
"            # The 40 phoneme probabilities predicted by the RNN (stable softmax, as above)\n",
"            shifted_logits = rnn_logits - np.max(rnn_logits, axis=-1, keepdims=True)\n",
"            rnn_probs = np.exp(shifted_logits) / np.sum(np.exp(shifted_logits), axis=-1, keepdims=True)\n",
"            for phoneme_idx in range(40):\n",
"                trial_df[f'phoneme_{phoneme_idx:02d}'] = rnn_probs[:, phoneme_idx]\n",
" \n",
"            # Confidence metrics\n",
"            trial_df['confidence_entropy'] = confidence_metrics['entropy']\n",
"            trial_df['confidence_max_prob'] = confidence_metrics['max_prob']\n",
"            trial_df['confidence_top2_margin'] = confidence_metrics['top2_margin']\n",
" \n",
"            # Metadata\n",
"            trial_df['session_name'] = session_name\n",
"            trial_df['trial_id'] = trial_idx\n",
"            trial_df['time_step'] = range(len(trial_df))\n",
"            trial_df['ground_truth_label'] = trial_labels\n",
" \n",
"            session_features.append(trial_df)\n",
" \n",
"        # Concatenate all trials\n",
"        if session_features:\n",
"            combined_df = pd.concat(session_features, ignore_index=True)\n",
"            print(f\" ✅ {session_name}: 处理完成,共 {len(combined_df)} 个样本\")\n",
"            return combined_df\n",
"        else:\n",
"            print(f\" ❌ {session_name}: 没有有效数据\")\n",
"            return None\n",
" \n",
"    def process_sessions(self, session_names=None, max_sessions=None):\n",
"        \"\"\"Process sessions in batch (with a progress bar).\"\"\"\n",
"        # The status check must have run first\n",
"        if not self._sessions_scanned or not self._rnn_checked:\n",
"            print(\"⚠️ 请先运行 check_status() 检查系统状态\")\n",
"            return\n",
" \n",
"        # Decide which sessions to process\n",
"        if session_names is None:\n",
"            target_sessions = self.sessions[:max_sessions] if max_sessions else self.sessions\n",
"        else:\n",
"            target_sessions = [s for s in session_names if s in self.sessions]\n",
" \n",
"        if not target_sessions:\n",
"            print(\"❌ 没有找到要处理的sessions\")\n",
"            return\n",
" \n",
"        print(f\"\\n🚀 开始批量处理 {len(target_sessions)} 个sessions\")\n",
"        print(\"=\"*60)\n",
" \n",
"        successful_results = {}\n",
" \n",
"        # Progress bar over the sessions\n",
"        with tqdm(target_sessions, desc=\"处理Sessions\", unit=\"session\") as pbar:\n",
"            for session_name in pbar:\n",
"                pbar.set_postfix_str(f\"处理: {session_name}\")\n",
" \n",
"                try:\n",
"                    result_df = self._process_single_session(session_name)\n",
"                    if result_df is not None:\n",
"                        successful_results[session_name] = result_df\n",
"                        pbar.set_postfix_str(f\"✅ {session_name}\")\n",
"                    else:\n",
"                        pbar.set_postfix_str(f\"❌ {session_name}\")\n",
" \n",
"                except Exception as e:\n",
"                    print(f\" 💥 处理 {session_name} 时出错: {str(e)}\")\n",
"                    pbar.set_postfix_str(f\"💥 {session_name}\")\n",
" \n",
"                # Small delay so the status is visible\n",
"                time.sleep(0.1)\n",
" \n",
"        self.results = successful_results\n",
" \n",
"        print(f\"\\n📊 批量处理完成!\")\n",
"        print(f\" 成功处理: {len(successful_results)} / {len(target_sessions)} sessions\")\n",
"        print(f\" 总样本数: {sum(len(df) for df in successful_results.values()):,}\")\n",
" \n",
"        if successful_results:\n",
"            print(f\" 特征维度: {list(successful_results.values())[0].shape[1]} 列\")\n",
" \n",
"    def get_combined_dataset(self):\n",
"        \"\"\"Return the combined dataset.\"\"\"\n",
"        if not self.results:\n",
"            print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n",
"            return None\n",
" \n",
"        print(\"🔗 合并所有session数据...\")\n",
"        combined_dfs = list(self.results.values())\n",
" \n",
"        if combined_dfs:\n",
"            combined_df = pd.concat(combined_dfs, ignore_index=True)\n",
"            print(f\"✅ 合并完成: {len(combined_df):,} 个样本\")\n",
"            return combined_df\n",
"        else:\n",
"            return None\n",
" \n",
"    def save_results(self, output_dir='./outputs'):\n",
"        \"\"\"Save the processing results (with a progress bar).\"\"\"\n",
"        if not self.results:\n",
"            print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n",
"            return\n",
" \n",
"        os.makedirs(output_dir, exist_ok=True)\n",
"        print(f\"💾 保存处理结果到: {output_dir}\")\n",
" \n",
"        # Save each file with a progress bar\n",
"        with tqdm(self.results.items(), desc=\"保存文件\", unit=\"文件\") as pbar:\n",
"            for session_name, df in pbar:\n",
"                pbar.set_postfix_str(f\"保存: {session_name}\")\n",
"                output_path = os.path.join(output_dir, f\"{session_name}_features.csv\")\n",
"                df.to_csv(output_path, index=False) \n",
"                pbar.set_postfix_str(f\"✅ {session_name}\")\n",
"                time.sleep(0.05)  # small delay so the progress is visible\n",
" \n",
"        # Save the combined dataset\n",
"        combined_df = self.get_combined_dataset()\n",
"        if combined_df is not None:\n",
"            print(\"💾 保存合并数据集...\")\n",
"            combined_path = os.path.join(output_dir, \"combined_all_sessions_features.csv\")\n",
"            combined_df.to_csv(combined_path, index=False)\n",
"            print(f\" ✅ 合并数据集: {combined_path}\")\n",
" \n",
"        print(f\"💾 保存完成!\")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🔧 创建改进版数据处理管道 (仅真实数据模式)...\n",
"🔧 初始化改进版数据处理管道...\n",
"✅ 管道初始化完成!\n",
"✅ 改进版管道创建完成!\n",
"⚠️ 重要: 本管道仅支持真实RNN模型,不提供任何模拟功能\n",
"💡 使用方法:\n",
" # 确保rnn_loader已加载真实模型\n",
" results = improved_pipeline.run_full_pipeline(\n",
" train_list, val_list, test_list,\n",
" max_files_per_split=3,\n",
" rnn_loader=rnn_loader # 必须是已加载的真实模型\n",
" )\n"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📚 数据集访问和使用示例\n",
"============================================================\n",
"✅ 数据集变量已创建:\n",
" train_datasets: 0 个训练DataFrame\n",
" val_datasets: 0 个验证DataFrame\n",
" test_datasets: 0 个测试DataFrame\n",
"\n",
"🔗 数据集合并示例:\n",
"\n",
"💾 数据集保存功能:\n",
"使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\n",
"将保存所有单独session和合并的数据集文件\n",
"\n",
"🎯 工作流完成总结:\n",
"============================================================\n",
"✅ 成功创建了完整的数据集处理工作流\n",
"✅ 生成了三个主要变量:train_datasets, val_datasets, test_datasets\n",
"✅ 每个变量包含多个DataFrame,对应不同的sessions\n",
"✅ 每个DataFrame包含:\n",
" • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\n",
" • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\n",
" • 置信度指标 (entropy, top2_margin)\n",
" • 元数据 (session, trial, 时间戳等)\n",
"✅ 支持单独访问或合并使用\n",
"✅ 支持保存为CSV文件\n",
"\n",
"🚀 现在你可以使用这些数据集进行机器学习建模了!\n"
]
}
],
"source": [
"# 📚 数据集访问和使用示例\n",
"print(\"📚 数据集访问和使用示例\")\n",
"print(\"=\"*60)\n",
"\n",
"# Convenience aliases for the three dataset lists\n",
"train_datasets = pipeline.train_datasets  # list of training DataFrames\n",
"val_datasets = pipeline.val_datasets  # list of validation DataFrames\n",
"test_datasets = pipeline.test_datasets  # list of test DataFrames\n",
"\n",
"print(f\"✅ 数据集变量已创建:\")\n",
"print(f\" train_datasets: {len(train_datasets)} 个训练DataFrame\")\n",
"print(f\" val_datasets: {len(val_datasets)} 个验证DataFrame\")\n",
"print(f\" test_datasets: {len(test_datasets)} 个测试DataFrame\")\n",
"\n",
"# Demonstrate how to use these datasets\n",
"if train_datasets:\n",
"    print(f\"\\n🔍 数据集使用示例:\")\n",
" \n",
"    # Access the first training session\n",
"    first_train_session = train_datasets[0]\n",
"    print(f\"\\n1. 访问第一个训练session:\")\n",
"    print(f\" 形状: {first_train_session.shape}\")\n",
"    print(f\" Session: {first_train_session['session'].iloc[0]}\")\n",
"    print(f\" 试验数: {first_train_session['trial_idx'].max() + 1}\")\n",
" \n",
"    # Pull out the neural features\n",
"    neural_cols = [col for col in first_train_session.columns if col.startswith('neural_feat_')]\n",
"    neural_features = first_train_session[neural_cols]\n",
"    print(f\"\\n2. 提取神经特征:\")\n",
"    print(f\" 神经特征形状: {neural_features.shape}\")\n",
"    print(f\" 特征维度: {len(neural_cols)}\")\n",
" \n",
"    # Pull out the phoneme predictions\n",
"    phoneme_cols = [col for col in first_train_session.columns if col.startswith('phoneme_')]\n",
"    phoneme_predictions = first_train_session[phoneme_cols]\n",
"    print(f\"\\n3. 提取音素预测:\")\n",
"    print(f\" 预测概率形状: {phoneme_predictions.shape}\")\n",
"    print(f\" 音素类别数: {len(phoneme_cols)}\")\n",
" \n",
"    # Pull out the metadata\n",
"    metadata_cols = ['time_step', 'session', 'trial_idx', 'ground_truth_phoneme_name', \n",
"                     'predicted_phoneme_name', 'max_probability', 'entropy', 'top2_margin']\n",
"    metadata = first_train_session[metadata_cols]\n",
"    print(f\"\\n4. 提取元数据:\")\n",
"    print(f\" 元数据列数: {len(metadata_cols)}\")\n",
"    print(f\" 包含: 时间步、session、试验信息、真实/预测标签、置信度指标\")\n",
"\n",
"# Helper to merge all datasets of one split\n",
"def combine_datasets(dataset_list, split_name):\n",
"    \"\"\"Concatenate all datasets of the same split.\"\"\"\n",
"    if not dataset_list:\n",
"        return None\n",
" \n",
"    combined_df = pd.concat(dataset_list, ignore_index=True)\n",
"    print(f\"\\n📊 {split_name}合并结果:\")\n",
"    print(f\" 总数据形状: {combined_df.shape}\")\n",
"    print(f\" 包含sessions: {combined_df['session'].nunique()} 个\")\n",
"    print(f\" 总试验数: {combined_df['trial_idx'].nunique()}\")\n",
"    print(f\" 总时间步数: {len(combined_df):,}\")\n",
" \n",
"    return combined_df\n",
"\n",
"# Demonstrate merging\n",
"print(f\"\\n🔗 数据集合并示例:\")\n",
"if train_datasets:\n",
"    combined_train = combine_datasets(train_datasets, \"训练集\")\n",
" \n",
"if val_datasets:\n",
"    combined_val = combine_datasets(val_datasets, \"验证集\")\n",
" \n",
"if test_datasets:\n",
"    combined_test = combine_datasets(test_datasets, \"测试集\")\n",
"\n",
"# Helper to save the datasets\n",
"def save_datasets_to_files(train_list, val_list, test_list, output_dir='./processed_datasets'):\n",
"    \"\"\"Save all datasets to files.\"\"\"\n",
"    os.makedirs(output_dir, exist_ok=True)  # create the output directory (was missing; to_csv fails without it)\n",
"    saved_files = []\n",
" \n",
"    # Save the training sets\n",
"    for i, df in enumerate(train_list):\n",
"        filename = os.path.join(output_dir, f'train_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n",
"        df.to_csv(filename, index=False)\n",
"        saved_files.append(filename)\n",
" \n",
"    # Save the validation sets\n",
"    for i, df in enumerate(val_list):\n",
"        filename = os.path.join(output_dir, f'val_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n",
"        df.to_csv(filename, index=False)\n",
"        saved_files.append(filename)\n",
" \n",
"    # Save the test sets\n",
"    for i, df in enumerate(test_list):\n",
"        filename = os.path.join(output_dir, f'test_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n",
"        df.to_csv(filename, index=False)\n",
"        saved_files.append(filename)\n",
" \n",
"    # Save the combined datasets\n",
"    if train_list:\n",
"        combined_train = pd.concat(train_list, ignore_index=True)\n",
"        train_combined_file = os.path.join(output_dir, 'combined_train_all_sessions.csv')\n",
"        combined_train.to_csv(train_combined_file, index=False)\n",
"        saved_files.append(train_combined_file)\n",
" \n",
"    if val_list:\n",
"        combined_val = pd.concat(val_list, ignore_index=True)\n",
"        val_combined_file = os.path.join(output_dir, 'combined_val_all_sessions.csv')\n",
"        combined_val.to_csv(val_combined_file, index=False)\n",
"        saved_files.append(val_combined_file)\n",
" \n",
"    if test_list:\n",
"        combined_test = pd.concat(test_list, ignore_index=True)\n",
"        test_combined_file = os.path.join(output_dir, 'combined_test_all_sessions.csv')\n",
"        combined_test.to_csv(test_combined_file, index=False)\n",
"        saved_files.append(test_combined_file)\n",
" \n",
"    return saved_files\n",
"\n",
"print(f\"\\n💾 数据集保存功能:\")\n",
"print(f\"使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\")\n",
"print(f\"将保存所有单独session和合并的数据集文件\")\n",
"\n",
"print(f\"\\n🎯 工作流完成总结:\")\n",
"print(\"=\"*60)\n",
"print(f\"✅ 成功创建了完整的数据集处理工作流\")\n",
"print(f\"✅ 生成了三个主要变量:train_datasets, val_datasets, test_datasets\")\n",
"print(f\"✅ 每个变量包含多个DataFrame,对应不同的sessions\")\n",
"print(f\"✅ 每个DataFrame包含:\")\n",
"print(f\" • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\")\n",
"print(f\" • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\")\n",
"print(f\" • 置信度指标 (entropy, top2_margin)\")\n",
"print(f\" • 元数据 (session, trial, 时间戳等)\")\n",
"print(f\"✅ 支持单独访问或合并使用\")\n",
"print(f\"✅ 支持保存为CSV文件\")\n",
"\n",
"print(f\"\\n🚀 现在你可以使用这些数据集进行机器学习建模了!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model Building"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🌲 Random-Forest Multi-Output Regression\n",
"\n",
"Use a random forest to regress the 40 phoneme probabilities from 512-dimensional neural features over a time window of 30 steps; a minimal windowing sketch appears in a cell below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
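{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧪 Added sketch (not part of the original pipeline): how the 30-step window flattening works.\n",
"# Assumptions: toy shapes only (T=100 steps, F=512 features, W=30). The real pipeline loops in\n",
"# TimeWindowRandomForestRegressor.create_time_windows below; this is just the vectorized picture.\n",
"import numpy as np\n",
"from numpy.lib.stride_tricks import sliding_window_view\n",
"\n",
"T, F, W = 100, 512, 30\n",
"x = np.random.randn(T, F).astype(np.float32)\n",
"\n",
"# sliding_window_view over axis 0 gives [T-W+1, F, W]; reorder to [T-W+1, W, F]\n",
"windows = sliding_window_view(x, window_shape=W, axis=0).transpose(0, 2, 1)\n",
"\n",
"# flatten each window: every sample becomes a W * F = 15,360-dimensional vector\n",
"flat = windows.reshape(-1, W * F)\n",
"print(flat.shape)  # (71, 15360)\n"
]
},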
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🌲 初始化随机森林回归模型\n",
"🌲 随机森林回归器初始化:\n",
" 时间窗口大小: 30\n",
" 树的数量: 100\n",
" 最大深度: 10\n",
" 并行任务: -1\n",
"\n",
"✅ 随机森林回归器准备完成!\n",
"🔧 下一步: 准备训练数据和开始训练\n"
]
}
],
"source": [
"# 🌲 Random-forest multi-output regression model\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from tqdm import tqdm\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"class TimeWindowRandomForestRegressor:\n",
"    \"\"\"\n",
"    Time-window random-forest multi-output regressor.\n",
"    Predicts the probability distribution over the 40 phonemes.\n",
"    \"\"\"\n",
" \n",
"    def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n",
"        \"\"\"\n",
"        Initialize the model.\n",
" \n",
"        Args:\n",
"            window_size: time-window size\n",
"            n_estimators: number of trees in the forest\n",
"            max_depth: maximum tree depth\n",
"            n_jobs: number of parallel jobs\n",
"            random_state: random seed\n",
"        \"\"\"\n",
"        self.window_size = window_size\n",
"        self.n_estimators = n_estimators\n",
"        self.max_depth = max_depth\n",
"        self.n_jobs = n_jobs\n",
"        self.random_state = random_state\n",
" \n",
"        # Model and preprocessing\n",
"        self.regressor = RandomForestRegressor(\n",
"            n_estimators=n_estimators,\n",
"            max_depth=max_depth,\n",
"            n_jobs=n_jobs,\n",
"            random_state=random_state,\n",
"            verbose=1\n",
"        )\n",
" \n",
"        self.scaler = StandardScaler()\n",
"        self.is_fitted = False\n",
" \n",
"        print(f\"🌲 随机森林回归器初始化:\")\n",
"        print(f\" 时间窗口大小: {window_size}\")\n",
"        print(f\" 树的数量: {n_estimators}\")\n",
"        print(f\" 最大深度: {max_depth}\")\n",
"        print(f\" 并行任务: {n_jobs}\")\n",
" \n",
"    def create_time_windows(self, neural_features, phoneme_targets=None):\n",
"        \"\"\"\n",
"        Build time-window features.\n",
" \n",
"        Args:\n",
"            neural_features: neural features [time_steps, 512]\n",
"            phoneme_targets: phoneme targets [time_steps, 40] (optional)\n",
" \n",
"        Returns:\n",
"            windowed_features: [samples, window_size * 512]\n",
"            windowed_targets: [samples, 40] (when targets are given)\n",
"        \"\"\"\n",
"        if len(neural_features) < self.window_size:\n",
"            print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n",
"            return None, None\n",
" \n",
"        n_samples = len(neural_features) - self.window_size + 1\n",
"        n_features = neural_features.shape[1]\n",
" \n",
"        # Build the windowed feature matrix\n",
"        windowed_features = np.zeros((n_samples, self.window_size * n_features))\n",
" \n",
"        for i in range(n_samples):\n",
"            # Flatten every feature inside the window\n",
"            window_data = neural_features[i:i+self.window_size].flatten()\n",
"            windowed_features[i] = window_data\n",
" \n",
"        windowed_targets = None\n",
"        if phoneme_targets is not None:\n",
"            # Use the phoneme probabilities at the window centre as the target\n",
"            center_offset = self.window_size // 2\n",
"            windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n",
" \n",
"        return windowed_features, windowed_targets\n",
" \n",
"    def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n",
"        \"\"\"\n",
"        Prepare a training dataset.\n",
" \n",
"        Args:\n",
"            datasets_list: list of DataFrames (train_datasets, val_datasets, etc.)\n",
"            dataset_type: display name of the split\n",
" \n",
"        Returns:\n",
"            X: feature matrix [total samples, window_size * 512]\n",
"            y: target matrix [total samples, 40]\n",
"        \"\"\"\n",
"        print(f\"\\n📊 准备{dataset_type}数据集:\")\n",
"        print(f\" 输入数据集数量: {len(datasets_list)}\")\n",
" \n",
"        all_X = []\n",
"        all_y = []\n",
" \n",
"        for i, df in enumerate(tqdm(datasets_list, desc=f\"处理{dataset_type}数据\")):\n",
"            # Neural features (the 512 neural_feat_ columns)\n",
"            neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n",
"            neural_features = df[neural_cols].values\n",
" \n",
"            # Phoneme targets (the 40 phoneme probability columns)\n",
"            phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n",
"            phoneme_targets = df[phoneme_cols].values\n",
" \n",
"            # Process trial by trial\n",
"            trials = df['trial_idx'].unique()\n",
" \n",
"            for trial_idx in trials:\n",
"                trial_mask = df['trial_idx'] == trial_idx\n",
"                trial_neural = neural_features[trial_mask]\n",
"                trial_phonemes = phoneme_targets[trial_mask]\n",
" \n",
"                # Build the time windows\n",
"                windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n",
" \n",
"                if windowed_X is not None and windowed_y is not None:\n",
"                    all_X.append(windowed_X)\n",
"                    all_y.append(windowed_y)\n",
" \n",
"        if not all_X:\n",
"            print(f\"❌ 没有有效的{dataset_type}数据\")\n",
"            return None, None\n",
" \n",
"        # Stack everything\n",
"        X = np.vstack(all_X)\n",
"        y = np.vstack(all_y)\n",
" \n",
"        print(f\" ✅ {dataset_type}数据准备完成:\")\n",
"        print(f\" 特征矩阵形状: {X.shape}\")\n",
"        print(f\" 目标矩阵形状: {y.shape}\")\n",
"        print(f\" 内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n",
" \n",
"        return X, y\n",
" \n",
"    def fit(self, X_train, y_train, X_val=None, y_val=None):\n",
"        \"\"\"\n",
"        Train the model.\n",
" \n",
"        Args:\n",
"            X_train: training features\n",
"            y_train: training targets\n",
"            X_val: validation features (optional)\n",
"            y_val: validation targets (optional)\n",
"        \"\"\"\n",
"        print(f\"\\n🚀 开始训练随机森林回归模型:\")\n",
"        print(f\" 训练样本数: {X_train.shape[0]:,}\")\n",
"        print(f\" 特征维度: {X_train.shape[1]:,}\")\n",
"        print(f\" 目标维度: {y_train.shape[1]}\")\n",
" \n",
"        # Standardize the features\n",
"        print(\" 🔄 标准化特征...\")\n",
"        X_train_scaled = self.scaler.fit_transform(X_train)\n",
" \n",
"        # Fit the forest\n",
"        print(\" 🌲 训练随机森林...\")\n",
"        self.regressor.fit(X_train_scaled, y_train)\n",
" \n",
"        self.is_fitted = True\n",
" \n",
"        # Training-set performance\n",
"        print(\" 📊 评估训练集性能...\")\n",
"        train_predictions = self.regressor.predict(X_train_scaled)\n",
"        train_mse = mean_squared_error(y_train, train_predictions)\n",
"        train_r2 = r2_score(y_train, train_predictions)\n",
"        train_mae = mean_absolute_error(y_train, train_predictions)\n",
" \n",
"        print(f\" ✅ 训练完成!\")\n",
"        print(f\" 训练集 MSE: {train_mse:.6f}\")\n",
"        print(f\" 训练集 R²: {train_r2:.4f}\")\n",
"        print(f\" 训练集 MAE: {train_mae:.6f}\")\n",
" \n",
"        # Validation-set performance, if provided\n",
"        if X_val is not None and y_val is not None:\n",
"            print(\" 📊 评估验证集性能...\")\n",
"            X_val_scaled = self.scaler.transform(X_val)\n",
"            val_predictions = self.regressor.predict(X_val_scaled)\n",
"            val_mse = mean_squared_error(y_val, val_predictions)\n",
"            val_r2 = r2_score(y_val, val_predictions)\n",
"            val_mae = mean_absolute_error(y_val, val_predictions)\n",
" \n",
"            print(f\" 验证集 MSE: {val_mse:.6f}\")\n",
"            print(f\" 验证集 R²: {val_r2:.4f}\")\n",
"            print(f\" 验证集 MAE: {val_mae:.6f}\")\n",
" \n",
"            return {\n",
"                'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n",
"                'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n",
"            }\n",
" \n",
"        return {\n",
"            'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n",
"        }\n",
" \n",
"    def predict(self, X):\n",
"        \"\"\"Predict.\"\"\"\n",
"        if not self.is_fitted:\n",
"            raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n",
" \n",
"        X_scaled = self.scaler.transform(X)\n",
"        return self.regressor.predict(X_scaled)\n",
" \n",
"    def get_feature_importance(self, top_k=20):\n",
"        \"\"\"Return feature importances.\"\"\"\n",
"        if not self.is_fitted:\n",
"            raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n",
" \n",
"        importances = self.regressor.feature_importances_\n",
" \n",
"        # Feature names (t{timestep}_feat{feature})\n",
"        feature_names = []\n",
"        for t in range(self.window_size):\n",
"            for f in range(512):\n",
"                feature_names.append(f\"t{t}_feat{f}\")\n",
" \n",
"        # Top-k most important features\n",
"        top_indices = np.argsort(importances)[::-1][:top_k]\n",
"        top_features = [(feature_names[i], importances[i]) for i in top_indices]\n",
" \n",
"        return top_features, importances\n",
"\n",
"# Initialize the model\n",
"print(\"🌲 初始化随机森林回归模型\")\n",
"rf_regressor = TimeWindowRandomForestRegressor(\n",
"    window_size=30,  # time-window size\n",
"    n_estimators=100,  # number of trees\n",
"    max_depth=10,  # maximum depth (guards against overfitting)\n",
"    n_jobs=-1,  # use all CPU cores\n",
"    random_state=42\n",
")\n",
"\n",
"print(\"\\n✅ 随机森林回归器准备完成!\")\n",
"print(\"🔧 下一步: 准备训练数据和开始训练\")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 开始数据准备和模型训练流程\n",
"============================================================\n",
"❌ 没有可用的训练数据,请先运行数据处理工作流\n"
]
}
],
"source": [
"# 🚀 Prepare the data and train the random-forest regressor\n",
"print(\"🚀 开始数据准备和模型训练流程\")\n",
"print(\"=\"*60)\n",
"\n",
"# Check data availability\n",
"if not train_datasets:\n",
"    print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n",
"else:\n",
"    print(f\"✅ 检测到数据:\")\n",
"    print(f\" 训练数据集: {len(train_datasets)} 个sessions\")\n",
"    print(f\" 验证数据集: {len(val_datasets)} 个sessions\")\n",
"    print(f\" 测试数据集: {len(test_datasets)} 个sessions\")\n",
"\n",
"    # 1. Prepare the training data\n",
"    print(f\"\\n📊 第1步: 准备训练数据\")\n",
"    X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n",
" \n",
"    # 2. Prepare the validation data\n",
"    print(f\"\\n📊 第2步: 准备验证数据\")\n",
"    X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n",
" \n",
"    if X_train is not None and y_train is not None:\n",
"        print(f\"\\n📈 数据准备完成统计:\")\n",
"        print(f\" 训练集: {X_train.shape[0]:,} 样本\")\n",
"        print(f\" 验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \" 验证集: 无\")\n",
"        print(f\" 特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n",
"        print(f\" 目标维度: {y_train.shape[1]} (40个音素概率)\")\n",
" \n",
"        # Data-quality checks\n",
"        print(f\"\\n🔍 数据质量检查:\")\n",
"        print(f\" 训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n",
"        print(f\" 训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n",
"        print(f\" 训练特征均值: {X_train.mean():.4f}\")\n",
"        print(f\" 训练目标均值: {y_train.mean():.4f}\")\n",
" \n",
"        # Look for NaN / infinite values\n",
"        nan_count_X = np.isnan(X_train).sum()\n",
"        nan_count_y = np.isnan(y_train).sum()\n",
"        inf_count_X = np.isinf(X_train).sum()\n",
"        inf_count_y = np.isinf(y_train).sum()\n",
" \n",
"        print(f\" NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n",
"        print(f\" Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n",
" \n",
"        if nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0:\n",
"            print(\"⚠️ 检测到异常值,将进行清理...\")\n",
"            # Drop rows containing NaN or Inf\n",
"            valid_mask = ~(np.isnan(X_train).any(axis=1) | np.isnan(y_train).any(axis=1) | \n",
"                           np.isinf(X_train).any(axis=1) | np.isinf(y_train).any(axis=1))\n",
"            X_train = X_train[valid_mask]\n",
"            y_train = y_train[valid_mask]\n",
" \n",
"            if X_val is not None and y_val is not None:\n",
"                valid_mask_val = ~(np.isnan(X_val).any(axis=1) | np.isnan(y_val).any(axis=1) | \n",
"                                   np.isinf(X_val).any(axis=1) | np.isinf(y_val).any(axis=1))\n",
"                X_val = X_val[valid_mask_val]\n",
"                y_val = y_val[valid_mask_val]\n",
" \n",
"            print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n",
" \n",
"        # 3. Train the model\n",
"        print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n",
"        training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n",
" \n",
"        # 4. Inspect the training results\n",
"        print(f\"\\n📊 第4步: 训练结果分析\")\n",
"        print(\"=\"*50)\n",
" \n",
"        for metric, value in training_results.items():\n",
"            metric_name = metric.replace('_', ' ').title()\n",
"            print(f\" {metric_name}: {value:.6f}\")\n",
" \n",
"        # 5. Feature-importance analysis\n",
"        print(f\"\\n🔍 第5步: 特征重要性分析\")\n",
"        top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n",
" \n",
"        print(f\"\\n🏆 Top 20 重要特征:\")\n",
"        print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n",
"        print(\"-\" * 35)\n",
"        for i, (feature_name, importance) in enumerate(top_features):\n",
"            print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n",
" \n",
"        # Importance distribution across the time window\n",
"        print(f\"\\n📈 时间窗口重要性分布:\")\n",
"        window_importances = np.zeros(rf_regressor.window_size)\n",
"        for i in range(rf_regressor.window_size):\n",
"            start_idx = i * 512\n",
"            end_idx = (i + 1) * 512\n",
"            window_importances[i] = all_importances[start_idx:end_idx].sum()\n",
" \n",
"        max_time_step = np.argmax(window_importances)\n",
"        print(f\" 最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n",
"        print(f\" 窗口中心位置: t{rf_regressor.window_size//2}\")\n",
"        print(f\" 重要性分布: 前5个时间步的重要性\")\n",
"        for i in range(min(5, len(window_importances))):\n",
"            print(f\" t{i}: {window_importances[i]:.6f}\")\n",
" \n",
"        print(f\"\\n✅ 随机森林回归模型训练完成!\")\n",
"        print(f\"🎯 模型可以预测40个音素的概率分布\")\n",
"        print(f\"📊 基于30时间步的神经特征窗口\")\n",
"        print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n",
" \n",
"    else:\n",
"        print(\"❌ 数据准备失败,无法训练模型\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 模型尚未训练完成,请先运行训练代码\n"
]
}
],
"source": [
"# 📊 Model evaluation and visualization\n",
"def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
"    \"\"\"\n",
"    Evaluate the prediction performance for every phoneme.\n",
"    \"\"\"\n",
"    print(f\"\\n📊 {dataset_name}详细评估\")\n",
"    print(\"=\"*50)\n",
" \n",
"    # Get the predictions\n",
"    y_pred = rf_model.predict(X_test)\n",
" \n",
"    # Per-phoneme metrics\n",
"    phoneme_metrics = []\n",
" \n",
"    for i in range(40):  # 40 phonemes\n",
"        phoneme_name = LOGIT_TO_PHONEME[i]\n",
" \n",
"        # Metrics for a single phoneme\n",
"        mse = mean_squared_error(y_test[:, i], y_pred[:, i])\n",
"        r2 = r2_score(y_test[:, i], y_pred[:, i])\n",
"        mae = mean_absolute_error(y_test[:, i], y_pred[:, i])\n",
" \n",
"        # Correlation coefficient\n",
"        correlation = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]\n",
" \n",
"        phoneme_metrics.append({\n",
"            'phoneme_id': i,\n",
"            'phoneme_name': phoneme_name,\n",
"            'mse': mse, \n",
"            'r2': r2,\n",
"            'mae': mae,\n",
"            'correlation': correlation if not np.isnan(correlation) else 0.0\n",
"        })\n",
" \n",
"    # DataFrame for easier analysis\n",
"    metrics_df = pd.DataFrame(phoneme_metrics)\n",
" \n",
"    # Overall statistics\n",
"    print(f\"📈 总体性能指标:\")\n",
"    print(f\" 平均 MSE: {metrics_df['mse'].mean():.6f}\")\n",
"    print(f\" 平均 R²: {metrics_df['r2'].mean():.4f}\")\n",
"    print(f\" 平均 MAE: {metrics_df['mae'].mean():.6f}\")\n",
"    print(f\" 平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n",
" \n",
"    # Best and worst predicted phonemes\n",
"    best_r2_idx = metrics_df['r2'].idxmax()\n",
"    worst_r2_idx = metrics_df['r2'].idxmin()\n",
" \n",
"    print(f\"\\n🏆 最佳预测音素:\")\n",
"    best_phoneme = metrics_df.loc[best_r2_idx]\n",
"    print(f\" {best_phoneme['phoneme_name']} (ID: {best_phoneme['phoneme_id']})\")\n",
"    print(f\" R²: {best_phoneme['r2']:.4f}, MSE: {best_phoneme['mse']:.6f}\")\n",
" \n",
"    print(f\"\\n📉 最差预测音素:\")\n",
"    worst_phoneme = metrics_df.loc[worst_r2_idx]\n",
"    print(f\" {worst_phoneme['phoneme_name']} (ID: {worst_phoneme['phoneme_id']})\")\n",
"    print(f\" R²: {worst_phoneme['r2']:.4f}, MSE: {worst_phoneme['mse']:.6f}\")\n",
" \n",
"    return metrics_df, y_pred\n",
"\n",
"def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n",
"    \"\"\"\n",
"    Visualize the prediction results.\n",
"    \"\"\"\n",
"    print(f\"\\n📊 创建可视化图表...\")\n",
" \n",
"    # Figure styling\n",
"    plt.style.use('default')\n",
"    fig = plt.figure(figsize=(20, 12))\n",
" \n",
"    # 1. Distribution of R² scores\n",
"    plt.subplot(2, 3, 1)\n",
"    plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n",
"    plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n",
"                label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n",
"    plt.xlabel('R² Score')\n",
"    plt.ylabel('音素数量')\n",
"    plt.title('R² Score 分布')\n",
"    plt.legend()\n",
"    plt.grid(True, alpha=0.3)\n",
" \n",
"    # 2. Distribution of MSE\n",
"    plt.subplot(2, 3, 2)\n",
"    plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n",
"    plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n",
"                label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n",
"    plt.xlabel('Mean Squared Error')\n",
"    plt.ylabel('音素数量')\n",
"    plt.title('MSE 分布')\n",
"    plt.legend()\n",
"    plt.grid(True, alpha=0.3)\n",
" \n",
"    # 3. Performance of the top-10 phonemes\n",
"    plt.subplot(2, 3, 3)\n",
"    top_10 = metrics_df.nlargest(10, 'r2')\n",
"    bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n",
"    plt.xlabel('音素排名')\n",
"    plt.ylabel('R² Score')\n",
"    plt.title('Top 10 音素预测性能')\n",
"    plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n",
"    plt.grid(True, alpha=0.3)\n",
" \n",
"    # Numeric labels on the bars\n",
"    for i, bar in enumerate(bars):\n",
"        height = bar.get_height()\n",
"        plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n",
"                 f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n",
" \n",
"    # 4. True-vs-predicted scatter for the best phoneme\n",
"    plt.subplot(2, 3, 4)\n",
"    best_phoneme_idx = metrics_df['r2'].idxmax()\n",
"    phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n",
"    phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n",
" \n",
"    # Sample 1000 points so the plot stays readable\n",
"    sample_size = min(1000, len(y_true))\n",
"    sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n",
" \n",
"    plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n",
"                alpha=0.6, s=20, color='blue')\n",
" \n",
"    # Diagonal (perfect-prediction) line\n",
"    min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n",
"    max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n",
"    plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n",
" \n",
"    plt.xlabel('真实值')\n",
"    plt.ylabel('预测值')\n",
"    plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n",
"    plt.legend()\n",
"    plt.grid(True, alpha=0.3)\n",
" \n",
"    # 5. Correlation heatmap (top-20 phonemes)\n",
"    plt.subplot(2, 3, 5)\n",
"    top_20_correlations = metrics_df.nlargest(20, 'correlation')\n",
"    corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n",
" \n",
"    # Heatmap data\n",
"    heatmap_data = corr_data.values.reshape(-1, 1)\n",
"    im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n",
" \n",
"    plt.colorbar(im, shrink=0.8)\n",
"    plt.yticks([0], ['相关系数'])\n",
"    plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n",
"               rotation=45, ha='right')\n",
"    plt.title('Top 20 音素相关系数')\n",
" \n",
"    # 6. Boxplot of absolute errors (top-10 phonemes)\n",
"    plt.subplot(2, 3, 6)\n",
"    top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n",
"    errors_data = []\n",
"    labels = []\n",
" \n",
"    for phoneme_id in top_10_ids:\n",
"        errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n",
"        errors_data.append(errors)\n",
"        labels.append(LOGIT_TO_PHONEME[phoneme_id])\n",
" \n",
"    plt.boxplot(errors_data, labels=labels)\n",
"    plt.xlabel('音素')\n",
"    plt.ylabel('绝对误差')\n",
"    plt.title('Top 10 音素预测误差分布')\n",
"    plt.xticks(rotation=45)\n",
"    plt.grid(True, alpha=0.3)\n",
" \n",
"    plt.tight_layout()\n",
" \n",
"    if save_plots:\n",
"        plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n",
"        print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n",
" \n",
"    plt.show()\n",
"\n",
"# Evaluate if the model trained successfully\n",
"if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n",
"    print(f\"\\n🎯 开始模型评估和可视化\")\n",
" \n",
"    # Evaluate the validation set\n",
"    if X_val is not None and y_val is not None:\n",
"        val_metrics, val_predictions = evaluate_phoneme_predictions(\n",
"            rf_regressor, X_val, y_val, \"验证集\"\n",
"        )\n",
" \n",
"        # Visualize\n",
"        visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n",
" \n",
"        # Save the detailed results\n",
"        val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n",
"        print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n",
" \n",
"    # Prepare the test set (if available)\n",
"    if test_datasets:\n",
"        print(f\"\\n🔮 准备测试集预测...\")\n",
"        X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n",
" \n",
"        if X_test is not None:\n",
"            test_metrics, test_predictions = evaluate_phoneme_predictions(\n",
"                rf_regressor, X_test, y_test, \"测试集\"\n",
"            )\n",
"            print(f\"\\n✅ 测试集评估完成\")\n",
"        else:\n",
"            print(f\"⚠️ 测试集数据准备失败\")\n",
" \n",
"    print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n",
"    print(f\"📊 生成了详细的性能分析和可视化图表\")\n",
"    print(f\"🔧 模型已准备好用于实际预测任务\")\n",
" \n",
"else:\n",
"    print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ 回归转分类分析功能已创建!\n",
"🎯 主要功能:\n",
"• 将40维概率回归结果转换为分类预测\n",
"• 计算分类准确率和置信度分析\n",
"• 提供Top-K准确率评估\n",
"• 生成详细的混淆矩阵和错误分析\n",
"• 创建全面的可视化图表\n"
]
}
],
"source": [
"# 🎯 Convert the regression output into classification results\n",
"def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n",
"    \"\"\"\n",
"    Convert the regressed 40-phoneme probabilities into class labels and analyze them.\n",
" \n",
"    Args:\n",
"        y_true_probs: true phoneme probabilities [n_samples, 40]\n",
"        y_pred_probs: predicted phoneme probabilities [n_samples, 40]\n",
"        show_detailed_metrics: whether to print detailed classification metrics\n",
" \n",
"    Returns:\n",
"        classification_results: dict with the classification results\n",
"    \"\"\"\n",
"    print(\"🎯 回归结果转分类结果分析\")\n",
"    print(\"=\"*60)\n",
" \n",
"    # 1. Probabilities -> class labels\n",
"    y_true_classes = np.argmax(y_true_probs, axis=1)  # true classes\n",
"    y_pred_classes = np.argmax(y_pred_probs, axis=1)  # predicted classes\n",
" \n",
"    # 2. Classification accuracy\n",
"    accuracy = (y_true_classes == y_pred_classes).mean()\n",
" \n",
"    print(f\"📊 分类结果概览:\")\n",
"    print(f\" 总样本数: {len(y_true_classes):,}\")\n",
"    print(f\" 分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n",
"    print(f\" 正确预测: {(y_true_classes == y_pred_classes).sum():,}\")\n",
"    print(f\" 错误预测: {(y_true_classes != y_pred_classes).sum():,}\")\n",
" \n",
"    # 3. Prediction-confidence analysis\n",
"    pred_confidences = np.max(y_pred_probs, axis=1)  # max predicted probability\n",
"    true_confidences = np.max(y_true_probs, axis=1)  # max true probability\n",
" \n",
"    print(f\"\\n🔍 预测置信度分析:\")\n",
"    print(f\" 预测置信度均值: {pred_confidences.mean():.4f}\")\n",
"    print(f\" 预测置信度标准差: {pred_confidences.std():.4f}\")\n",
"    print(f\" 预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n",
"    print(f\" 真实置信度均值: {true_confidences.mean():.4f}\")\n",
" \n",
"    # 4. Accuracy grouped by confidence bin\n",
"    confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n",
"    print(f\"\\n📈 按预测置信度分组的准确率:\")\n",
"    print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n",
"    print(\"-\" * 40)\n",
" \n",
"    for i in range(len(confidence_bins)-1):\n",
"        low, high = confidence_bins[i], confidence_bins[i+1]\n",
"        mask = (pred_confidences >= low) & (pred_confidences < high)\n",
"        if i == len(confidence_bins)-2:  # the last bin is closed on the right\n",
"            mask = (pred_confidences >= low) & (pred_confidences <= high)\n",
" \n",
"        if mask.sum() > 0:\n",
"            bin_accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
"            count = mask.sum()\n",
"            percentage = count / len(y_true_classes) * 100\n",
"            bracket = ')' if i < len(confidence_bins)-2 else ']'  # half-open bins, closed last bin\n",
"            print(f\"[{low:.1f}, {high:.1f}{bracket} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n",
|
||
" \n",
|
||
" # 5. 混淆矩阵分析(Top-K音素)\n",
|
||
" from collections import Counter\n",
|
||
" \n",
|
||
" # 找出最常见的音素\n",
|
||
" true_counter = Counter(y_true_classes)\n",
|
||
" pred_counter = Counter(y_pred_classes)\n",
|
||
" \n",
|
||
" most_common_true = true_counter.most_common(10)\n",
|
||
" most_common_pred = pred_counter.most_common(10)\n",
|
||
" \n",
|
||
" print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n",
|
||
" print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n",
|
||
" print(\"-\" * 42)\n",
|
||
" \n",
|
||
" for i in range(min(len(most_common_true), len(most_common_pred))):\n",
|
||
" true_id, true_count = most_common_true[i]\n",
|
||
" pred_id, pred_count = most_common_pred[i]\n",
|
||
" true_name = LOGIT_TO_PHONEME[true_id]\n",
|
||
" pred_name = LOGIT_TO_PHONEME[pred_id]\n",
|
||
" print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n",
|
||
" \n",
|
"    # 6. Per-phoneme classification performance\n",
"    if show_detailed_metrics:\n",
"        from sklearn.metrics import classification_report, confusion_matrix\n",
"        \n",
"        print(f\"\\n📋 Detailed classification report (20 most frequent phonemes):\")\n",
"        \n",
"        # Take the 20 most frequent phonemes; most_common_true only holds 10,\n",
"        # so query the counter directly\n",
"        top_20_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(20)]\n",
"        \n",
"        # Mask down to just these phonemes\n",
"        mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n",
"        y_true_top20 = y_true_classes[mask_top20]\n",
"        y_pred_top20 = y_pred_classes[mask_top20]\n",
"        \n",
"        # Build the classification report\n",
"        target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n",
"        \n",
"        try:\n",
"            report = classification_report(\n",
"                y_true_top20, y_pred_top20, \n",
"                labels=top_20_phonemes,\n",
"                target_names=target_names,\n",
"                output_dict=True,\n",
"                zero_division=0\n",
"            )\n",
"            \n",
"            # Print the formatted report\n",
"            print(f\"{'Phoneme':>8} {'Prec.':>8} {'Recall':>8} {'F1':>8} {'Support':>8}\")\n",
"            print(\"-\" * 48)\n",
"            \n",
"            for phoneme_id in top_20_phonemes:\n",
"                phoneme_name = LOGIT_TO_PHONEME[phoneme_id]\n",
"                if phoneme_name in report:\n",
"                    metrics = report[phoneme_name]\n",
"                    print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n",
"                          f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n",
"            \n",
"            # Aggregate metrics\n",
"            macro_avg = report['macro avg']\n",
"            weighted_avg = report['weighted avg']\n",
"            print(\"-\" * 48)\n",
"            print(f\"{'Macro avg':>9} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n",
"                  f\"{macro_avg['f1-score']:>8.4f}\")\n",
"            print(f\"{'Weighted avg':>12} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n",
"                  f\"{weighted_avg['f1-score']:>8.4f}\")\n",
"            \n",
"        except Exception as e:\n",
"            print(f\"Failed to build classification report: {e}\")\n",
"    \n",
"    # 7. Top-K accuracy analysis\n",
"    print(f\"\\n🎯 Top-K accuracy analysis:\")\n",
"    for k in [1, 3, 5, 10]:\n",
"        # Top-K accuracy: is the true class among the K highest-probability classes?\n",
"        top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]  # indices of the K largest probabilities\n",
"        top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
"        print(f\"  Top-{k} accuracy: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
"    \n",
"    # 8. Error analysis - most frequent misclassifications\n",
"    print(f\"\\n❌ Most frequent misclassifications:\")\n",
"    error_mask = y_true_classes != y_pred_classes\n",
"    error_pairs = list(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n",
"    error_counter = Counter(error_pairs)\n",
"    \n",
"    print(f\"{'True phoneme':>12} {'Pred phoneme':>12} {'Errors':>8}\")\n",
"    print(\"-\" * 36)\n",
"    for (true_id, pred_id), count in error_counter.most_common(10):\n",
"        true_name = LOGIT_TO_PHONEME[true_id]\n",
"        pred_name = LOGIT_TO_PHONEME[pred_id]\n",
"        print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n",
"    \n",
"    # Return a results dictionary\n",
"    classification_results = {\n",
"        'accuracy': accuracy,\n",
"        'y_true_classes': y_true_classes,\n",
"        'y_pred_classes': y_pred_classes,\n",
"        'pred_confidences': pred_confidences,\n",
"        'true_confidences': true_confidences,\n",
"        'most_common_errors': error_counter.most_common(10)\n",
"    }\n",
"    \n",
"    return classification_results\n",
"\n",
"def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n",
"    \"\"\"\n",
"    Create visualization charts for the classification results.\n",
"    \"\"\"\n",
"    print(f\"\\n📊 Building classification visualizations...\")\n",
"    \n",
"    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
"    fig.suptitle('Random forest regression-to-classification analysis', fontsize=16, fontweight='bold')\n",
"    \n",
"    y_true_classes = classification_results['y_true_classes']\n",
"    y_pred_classes = classification_results['y_pred_classes']\n",
"    pred_confidences = classification_results['pred_confidences']\n",
"    \n",
"    # 1. Distribution of prediction confidence\n",
"    axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n",
"    axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n",
"                       label=f'Mean: {pred_confidences.mean():.3f}')\n",
"    axes[0, 0].set_xlabel('Prediction confidence')\n",
"    axes[0, 0].set_ylabel('Number of samples')\n",
"    axes[0, 0].set_title('Prediction confidence distribution')\n",
"    axes[0, 0].legend()\n",
"    axes[0, 0].grid(True, alpha=0.3)\n",
"    \n",
"    # 2. Accuracy vs confidence\n",
"    confidence_bins = np.linspace(0, 1, 21)\n",
"    bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n",
"    bin_accuracies = []\n",
"    bin_counts = []\n",
"    \n",
"    for i in range(len(confidence_bins)-1):\n",
"        mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n",
"        if mask.sum() > 0:\n",
"            accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
"            bin_accuracies.append(accuracy)\n",
"            bin_counts.append(mask.sum())\n",
"        else:\n",
"            bin_accuracies.append(0)\n",
"            bin_counts.append(0)\n",
"    \n",
"    # Only plot bins that contain samples\n",
"    valid_bins = np.array(bin_counts) > 0\n",
"    axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n",
"                    'bo-', linewidth=2, markersize=6)\n",
"    axes[0, 1].set_xlabel('Prediction confidence')\n",
"    axes[0, 1].set_ylabel('Accuracy')\n",
"    axes[0, 1].set_title('Accuracy vs prediction confidence')\n",
"    axes[0, 1].grid(True, alpha=0.3)\n",
"    axes[0, 1].set_ylim(0, 1)\n",
"    \n",
"    # 3. Classification accuracy for the most frequent phonemes\n",
"    from collections import Counter\n",
"    true_counter = Counter(y_true_classes)\n",
"    most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n",
"    \n",
"    phoneme_accuracies = []\n",
"    phoneme_names = []\n",
"    for phoneme_id in most_common_phonemes:\n",
"        mask = y_true_classes == phoneme_id\n",
"        if mask.sum() > 0:\n",
"            accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n",
"            phoneme_accuracies.append(accuracy)\n",
"            phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n",
"    \n",
"    bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n",
"                          color='lightgreen', alpha=0.7)\n",
"    axes[0, 2].set_xlabel('Phoneme')\n",
"    axes[0, 2].set_ylabel('Accuracy')\n",
"    axes[0, 2].set_title('Classification accuracy, top 15 phonemes')\n",
"    axes[0, 2].set_xticks(range(len(phoneme_names)))\n",
"    axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
"    axes[0, 2].grid(True, alpha=0.3)\n",
"    \n",
"    # Add value labels above each bar\n",
"    for bar, acc in zip(bars, phoneme_accuracies):\n",
"        height = bar.get_height()\n",
"        axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n",
"                        f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n",
"    \n",
"    # 4. Confusion matrix (10 most frequent phonemes)\n",
"    from sklearn.metrics import confusion_matrix\n",
"    top_10_phonemes = most_common_phonemes[:10]\n",
"    mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n",
"    \n",
"    if mask_top10.sum() > 0:\n",
"        cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n",
"                              labels=top_10_phonemes)\n",
"        \n",
"        im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n",
"        axes[1, 0].set_title('Confusion matrix (top 10 phonemes)')\n",
"        \n",
"        # Add a colorbar\n",
"        cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n",
"        cbar.set_label('Prediction count')\n",
"        \n",
"        # Tick labels\n",
"        tick_marks = np.arange(len(top_10_phonemes))\n",
"        top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n",
"        axes[1, 0].set_xticks(tick_marks)\n",
"        axes[1, 0].set_yticks(tick_marks)\n",
"        axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n",
"        axes[1, 0].set_yticklabels(top_10_names)\n",
"        axes[1, 0].set_xlabel('Predicted phoneme')\n",
"        axes[1, 0].set_ylabel('True phoneme')\n",
"    \n",
"    # 5. Top-K accuracy\n",
"    k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n",
"    top_k_accuracies = []\n",
"    \n",
"    for k in k_values:\n",
"        top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n",
"        top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
"        top_k_accuracies.append(top_k_accuracy)\n",
"    \n",
"    axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n",
"    axes[1, 1].set_xlabel('K')\n",
"    axes[1, 1].set_ylabel('Top-K accuracy')\n",
"    axes[1, 1].set_title('Top-K accuracy curve')\n",
"    axes[1, 1].grid(True, alpha=0.3)\n",
"    axes[1, 1].set_ylim(0, 1)\n",
"    \n",
"    # Add value labels\n",
"    for k, acc in zip(k_values, top_k_accuracies):\n",
"        axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n",
"                            xytext=(0,10), ha='center')\n",
"    \n",
"    # 6. Error analysis - heatmap of the most frequent errors\n",
"    # Note: the results dict stores at most 10 error pairs, so this slice\n",
"    # yields up to 10 entries even though it asks for 25\n",
"    error_pairs = classification_results['most_common_errors'][:25]\n",
"    if error_pairs:\n",
"        # Build the error matrix\n",
"        unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n",
"        error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n",
"        \n",
"        phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n",
"        \n",
"        for (true_id, pred_id), count in error_pairs:\n",
"            if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n",
"                true_idx = phoneme_to_idx[true_id]\n",
"                pred_idx = phoneme_to_idx[pred_id]\n",
"                error_matrix[true_idx, pred_idx] = count\n",
"        \n",
"        im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n",
"        axes[1, 2].set_title('Most frequent error distribution')\n",
"        \n",
"        # Tick labels\n",
"        phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n",
"        axes[1, 2].set_xticks(range(len(phoneme_names)))\n",
"        axes[1, 2].set_yticks(range(len(phoneme_names)))\n",
"        axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
"        axes[1, 2].set_yticklabels(phoneme_names)\n",
"        axes[1, 2].set_xlabel('Predicted phoneme')\n",
"        axes[1, 2].set_ylabel('True phoneme')\n",
"        \n",
"        # Add a colorbar\n",
"        cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n",
"        cbar.set_label('Error count')\n",
"    \n",
"    plt.tight_layout()\n",
"    plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n",
"    print(\"📁 Classification analysis figure saved to: ./processed_datasets/classification_analysis.png\")\n",
"    plt.show()\n",
"\n",
"print(\"✅ Regression-to-classification analysis utilities created!\")\n",
"print(\"🎯 Main features:\")\n",
"print(\"• Convert 40-dimensional probability regression outputs into class predictions\")\n",
"print(\"• Compute classification accuracy and confidence analysis\")\n",
"print(\"• Provide Top-K accuracy evaluation\")\n",
"print(\"• Produce detailed confusion-matrix and error analysis\")\n",
"print(\"• Build the full set of visualization charts\")"
]
},
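{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity check: regression output → classification (synthetic data)\n",
"\n",
"A minimal sketch, not part of the original pipeline: it fabricates 40-dimensional \"probability\" rows with NumPy to show the exact conversion `regression_to_classification_analysis` relies on - `argmax` for the predicted class, the row maximum as confidence, and `argsort` for Top-K hits. The 40-class size mirrors the assumption above; no real model or `LOGIT_TO_PHONEME` lookup is involved."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged demo on synthetic data (assumes 40 phoneme classes, as elsewhere in this notebook)\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n_samples, n_classes = 1000, 40\n",
"\n",
"# Fake regression outputs, normalized so each row sums to 1\n",
"y_pred_demo = rng.random((n_samples, n_classes))\n",
"y_pred_demo /= y_pred_demo.sum(axis=1, keepdims=True)\n",
"y_true_demo = rng.integers(0, n_classes, size=n_samples)\n",
"\n",
"# The core conversion used throughout this notebook\n",
"pred_classes = np.argmax(y_pred_demo, axis=1)   # predicted class per row\n",
"pred_confidence = np.max(y_pred_demo, axis=1)   # winning probability per row\n",
"top5 = np.argsort(y_pred_demo, axis=1)[:, -5:]  # indices of the 5 largest probabilities\n",
"\n",
"accuracy = (pred_classes == y_true_demo).mean()\n",
"top5_acc = np.mean([y_true_demo[i] in top5[i] for i in range(n_samples)])\n",
"print(f'demo accuracy: {accuracy:.4f} (chance is {1/n_classes:.4f})')\n",
"print(f'demo top-5 accuracy: {top5_acc:.4f}, mean confidence: {pred_confidence.mean():.4f}')"
]
},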
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ The random forest model has not been trained yet\n",
"💡 Please run the training cells above first\n"
]
}
],
"source": [
"# 🎯 Full regression-to-classification evaluation pipeline\n",
"def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"test set\"):\n",
"    \"\"\"\n",
"    Full pipeline for evaluating a regression model's outputs as classification results.\n",
"    \"\"\"\n",
"    print(f\"\\n🎯 Full evaluation on the {dataset_name}: regression → classification\")\n",
"    print(\"=\"*70)\n",
"    \n",
"    # 1. Get the regression predictions\n",
"    print(\"📊 Step 1: getting regression predictions...\")\n",
"    y_pred_probs = rf_model.predict(X_test)\n",
"    \n",
"    # Keep the probability values inside [0, 1]\n",
"    y_pred_probs = np.clip(y_pred_probs, 0, 1)\n",
"    \n",
"    # 2. Regression performance\n",
"    print(\"\\n📈 Step 2: regression performance...\")\n",
"    mse = mean_squared_error(y_test, y_pred_probs)\n",
"    mae = mean_absolute_error(y_test, y_pred_probs)\n",
"    r2 = r2_score(y_test, y_pred_probs)\n",
"    \n",
"    print(f\"  Regression MSE: {mse:.6f}\")\n",
"    print(f\"  Regression MAE: {mae:.6f}\")\n",
"    print(f\"  Regression R²: {r2:.4f}\")\n",
"    \n",
"    # 3. Probability normalization (softmax)\n",
"    print(\"\\n🔄 Step 3: probability normalization...\")\n",
"    def softmax(x, axis=-1):\n",
"        # Subtract the row maximum first for numerical stability\n",
"        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
"        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n",
"    \n",
"    # Apply softmax so each prediction row becomes a proper probability distribution\n",
"    y_pred_probs_normalized = softmax(y_pred_probs)\n",
"    y_test_normalized = softmax(y_test)  # normalize the true labels the same way\n",
"    \n",
"    print(f\"  Before normalization: mean row sum = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n",
"    print(f\"  After normalization: mean row sum = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n",
"    \n",
"    # 4. Classification analysis\n",
"    print(\"\\n🎯 Step 4: classification analysis...\")\n",
"    classification_results = regression_to_classification_analysis(\n",
"        y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n",
"    )\n",
"    \n",
"    # 5. Visualizations\n",
"    print(\"\\n📊 Step 5: building visualization charts...\")\n",
"    create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n",
"    \n",
"    # 6. Save the results\n",
"    print(\"\\n💾 Step 6: saving analysis results...\")\n",
"    \n",
"    # Save the classification results\n",
"    results_df = pd.DataFrame({\n",
"        'true_class': classification_results['y_true_classes'],\n",
"        'pred_class': classification_results['y_pred_classes'],\n",
"        'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n",
"        'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n",
"        'pred_confidence': classification_results['pred_confidences'],\n",
"        'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n",
"    })\n",
"    \n",
"    results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n",
"    \n",
"    # Save the full probability predictions\n",
"    prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n",
"                                   columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n",
"    prob_results_df['true_class'] = classification_results['y_true_classes']\n",
"    prob_results_df['pred_class'] = classification_results['y_pred_classes']\n",
"    \n",
"    prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n",
"    \n",
"    print(\"📁 Results saved:\")\n",
"    print(\"  • ./processed_datasets/classification_results.csv (classification results)\")\n",
"    print(\"  • ./processed_datasets/probability_predictions.csv (probability predictions)\")\n",
"    print(\"  • ./processed_datasets/classification_analysis.png (visualization charts)\")\n",
"    \n",
"    # 7. Summary report\n",
"    print(f\"\\n📋 Evaluation summary for the {dataset_name}:\")\n",
"    print(\"=\"*50)\n",
"    print(f\"🔸 Regression performance:\")\n",
"    print(f\"  MSE: {mse:.6f}\")\n",
"    print(f\"  R²: {r2:.4f}\")\n",
"    print(f\"🔸 Classification performance:\")\n",
"    print(f\"  Accuracy: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n",
"    print(f\"  Mean confidence: {classification_results['pred_confidences'].mean():.4f}\")\n",
"    \n",
"    # Top-K accuracy\n",
"    for k in [1, 3, 5]:\n",
"        top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n",
"        top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n",
"                                  for i in range(len(classification_results['y_true_classes']))])\n",
"        print(f\"  Top-{k} accuracy: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
"    \n",
"    return {\n",
"        'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n",
"        'classification_results': classification_results,\n",
"        'normalized_predictions': y_pred_probs_normalized,\n",
"        'normalized_true': y_test_normalized\n",
"    }\n",
"\n",
"# Run the full evaluation if the model is trained and validation data is available\n",
"if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n",
"    if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n",
"        print(\"🚀 Starting the full regression-to-classification evaluation...\")\n",
"        \n",
"        # Run the full evaluation\n",
"        evaluation_results = complete_regression_classification_evaluation(\n",
"            rf_regressor, X_val, y_val, \"validation set\"\n",
"        )\n",
"        \n",
"        print(f\"\\n🎉 Evaluation complete!\")\n",
"        print(f\"✅ Random forest regression outputs converted to classification results\")\n",
"        print(f\"📊 Detailed performance analysis and visualizations generated\")\n",
"        print(f\"💾 All results saved to disk\")\n",
"        \n",
"        # Evaluate on the test set as well, if available\n",
"        if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n",
"            print(f\"\\n🔮 Starting test set evaluation...\")\n",
"            test_evaluation_results = complete_regression_classification_evaluation(\n",
"                rf_regressor, X_test, y_test, \"test set\"\n",
"            )\n",
"    else:\n",
"        print(\"⚠️ No validation data available for evaluation\")\n",
"else:\n",
"    print(\"⚠️ The random forest model has not been trained yet\")\n",
"    print(\"💡 Please run the training cells above first\")"
]
},
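{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Optional: vectorized Top-K accuracy\n",
"\n",
"`complete_regression_classification_evaluation` computes Top-K accuracy with a Python list comprehension, which gets slow on large validation sets. The cell below is a hedged sketch of an equivalent vectorized form (same assumption: one probability row per sample) and checks it against the loop version on synthetic data; it is not wired into the pipeline above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def top_k_accuracy_vectorized(y_true_classes, y_pred_probs, k):\n",
"    # Fraction of rows whose true class is among the k largest probabilities\n",
"    top_k = np.argsort(y_pred_probs, axis=1)[:, -k:]\n",
"    return (top_k == y_true_classes[:, None]).any(axis=1).mean()\n",
"\n",
"# Self-check against the loop-based version used in the evaluation function\n",
"rng = np.random.default_rng(1)\n",
"probs = rng.random((500, 40))\n",
"truth = rng.integers(0, 40, size=500)\n",
"for k in (1, 3, 5):\n",
"    top_k = np.argsort(probs, axis=1)[:, -k:]\n",
"    loop = np.mean([truth[i] in top_k[i] for i in range(len(truth))])\n",
"    vec = top_k_accuracy_vectorized(truth, probs, k)\n",
"    assert np.isclose(loop, vec)\n",
"    print(f'Top-{k}: {vec:.4f}')"
]
},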
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}