{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 环境配置 与 utils" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu126\n", "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n", "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in 
/usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n", "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", "Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n", "Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n", "Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n", "Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n", "Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n", "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n", "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n", "Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n", "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n", "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n", "Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n", "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n", "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n", "Requirement already satisfied: nbconvert in 
/usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n", "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n", "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n", "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n", "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n", "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n", "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n", "Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n", "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n", "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) 
(0.5.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n", "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n", "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n", "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from 
torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n", "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n", "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n", "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n", "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n", "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n", "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n", "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n", "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n", "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n", "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n", "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n", "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n", "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n", "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n", "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n", "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n", "Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: argon2-cffi in 
/usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n", "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n", "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n", "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n", "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n", "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n", "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n", "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n", "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n", "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: setuptools>=18.5 in 
/usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n", "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n", "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n", "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n", "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n", "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n", "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n", "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n", "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n", "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n", "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n", "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n", "Requirement already satisfied: argon2-cffi-bindings in 
/usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n", "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n", "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n", "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n", "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n", "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n", "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n", "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n", "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n", "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n", "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages 
(from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n", "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n", "Obtaining file:///kaggle/working/nejm-brain-to-text\n", " Preparing metadata (setup.py): started\n", " Preparing metadata (setup.py): finished with status 'done'\n", "Installing collected packages: nejm_b2txt_utils\n", " Running setup.py develop for nejm_b2txt_utils\n", "Successfully installed nejm_b2txt_utils-0.0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Cloning into 'nejm-brain-to-text'...\n", "cp: cannot stat '/kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl': No such file or directory\n" ] } ], "source": [ "%%bash\n", "cd /kaggle/working/\n", "rm -rf /kaggle/working/nejm-brain-to-text/\n", "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n", "cd /kaggle/working/nejm-brain-to-text/\n", "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n", "# Install PyTorch with CUDA 12.6\n", "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", "\n", "# Install additional packages with compatible versions\n", "# TODO: remove redis\n", "pip install \\\n", " jupyter==1.1.1 \\\n", " \"numpy>=1.26.0,<2.1.0\" \\\n", " pandas==2.3.0 \\\n", " matplotlib==3.10.1 \\\n", " scipy==1.15.2 \\\n", " scikit-learn==1.6.1 \\\n", " tqdm==4.67.1 \\\n", " g2p_en==2.1.0 \\\n", " h5py==3.13.0 \\\n", " omegaconf==2.3.0 \\\n", " editdistance==0.8.1 \\\n", " huggingface-hub==0.33.1 \\\n", " transformers==4.53.0 \\\n", " tokenizers==0.21.2 \\\n", " accelerate==1.8.1 \\\n", " bitsandbytes==0.46.0\n", "\n", "# Install the local package\n", "pip install -e .\n", "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text\n", "import numpy as np\n", "import os\n", "import pickle\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from g2p_en import G2p\n", "import pandas as pd\n", "import numpy as np\n", "from nejm_b2txt_utils.general_utils import *\n", "\n", "matplotlib.rcParams['pdf.fonttype'] = 42\n", "matplotlib.rcParams['ps.fonttype'] = 42\n", "matplotlib.rcParams['font.family'] = 'sans-serif'\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import h5py\n", "def load_h5py_file(file_path, b2txt_csv_df):\n", " data = {\n", " 'neural_features': 
[],\n", " 'n_time_steps': [],\n", " 'seq_class_ids': [],\n", " 'seq_len': [],\n", " 'transcriptions': [],\n", " 'sentence_label': [],\n", " 'session': [],\n", " 'block_num': [],\n", " 'trial_num': [],\n", " 'corpus': [],\n", " }\n", " # Open the hdf5 file for that day\n", " with h5py.File(file_path, 'r') as f:\n", "\n", " keys = list(f.keys())\n", "\n", " # For each trial in the selected trials in that day\n", " for key in keys:\n", " g = f[key]\n", "\n", " neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n", " n_time_steps = g.attrs['n_time_steps']\n", " seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n", " seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n", " transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n", " sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n", " session = g.attrs['session']\n", " block_num = g.attrs['block_num']\n", " trial_num = g.attrs['trial_num']\n", "\n", " # match this trial up with the csv to get the corpus name\n", " year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n", " date = f'{year}-{month}-{day}'\n", " row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n", " corpus_name = row['Corpus'].values[0]\n", "\n", " data['neural_features'].append(neural_features)\n", " data['n_time_steps'].append(n_time_steps)\n", " data['seq_class_ids'].append(seq_class_ids)\n", " data['seq_len'].append(seq_len)\n", " data['transcriptions'].append(transcription)\n", " data['sentence_label'].append(sentence_label)\n", " data['session'].append(session)\n", " data['block_num'].append(block_num)\n", " data['trial_num'].append(trial_num)\n", " data['corpus'].append(corpus_name)\n", " return data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "LOGIT_TO_PHONEME = [\n", " 'BLANK',\n", " 'AA', 'AE', 'AH', 'AO', 'AW',\n", " 'AY', 'B', 'CH', 'D', 'DH',\n", " 'EH', 'ER', 'EY', 'F', 'G',\n", " 'HH', 'IH', 'IY', 'JH', 'K',\n", " 'L', 'M', 'N', 'NG', 'OW',\n", " 'OY', 'P', 'R', 'S', 'SH',\n", " 'T', 'TH', 'UH', 'UW', 'V',\n", " 'W', 'Y', 'Z', 'ZH',\n", " ' | ',\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 数据分析与预处理" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 数据准备" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text/" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "data = load_h5py_file(file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n", " b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **任务介绍** :机器学习解决高维信号的模式识别问题" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "我们的数据集标签缺少时间戳,现在要进行的是半监督学习" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n", " -0.7091646 , -0.11018186],\n", " [-0.5859305 , -0.78699756, 
-0.64687246, ..., 0.3122117 ,\n", " 1.7943763 , -0.76884896],\n", " [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n", " -0.8481289 , -0.7648201 ],\n", " ...,\n", " [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n", " 0.9710176 , 2.5397465 ],\n", " [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n", " -0.68723625, 0.10479005],\n", " [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n", " 0.7417906 , -0.7008622 ]], dtype=float32),\n", " 'n_time_steps': 321,\n", " 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0], dtype=int32),\n", " 'seq_len': 14,\n", " 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n", " 101, 114, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0], dtype=int32),\n", " 'sentence_label': 'Bring it closer.',\n", " 'session': 't15.2023.08.11',\n", " 'block_num': 2,\n", " 'trial_num': 0,\n", " 'corpus': '50-Word'}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def data_patch(data, index):\n", " data_patch = {}\n", " data_patch['neural_features'] = data['neural_features'][index]\n", " data_patch['n_time_steps'] = data['n_time_steps'][index]\n", " data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n", " data_patch['seq_len'] = data['seq_len'][index]\n", " data_patch['transcriptions'] = data['transcriptions'][index]\n", " data_patch['sentence_label'] = data['sentence_label'][index]\n", " data_patch['session'] = data['session'][index]\n", " data_patch['block_num'] = data['block_num'][index]\n", " data_patch['trial_num'] = data['trial_num'][index]\n", " data_patch['corpus'] = data['corpus'][index]\n", " return data_patch\n", "\n", "data_patch(data, 0)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "d1 = data_patch(data, 0)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Transcriptions non-zero length: 16\n", "Seq class ids non-zero length: 14\n", "Seq len: 14\n" ] } ], "source": [ "trans_len = len([x for x in d1['transcriptions'] if x != 0])\n", "seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n", "seq_len = d1['seq_len']\n", "print(f\"Transcriptions non-zero length: {trans_len}\")\n", "print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n", "print(f\"Seq len: {seq_len}\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of feature sequences: 14\n", "Shape of first sequence: (22, 512)\n" ] } ], "source": [ "def create_time_windows(d1):\n", " import numpy as np\n", " n_time_steps = d1['n_time_steps']\n", " seq_len = d1['seq_len']\n", " # Create equal windows\n", " edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n", " windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n", " \n", " # Extract feature sequences for each window\n", " feature_sequences = []\n", " for start, end in windows:\n", " seq = d1['neural_features'][start:end, :]\n", " feature_sequences.append(seq)\n", " \n", " return feature_sequences\n", "\n", "# Example usage\n", "feature_sequences = create_time_windows(d1)\n", "print(\"Number of feature sequences:\", len(feature_sequences))\n", "print(\"Shape of first sequence:\", feature_sequences[0].shape)\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: 45, Val: 41, Test: 41\n", "Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', 
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of feature sequences: 14\n", "Shape of first sequence: (22, 512)\n" ] } ], "source": [
"def create_time_windows(d1):\n",
"    n_time_steps = d1['n_time_steps']\n",
"    seq_len = d1['seq_len']\n",
"    # Create equal-duration windows, one per phoneme label\n",
"    edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
"    windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
"\n",
"    # Extract the feature sequence for each window\n",
"    feature_sequences = []\n",
"    for start, end in windows:\n",
"        seq = d1['neural_features'][start:end, :]\n",
"        feature_sequences.append(seq)\n",
"\n",
"    return feature_sequences\n",
"\n",
"# Example usage\n",
"feature_sequences = create_time_windows(d1)\n",
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: 45, Val: 41, Test: 41\n", "Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n", "Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n", "Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n" ] } ], "source": [
"import os\n",
"\n",
"def scan_hdf5_files(base_path):\n",
"    train_files = []\n",
"    val_files = []\n",
"    test_files = []\n",
"    for root, dirs, files in os.walk(base_path):\n",
"        for file in files:\n",
"            if file.endswith('.hdf5'):\n",
"                abs_path = os.path.abspath(os.path.join(root, file))\n",
"                if 'data_train.hdf5' in file:\n",
"                    train_files.append(abs_path)\n",
"                elif 'data_val.hdf5' in file:\n",
"                    val_files.append(abs_path)\n",
"                elif 'data_test.hdf5' in file:\n",
"                    test_files.append(abs_path)\n",
"    return train_files, val_files, test_files\n",
"\n",
"# Example usage\n",
"FILE_PATH = 'data/hdf5_data_final'\n",
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
"print(\"Train files (first 3):\", train_list[:3])\n",
"print(\"Val files (first 3):\", val_list[:3])\n",
"print(\"Test files (first 3):\", test_list[:3])" ] },
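{ "cell_type": "markdown", "metadata": {}, "source": [ "With the file lists above, each session's split can be pulled through `load_h5py_file`. A minimal sketch for the training splits follows (an assumed workflow, not part of the original pipeline; holding all 45 sessions in RAM at once is memory-hungry)." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Load every training split found above into a dict keyed by session name\n",
"# (memory note: this reads all sessions' neural features into RAM at once)\n",
"b2txt_csv_df = pd.read_csv('data/t15_copyTaskData_description.csv')\n",
"train_data = {}\n",
"for path in train_list:\n",
"    # the session name is the parent directory, e.g. 't15.2023.08.11'\n",
"    session = os.path.basename(os.path.dirname(path))\n",
"    train_data[session] = load_h5py_file(path, b2txt_csv_df)\n",
"print(f'Loaded {len(train_data)} sessions')" ] },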
还有 40 个会话\n", "\n", "🎉 RNN模型加载成功!\n", "\n", "✅ 真实RNN模型加载器准备完成!\n", "🔧 现在可以使用真实的预训练RNN进行预测\n" ] } ], "source": [ "# 🧠 真实RNN模型加载器\n", "import torch\n", "from omegaconf import OmegaConf\n", "\n", "class RealRNNModelLoader:\n", " \"\"\"\n", " 加载和使用真实的预训练RNN模型\n", " \"\"\"\n", " \n", " def __init__(self, model_path='./data/t15_pretrained_rnn_baseline', device='auto'):\n", " \"\"\"\n", " 初始化RNN模型加载器\n", " \n", " 参数:\n", " model_path: 预训练模型路径\n", " device: 使用的设备 ('auto', 'cpu', 'cuda' 或 'cuda:0' 等)\n", " \"\"\"\n", " self.model_path = model_path\n", " self.model = None\n", " self.model_args = None\n", " self.device = self._setup_device(device)\n", " \n", " print(f\"🧠 RNN模型加载器初始化:\")\n", " print(f\" 模型路径: {model_path}\")\n", " print(f\" 使用设备: {self.device}\")\n", " \n", " # 检查模型文件是否存在\n", " self._check_model_files()\n", " \n", " def _setup_device(self, device):\n", " \"\"\"设置计算设备\"\"\"\n", " if device == 'auto':\n", " if torch.cuda.is_available():\n", " device = 'cuda'\n", " print(f\" 自动检测到CUDA,使用GPU\")\n", " else:\n", " device = 'cpu'\n", " print(f\" 未检测到CUDA,使用CPU\")\n", " \n", " return torch.device(device)\n", " \n", " def _check_model_files(self):\n", " \"\"\"检查必需的模型文件\"\"\"\n", " required_files = {\n", " 'args.yaml': os.path.join(self.model_path, 'checkpoint/args.yaml'),\n", " 'checkpoint': os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n", " }\n", " \n", " missing_files = []\n", " for name, path in required_files.items():\n", " if not os.path.exists(path):\n", " missing_files.append(f\"{name}: {path}\")\n", " \n", " if missing_files:\n", " print(f\"❌ 缺少模型文件:\")\n", " for file in missing_files:\n", " print(f\" • {file}\")\n", " print(f\"\\n💡 请确保已下载预训练模型到: {self.model_path}\")\n", " return False\n", " else:\n", " print(f\"✅ 模型文件检查通过\")\n", " return True\n", " \n", " def load_model(self):\n", " \"\"\"加载预训练的RNN模型\"\"\"\n", " try:\n", " # 需要先检查是否已经导入了rnn_model\n", " try:\n", " from rnn_model import GRUDecoder\n", " except ImportError:\n", " print(\"⚠️ 无法导入rnn_model,尝试从model_training目录导入...\")\n", " import sys\n", " model_training_path = os.path.abspath('./model_training')\n", " if model_training_path not in sys.path:\n", " sys.path.append(model_training_path)\n", " \n", " try:\n", " from rnn_model import GRUDecoder\n", " print(\"✅ 从model_training目录成功导入GRUDecoder\")\n", " except ImportError as e:\n", " print(f\"❌ 无法导入GRUDecoder: {e}\")\n", " print(\"💡 请确保rnn_model.py在model_training目录中\")\n", " return False\n", " \n", " print(f\"\\n🔄 加载RNN模型...\")\n", " \n", " # 1. 加载模型配置\n", " args_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n", " self.model_args = OmegaConf.load(args_path)\n", " \n", " print(f\" ✅ 模型配置加载完成\")\n", " print(f\" 输入特征维度: {self.model_args['model']['n_input_features']}\")\n", " print(f\" 隐藏单元数: {self.model_args['model']['n_units']}\")\n", " print(f\" 层数: {self.model_args['model']['n_layers']}\")\n", " print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n", " \n", " # 2. 
"    def load_model(self):\n",
"        \"\"\"Load the pretrained RNN model.\"\"\"\n",
"        try:\n",
"            # Try the top-level rnn_model import first\n",
"            try:\n",
"                from rnn_model import GRUDecoder\n",
"            except ImportError:\n",
"                print(\"⚠️ Could not import rnn_model, trying the model_training directory...\")\n",
"                import sys\n",
"                model_training_path = os.path.abspath('./model_training')\n",
"                if model_training_path not in sys.path:\n",
"                    sys.path.append(model_training_path)\n",
"\n",
"                try:\n",
"                    from rnn_model import GRUDecoder\n",
"                    print(\"✅ Imported GRUDecoder from the model_training directory\")\n",
"                except ImportError as e:\n",
"                    print(f\"❌ Could not import GRUDecoder: {e}\")\n",
"                    print(\"💡 Make sure rnn_model.py is in the model_training directory\")\n",
"                    return False\n",
"\n",
"            print(f\"\\n🔄 Loading RNN model...\")\n",
"\n",
"            # 1. Load the model config\n",
"            args_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
"            self.model_args = OmegaConf.load(args_path)\n",
"\n",
"            print(f\"   ✅ Model config loaded\")\n",
"            print(f\"      Input feature dim: {self.model_args['model']['n_input_features']}\")\n",
"            print(f\"      Hidden units: {self.model_args['model']['n_units']}\")\n",
"            print(f\"      Layers: {self.model_args['model']['n_layers']}\")\n",
"            print(f\"      Output classes: {self.model_args['dataset']['n_classes']}\")\n",
"\n",
"            # 2. Initialize the model architecture\n",
"            self.model = GRUDecoder(\n",
"                neural_dim=self.model_args['model']['n_input_features'],\n",
"                n_units=self.model_args['model']['n_units'],\n",
"                n_days=len(self.model_args['dataset']['sessions']),\n",
"                n_classes=self.model_args['dataset']['n_classes'],\n",
"                rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
"                input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
"                n_layers=self.model_args['model']['n_layers'],\n",
"                patch_size=self.model_args['model']['patch_size'],\n",
"                patch_stride=self.model_args['model']['patch_stride'],\n",
"            )\n",
"\n",
"            print(f\"   ✅ Model architecture initialized\")\n",
"\n",
"            # 3. Load the pretrained weights\n",
"            checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
"\n",
"            # weights_only=False is needed to unpickle this checkpoint; only use it on trusted files\n",
"            print(f\"   🔐 Loading checkpoint (note: weights_only=False, make sure the checkpoint source is trusted)...\")\n",
"            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
"\n",
"            # Extract the model state dict\n",
"            if 'model_state_dict' in checkpoint:\n",
"                state_dict = checkpoint['model_state_dict']\n",
"                print(f\"   📋 Found the model_state_dict key\")\n",
"            elif 'model' in checkpoint:\n",
"                state_dict = checkpoint['model']\n",
"                print(f\"   📋 Found the model key\")\n",
"            else:\n",
"                state_dict = checkpoint\n",
"                print(f\"   📋 Using the whole checkpoint as the state_dict\")\n",
"\n",
"            # Handle possible key-name mismatches\n",
"            model_state_dict = self.model.state_dict()\n",
"            filtered_state_dict = {}\n",
"\n",
"            print(f\"   🔍 Matching model parameters...\")\n",
"            matched_keys = 0\n",
"            total_keys = len(state_dict)\n",
"            unmatched_samples = []\n",
"\n",
"            for key, value in state_dict.items():\n",
"                # Strip prefixes that other training setups may have added\n",
"                clean_key = key\n",
"\n",
"                # '_orig_mod.' prefix (added by torch.compile)\n",
"                if clean_key.startswith('_orig_mod.'):\n",
"                    clean_key = clean_key.replace('_orig_mod.', '')\n",
"\n",
"                # 'module.' prefix (added by distributed training)\n",
"                if clean_key.startswith('module.'):\n",
"                    clean_key = clean_key.replace('module.', '')\n",
"\n",
"                if clean_key in model_state_dict:\n",
"                    filtered_state_dict[clean_key] = value\n",
"                    matched_keys += 1\n",
"                else:\n",
"                    # Keep only the first few unmatched keys as examples\n",
"                    if len(unmatched_samples) < 3:\n",
"                        unmatched_samples.append(f\"{key} -> {clean_key}\")\n",
"\n",
"            print(f\"   📊 Parameter matching: {matched_keys}/{total_keys} keys matched\")\n",
"\n",
"            if unmatched_samples:\n",
"                print(f\"   ⚠️ Unmatched key examples: {', '.join(unmatched_samples)}\")\n",
"\n",
"            # Load the weights\n",
"            missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)\n",
"\n",
"            if missing_keys:\n",
"                print(f\"   ⚠️ Missing keys ({len(missing_keys)}): {missing_keys[:3]}{'...' if len(missing_keys) > 3 else ''}\")\n",
"\n",
"            if unexpected_keys:\n",
"                print(f\"   ⚠️ Unexpected keys ({len(unexpected_keys)}): {unexpected_keys[:3]}{'...' if len(unexpected_keys) > 3 else ''}\")\n",
"\n",
"            self.model.to(self.device)\n",
"            self.model.eval()\n",
"\n",
"            if matched_keys > 0:\n",
"                print(f\"   ✅ Pretrained weights loaded ({matched_keys} parameters)\")\n",
"            else:\n",
"                print(f\"   ❌ No pretrained weights could be matched\")\n",
"                print(f\"   🔄 Continuing with randomly initialized weights\")\n",
"\n",
"            # 4. Report model info\n",
"            total_params = sum(p.numel() for p in self.model.parameters())\n",
"            print(f\"   📊 Model parameter stats:\")\n",
"            print(f\"      Total parameters: {total_params:,}\")\n",
"\n",
"            # Count the day-specific parameters\n",
"            day_params = 0\n",
"            for name, param in self.model.named_parameters():\n",
"                if 'day' in name:\n",
"                    day_params += param.numel()\n",
"\n",
"            print(f\"      Day-specific parameters: {day_params:,} ({day_params/total_params*100:.1f}%)\")\n",
"            print(f\"      Number of sessions: {len(self.model_args['dataset']['sessions'])}\")\n",
"\n",
"            # List the supported sessions\n",
"            print(f\"   📅 Supported sessions:\")\n",
"            sessions = self.model_args['dataset']['sessions']\n",
"            for i, session in enumerate(sessions[:5]):\n",
"                print(f\"      {i}: {session}\")\n",
"            if len(sessions) > 5:\n",
"                print(f\"      ... plus {len(sessions)-5} more sessions\")\n",
"\n",
"            print(f\"\\n🎉 RNN model load {'succeeded' if matched_keys > 0 else 'finished (random weights)'}!\")\n",
"            return True\n",
"\n",
"        except Exception as e:\n",
"            print(f\"❌ RNN model load failed: {str(e)}\")\n",
"            import traceback\n",
"            print(f\"Full traceback:\")\n",
"            print(traceback.format_exc())\n",
"            return False\n",
"\n",
"    def predict_trial(self, neural_features, day_idx=0):\n",
"        \"\"\"\n",
"        Run RNN inference on a single trial.\n",
"\n",
"        Args:\n",
"            neural_features: neural features [time_steps, features]\n",
"            day_idx: day index (one per training session)\n",
"\n",
"        Returns:\n",
"            logits: RNN output [time_steps, n_classes]\n",
"        \"\"\"\n",
"        if self.model is None:\n",
"            raise ValueError(\"Model not loaded yet; call load_model() first\")\n",
"\n",
"        try:\n",
"            # Convert to a tensor and add a batch dimension\n",
"            if isinstance(neural_features, np.ndarray):\n",
"                neural_features = torch.from_numpy(neural_features).float()\n",
"\n",
"            neural_features = neural_features.unsqueeze(0).to(self.device)  # [1, time_steps, features]\n",
"\n",
"            # Clamp day_idx to the valid range\n",
"            n_days = len(self.model_args['dataset']['sessions'])\n",
"            day_idx = max(0, min(day_idx, n_days - 1))\n",
"\n",
"            # Build the day-index tensor\n",
"            day_tensor = torch.tensor([day_idx], dtype=torch.long, device=self.device)\n",
"\n",
"            # Run inference\n",
"            with torch.no_grad():\n",
"                logits = self.model(neural_features, day_tensor)  # [1, time_steps, n_classes]\n",
"\n",
"            # Drop the batch dimension and convert to numpy\n",
"            logits = logits.squeeze(0).cpu().numpy()  # [time_steps, n_classes]\n",
"\n",
"            return logits\n",
"\n",
"        except Exception as e:\n",
"            print(f\"❌ RNN prediction failed: {str(e)}\")\n",
"            return None\n",
"\n",
"    def get_day_index(self, session_name):\n",
"        \"\"\"Map a session name (e.g. 't15.2023.08.11') to its day index.\"\"\"\n",
"        if self.model_args is None:\n",
"            return 0\n",
"\n",
"        sessions = self.model_args['dataset']['sessions']\n",
"        try:\n",
"            return sessions.index(session_name)\n",
"        except ValueError:\n",
"            print(f\"⚠️ Session {session_name} not found, using day_idx=0\")\n",
"            return 0\n",
"\n",
"    def is_loaded(self):\n",
"        \"\"\"Return whether the model has been loaded.\"\"\"\n",
"        return self.model is not None\n",
"\n",
"# Create the RNN model loader instance\n",
"print(\"🧠 Creating the real RNN model loader...\")\n",
"rnn_loader = RealRNNModelLoader(\n",
"    model_path='./data/t15_pretrained_rnn_baseline',\n",
"    device='auto'\n",
")\n",
"\n",
"# Try to load the model\n",
"if rnn_loader.load_model():\n",
"    print(\"\\n✅ Real RNN model loader is ready!\")\n",
"    print(\"🔧 Predictions can now use the real pretrained RNN\")\n",
"else:\n",
"    print(\"\\n❌ RNN model load failed\")\n",
"    print(\"💡 Falling back to simulated predictions\")" ] },
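{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick smoke test of the loader API: run the pretrained RNN on the trial `d1` from earlier and greedy-decode the argmax path. This is a sketch that assumes the checkpoint above loaded; class 0 is treated as the CTC blank, matching `LOGIT_TO_PHONEME`." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# Smoke test: predict one trial with the pretrained RNN, then greedy-decode\n",
"if rnn_loader.is_loaded():\n",
"    day_idx = rnn_loader.get_day_index(d1['session'])\n",
"    logits = rnn_loader.predict_trial(d1['neural_features'], day_idx)\n",
"    if logits is not None:\n",
"        ids = logits.argmax(axis=-1)\n",
"        # collapse repeats and drop blanks (class 0), CTC-style\n",
"        decoded = [LOGIT_TO_PHONEME[i] for j, i in enumerate(ids)\n",
"                   if i != 0 and (j == 0 or i != ids[j - 1])]\n",
"        print('Greedy phoneme path:', ' '.join(decoded))" ] },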
{ "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [
"# 🏗️ Full dataset batch-processing pipeline - using the real RNN model (optimized)\n",
"import os\n",
"import re\n",
"import time\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import h5py\n",
"\n",
"class BrainToTextDatasetPipeline:\n",
"    \"\"\"\n",
"    Complete pipeline for batch-processing every dataset session.\n",
"    Integrates the real RNN model for feature extraction.\n",
"    Optimized version: adds lazy status checks and progress bars.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data_dir='./data/hdf5_data_final', rnn_loader=None):\n",
"        \"\"\"\n",
"        Lightweight pipeline initialization.\n",
"\n",
"        Args:\n",
"            data_dir: path to the data directory\n",
"            rnn_loader: a RealRNNModelLoader instance\n",
"        \"\"\"\n",
"        print(\"🔧 Initializing dataset pipeline...\")\n",
"        self.data_dir = data_dir\n",
"        self.rnn_loader = rnn_loader\n",
"        self.sessions = []\n",
"        self.results = {}\n",
"\n",
"        # Lazy-check flags\n",
"        self._rnn_checked = False\n",
"        self._sessions_scanned = False\n",
"        self.use_real_rnn = False\n",
"\n",
"        print(\"✅ Pipeline initialized! Using lazy-loading mode\")\n",
"        print(\"💡 Call check_status() to check the RNN model and the data\")\n",
"\n",
"    def check_status(self):\n",
"        \"\"\"Check the RNN model and data status (with a progress bar).\"\"\"\n",
"        print(\"\\n🔍 Checking system status...\")\n",
"\n",
"        tasks = [\n",
"            (\"Check data directory\", self._check_data_directory),\n",
"            (\"Scan session files\", self._scan_sessions),\n",
"            (\"Check RNN model\", self._check_rnn_model),\n",
"            (\"Verify system readiness\", self._verify_system_ready)\n",
"        ]\n",
"\n",
"        with tqdm(tasks, desc=\"System check\", unit=\"step\") as pbar:\n",
"            for task_name, task_func in pbar:\n",
"                pbar.set_postfix_str(f\"Running: {task_name}\")\n",
"                time.sleep(0.1)  # brief pause so the progress bar is visible\n",
"                task_func()\n",
"                pbar.set_postfix_str(f\"✅ {task_name}\")\n",
"                time.sleep(0.2)  # show the completed state\n",
"\n",
"        self._print_status_summary()\n",
"\n",
"    def _check_data_directory(self):\n",
"        \"\"\"Check the data directory.\"\"\"\n",
"        if not os.path.exists(self.data_dir):\n",
"            raise FileNotFoundError(f\"❌ Data directory does not exist: {self.data_dir}\")\n",
"\n",
"        # Make sure the directory is not empty\n",
"        items = os.listdir(self.data_dir)\n",
"        if not items:\n",
"            raise ValueError(f\"❌ Data directory is empty: {self.data_dir}\")\n",
"\n",
"    def _scan_sessions(self):\n",
"        \"\"\"Scan the data directory for all sessions (with a progress bar).\"\"\"\n",
"        if self._sessions_scanned:\n",
"            return\n",
"\n",
"        if not os.path.exists(self.data_dir):\n",
"            print(f\"❌ Data directory does not exist: {self.data_dir}\")\n",
"            return\n",
"\n",
"        print(f\"📂 Scanning data directory: {self.data_dir}\")\n",
"\n",
"        # List all entries for the progress bar\n",
"        all_items = os.listdir(self.data_dir)\n",
"        pattern = re.compile(r't15\\.\\d{4}\\.\\d{2}\\.\\d{2}')\n",
"\n",
"        # Scan with a progress bar\n",
"        valid_sessions = []\n",
"        with tqdm(all_items, desc=\"Scanning sessions\", unit=\"item\", leave=False) as pbar:\n",
"            for item in pbar:\n",
"                pbar.set_postfix_str(f\"Checking: {item[:20]}...\")\n",
"\n",
"                if pattern.match(item):\n",
"                    session_path = os.path.join(self.data_dir, item)\n",
"                    if os.path.isdir(session_path):\n",
"                        # Check for data files\n",
"                        try:\n",
"                            files = os.listdir(session_path)\n",
"                            has_data = any(f.endswith('.h5') or f.endswith('.hdf5') for f in files)\n",
"                            if has_data:\n",
"                                valid_sessions.append(item)\n",
"                                pbar.set_postfix_str(f\"✅ {item}\")\n",
"                        except PermissionError:\n",
"                            pbar.set_postfix_str(f\"⚠️ Permission error: {item}\")\n",
"                        except Exception:\n",
"                            pbar.set_postfix_str(f\"❌ Error: {item}\")\n",
"\n",
"                # brief pause so the progress bar is visible\n",
"                time.sleep(0.01)\n",
"\n",
"        self.sessions = sorted(valid_sessions)\n",
"        self._sessions_scanned = True\n",
"\n",
"        print(f\"✅ Scan complete! Found {len(self.sessions)} valid sessions\")\n",
"\n",
发现 {len(self.sessions)} 个有效session\")\n", " \n", " def _check_rnn_model(self):\n", " \"\"\"延迟检查RNN模型状态(带进度条)\"\"\"\n", " if self._rnn_checked:\n", " return\n", " \n", " print(\"🔍 正在检查RNN模型状态...\")\n", " \n", " # 模拟检查过程的进度条\n", " check_steps = [\n", " \"检查模型加载器\",\n", " \"验证模型文件路径\", \n", " \"测试文件可读性\",\n", " \"验证模型结构\",\n", " \"确认模型状态\"\n", " ]\n", " \n", " with tqdm(check_steps, desc=\"模型检查\", unit=\"步\", leave=False) as pbar:\n", " for step in pbar:\n", " pbar.set_postfix_str(step)\n", " time.sleep(0.15) # 模拟检查时间\n", " \n", " if \"测试文件可读性\" in step:\n", " # 实际的模型检查逻辑\n", " if self.rnn_loader and self.rnn_loader.is_loaded():\n", " self.use_real_rnn = True\n", " pbar.set_postfix_str(\"✅ 真实模型可用\")\n", " else:\n", " self.use_real_rnn = False\n", " pbar.set_postfix_str(\"❌ 使用模拟模型\")\n", " \n", " self._rnn_checked = True\n", " \n", " model_type = \"真实RNN模型\" if self.use_real_rnn else \"模拟模型\"\n", " print(f\"🤖 模型状态: {model_type}\")\n", " \n", " def _verify_system_ready(self):\n", " \"\"\"验证系统就绪状态\"\"\"\n", " if not self._sessions_scanned:\n", " raise RuntimeError(\"Sessions未扫描完成\")\n", " if not self._rnn_checked:\n", " raise RuntimeError(\"RNN模型未检查完成\")\n", " \n", " # 验证基本要求\n", " if len(self.sessions) == 0:\n", " raise ValueError(\"未找到有效的session数据\")\n", " \n", " def _print_status_summary(self):\n", " \"\"\"打印状态摘要\"\"\"\n", " print(f\"\\n📋 系统状态摘要:\")\n", " print(\"=\"*50)\n", " print(f\"📂 数据目录: {self.data_dir}\")\n", " print(f\"📊 有效Sessions: {len(self.sessions)}\")\n", " print(f\"🤖 RNN模型: {'真实模型' if self.use_real_rnn else '模拟模型'}\")\n", " print(f\"✅ 系统状态: {'就绪' if self._sessions_scanned and self._rnn_checked else '未就绪'}\")\n", " \n", " if len(self.sessions) > 0:\n", " print(f\"\\n📝 前5个session:\")\n", " for i, session in enumerate(self.sessions[:5]):\n", " print(f\" {i+1:2d}. {session}\")\n", " if len(self.sessions) > 5:\n", " print(f\" ... 
还有 {len(self.sessions)-5} 个session\")\n", " \n", " def _load_session_data(self, session_name):\n", " \"\"\"加载单个session的数据\"\"\"\n", " session_path = os.path.join(self.data_dir, session_name)\n", " \n", " try:\n", " # 查找数据文件\n", " data_files = [f for f in os.listdir(session_path) \n", " if f.endswith('.h5') or f.endswith('.hdf5')]\n", " \n", " if not data_files:\n", " return None, \"未找到数据文件\"\n", " \n", " # 使用第一个找到的数据文件\n", " data_file = os.path.join(session_path, data_files[0])\n", " \n", " with h5py.File(data_file, 'r') as f:\n", " neural_data = []\n", " labels = []\n", " \n", " # 遍历所有试验\n", " for trial_key in f.keys():\n", " if trial_key.startswith('trial_'):\n", " trial_group = f[trial_key]\n", " \n", " # 获取神经数据和标签\n", " neural_features = trial_group['neural_data'][:] # [time_steps, features]\n", " trial_labels = trial_group['labels'][:] # [time_steps]\n", " \n", " # 确保数据格式正确\n", " if len(neural_features.shape) == 2 and len(trial_labels.shape) == 1:\n", " # 检查时间步长是否匹配\n", " if neural_features.shape[0] == trial_labels.shape[0]:\n", " neural_data.append(neural_features)\n", " labels.append(trial_labels)\n", " \n", " if not neural_data:\n", " return None, \"未找到有效的试验数据\"\n", " \n", " print(f\" 📊 {session_name}: 加载了 {len(neural_data)} 个试验\")\n", " return (neural_data, labels), None\n", " \n", " except Exception as e:\n", " return None, f\"加载失败: {str(e)}\"\n", " \n", " def _extract_rnn_features(self, neural_data, session_name):\n", " \"\"\"使用真实RNN模型提取特征\"\"\"\n", " if not self.use_real_rnn:\n", " return self._simulate_rnn_predictions(neural_data)\n", " \n", " try:\n", " # 获取session对应的day索引\n", " day_idx = self.rnn_loader.get_day_index(session_name)\n", " \n", " rnn_predictions = []\n", " \n", " # 为试验预测添加进度条\n", " with tqdm(neural_data, desc=f\"RNN预测 {session_name}\", unit=\"试验\", leave=False) as pbar:\n", " for trial_idx, neural_features in enumerate(pbar):\n", " pbar.set_postfix_str(f\"试验 {trial_idx+1}\")\n", " \n", " # 使用真实RNN模型进行预测\n", " logits = self.rnn_loader.predict_trial(neural_features, day_idx)\n", " \n", " if logits is not None:\n", " rnn_predictions.append(logits)\n", " pbar.set_postfix_str(f\"✅ 试验 {trial_idx+1}\")\n", " else:\n", " # 如果预测失败,使用模拟数据\n", " print(f\" ⚠️ 试验 {trial_idx} RNN预测失败,使用模拟数据\")\n", " simulated = self._simulate_single_trial_prediction(neural_features)\n", " rnn_predictions.append(simulated)\n", " pbar.set_postfix_str(f\"⚠️ 试验 {trial_idx+1} (模拟)\")\n", " \n", " return rnn_predictions\n", " \n", " except Exception as e:\n", " print(f\" ❌ RNN特征提取失败: {str(e)}\")\n", " print(f\" 🔄 回退到模拟预测\")\n", " return self._simulate_rnn_predictions(neural_data)\n", " \n", " def _simulate_single_trial_prediction(self, neural_features):\n", " \"\"\"为单个试验生成模拟RNN预测\"\"\"\n", " time_steps = neural_features.shape[0]\n", " n_phonemes = 40\n", " \n", " # 生成模拟的logits(更加真实的分布)\n", " logits = np.random.randn(time_steps, n_phonemes) * 2.0\n", " \n", " # 添加一些时间相关的模式\n", " for t in range(time_steps):\n", " # 静音类在开始和结束时概率更高\n", " if t < 5 or t > time_steps - 5:\n", " logits[t, 0] += 2.0 # 静音类\n", " \n", " # 添加一些语音学合理的模式\n", " if t % 10 < 3: # 模拟辅音\n", " logits[t, 1:15] += 1.0\n", " else: # 模拟元音\n", " logits[t, 15:25] += 1.0\n", " \n", " return logits\n", " \n", " # def _simulate_rnn_predictions(self, neural_data):\n", " # \"\"\"为所有试验生成模拟RNN预测\"\"\"\n", " # print(\" 🎭 使用模拟RNN预测\")\n", " # predictions = []\n", " \n", " # for neural_features in neural_data:\n", " # pred = self._simulate_single_trial_prediction(neural_features)\n", " # predictions.append(pred)\n", " \n", " # return predictions\n", 
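" # 下面恢复上方被注释掉的 _simulate_rnn_predictions: _extract_rnn_features 在 use_real_rnn=False 或预测失败回退时会调用它,缺失会抛出 AttributeError\n", " def _simulate_rnn_predictions(self, neural_data):\n", " \"\"\"为所有试验生成模拟RNN预测(恢复自上方注释掉的实现)\"\"\"\n", " print(\" 🎭 使用模拟RNN预测\")\n", " predictions = []\n", " for neural_features in neural_data:\n", " pred = self._simulate_single_trial_prediction(neural_features)\n", " predictions.append(pred)\n", " return predictions\n",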
" \n", " def _compute_confidence_metrics(self, logits):\n", " \"\"\"计算置信度指标\"\"\"\n", " # 转换为概率\n", " probs = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)\n", " \n", " # 计算熵(不确定性指标)\n", " entropy = -np.sum(probs * np.log(probs + 1e-8), axis=-1)\n", " \n", " # 计算最大概率(置信度指标)\n", " max_prob = np.max(probs, axis=-1)\n", " \n", " # 计算top-2差距(决策边界指标)\n", " sorted_probs = np.sort(probs, axis=-1)\n", " top2_margin = sorted_probs[:, -1] - sorted_probs[:, -2]\n", " \n", " return {\n", " 'entropy': entropy,\n", " 'max_prob': max_prob,\n", " 'top2_margin': top2_margin,\n", " 'mean_entropy': np.mean(entropy),\n", " 'mean_max_prob': np.mean(max_prob),\n", " 'mean_top2_margin': np.mean(top2_margin)\n", " }\n", " \n", " def _process_single_session(self, session_name):\n", " \"\"\"处理单个session\"\"\"\n", " print(f\"\\n🔄 处理session: {session_name}\")\n", " \n", " # 加载数据\n", " data_result, error = self._load_session_data(session_name)\n", " if data_result is None:\n", " print(f\" ❌ 数据加载失败: {error}\")\n", " return None\n", " \n", " neural_data, labels = data_result\n", " \n", " # 提取RNN特征\n", " print(f\" 🧠 提取RNN特征...\")\n", " rnn_predictions = self._extract_rnn_features(neural_data, session_name)\n", " \n", " if not rnn_predictions:\n", " print(f\" ❌ RNN特征提取失败\")\n", " return None\n", " \n", " # 处理所有试验数据\n", " session_features = []\n", " \n", " for trial_idx, (neural_features, trial_labels, rnn_logits) in enumerate(\n", " zip(neural_data, labels, rnn_predictions)):\n", " \n", " # 确保维度匹配\n", " min_length = min(len(neural_features), len(trial_labels), len(rnn_logits))\n", " neural_features = neural_features[:min_length]\n", " trial_labels = trial_labels[:min_length]\n", " rnn_logits = rnn_logits[:min_length]\n", " \n", " # 计算置信度指标\n", " confidence_metrics = self._compute_confidence_metrics(rnn_logits)\n", " \n", " # 创建DataFrame\n", " trial_df = pd.DataFrame()\n", " \n", " # 添加神经特征\n", " for feat_idx in range(neural_features.shape[1]):\n", " trial_df[f'neural_feat_{feat_idx:03d}'] = neural_features[:, feat_idx]\n", " \n", " # 添加RNN预测的40个音素概率\n", " rnn_probs = np.exp(rnn_logits) / np.sum(np.exp(rnn_logits), axis=-1, keepdims=True)\n", " for phoneme_idx in range(40):\n", " trial_df[f'phoneme_{phoneme_idx:02d}'] = rnn_probs[:, phoneme_idx]\n", " \n", " # 添加置信度指标\n", " trial_df['confidence_entropy'] = confidence_metrics['entropy']\n", " trial_df['confidence_max_prob'] = confidence_metrics['max_prob']\n", " trial_df['confidence_top2_margin'] = confidence_metrics['top2_margin']\n", " \n", " # 添加元数据\n", " trial_df['session_name'] = session_name\n", " trial_df['trial_id'] = trial_idx\n", " trial_df['time_step'] = range(len(trial_df))\n", " trial_df['ground_truth_label'] = trial_labels\n", " \n", " session_features.append(trial_df)\n", " \n", " # 合并所有试验\n", " if session_features:\n", " combined_df = pd.concat(session_features, ignore_index=True)\n", " print(f\" ✅ {session_name}: 处理完成,共 {len(combined_df)} 个样本\")\n", " return combined_df\n", " else:\n", " print(f\" ❌ {session_name}: 没有有效数据\")\n", " return None\n", " \n", " def process_sessions(self, session_names=None, max_sessions=None):\n", " \"\"\"批量处理sessions(带进度条)\"\"\"\n", " # 确保已检查状态\n", " if not self._sessions_scanned or not self._rnn_checked:\n", " print(\"⚠️ 请先运行 check_status() 检查系统状态\")\n", " return\n", " \n", " # 确定要处理的sessions\n", " if session_names is None:\n", " target_sessions = self.sessions[:max_sessions] if max_sessions else self.sessions\n", " else:\n", " target_sessions = [s for s in session_names if s in self.sessions]\n", " \n", " if 
not target_sessions:\n", " print(\"❌ 没有找到要处理的sessions\")\n", " return\n", " \n", " print(f\"\\n🚀 开始批量处理 {len(target_sessions)} 个sessions\")\n", " print(\"=\"*60)\n", " \n", " successful_results = {}\n", " \n", " # 使用进度条处理sessions\n", " with tqdm(target_sessions, desc=\"处理Sessions\", unit=\"session\") as pbar:\n", " for session_name in pbar:\n", " pbar.set_postfix_str(f\"处理: {session_name}\")\n", " \n", " try:\n", " result_df = self._process_single_session(session_name)\n", " if result_df is not None:\n", " successful_results[session_name] = result_df\n", " pbar.set_postfix_str(f\"✅ {session_name}\")\n", " else:\n", " pbar.set_postfix_str(f\"❌ {session_name}\")\n", " \n", " except Exception as e:\n", " print(f\" 💥 处理 {session_name} 时出错: {str(e)}\")\n", " pbar.set_postfix_str(f\"💥 {session_name}\")\n", " \n", " # 小延迟以显示状态\n", " time.sleep(0.1)\n", " \n", " self.results = successful_results\n", " \n", " print(f\"\\n📊 批量处理完成!\")\n", " print(f\" 成功处理: {len(successful_results)} / {len(target_sessions)} sessions\")\n", " print(f\" 总样本数: {sum(len(df) for df in successful_results.values()):,}\")\n", " \n", " if successful_results:\n", " print(f\" 特征维度: {list(successful_results.values())[0].shape[1]} 列\")\n", " \n", " def get_combined_dataset(self):\n", " \"\"\"获取合并的数据集\"\"\"\n", " if not self.results:\n", " print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n", " return None\n", " \n", " print(\"🔗 合并所有session数据...\")\n", " combined_dfs = list(self.results.values())\n", " \n", " if combined_dfs:\n", " combined_df = pd.concat(combined_dfs, ignore_index=True)\n", " print(f\"✅ 合并完成: {len(combined_df):,} 个样本\")\n", " return combined_df\n", " else:\n", " return None\n", " \n", " def save_results(self, output_dir='./outputs'):\n", " \"\"\"保存处理结果(带进度条)\"\"\"\n", " if not self.results:\n", " print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n", " return\n", " \n", " os.makedirs(output_dir, exist_ok=True)\n", " print(f\"💾 保存处理结果到: {output_dir}\")\n", " \n", " # 使用进度条保存文件\n", " with tqdm(self.results.items(), desc=\"保存文件\", unit=\"文件\") as pbar:\n", " for session_name, df in pbar:\n", " pbar.set_postfix_str(f\"保存: {session_name}\")\n", " output_path = os.path.join(output_dir, f\"{session_name}_features.csv\")\n", " df.to_csv(output_path, index=False)\n", " pbar.set_postfix_str(f\"✅ {session_name}\")\n", " time.sleep(0.05) # 小延迟以显示进度\n", " \n", " # 保存合并数据集\n", " combined_df = self.get_combined_dataset()\n", " if combined_df is not None:\n", " print(\"💾 保存合并数据集...\")\n", " combined_path = os.path.join(output_dir, \"combined_all_sessions_features.csv\")\n", " combined_df.to_csv(combined_path, index=False)\n", " print(f\" ✅ 合并数据集: {combined_path}\")\n", " \n", " print(f\"💾 保存完成!\")" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🔧 创建改进版数据处理管道 (仅真实数据模式)...\n", "🔧 初始化改进版数据处理管道...\n", "✅ 管道初始化完成!\n", "✅ 改进版管道创建完成!\n", "⚠️ 重要: 本管道仅支持真实RNN模型,不提供任何模拟功能\n", "💡 使用方法:\n", " # 确保rnn_loader已加载真实模型\n", " results = improved_pipeline.run_full_pipeline(\n", " train_list, val_list, test_list,\n", " max_files_per_split=3,\n", " rnn_loader=rnn_loader # 必须是已加载的真实模型\n", " )\n" ] } ], "source": [] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "📚 数据集访问和使用示例\n", "============================================================\n", "✅ 数据集变量已创建:\n", " train_datasets: 0 个训练DataFrame\n", " val_datasets: 0 个验证DataFrame\n", " test_datasets: 0 个测试DataFrame\n", "\n", "🔗 
数据集合并示例:\n", "\n", "💾 数据集保存功能:\n", "使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\n", "将保存所有单独session和合并的数据集文件\n", "\n", "🎯 工作流完成总结:\n", "============================================================\n", "✅ 成功创建了完整的数据集处理工作流\n", "✅ 生成了三个主要变量:train_datasets, val_datasets, test_datasets\n", "✅ 每个变量包含多个DataFrame,对应不同的sessions\n", "✅ 每个DataFrame包含:\n", " • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\n", " • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\n", " • 置信度指标 (entropy, top2_margin)\n", " • 元数据 (session, trial, 时间戳等)\n", "✅ 支持单独访问或合并使用\n", "✅ 支持保存为CSV文件\n", "\n", "🚀 现在你可以使用这些数据集进行机器学习建模了!\n" ] } ], "source": [ "# 📚 数据集访问和使用示例\n", "print(\"📚 数据集访问和使用示例\")\n", "print(\"=\"*60)\n", "\n", "# 方便的变量赋值(按照你的要求)\n", "train_datasets = pipeline.train_datasets # 训练集列表\n", "val_datasets = pipeline.val_datasets # 验证集列表 \n", "test_datasets = pipeline.test_datasets # 测试集列表\n", "\n", "print(f\"✅ 数据集变量已创建:\")\n", "print(f\" train_datasets: {len(train_datasets)} 个训练DataFrame\")\n", "print(f\" val_datasets: {len(val_datasets)} 个验证DataFrame\")\n", "print(f\" test_datasets: {len(test_datasets)} 个测试DataFrame\")\n", "\n", "# 演示如何使用这些数据集\n", "if train_datasets:\n", " print(f\"\\n🔍 数据集使用示例:\")\n", " \n", " # 访问第一个训练集\n", " first_train_session = train_datasets[0]\n", " print(f\"\\n1. 访问第一个训练session:\")\n", " print(f\" 形状: {first_train_session.shape}\")\n", " print(f\" Session: {first_train_session['session'].iloc[0]}\")\n", " print(f\" 试验数: {first_train_session['trial_idx'].max() + 1}\")\n", " \n", " # 获取神经特征\n", " neural_cols = [col for col in first_train_session.columns if col.startswith('neural_feat_')]\n", " neural_features = first_train_session[neural_cols]\n", " print(f\"\\n2. 提取神经特征:\")\n", " print(f\" 神经特征形状: {neural_features.shape}\")\n", " print(f\" 特征维度: {len(neural_cols)}\")\n", " \n", " # 获取音素预测\n", " phoneme_cols = [col for col in first_train_session.columns if col.startswith('phoneme_')]\n", " phoneme_predictions = first_train_session[phoneme_cols]\n", " print(f\"\\n3. 提取音素预测:\")\n", " print(f\" 预测概率形状: {phoneme_predictions.shape}\")\n", " print(f\" 音素类别数: {len(phoneme_cols)}\")\n", " \n", " # 获取元数据\n", " metadata_cols = ['time_step', 'session', 'trial_idx', 'ground_truth_phoneme_name', \n", " 'predicted_phoneme_name', 'max_probability', 'entropy', 'top2_margin']\n", " metadata = first_train_session[metadata_cols]\n", " print(f\"\\n4. 
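提取元数据之前,先看一个小的数值示例:\")\n", " # (补充示例) 把第一个试验前5个时间步的40维概率用argmax转成音素名称\n", " # 假设 LOGIT_TO_PHONEME 已在前文的utils中定义\n", " tid0 = first_train_session['trial_idx'].iloc[0]\n", " trial0 = first_train_session[first_train_session['trial_idx'] == tid0]\n", " probs0 = trial0[phoneme_cols].values[:5]\n", " top1_ids = probs0.argmax(axis=1)\n", " print(f\" 前5步Top-1音素: {[LOGIT_TO_PHONEME[i] for i in top1_ids]}\")\n", " \n", " print(f\"\\n4. 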
提取元数据:\")\n", " print(f\" 元数据列数: {len(metadata_cols)}\")\n", " print(f\" 包含: 时间步、session、试验信息、真实/预测标签、置信度指标\")\n", "\n", "# 合并所有数据集的函数\n", "def combine_datasets(dataset_list, split_name):\n", " \"\"\"合并同一类型的所有数据集\"\"\"\n", " if not dataset_list:\n", " return None\n", " \n", " combined_df = pd.concat(dataset_list, ignore_index=True)\n", " print(f\"\\n📊 {split_name}合并结果:\")\n", " print(f\" 总数据形状: {combined_df.shape}\")\n", " print(f\" 包含sessions: {combined_df['session'].nunique()} 个\")\n", " print(f\" 总试验数: {combined_df['trial_idx'].nunique()}\")\n", " print(f\" 总时间步数: {len(combined_df):,}\")\n", " \n", " return combined_df\n", "\n", "# 演示合并数据集\n", "print(f\"\\n🔗 数据集合并示例:\")\n", "if train_datasets:\n", " combined_train = combine_datasets(train_datasets, \"训练集\")\n", " \n", "if val_datasets:\n", " combined_val = combine_datasets(val_datasets, \"验证集\")\n", " \n", "if test_datasets:\n", " combined_test = combine_datasets(test_datasets, \"测试集\")\n", "\n", "# 保存数据集的函数\n", "def save_datasets_to_files(train_list, val_list, test_list, output_dir='./processed_datasets'):\n", " \"\"\"保存所有数据集到文件\"\"\"\n", " saved_files = []\n", " \n", " # 保存训练集\n", " for i, df in enumerate(train_list):\n", " filename = os.path.join(output_dir, f'train_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n", " df.to_csv(filename, index=False)\n", " saved_files.append(filename)\n", " \n", " # 保存验证集\n", " for i, df in enumerate(val_list):\n", " filename = os.path.join(output_dir, f'val_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n", " df.to_csv(filename, index=False)\n", " saved_files.append(filename)\n", " \n", " # 保存测试集\n", " for i, df in enumerate(test_list):\n", " filename = os.path.join(output_dir, f'test_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n", " df.to_csv(filename, index=False)\n", " saved_files.append(filename)\n", " \n", " # 保存合并的数据集\n", " if train_list:\n", " combined_train = pd.concat(train_list, ignore_index=True)\n", " train_combined_file = os.path.join(output_dir, 'combined_train_all_sessions.csv')\n", " combined_train.to_csv(train_combined_file, index=False)\n", " saved_files.append(train_combined_file)\n", " \n", " if val_list:\n", " combined_val = pd.concat(val_list, ignore_index=True)\n", " val_combined_file = os.path.join(output_dir, 'combined_val_all_sessions.csv')\n", " combined_val.to_csv(val_combined_file, index=False)\n", " saved_files.append(val_combined_file)\n", " \n", " if test_list:\n", " combined_test = pd.concat(test_list, ignore_index=True)\n", " test_combined_file = os.path.join(output_dir, 'combined_test_all_sessions.csv')\n", " combined_test.to_csv(test_combined_file, index=False)\n", " saved_files.append(test_combined_file)\n", " \n", " return saved_files\n", "\n", "print(f\"\\n💾 数据集保存功能:\")\n", "print(f\"使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\")\n", "print(f\"将保存所有单独session和合并的数据集文件\")\n", "\n", "print(f\"\\n🎯 工作流完成总结:\")\n", "print(\"=\"*60)\n", "print(f\"✅ 成功创建了完整的数据集处理工作流\")\n", "print(f\"✅ 生成了三个主要变量:train_datasets, val_datasets, test_datasets\")\n", "print(f\"✅ 每个变量包含多个DataFrame,对应不同的sessions\")\n", "print(f\"✅ 每个DataFrame包含:\")\n", "print(f\" • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\")\n", "print(f\" • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\")\n", "print(f\" • 置信度指标 (entropy, top2_margin)\")\n", "print(f\" • 元数据 (session, trial, 时间戳等)\")\n", "print(f\"✅ 支持单独访问或合并使用\")\n", "print(f\"✅ 支持保存为CSV文件\")\n", "\n", "print(f\"\\n🚀 现在你可以使用这些数据集进行机器学习建模了!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 模型建立" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## 🌲 随机森林多输出回归模型\n", "\n", "使用随机森林对40个音素概率进行回归预测,输入为512维神经特征,时间窗口为30。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🌲 初始化随机森林回归模型\n", "🌲 随机森林回归器初始化:\n", " 时间窗口大小: 30\n", " 树的数量: 100\n", " 最大深度: 10\n", " 并行任务: -1\n", "\n", "✅ 随机森林回归器准备完成!\n", "🔧 下一步: 准备训练数据和开始训练\n" ] } ], "source": [ "# 🌲 随机森林多输出回归模型实现\n", "import numpy as np\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n", "from sklearn.preprocessing import StandardScaler\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from tqdm import tqdm\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "class TimeWindowRandomForestRegressor:\n", " \"\"\"\n", " 基于时间窗口的随机森林多输出回归器\n", " 用于预测40个音素的概率分布\n", " \"\"\"\n", " \n", " def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n", " \"\"\"\n", " 初始化模型\n", " \n", " 参数:\n", " window_size: 时间窗口大小\n", " n_estimators: 随机森林中树的数量\n", " max_depth: 树的最大深度\n", " n_jobs: 并行任务数\n", " random_state: 随机种子\n", " \"\"\"\n", " self.window_size = window_size\n", " self.n_estimators = n_estimators\n", " self.max_depth = max_depth\n", " self.n_jobs = n_jobs\n", " self.random_state = random_state\n", " \n", " # 初始化模型和预处理器\n", " self.regressor = RandomForestRegressor(\n", " n_estimators=n_estimators,\n", " max_depth=max_depth,\n", " n_jobs=n_jobs,\n", " random_state=random_state,\n", " verbose=1\n", " )\n", " \n", " self.scaler = StandardScaler()\n", " self.is_fitted = False\n", " \n", " print(f\"🌲 随机森林回归器初始化:\")\n", " print(f\" 时间窗口大小: {window_size}\")\n", " print(f\" 树的数量: {n_estimators}\")\n", " print(f\" 最大深度: {max_depth}\")\n", " print(f\" 并行任务: {n_jobs}\")\n", " \n", " def create_time_windows(self, neural_features, phoneme_targets=None):\n", " \"\"\"\n", " 创建时间窗口特征\n", " \n", " 参数:\n", " neural_features: 神经特征 [time_steps, 512]\n", " phoneme_targets: 音素目标 [time_steps, 40] (可选)\n", " \n", " 返回:\n", " windowed_features: [samples, window_size * 512]\n", " windowed_targets: [samples, 40] (如果提供targets)\n", " \"\"\"\n", " if len(neural_features) < self.window_size:\n", " print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n", " return None, None\n", " \n", " n_samples = len(neural_features) - self.window_size + 1\n", " n_features = neural_features.shape[1]\n", " \n", " # 创建时间窗口特征\n", " windowed_features = np.zeros((n_samples, self.window_size * n_features))\n", " \n", " for i in range(n_samples):\n", " # 展平时间窗口内的所有特征\n", " window_data = neural_features[i:i+self.window_size].flatten()\n", " windowed_features[i] = window_data\n", " \n", " windowed_targets = None\n", " if phoneme_targets is not None:\n", " # 使用窗口中心点的音素概率作为目标\n", " center_offset = self.window_size // 2\n", " windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n", " \n", " return windowed_features, windowed_targets\n", " \n", " def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n", " \"\"\"\n", " 准备训练数据集\n", " \n", " 参数:\n", " datasets_list: DataFrame列表 (train_datasets, val_datasets, etc.)\n", " dataset_type: 数据集类型名称\n", " \n", " 返回:\n", " X: 特征矩阵 [总样本数, window_size * 512]\n", " y: 目标矩阵 [总样本数, 40]\n", " \"\"\"\n", " print(f\"\\n📊 
准备{dataset_type}数据集:\")\n", " print(f\" 输入数据集数量: {len(datasets_list)}\")\n", " \n", " all_X = []\n", " all_y = []\n", " \n", " for i, df in enumerate(tqdm(datasets_list, desc=f\"处理{dataset_type}数据\")):\n", " # 提取神经特征 (前512列)\n", " neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n", " neural_features = df[neural_cols].values\n", " \n", " # 提取音素目标 (40列音素概率)\n", " phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n", " phoneme_targets = df[phoneme_cols].values\n", " \n", " # 按trial分组处理\n", " trials = df['trial_idx'].unique()\n", " \n", " for trial_idx in trials:\n", " trial_mask = df['trial_idx'] == trial_idx\n", " trial_neural = neural_features[trial_mask]\n", " trial_phonemes = phoneme_targets[trial_mask]\n", " \n", " # 创建时间窗口\n", " windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n", " \n", " if windowed_X is not None and windowed_y is not None:\n", " all_X.append(windowed_X)\n", " all_y.append(windowed_y)\n", " \n", " if not all_X:\n", " print(f\"❌ 没有有效的{dataset_type}数据\")\n", " return None, None\n", " \n", " # 合并所有数据\n", " X = np.vstack(all_X)\n", " y = np.vstack(all_y)\n", " \n", " print(f\" ✅ {dataset_type}数据准备完成:\")\n", " print(f\" 特征矩阵形状: {X.shape}\")\n", " print(f\" 目标矩阵形状: {y.shape}\")\n", " print(f\" 内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n", " \n", " return X, y\n", " \n", " def fit(self, X_train, y_train, X_val=None, y_val=None):\n", " \"\"\"\n", " 训练模型\n", " \n", " 参数:\n", " X_train: 训练特征\n", " y_train: 训练目标\n", " X_val: 验证特征 (可选)\n", " y_val: 验证目标 (可选)\n", " \"\"\"\n", " print(f\"\\n🚀 开始训练随机森林回归模型:\")\n", " print(f\" 训练样本数: {X_train.shape[0]:,}\")\n", " print(f\" 特征维度: {X_train.shape[1]:,}\")\n", " print(f\" 目标维度: {y_train.shape[1]}\")\n", " \n", " # 标准化特征\n", " print(\" 🔄 标准化特征...\")\n", " X_train_scaled = self.scaler.fit_transform(X_train)\n", " \n", " # 训练模型\n", " print(\" 🌲 训练随机森林...\")\n", " self.regressor.fit(X_train_scaled, y_train)\n", " \n", " self.is_fitted = True\n", " \n", " # 计算训练集性能\n", " print(\" 📊 评估训练集性能...\")\n", " train_predictions = self.regressor.predict(X_train_scaled)\n", " train_mse = mean_squared_error(y_train, train_predictions)\n", " train_r2 = r2_score(y_train, train_predictions)\n", " train_mae = mean_absolute_error(y_train, train_predictions)\n", " \n", " print(f\" ✅ 训练完成!\")\n", " print(f\" 训练集 MSE: {train_mse:.6f}\")\n", " print(f\" 训练集 R²: {train_r2:.4f}\")\n", " print(f\" 训练集 MAE: {train_mae:.6f}\")\n", " \n", " # 如果有验证集,计算验证集性能\n", " if X_val is not None and y_val is not None:\n", " print(\" 📊 评估验证集性能...\")\n", " X_val_scaled = self.scaler.transform(X_val)\n", " val_predictions = self.regressor.predict(X_val_scaled)\n", " val_mse = mean_squared_error(y_val, val_predictions)\n", " val_r2 = r2_score(y_val, val_predictions)\n", " val_mae = mean_absolute_error(y_val, val_predictions)\n", " \n", " print(f\" 验证集 MSE: {val_mse:.6f}\")\n", " print(f\" 验证集 R²: {val_r2:.4f}\")\n", " print(f\" 验证集 MAE: {val_mae:.6f}\")\n", " \n", " return {\n", " 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n", " 'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n", " }\n", " \n", " return {\n", " 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n", " }\n", " \n", " def predict(self, X):\n", " \"\"\"预测\"\"\"\n", " if not self.is_fitted:\n", " raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n", " \n", " X_scaled = self.scaler.transform(X)\n", " return self.regressor.predict(X_scaled)\n", " \n", " def 
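save(self, path='./processed_datasets/rf_regressor.joblib'):\n", " \"\"\"(补充示例) 模型持久化的最小草图: 用joblib保存回归器与标准化器; 路径与文件名均为假设值\"\"\"\n", " # 假设 joblib 可用(它是scikit-learn的依赖,通常随之安装)\n", " import os, joblib\n", " if not self.is_fitted:\n", " raise ValueError(\"模型尚未训练,无法保存\")\n", " os.makedirs(os.path.dirname(path), exist_ok=True)\n", " joblib.dump({'regressor': self.regressor, 'scaler': self.scaler,\n", " 'window_size': self.window_size}, path)\n", " print(f\"💾 模型已保存至: {path}\")\n", " \n", " def 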
get_feature_importance(self, top_k=20):\n", " \"\"\"获取特征重要性\"\"\"\n", " if not self.is_fitted:\n", " raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n", " \n", " importances = self.regressor.feature_importances_\n", " \n", " # 创建特征名称 (window_timestep_feature)\n", " feature_names = []\n", " for t in range(self.window_size):\n", " for f in range(512):\n", " feature_names.append(f\"t{t}_feat{f}\")\n", " \n", " # 获取top-k重要特征\n", " top_indices = np.argsort(importances)[::-1][:top_k]\n", " top_features = [(feature_names[i], importances[i]) for i in top_indices]\n", " \n", " return top_features, importances\n", "\n", "# 初始化模型\n", "print(\"🌲 初始化随机森林回归模型\")\n", "rf_regressor = TimeWindowRandomForestRegressor(\n", " window_size=30, # 时间窗口大小\n", " n_estimators=100, # 树的数量\n", " max_depth=10, # 最大深度 (防止过拟合)\n", " n_jobs=-1, # 使用所有CPU核心\n", " random_state=42\n", ")\n", "\n", "print(\"\\n✅ 随机森林回归器准备完成!\")\n", "print(\"🔧 下一步: 准备训练数据和开始训练\")" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 开始数据准备和模型训练流程\n", "============================================================\n", "❌ 没有可用的训练数据,请先运行数据处理工作流\n" ] } ], "source": [ "# 🚀 准备数据并训练随机森林回归模型\n", "print(\"🚀 开始数据准备和模型训练流程\")\n", "print(\"=\"*60)\n", "\n", "# 检查数据可用性\n", "if not train_datasets:\n", " print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n", "else:\n", " print(f\"✅ 检测到数据:\")\n", " print(f\" 训练数据集: {len(train_datasets)} 个sessions\")\n", " print(f\" 验证数据集: {len(val_datasets)} 个sessions\")\n", " print(f\" 测试数据集: {len(test_datasets)} 个sessions\")\n", "\n", " # 1. 准备训练数据\n", " print(f\"\\n📊 第1步: 准备训练数据\")\n", " X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n", " \n", " # 2. 准备验证数据\n", " print(f\"\\n📊 第2步: 准备验证数据\")\n", " X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n", " \n", " if X_train is not None and y_train is not None:\n", " print(f\"\\n📈 数据准备完成统计:\")\n", " print(f\" 训练集: {X_train.shape[0]:,} 样本\")\n", " print(f\" 验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \" 验证集: 无\")\n", " print(f\" 特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n", " print(f\" 目标维度: {y_train.shape[1]} (40个音素概率)\")\n", " \n", " # 检查数据质量\n", " print(f\"\\n🔍 数据质量检查:\")\n", " print(f\" 训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n", " print(f\" 训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n", " print(f\" 训练特征均值: {X_train.mean():.4f}\")\n", " print(f\" 训练目标均值: {y_train.mean():.4f}\")\n", " \n", " # 检查是否有NaN或无穷值\n", " nan_count_X = np.isnan(X_train).sum()\n", " nan_count_y = np.isnan(y_train).sum()\n", " inf_count_X = np.isinf(X_train).sum()\n", " inf_count_y = np.isinf(y_train).sum()\n", " \n", " print(f\" NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n", " print(f\" Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n", " \n", " if nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0:\n", " print(\"⚠️ 检测到异常值,将进行清理...\")\n", " # 清理异常值\n", " valid_mask = ~(np.isnan(X_train).any(axis=1) | np.isnan(y_train).any(axis=1) | \n", " np.isinf(X_train).any(axis=1) | np.isinf(y_train).any(axis=1))\n", " X_train = X_train[valid_mask]\n", " y_train = y_train[valid_mask]\n", " \n", " if X_val is not None and y_val is not None:\n", " valid_mask_val = ~(np.isnan(X_val).any(axis=1) | np.isnan(y_val).any(axis=1) | \n", " np.isinf(X_val).any(axis=1) | np.isinf(y_val).any(axis=1))\n", " X_val = X_val[valid_mask_val]\n", " y_val = y_val[valid_mask_val]\n", " \n", " print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n", " 
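\n", " # (补充检查) 训练前核对窗口化后的维度: 特征应为 window_size*512, 目标应为 40 个音素概率\n", " expected_dim = rf_regressor.window_size * 512\n", " assert X_train.shape[1] == expected_dim, f\"特征维度 {X_train.shape[1]} != {expected_dim}\"\n", " assert y_train.shape[1] == 40, f\"目标维度 {y_train.shape[1]} != 40\"\n", " print(f\"🔎 维度检查通过: X={X_train.shape}, y={y_train.shape}\")\n", " 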
\n", " # 3. 训练模型\n", " print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n", " training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n", " \n", " # 4. 分析训练结果\n", " print(f\"\\n📊 第4步: 训练结果分析\")\n", " print(\"=\"*50)\n", " \n", " for metric, value in training_results.items():\n", " metric_name = metric.replace('_', ' ').title()\n", " print(f\" {metric_name}: {value:.6f}\")\n", " \n", " # 5. 特征重要性分析\n", " print(f\"\\n🔍 第5步: 特征重要性分析\")\n", " top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n", " \n", " print(f\"\\n🏆 Top 20 重要特征:\")\n", " print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n", " print(\"-\" * 35)\n", " for i, (feature_name, importance) in enumerate(top_features):\n", " print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n", " \n", " # 分析时间窗口内的重要性分布\n", " print(f\"\\n📈 时间窗口重要性分布:\")\n", " window_importances = np.zeros(rf_regressor.window_size)\n", " for i in range(rf_regressor.window_size):\n", " start_idx = i * 512\n", " end_idx = (i + 1) * 512\n", " window_importances[i] = all_importances[start_idx:end_idx].sum()\n", " \n", " max_time_step = np.argmax(window_importances)\n", " print(f\" 最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n", " print(f\" 窗口中心位置: t{rf_regressor.window_size//2}\")\n", " print(f\" 重要性分布: 前5个时间步的重要性\")\n", " for i in range(min(5, len(window_importances))):\n", " print(f\" t{i}: {window_importances[i]:.6f}\")\n", " \n", " print(f\"\\n✅ 随机森林回归模型训练完成!\")\n", " print(f\"🎯 模型可以预测40个音素的概率分布\")\n", " print(f\"📊 基于30时间步的神经特征窗口\")\n", " print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n", " \n", " else:\n", " print(\"❌ 数据准备失败,无法训练模型\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "⚠️ 模型尚未训练完成,请先运行训练代码\n" ] } ], "source": [ "# 📊 模型评估和可视化分析\n", "def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n", " \"\"\"\n", " 评估每个音素的预测性能\n", " \"\"\"\n", " print(f\"\\n📊 {dataset_name}详细评估\")\n", " print(\"=\"*50)\n", " \n", " # 获取预测结果\n", " y_pred = rf_model.predict(X_test)\n", " \n", " # 计算每个音素的性能指标\n", " phoneme_metrics = []\n", " \n", " for i in range(40): # 40个音素\n", " phoneme_name = LOGIT_TO_PHONEME[i]\n", " \n", " # 计算单个音素的指标\n", " mse = mean_squared_error(y_test[:, i], y_pred[:, i])\n", " r2 = r2_score(y_test[:, i], y_pred[:, i])\n", " mae = mean_absolute_error(y_test[:, i], y_pred[:, i])\n", " \n", " # 计算相关系数\n", " correlation = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]\n", " \n", " phoneme_metrics.append({\n", " 'phoneme_id': i,\n", " 'phoneme_name': phoneme_name,\n", " 'mse': mse, \n", " 'r2': r2,\n", " 'mae': mae,\n", " 'correlation': correlation if not np.isnan(correlation) else 0.0\n", " })\n", " \n", " # 转换为DataFrame便于分析\n", " metrics_df = pd.DataFrame(phoneme_metrics)\n", " \n", " # 打印总体统计\n", " print(f\"📈 总体性能指标:\")\n", " print(f\" 平均 MSE: {metrics_df['mse'].mean():.6f}\")\n", " print(f\" 平均 R²: {metrics_df['r2'].mean():.4f}\")\n", " print(f\" 平均 MAE: {metrics_df['mae'].mean():.6f}\")\n", " print(f\" 平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n", " \n", " # 找出最佳和最差预测的音素\n", " best_r2_idx = metrics_df['r2'].idxmax()\n", " worst_r2_idx = metrics_df['r2'].idxmin()\n", " \n", " print(f\"\\n🏆 最佳预测音素:\")\n", " best_phoneme = metrics_df.loc[best_r2_idx]\n", " print(f\" {best_phoneme['phoneme_name']} (ID: {best_phoneme['phoneme_id']})\")\n", " print(f\" R²: {best_phoneme['r2']:.4f}, MSE: {best_phoneme['mse']:.6f}\")\n", " \n", " print(f\"\\n📉 最差预测音素:\")\n", " worst_phoneme = 
metrics_df.loc[worst_r2_idx]\n", " print(f\" {worst_phoneme['phoneme_name']} (ID: {worst_phoneme['phoneme_id']})\")\n", " print(f\" R²: {worst_phoneme['r2']:.4f}, MSE: {worst_phoneme['mse']:.6f}\")\n", " \n", " return metrics_df, y_pred\n", "\n", "def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n", " \"\"\"\n", " 可视化预测结果\n", " \"\"\"\n", " print(f\"\\n📊 创建可视化图表...\")\n", " \n", " # 设置图表样式\n", " plt.style.use('default')\n", " fig = plt.figure(figsize=(20, 12))\n", " \n", " # 1. R²分数分布\n", " plt.subplot(2, 3, 1)\n", " plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n", " plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n", " label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n", " plt.xlabel('R² Score')\n", " plt.ylabel('音素数量')\n", " plt.title('R² Score 分布')\n", " plt.legend()\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 2. MSE分布\n", " plt.subplot(2, 3, 2)\n", " plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n", " plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n", " label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n", " plt.xlabel('Mean Squared Error')\n", " plt.ylabel('音素数量')\n", " plt.title('MSE 分布')\n", " plt.legend()\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 3. 前10个音素的性能对比\n", " plt.subplot(2, 3, 3)\n", " top_10 = metrics_df.nlargest(10, 'r2')\n", " bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n", " plt.xlabel('音素排名')\n", " plt.ylabel('R² Score')\n", " plt.title('Top 10 音素预测性能')\n", " plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 添加数值标签\n", " for i, bar in enumerate(bars):\n", " height = bar.get_height()\n", " plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n", " f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n", " \n", " # 4. 真实值 vs 预测值散点图 (选择最佳音素)\n", " plt.subplot(2, 3, 4)\n", " best_phoneme_idx = metrics_df['r2'].idxmax()\n", " phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n", " phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n", " \n", " # 随机采样1000个点以避免图表过于密集\n", " sample_size = min(1000, len(y_true))\n", " sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n", " \n", " plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n", " alpha=0.6, s=20, color='blue')\n", " \n", " # 添加对角线 (完美预测线)\n", " min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n", " max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n", " plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n", " \n", " plt.xlabel('真实值')\n", " plt.ylabel('预测值')\n", " plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n", " plt.legend()\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 5. 相关系数热力图 (前20个音素)\n", " plt.subplot(2, 3, 5)\n", " top_20_correlations = metrics_df.nlargest(20, 'correlation')\n", " corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n", " \n", " # 创建热力图数据\n", " heatmap_data = corr_data.values.reshape(-1, 1)\n", " im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n", " \n", " plt.colorbar(im, shrink=0.8)\n", " plt.yticks([0], ['相关系数'])\n", " plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n", " rotation=45, ha='right')\n", " plt.title('Top 20 音素相关系数')\n", " \n", " # 6. 
各音素预测误差箱线图 (前10个音素)\n", " plt.subplot(2, 3, 6)\n", " top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n", " errors_data = []\n", " labels = []\n", " \n", " for phoneme_id in top_10_ids:\n", " errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n", " errors_data.append(errors)\n", " labels.append(LOGIT_TO_PHONEME[phoneme_id])\n", " \n", " plt.boxplot(errors_data, labels=labels)\n", " plt.xlabel('音素')\n", " plt.ylabel('绝对误差')\n", " plt.title('Top 10 音素预测误差分布')\n", " plt.xticks(rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " plt.tight_layout()\n", " \n", " if save_plots:\n", " plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n", " print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n", " \n", " plt.show()\n", "\n", "# 如果模型训练成功,进行评估\n", "if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n", " print(f\"\\n🎯 开始模型评估和可视化\")\n", " \n", " # 评估验证集\n", " if X_val is not None and y_val is not None:\n", " val_metrics, val_predictions = evaluate_phoneme_predictions(\n", " rf_regressor, X_val, y_val, \"验证集\"\n", " )\n", " \n", " # 可视化结果\n", " visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n", " \n", " # 保存详细结果\n", " val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n", " print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n", " \n", " # 准备测试集数据 (如果有)\n", " if test_datasets:\n", " print(f\"\\n🔮 准备测试集预测...\")\n", " X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n", " \n", " if X_test is not None:\n", " test_metrics, test_predictions = evaluate_phoneme_predictions(\n", " rf_regressor, X_test, y_test, \"测试集\"\n", " )\n", " print(f\"\\n✅ 测试集评估完成\")\n", " else:\n", " print(f\"⚠️ 测试集数据准备失败\")\n", " \n", " print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n", " print(f\"📊 生成了详细的性能分析和可视化图表\")\n", " print(f\"🔧 模型已准备好用于实际预测任务\")\n", " \n", "else:\n", " print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ 回归转分类分析功能已创建!\n", "🎯 主要功能:\n", "• 将40维概率回归结果转换为分类预测\n", "• 计算分类准确率和置信度分析\n", "• 提供Top-K准确率评估\n", "• 生成详细的混淆矩阵和错误分析\n", "• 创建全面的可视化图表\n" ] } ], "source": [ "# 🎯 回归结果转分类结果分析\n", "def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n", " \"\"\"\n", " 将回归预测的40个音素概率转换为分类结果并进行分析\n", " \n", " 参数:\n", " y_true_probs: 真实的40个音素概率 [n_samples, 40]\n", " y_pred_probs: 预测的40个音素概率 [n_samples, 40]\n", " show_detailed_metrics: 是否显示详细的分类指标\n", " \n", " 返回:\n", " classification_results: 包含分类结果的字典\n", " \"\"\"\n", " print(\"🎯 回归结果转分类结果分析\")\n", " print(\"=\"*60)\n", " \n", " # 1. 将概率转换为分类标签\n", " y_true_classes = np.argmax(y_true_probs, axis=1) # 真实类别\n", " y_pred_classes = np.argmax(y_pred_probs, axis=1) # 预测类别\n", " \n", " # 2. 计算分类准确率\n", " accuracy = (y_true_classes == y_pred_classes).mean()\n", " \n", " print(f\"📊 分类结果概览:\")\n", " print(f\" 总样本数: {len(y_true_classes):,}\")\n", " print(f\" 分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n", " print(f\" 正确预测: {(y_true_classes == y_pred_classes).sum():,}\")\n", " print(f\" 错误预测: {(y_true_classes != y_pred_classes).sum():,}\")\n", " \n", " # 3. 
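分析预测置信度之前,先给出一个随机基线做参照: 40类均匀猜测的期望Top-1准确率为 1/40 = 2.5%\n", " print(f\" 随机基线准确率(1/40): {1/40:.4f} ({100/40:.2f}%)\")\n", " \n",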
" # 3. 分析预测置信度\n", " pred_confidences = np.max(y_pred_probs, axis=1) # 预测的最大概率\n", " true_confidences = np.max(y_true_probs, axis=1) # 真实的最大概率\n", " \n", " print(f\"\\n🔍 预测置信度分析:\")\n", " print(f\" 预测置信度均值: {pred_confidences.mean():.4f}\")\n", " print(f\" 预测置信度标准差: {pred_confidences.std():.4f}\")\n", " print(f\" 预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n", " print(f\" 真实置信度均值: {true_confidences.mean():.4f}\")\n", " \n", " # 4. 按置信度分组的准确率分析\n", " confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n", " print(f\"\\n📈 按预测置信度分组的准确率:\")\n", " print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n", " print(\"-\" * 40)\n", " \n", " for i in range(len(confidence_bins)-1):\n", " low, high = confidence_bins[i], confidence_bins[i+1]\n", " mask = (pred_confidences >= low) & (pred_confidences < high)\n", " if i == len(confidence_bins)-2: # 最后一个区间包含等号\n", " mask = (pred_confidences >= low) & (pred_confidences <= high)\n", " \n", " if mask.sum() > 0:\n", " bin_accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n", " count = mask.sum()\n", " percentage = count / len(y_true_classes) * 100\n", " interval = f\"[{low:.1f}, {high:.1f}\" + (']' if i == len(confidence_bins)-2 else ')')\n", " print(f\"{interval:>12} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n", " \n", " # 5. 混淆矩阵分析(Top-K音素)\n", " from collections import Counter\n", " \n", " # 找出最常见的音素(取前20个, 供下方详细分类报告使用)\n", " true_counter = Counter(y_true_classes)\n", " pred_counter = Counter(y_pred_classes)\n", " \n", " most_common_true = true_counter.most_common(20)\n", " most_common_pred = pred_counter.most_common(20)\n", " \n", " print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n", " print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n", " print(\"-\" * 42)\n", " \n", " # 表格只打印前10个\n", " for i in range(min(10, len(most_common_true), len(most_common_pred))):\n", " true_id, true_count = most_common_true[i]\n", " pred_id, pred_count = most_common_pred[i]\n", " true_name = LOGIT_TO_PHONEME[true_id]\n", " pred_name = LOGIT_TO_PHONEME[pred_id]\n", " print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n", " \n", " # 6. 
每个音素的分类性能\n", " if show_detailed_metrics:\n", " from sklearn.metrics import classification_report, confusion_matrix\n", " \n", " print(f\"\\n📋 详细分类报告 (前20个最常见音素):\")\n", " \n", " # 获取前20个最常见的音素\n", " top_20_phonemes = [phoneme_id for phoneme_id, _ in most_common_true[:20]]\n", " \n", " # 创建掩码,只包含这些音素\n", " mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n", " y_true_top20 = y_true_classes[mask_top20]\n", " y_pred_top20 = y_pred_classes[mask_top20]\n", " \n", " # 生成分类报告\n", " target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n", " \n", " try:\n", " report = classification_report(\n", " y_true_top20, y_pred_top20, \n", " labels=top_20_phonemes,\n", " target_names=target_names,\n", " output_dict=True,\n", " zero_division=0\n", " )\n", " \n", " # 打印格式化的报告\n", " print(f\"{'音素':>8} {'精确率':>8} {'召回率':>8} {'F1分数':>8} {'支持数':>8}\")\n", " print(\"-\" * 48)\n", " \n", " for phoneme_id in top_20_phonemes:\n", " phoneme_name = LOGIT_TO_PHONEME[phoneme_id]\n", " if phoneme_name in report:\n", " metrics = report[phoneme_name]\n", " print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n", " f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n", " \n", " # 总体指标\n", " macro_avg = report['macro avg']\n", " weighted_avg = report['weighted avg']\n", " print(\"-\" * 48)\n", " print(f\"{'宏平均':>8} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n", " f\"{macro_avg['f1-score']:>8.4f}\")\n", " print(f\"{'加权平均':>8} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n", " f\"{weighted_avg['f1-score']:>8.4f}\")\n", " \n", " except Exception as e:\n", " print(f\"分类报告生成失败: {e}\")\n", " \n", " # 7. Top-K准确率分析\n", " print(f\"\\n🎯 Top-K 准确率分析:\")\n", " for k in [1, 3, 5, 10]:\n", " # 计算Top-K准确率\n", " top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:] # 取概率最高的K个\n", " top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n", " print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n", " \n", " # 8. 错误分析 - 最常见的预测错误\n", " print(f\"\\n❌ 最常见的预测错误:\")\n", " error_mask = y_true_classes != y_pred_classes\n", " error_pairs = list(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n", " error_counter = Counter(error_pairs)\n", " \n", " print(f\"{'真实音素':>12} {'预测音素':>12} {'错误次数':>8}\")\n", " print(\"-\" * 36)\n", " for (true_id, pred_id), count in error_counter.most_common(10):\n", " true_name = LOGIT_TO_PHONEME[true_id]\n", " pred_name = LOGIT_TO_PHONEME[pred_id]\n", " print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n", " \n", " # 返回结果字典\n", " classification_results = {\n", " 'accuracy': accuracy,\n", " 'y_true_classes': y_true_classes,\n", " 'y_pred_classes': y_pred_classes,\n", " 'pred_confidences': pred_confidences,\n", " 'true_confidences': true_confidences,\n", " 'most_common_errors': error_counter.most_common(10)\n", " }\n", " \n", " return classification_results\n", "\n", "def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n", " \"\"\"\n", " 为分类结果创建可视化图表\n", " \"\"\"\n", " print(f\"\\n📊 创建分类结果可视化...\")\n", " \n", " fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", " fig.suptitle('随机森林回归转分类结果分析', fontsize=16, fontweight='bold')\n", " \n", " y_true_classes = classification_results['y_true_classes']\n", " y_pred_classes = classification_results['y_pred_classes']\n", " pred_confidences = classification_results['pred_confidences']\n", " \n", " # 1. 
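画图前的中文字体配置(假设性设置): 若系统装有 SimHei 或 Noto Sans CJK 可避免中文标签显示为方块, 否则回退到默认 DejaVu Sans\n", " plt.rcParams['font.sans-serif'] = ['SimHei', 'Noto Sans CJK SC', 'DejaVu Sans']\n", " plt.rcParams['axes.unicode_minus'] = False # 让负号正常显示\n", " \n", " # 1. 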
预测置信度分布\n", " axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n", " axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n", " label=f'均值: {pred_confidences.mean():.3f}')\n", " axes[0, 0].set_xlabel('预测置信度')\n", " axes[0, 0].set_ylabel('样本数量')\n", " axes[0, 0].set_title('预测置信度分布')\n", " axes[0, 0].legend()\n", " axes[0, 0].grid(True, alpha=0.3)\n", " \n", " # 2. 准确率 vs 置信度\n", " confidence_bins = np.linspace(0, 1, 21)\n", " bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n", " bin_accuracies = []\n", " bin_counts = []\n", " \n", " for i in range(len(confidence_bins)-1):\n", " mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n", " if mask.sum() > 0:\n", " accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n", " bin_accuracies.append(accuracy)\n", " bin_counts.append(mask.sum())\n", " else:\n", " bin_accuracies.append(0)\n", " bin_counts.append(0)\n", " \n", " # 只显示有数据的bins\n", " valid_bins = np.array(bin_counts) > 0\n", " axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n", " 'bo-', linewidth=2, markersize=6)\n", " axes[0, 1].set_xlabel('预测置信度')\n", " axes[0, 1].set_ylabel('准确率')\n", " axes[0, 1].set_title('准确率 vs 预测置信度')\n", " axes[0, 1].grid(True, alpha=0.3)\n", " axes[0, 1].set_ylim(0, 1)\n", " \n", " # 3. 最常见音素的预测准确率\n", " from collections import Counter\n", " true_counter = Counter(y_true_classes)\n", " most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n", " \n", " phoneme_accuracies = []\n", " phoneme_names = []\n", " for phoneme_id in most_common_phonemes:\n", " mask = y_true_classes == phoneme_id\n", " if mask.sum() > 0:\n", " accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n", " phoneme_accuracies.append(accuracy)\n", " phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n", " \n", " bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n", " color='lightgreen', alpha=0.7)\n", " axes[0, 2].set_xlabel('音素')\n", " axes[0, 2].set_ylabel('准确率')\n", " axes[0, 2].set_title('Top 15 音素的分类准确率')\n", " axes[0, 2].set_xticks(range(len(phoneme_names)))\n", " axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n", " axes[0, 2].grid(True, alpha=0.3)\n", " \n", " # 添加数值标签\n", " for bar, acc in zip(bars, phoneme_accuracies):\n", " height = bar.get_height()\n", " axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n", " f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n", " \n", " # 4. 
混淆矩阵(前10个最常见音素)\n", " from sklearn.metrics import confusion_matrix\n", " top_10_phonemes = most_common_phonemes[:10]\n", " mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n", " \n", " if mask_top10.sum() > 0:\n", " cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n", " labels=top_10_phonemes)\n", " \n", " im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n", " axes[1, 0].set_title('混淆矩阵 (Top 10 音素)')\n", " \n", " # 添加颜色条\n", " cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n", " cbar.set_label('预测次数')\n", " \n", " # 设置标签\n", " tick_marks = np.arange(len(top_10_phonemes))\n", " top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n", " axes[1, 0].set_xticks(tick_marks)\n", " axes[1, 0].set_yticks(tick_marks)\n", " axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n", " axes[1, 0].set_yticklabels(top_10_names)\n", " axes[1, 0].set_xlabel('预测音素')\n", " axes[1, 0].set_ylabel('真实音素')\n", " \n", " # 5. Top-K准确率\n", " k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n", " top_k_accuracies = []\n", " \n", " for k in k_values:\n", " top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n", " top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n", " top_k_accuracies.append(top_k_accuracy)\n", " \n", " axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n", " axes[1, 1].set_xlabel('K 值')\n", " axes[1, 1].set_ylabel('Top-K 准确率')\n", " axes[1, 1].set_title('Top-K 准确率曲线')\n", " axes[1, 1].grid(True, alpha=0.3)\n", " axes[1, 1].set_ylim(0, 1)\n", " \n", " # 添加数值标签\n", " for k, acc in zip(k_values, top_k_accuracies):\n", " axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n", " xytext=(0,10), ha='center')\n", " \n", " # 6. 
错误分析 - 最常见错误的热力图\n", " error_pairs = classification_results['most_common_errors'][:25] # 前25个最常见错误\n", " if error_pairs:\n", " # 创建错误矩阵\n", " unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n", " error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n", " \n", " phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n", " \n", " for (true_id, pred_id), count in error_pairs:\n", " if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n", " true_idx = phoneme_to_idx[true_id]\n", " pred_idx = phoneme_to_idx[pred_id]\n", " error_matrix[true_idx, pred_idx] = count\n", " \n", " im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n", " axes[1, 2].set_title('最常见错误分布')\n", " \n", " # 设置标签\n", " phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n", " axes[1, 2].set_xticks(range(len(phoneme_names)))\n", " axes[1, 2].set_yticks(range(len(phoneme_names)))\n", " axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n", " axes[1, 2].set_yticklabels(phoneme_names)\n", " axes[1, 2].set_xlabel('预测音素')\n", " axes[1, 2].set_ylabel('真实音素')\n", " \n", " # 添加颜色条\n", " cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n", " cbar.set_label('错误次数')\n", " \n", " plt.tight_layout()\n", " plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n", " print(\"📁 分类分析图表已保存至: ./processed_datasets/classification_analysis.png\")\n", " plt.show()\n", "\n", "print(\"✅ 回归转分类分析功能已创建!\")\n", "print(\"🎯 主要功能:\")\n", "print(\"• 将40维概率回归结果转换为分类预测\")\n", "print(\"• 计算分类准确率和置信度分析\")\n", "print(\"• 提供Top-K准确率评估\")\n", "print(\"• 生成详细的混淆矩阵和错误分析\")\n", "print(\"• 创建全面的可视化图表\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "⚠️ 随机森林模型尚未训练完成\n", "💡 请先运行前面的训练代码\n" ] } ], "source": [ "# 🎯 完整的回归转分类评估流程\n", "def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n", " \"\"\"\n", " 完整的回归模型转分类结果评估流程\n", " \"\"\"\n", " print(f\"\\n🎯 {dataset_name}完整评估: 回归 → 分类\")\n", " print(\"=\"*70)\n", " \n", " # 1. 获取回归预测结果\n", " print(\"📊 第1步: 获取回归预测...\")\n", " y_pred_probs = rf_model.predict(X_test)\n", " \n", " # 确保概率值在合理范围内\n", " y_pred_probs = np.clip(y_pred_probs, 0, 1)\n", " \n", " # 2. 回归性能评估\n", " print(\"\\n📈 第2步: 回归性能评估...\")\n", " mse = mean_squared_error(y_test, y_pred_probs) \n", " mae = mean_absolute_error(y_test, y_pred_probs)\n", " r2 = r2_score(y_test, y_pred_probs)\n", " \n", " print(f\" 回归 MSE: {mse:.6f}\")\n", " print(f\" 回归 MAE: {mae:.6f}\")\n", " print(f\" 回归 R²: {r2:.4f}\")\n", " \n", " # 3. 概率归一化(softmax)\n", " print(\"\\n🔄 第3步: 概率归一化...\")\n", " def softmax(x, axis=-1):\n", " exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n", " return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n", " \n", " # 对预测结果应用softmax,使其成为真正的概率分布\n", " y_pred_probs_normalized = softmax(y_pred_probs)\n", " y_test_normalized = softmax(y_test) # 也对真实标签归一化\n", " \n", " print(f\" 预测概率归一化前: 每行和均值 = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n", " print(f\" 预测概率归一化后: 每行和均值 = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n", " \n", " # 4. 分类结果分析\n", " print(\"\\n🎯 第4步: 分类结果分析...\")\n", " classification_results = regression_to_classification_analysis(\n", " y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n", " )\n", " \n", " # 5. 
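创建可视化之前,先补一个简单的校准检查(草图): 平均预测置信度与实际准确率之差,正值偏大通常意味着模型过度自信\n", " calib_gap = classification_results['pred_confidences'].mean() - classification_results['accuracy']\n", " print(f\" 校准差距(平均置信度 - 准确率): {calib_gap:+.4f}\")\n", " \n", " # 5. 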
创建可视化\n", " print(\"\\n📊 第5步: 创建可视化图表...\")\n", " create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n", " \n", " # 6. 保存结果\n", " print(\"\\n💾 第6步: 保存分析结果...\")\n", " \n", " # 保存分类结果\n", " results_df = pd.DataFrame({\n", " 'true_class': classification_results['y_true_classes'],\n", " 'pred_class': classification_results['y_pred_classes'],\n", " 'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n", " 'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n", " 'pred_confidence': classification_results['pred_confidences'],\n", " 'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n", " })\n", " \n", " results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n", " \n", " # 保存详细的概率预测\n", " prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n", " columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n", " prob_results_df['true_class'] = classification_results['y_true_classes']\n", " prob_results_df['pred_class'] = classification_results['y_pred_classes']\n", " \n", " prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n", " \n", " print(\"📁 结果已保存:\")\n", " print(\" • ./processed_datasets/classification_results.csv (分类结果)\")\n", " print(\" • ./processed_datasets/probability_predictions.csv (概率预测)\")\n", " print(\" • ./processed_datasets/classification_analysis.png (可视化图表)\")\n", " \n", " # 7. 总结报告\n", " print(f\"\\n📋 {dataset_name}评估总结:\")\n", " print(\"=\"*50)\n", " print(f\"🔸 回归性能:\")\n", " print(f\" MSE: {mse:.6f}\")\n", " print(f\" R²: {r2:.4f}\")\n", " print(f\"🔸 分类性能:\")\n", " print(f\" 准确率: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n", " print(f\" 平均置信度: {classification_results['pred_confidences'].mean():.4f}\")\n", " \n", " # 计算Top-K准确率\n", " for k in [1, 3, 5]:\n", " top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n", " top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n", " for i in range(len(classification_results['y_true_classes']))])\n", " print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n", " \n", " return {\n", " 'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n", " 'classification_results': classification_results,\n", " 'normalized_predictions': y_pred_probs_normalized,\n", " 'normalized_true': y_test_normalized\n", " }\n", "\n", "# 如果模型已训练且有验证数据,执行完整评估\n", "if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n", " if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n", " print(\"🚀 开始完整的回归转分类评估...\")\n", " \n", " # 执行完整评估\n", " evaluation_results = complete_regression_classification_evaluation(\n", " rf_regressor, X_val, y_val, \"验证集\"\n", " )\n", " \n", " print(f\"\\n🎉 评估完成!\")\n", " print(f\"✅ 随机森林回归模型成功转换为分类结果\")\n", " print(f\"📊 生成了详细的性能分析和可视化\")\n", " print(f\"💾 所有结果已保存到文件\")\n", " \n", " # 如果有测试数据,也进行评估\n", " if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n", " print(f\"\\n🔮 开始测试集评估...\")\n", " test_evaluation_results = complete_regression_classification_evaluation(\n", " rf_regressor, X_test, y_test, \"测试集\"\n", " )\n", " else:\n", " print(\"⚠️ 没有可用的验证数据进行评估\")\n", "else:\n", " print(\"⚠️ 随机森林模型尚未训练完成\")\n", " print(\"💡 请先运行前面的训练代码\")" ] }, { "cell_type": 
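"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🔮 (补充示例) 端到端推理草图: 从单个试验的DataFrame得到音素序列\n", "# 假设条件: rf_regressor 已训练、val_datasets 非空、LOGIT_TO_PHONEME 已在前文定义; 这只是演示性草图,并非正式解码流程\n", "if 'rf_regressor' in locals() and rf_regressor.is_fitted and 'val_datasets' in locals() and val_datasets:\n", " demo_df = val_datasets[0]\n", " trial_df = demo_df[demo_df['trial_idx'] == demo_df['trial_idx'].iloc[0]]\n", " neural_cols = [c for c in trial_df.columns if c.startswith('neural_feat_')]\n", " \n", " # 滑动窗口 -> 回归预测 -> softmax归一化 -> 逐窗口argmax\n", " X_demo, _ = rf_regressor.create_time_windows(trial_df[neural_cols].values)\n", " if X_demo is not None:\n", " probs = rf_regressor.predict(X_demo)\n", " probs = np.exp(probs - probs.max(axis=1, keepdims=True))\n", " probs /= probs.sum(axis=1, keepdims=True)\n", " phoneme_seq = [LOGIT_TO_PHONEME[i] for i in probs.argmax(axis=1)]\n", " # 合并相邻重复标签,得到一个粗略的去抖音素序列\n", " collapsed = [p for i, p in enumerate(phoneme_seq) if i == 0 or p != phoneme_seq[i-1]]\n", " print(f\"🔮 试验长度 {len(trial_df)} 步 -> {len(X_demo)} 个窗口\")\n", " print(f\" 前20个去重音素: {collapsed[:20]}\")\n", "else:\n", " print(\"⚠️ 跳过演示: 需要已训练的 rf_regressor 和非空的 val_datasets\")" ] }, { "cell_type": 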
"code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kaggle": { "accelerator": "tpu1vmV38", "dataSources": [ { "databundleVersionId": 13056355, "sourceId": 106809, "sourceType": "competition" } ], "dockerImageVersionId": 31091, "isGpuEnabled": false, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }