{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境配置 与 utils"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n",
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
"Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
"Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n",
"Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n",
"Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
"Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n",
"Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n",
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
"Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n",
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
"Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n",
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
"Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n",
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Installing collected packages: nejm_b2txt_utils\n",
" Running setup.py develop for nejm_b2txt_utils\n",
"Successfully installed nejm_b2txt_utils-0.0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cloning into 'nejm-brain-to-text'...\n",
"cp: cannot stat '/kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl': No such file or directory\n"
]
}
],
"source": [
"%%bash\n",
"cd /kaggle/working/\n",
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
"cd /kaggle/working/nejm-brain-to-text/\n",
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
"# Install PyTorch with CUDA 12.6\n",
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
"\n",
"# Install additional packages with compatible versions\n",
"# TODO: remove redis\n",
"pip install \\\n",
" jupyter==1.1.1 \\\n",
" \"numpy>=1.26.0,<2.1.0\" \\\n",
" pandas==2.3.0 \\\n",
" matplotlib==3.10.1 \\\n",
" scipy==1.15.2 \\\n",
" scikit-learn==1.6.1 \\\n",
" tqdm==4.67.1 \\\n",
" g2p_en==2.1.0 \\\n",
" h5py==3.13.0 \\\n",
" omegaconf==2.3.0 \\\n",
" editdistance==0.8.1 \\\n",
" huggingface-hub==0.33.1 \\\n",
" transformers==4.53.0 \\\n",
" tokenizers==0.21.2 \\\n",
" accelerate==1.8.1 \\\n",
" bitsandbytes==0.46.0\n",
"\n",
"# Install the local package\n",
"pip install -e .\n",
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from g2p_en import G2p\n",
"import pandas as pd\n",
"import numpy as np\n",
"from nejm_b2txt_utils.general_utils import *\n",
"\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import h5py\n",
"def load_h5py_file(file_path, b2txt_csv_df):\n",
" data = {\n",
" 'neural_features': [],\n",
" 'n_time_steps': [],\n",
" 'seq_class_ids': [],\n",
" 'seq_len': [],\n",
" 'transcriptions': [],\n",
" 'sentence_label': [],\n",
" 'session': [],\n",
" 'block_num': [],\n",
" 'trial_num': [],\n",
" 'corpus': [],\n",
" }\n",
" # Open the hdf5 file for that day\n",
" with h5py.File(file_path, 'r') as f:\n",
"\n",
" keys = list(f.keys())\n",
"\n",
" # For each trial in the selected trials in that day\n",
" for key in keys:\n",
" g = f[key]\n",
"\n",
" neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
" n_time_steps = g.attrs['n_time_steps']\n",
" seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
" seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
" transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
" sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
" session = g.attrs['session']\n",
" block_num = g.attrs['block_num']\n",
" trial_num = g.attrs['trial_num']\n",
"\n",
" # match this trial up with the csv to get the corpus name\n",
" year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
" date = f'{year}-{month}-{day}'\n",
" row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
" corpus_name = row['Corpus'].values[0]\n",
"\n",
" data['neural_features'].append(neural_features)\n",
" data['n_time_steps'].append(n_time_steps)\n",
" data['seq_class_ids'].append(seq_class_ids)\n",
" data['seq_len'].append(seq_len)\n",
" data['transcriptions'].append(transcription)\n",
" data['sentence_label'].append(sentence_label)\n",
" data['session'].append(session)\n",
" data['block_num'].append(block_num)\n",
" data['trial_num'].append(trial_num)\n",
" data['corpus'].append(corpus_name)\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
" 'BLANK',\n",
" 'AA', 'AE', 'AH', 'AO', 'AW',\n",
" 'AY', 'B', 'CH', 'D', 'DH',\n",
" 'EH', 'ER', 'EY', 'F', 'G',\n",
" 'HH', 'IH', 'IY', 'JH', 'K',\n",
" 'L', 'M', 'N', 'NG', 'OW',\n",
" 'OY', 'P', 'R', 'S', 'SH',\n",
" 'T', 'TH', 'UH', 'UW', 'V',\n",
" 'W', 'Y', 'Z', 'ZH',\n",
" ' | ',\n",
"]"
]
},
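{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the table above, the sketch below decodes a padded `seq_class_ids`-style array into phoneme symbols. The IDs here are a hypothetical example (they happen to match the first trial of `t15.2023.08.11` loaded later); 0 is the CTC `BLANK`/padding and is skipped, and 40 is the word separator `' | '`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: map padded phoneme IDs to symbols with LOGIT_TO_PHONEME.\n",
"# example_ids is a hypothetical input; 0 (BLANK/padding) entries are dropped.\n",
"example_ids = [7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0]\n",
"decoded = [LOGIT_TO_PHONEME[i] for i in example_ids if i != 0]\n",
"print(decoded)  # B R IH NG | IH T | K L OW S ER | -> \"Bring it closer.\""
]
},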
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 数据分析与预处理"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据准备"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text/"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"data = load_h5py_file(file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
" b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **任务介绍** :机器学习解决高维信号的模式识别问题"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们的数据集标签缺少时间戳,现在要进行的是半监督学习"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n",
" -0.7091646 , -0.11018186],\n",
" [-0.5859305 , -0.78699756, -0.64687246, ..., 0.3122117 ,\n",
" 1.7943763 , -0.76884896],\n",
" [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n",
" -0.8481289 , -0.7648201 ],\n",
" ...,\n",
" [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n",
" 0.9710176 , 2.5397465 ],\n",
" [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n",
" -0.68723625, 0.10479005],\n",
" [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n",
" 0.7417906 , -0.7008622 ]], dtype=float32),\n",
" 'n_time_steps': 321,\n",
" 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0], dtype=int32),\n",
" 'seq_len': 14,\n",
" 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n",
" 101, 114, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0], dtype=int32),\n",
" 'sentence_label': 'Bring it closer.',\n",
" 'session': 't15.2023.08.11',\n",
" 'block_num': 2,\n",
" 'trial_num': 0,\n",
" 'corpus': '50-Word'}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def data_patch(data, index):\n",
" data_patch = {}\n",
" data_patch['neural_features'] = data['neural_features'][index]\n",
" data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
" data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
" data_patch['seq_len'] = data['seq_len'][index]\n",
" data_patch['transcriptions'] = data['transcriptions'][index]\n",
" data_patch['sentence_label'] = data['sentence_label'][index]\n",
" data_patch['session'] = data['session'][index]\n",
" data_patch['block_num'] = data['block_num'][index]\n",
" data_patch['trial_num'] = data['trial_num'][index]\n",
" data_patch['corpus'] = data['corpus'][index]\n",
" return data_patch\n",
"\n",
"data_patch(data, 0)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"d1 = data_patch(data, 0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transcriptions non-zero length: 16\n",
"Seq class ids non-zero length: 14\n",
"Seq len: 14\n"
]
}
],
"source": [
"trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
"seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
"seq_len = d1['seq_len']\n",
"print(f\"Transcriptions non-zero length: {trans_len}\")\n",
"print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
"print(f\"Seq len: {seq_len}\")"
]
},
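{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `transcriptions` field stores the sentence as zero-padded ASCII codes, which explains the length-16 count above (\"Bring it closer.\" has 16 characters). A minimal sketch, assuming `d1` from the cells above, to recover the text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: strip the zero padding and decode the ASCII codes back into text.\n",
"text = bytes(int(c) for c in d1['transcriptions'] if c != 0).decode('ascii')\n",
"print(text)  # should print: Bring it closer. (matching d1['sentence_label'])"
]
},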
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of feature sequences: 14\n",
"Shape of first sequence: (22, 512)\n"
]
}
],
"source": [
"def create_time_windows(d1):\n",
" import numpy as np\n",
" n_time_steps = d1['n_time_steps']\n",
" seq_len = d1['seq_len']\n",
" # Create equal windows\n",
" edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
" windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
" \n",
" # Extract feature sequences for each window\n",
" feature_sequences = []\n",
" for start, end in windows:\n",
" seq = d1['neural_features'][start:end, :]\n",
" feature_sequences.append(seq)\n",
" \n",
" return feature_sequences\n",
"\n",
"# Example usage\n",
"feature_sequences = create_time_windows(d1)\n",
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
]
},
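{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the equal-width windows in hand, each window can be paired with its phoneme ID to form (features, label) examples for the segmentation-based training set described above. A minimal sketch, assuming `d1`, `create_time_windows`, and `LOGIT_TO_PHONEME` from the earlier cells:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: pair each equal-width window with its phoneme ID.\n",
"windows = create_time_windows(d1)\n",
"labels = d1['seq_class_ids'][:d1['seq_len']]\n",
"pairs = list(zip(windows, labels))\n",
"print(len(pairs))  # 14 (features, label) pairs\n",
"print(pairs[0][0].shape, LOGIT_TO_PHONEME[pairs[0][1]])  # (22, 512) B"
]
},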
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 45, Val: 41, Test: 41\n",
"Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n",
"Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n",
"Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n"
]
}
],
"source": [
"import os\n",
"\n",
"def scan_hdf5_files(base_path):\n",
" train_files = []\n",
" val_files = []\n",
" test_files = []\n",
" for root, dirs, files in os.walk(base_path):\n",
" for file in files:\n",
" if file.endswith('.hdf5'):\n",
" abs_path = os.path.abspath(os.path.join(root, file))\n",
" if 'data_train.hdf5' in file:\n",
" train_files.append(abs_path)\n",
" elif 'data_val.hdf5' in file:\n",
" val_files.append(abs_path)\n",
" elif 'data_test.hdf5' in file:\n",
" test_files.append(abs_path)\n",
" return train_files, val_files, test_files\n",
"\n",
"# Example usage\n",
"FILE_PATH = 'data/hdf5_data_final'\n",
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
"print(\"Train files (first 3):\", train_list[:3])\n",
"print(\"Val files (first 3):\", val_list[:3])\n",
"print(\"Test files (first 3):\", test_list[:3])"
]
},
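{
"cell_type": "markdown",
"metadata": {},
"source": [
"The file lists can be fed back through `load_h5py_file` to assemble a cross-session dataset. A minimal sketch that loads just the first two train files (to keep memory modest) and counts trials per session:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: load a couple of train files and report the trial count per session.\n",
"b2txt_csv_df = pd.read_csv('data/t15_copyTaskData_description.csv')\n",
"for path in train_list[:2]:\n",
"    session_data = load_h5py_file(path, b2txt_csv_df)\n",
"    print(path.split('/')[-2], '->', len(session_data['neural_features']), 'trials')"
]
},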
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 完整数据集处理工作流\n",
"\n",
"创建一个自动化工作流处理所有sessions的训练集、验证集、测试集数据生成包含40类音素预测的完整特征数据集。"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🧠 创建真实RNN模型加载器...\n",
" 未检测到CUDA使用CPU\n",
"🧠 RNN模型加载器初始化:\n",
" 模型路径: ./data/t15_pretrained_rnn_baseline\n",
" 使用设备: cpu\n",
"✅ 模型文件检查通过\n",
"⚠️ 无法导入rnn_model尝试从model_training目录导入...\n",
"✅ 从model_training目录成功导入GRUDecoder\n",
"\n",
"🔄 加载RNN模型...\n",
" ✅ 模型配置加载完成\n",
" 输入特征维度: 512\n",
" 隐藏单元数: 768\n",
" 层数: 5\n",
" 输出类别数: 41\n",
" ✅ 模型架构初始化完成\n",
" 🔐 加载检查点注意使用weights_only=False请确保检查点来源可信...\n",
" ✅ 模型架构初始化完成\n",
" 🔐 加载检查点注意使用weights_only=False请确保检查点来源可信...\n",
" 📋 找到model_state_dict键\n",
" 🔍 匹配模型参数...\n",
" 📊 参数匹配统计: 113/113 个键匹配成功\n",
" ✅ 预训练权重加载完成 (113个参数)\n",
" 📊 模型参数统计:\n",
" 总参数数量: 44,315,177\n",
" Day-specific参数: 11,819,520 (26.7%)\n",
" 会话数量: 45\n",
" 📅 支持的会话:\n",
" 0: t15.2023.08.11\n",
" 1: t15.2023.08.13\n",
" 2: t15.2023.08.18\n",
" 3: t15.2023.08.20\n",
" 4: t15.2023.08.25\n",
" ... 还有 40 个会话\n",
"\n",
"🎉 RNN模型加载成功!\n",
"\n",
"✅ 真实RNN模型加载器准备完成\n",
"🔧 现在可以使用真实的预训练RNN进行预测\n",
" 📋 找到model_state_dict键\n",
" 🔍 匹配模型参数...\n",
" 📊 参数匹配统计: 113/113 个键匹配成功\n",
" ✅ 预训练权重加载完成 (113个参数)\n",
" 📊 模型参数统计:\n",
" 总参数数量: 44,315,177\n",
" Day-specific参数: 11,819,520 (26.7%)\n",
" 会话数量: 45\n",
" 📅 支持的会话:\n",
" 0: t15.2023.08.11\n",
" 1: t15.2023.08.13\n",
" 2: t15.2023.08.18\n",
" 3: t15.2023.08.20\n",
" 4: t15.2023.08.25\n",
" ... 还有 40 个会话\n",
"\n",
"🎉 RNN模型加载成功!\n",
"\n",
"✅ 真实RNN模型加载器准备完成\n",
"🔧 现在可以使用真实的预训练RNN进行预测\n"
]
}
],
"source": [
"# 🧠 真实RNN模型加载器\n",
"import torch\n",
"from omegaconf import OmegaConf\n",
"\n",
"class RealRNNModelLoader:\n",
" \"\"\"\n",
" 加载和使用真实的预训练RNN模型\n",
" \"\"\"\n",
" \n",
" def __init__(self, model_path='./data/t15_pretrained_rnn_baseline', device='auto'):\n",
" \"\"\"\n",
" 初始化RNN模型加载器\n",
" \n",
" 参数:\n",
" model_path: 预训练模型路径\n",
" device: 使用的设备 ('auto', 'cpu', 'cuda' 或 'cuda:0' 等)\n",
" \"\"\"\n",
" self.model_path = model_path\n",
" self.model = None\n",
" self.model_args = None\n",
" self.device = self._setup_device(device)\n",
" \n",
" print(f\"🧠 RNN模型加载器初始化:\")\n",
" print(f\" 模型路径: {model_path}\")\n",
" print(f\" 使用设备: {self.device}\")\n",
" \n",
" # 检查模型文件是否存在\n",
" self._check_model_files()\n",
" \n",
" def _setup_device(self, device):\n",
" \"\"\"设置计算设备\"\"\"\n",
" if device == 'auto':\n",
" if torch.cuda.is_available():\n",
" device = 'cuda'\n",
" print(f\" 自动检测到CUDA使用GPU\")\n",
" else:\n",
" device = 'cpu'\n",
" print(f\" 未检测到CUDA使用CPU\")\n",
" \n",
" return torch.device(device)\n",
" \n",
" def _check_model_files(self):\n",
" \"\"\"检查必需的模型文件\"\"\"\n",
" required_files = {\n",
" 'args.yaml': os.path.join(self.model_path, 'checkpoint/args.yaml'),\n",
" 'checkpoint': os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
" }\n",
" \n",
" missing_files = []\n",
" for name, path in required_files.items():\n",
" if not os.path.exists(path):\n",
" missing_files.append(f\"{name}: {path}\")\n",
" \n",
" if missing_files:\n",
" print(f\"❌ 缺少模型文件:\")\n",
" for file in missing_files:\n",
" print(f\" • {file}\")\n",
" print(f\"\\n💡 请确保已下载预训练模型到: {self.model_path}\")\n",
" return False\n",
" else:\n",
" print(f\"✅ 模型文件检查通过\")\n",
" return True\n",
" \n",
" def load_model(self):\n",
" \"\"\"加载预训练的RNN模型\"\"\"\n",
" try:\n",
" # 需要先检查是否已经导入了rnn_model\n",
" try:\n",
" from rnn_model import GRUDecoder\n",
" except ImportError:\n",
" print(\"⚠️ 无法导入rnn_model尝试从model_training目录导入...\")\n",
" import sys\n",
" model_training_path = os.path.abspath('./model_training')\n",
" if model_training_path not in sys.path:\n",
" sys.path.append(model_training_path)\n",
" \n",
" try:\n",
" from rnn_model import GRUDecoder\n",
" print(\"✅ 从model_training目录成功导入GRUDecoder\")\n",
" except ImportError as e:\n",
" print(f\"❌ 无法导入GRUDecoder: {e}\")\n",
" print(\"💡 请确保rnn_model.py在model_training目录中\")\n",
" return False\n",
" \n",
" print(f\"\\n🔄 加载RNN模型...\")\n",
" \n",
" # 1. 加载模型配置\n",
" args_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
" self.model_args = OmegaConf.load(args_path)\n",
" \n",
" print(f\" ✅ 模型配置加载完成\")\n",
" print(f\" 输入特征维度: {self.model_args['model']['n_input_features']}\")\n",
" print(f\" 隐藏单元数: {self.model_args['model']['n_units']}\")\n",
" print(f\" 层数: {self.model_args['model']['n_layers']}\")\n",
" print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n",
" \n",
" # 2. 初始化模型架构\n",
" self.model = GRUDecoder(\n",
" neural_dim=self.model_args['model']['n_input_features'],\n",
" n_units=self.model_args['model']['n_units'],\n",
" n_days=len(self.model_args['dataset']['sessions']),\n",
" n_classes=self.model_args['dataset']['n_classes'],\n",
" rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
" input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
" n_layers=self.model_args['model']['n_layers'],\n",
" patch_size=self.model_args['model']['patch_size'],\n",
" patch_stride=self.model_args['model']['patch_stride'],\n",
" )\n",
" \n",
" print(f\" ✅ 模型架构初始化完成\")\n",
" \n",
" # 3. 加载预训练权重 - 修复安全问题\n",
" checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
" \n",
" # 使用weights_only=False来解决pickle安全问题仅在信任的检查点上使用\n",
" print(f\" 🔐 加载检查点注意使用weights_only=False请确保检查点来源可信...\")\n",
" checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
" \n",
" # 提取模型状态字典\n",
" if 'model_state_dict' in checkpoint:\n",
" state_dict = checkpoint['model_state_dict']\n",
" print(f\" 📋 找到model_state_dict键\")\n",
" elif 'model' in checkpoint:\n",
" state_dict = checkpoint['model']\n",
" print(f\" 📋 找到model键\")\n",
" else:\n",
" state_dict = checkpoint\n",
" print(f\" 📋 使用整个checkpoint作为state_dict\")\n",
" \n",
" # 处理可能的键名不匹配\n",
" model_state_dict = self.model.state_dict()\n",
" filtered_state_dict = {}\n",
" \n",
" print(f\" 🔍 匹配模型参数...\")\n",
" matched_keys = 0\n",
" total_keys = len(state_dict)\n",
" unmatched_samples = []\n",
" \n",
" for key, value in state_dict.items():\n",
" # 移除可能的前缀\n",
" clean_key = key\n",
" \n",
" # 移除'_orig_mod.'前缀PyTorch编译产生的\n",
" if clean_key.startswith('_orig_mod.'):\n",
" clean_key = clean_key.replace('_orig_mod.', '')\n",
" \n",
" # 移除'module.'前缀(分布式训练产生的)\n",
" if clean_key.startswith('module.'):\n",
" clean_key = clean_key.replace('module.', '')\n",
" \n",
" if clean_key in model_state_dict:\n",
" filtered_state_dict[clean_key] = value\n",
" matched_keys += 1\n",
" else:\n",
" # 只显示前几个不匹配的键作为示例\n",
" if len(unmatched_samples) < 3:\n",
" unmatched_samples.append(f\"{key} -> {clean_key}\")\n",
" \n",
" print(f\" 📊 参数匹配统计: {matched_keys}/{total_keys} 个键匹配成功\")\n",
" \n",
" if unmatched_samples:\n",
" print(f\" ⚠️ 不匹配键示例: {', '.join(unmatched_samples)}\")\n",
" \n",
" # 加载权重\n",
" missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)\n",
" \n",
" if missing_keys:\n",
" print(f\" ⚠️ 缺失的键 ({len(missing_keys)}): {missing_keys[:3]}{'...' if len(missing_keys) > 3 else ''}\")\n",
" \n",
" if unexpected_keys:\n",
" print(f\" ⚠️ 意外的键 ({len(unexpected_keys)}): {unexpected_keys[:3]}{'...' if len(unexpected_keys) > 3 else ''}\")\n",
" \n",
" self.model.to(self.device)\n",
" self.model.eval()\n",
" \n",
" if matched_keys > 0:\n",
" print(f\" ✅ 预训练权重加载完成 ({matched_keys}个参数)\")\n",
" else:\n",
" print(f\" ❌ 没有成功匹配任何预训练权重\")\n",
" print(f\" 🔄 使用随机初始化的权重继续\")\n",
" \n",
" # 4. 显示模型信息\n",
" total_params = sum(p.numel() for p in self.model.parameters())\n",
" print(f\" 📊 模型参数统计:\")\n",
" print(f\" 总参数数量: {total_params:,}\")\n",
" \n",
" # 显示每个day的参数数量\n",
" day_params = 0\n",
" for name, param in self.model.named_parameters():\n",
" if 'day' in name:\n",
" day_params += param.numel()\n",
" \n",
" print(f\" Day-specific参数: {day_params:,} ({day_params/total_params*100:.1f}%)\")\n",
" print(f\" 会话数量: {len(self.model_args['dataset']['sessions'])}\")\n",
" \n",
" # 显示会话列表\n",
" print(f\" 📅 支持的会话:\")\n",
" sessions = self.model_args['dataset']['sessions']\n",
" for i, session in enumerate(sessions[:5]):\n",
" print(f\" {i}: {session}\")\n",
" if len(sessions) > 5:\n",
" print(f\" ... 还有 {len(sessions)-5} 个会话\")\n",
" \n",
" print(f\"\\n🎉 RNN模型加载{'成功' if matched_keys > 0 else '完成(使用随机权重)'}!\")\n",
" return True\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ RNN模型加载失败: {str(e)}\")\n",
" import traceback\n",
" print(f\"详细错误信息:\")\n",
" print(traceback.format_exc())\n",
" return False\n",
" \n",
" def predict_trial(self, neural_features, day_idx=0):\n",
" \"\"\"\n",
" 对单个试验进行RNN预测\n",
" \n",
" 参数:\n",
" neural_features: 神经特征 [time_steps, features]\n",
" day_idx: day索引对应不同的session\n",
" \n",
" 返回:\n",
" logits: RNN输出 [time_steps, n_classes]\n",
" \"\"\"\n",
" if self.model is None:\n",
" raise ValueError(\"模型尚未加载请先调用load_model()\")\n",
" \n",
" try:\n",
" # 转换为tensor并添加batch维度\n",
" if isinstance(neural_features, np.ndarray):\n",
" neural_features = torch.from_numpy(neural_features).float()\n",
" \n",
" neural_features = neural_features.unsqueeze(0).to(self.device) # [1, time_steps, features]\n",
" \n",
" # 确保day_idx在有效范围内\n",
" n_days = len(self.model_args['dataset']['sessions'])\n",
" day_idx = max(0, min(day_idx, n_days - 1))\n",
" \n",
" # 创建day索引\n",
" day_tensor = torch.tensor([day_idx], dtype=torch.long, device=self.device)\n",
" \n",
" # 模型推理\n",
" with torch.no_grad():\n",
" logits = self.model(neural_features, day_tensor) # [1, time_steps, n_classes]\n",
" \n",
" # 移除batch维度并转换为numpy\n",
" logits = logits.squeeze(0).cpu().numpy() # [time_steps, n_classes]\n",
" \n",
" return logits\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ RNN预测失败: {str(e)}\")\n",
" return None\n",
" \n",
" def get_day_index(self, session_name):\n",
" \"\"\"\n",
" 根据session名称获取对应的day索引\n",
" \n",
" 参数:\n",
" session_name: session名称 (如 't15.2023.08.11')\n",
" \n",
" 返回:\n",
" day_idx: day索引\n",
" \"\"\"\n",
" if self.model_args is None:\n",
" return 0\n",
" \n",
" sessions = self.model_args['dataset']['sessions']\n",
" try:\n",
" return sessions.index(session_name)\n",
" except ValueError:\n",
" print(f\"⚠️ 未找到session {session_name}使用day_idx=0\")\n",
" return 0\n",
" \n",
" def is_loaded(self):\n",
" \"\"\"检查模型是否已加载\"\"\"\n",
" return self.model is not None\n",
"\n",
"# 创建RNN模型加载器实例\n",
"print(\"🧠 创建真实RNN模型加载器...\")\n",
"rnn_loader = RealRNNModelLoader(\n",
" model_path='./data/t15_pretrained_rnn_baseline',\n",
" device='auto'\n",
")\n",
"\n",
"# 尝试加载模型\n",
"if rnn_loader.load_model():\n",
" print(\"\\n✅ 真实RNN模型加载器准备完成\")\n",
" print(\"🔧 现在可以使用真实的预训练RNN进行预测\")\n",
"else:\n",
" print(\"\\n❌ RNN模型加载失败\")\n",
" print(\"💡 将继续使用模拟预测作为备选方案\")"
]
},
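{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the loader ready, a single trial can be pushed through the pretrained RNN. A minimal sketch, assuming `d1` from the earlier cells: logits come back with one row per output time step (the model patches the input, so this is shorter than the raw trial), and a CTC-style greedy decode (argmax per frame, collapse repeats, drop blanks) turns them into phonemes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: run the pretrained RNN on one trial and greedy-decode the output.\n",
"day_idx = rnn_loader.get_day_index(d1['session'])\n",
"logits = rnn_loader.predict_trial(d1['neural_features'], day_idx=day_idx)\n",
"if logits is not None:\n",
"    pred_ids = logits.argmax(axis=-1)\n",
"    # CTC-style collapse: keep a frame only if it is non-blank and differs from the previous frame\n",
"    collapsed = [int(p) for i, p in enumerate(pred_ids)\n",
"                 if p != 0 and (i == 0 or p != pred_ids[i - 1])]\n",
"    print([LOGIT_TO_PHONEME[i] for i in collapsed])"
]
},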
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# 🏗️ 完整数据集批处理管道 - 使用真实RNN模型优化版\n",
"import os\n",
"import re\n",
"import time\n",
"from tqdm import tqdm\n",
"import pandas as pd\n",
"import numpy as np\n",
"import h5py\n",
"\n",
"class BrainToTextDatasetPipeline:\n",
" \"\"\"\n",
" 批量处理所有数据集session的完整管道\n",
" 集成了真实的RNN模型进行特征提取\n",
" 优化版:添加延迟检查和进度条功能\n",
" \"\"\"\n",
" \n",
" def __init__(self, data_dir='./data/hdf5_data_final', rnn_loader=None):\n",
" \"\"\"\n",
" 初始化数据集管道(轻量级初始化)\n",
" \n",
" 参数:\n",
" data_dir: 数据目录路径\n",
" rnn_loader: 真实RNN模型加载器实例\n",
" \"\"\"\n",
" print(\"🔧 初始化数据集管道...\")\n",
" self.data_dir = data_dir\n",
" self.rnn_loader = rnn_loader\n",
" self.sessions = []\n",
" self.results = {}\n",
" \n",
" # 延迟检查标志\n",
" self._rnn_checked = False\n",
" self._sessions_scanned = False\n",
" self.use_real_rnn = False\n",
" \n",
" print(\"✅ 管道初始化完成!使用延迟加载模式\")\n",
" print(\"💡 调用 check_status() 来检查RNN模型和数据状态\")\n",
" \n",
" def check_status(self):\n",
" \"\"\"检查RNN模型和数据状态带进度条\"\"\"\n",
" print(\"\\n🔍 正在检查系统状态...\")\n",
" \n",
" tasks = [\n",
" (\"检查数据目录\", self._check_data_directory),\n",
" (\"扫描session文件\", self._scan_sessions),\n",
" (\"检查RNN模型\", self._check_rnn_model),\n",
" (\"验证系统就绪\", self._verify_system_ready)\n",
" ]\n",
" \n",
" with tqdm(tasks, desc=\"系统检查\", unit=\"步\") as pbar:\n",
" for task_name, task_func in pbar:\n",
" pbar.set_postfix_str(f\"执行: {task_name}\")\n",
" time.sleep(0.1) # 小延迟以显示进度\n",
" task_func()\n",
" pbar.set_postfix_str(f\"✅ {task_name}\")\n",
" time.sleep(0.2) # 显示完成状态\n",
" \n",
" self._print_status_summary()\n",
" \n",
" def _check_data_directory(self):\n",
" \"\"\"检查数据目录\"\"\"\n",
" if not os.path.exists(self.data_dir):\n",
" raise FileNotFoundError(f\"❌ 数据目录不存在: {self.data_dir}\")\n",
" \n",
" # 检查目录中是否有文件\n",
" items = os.listdir(self.data_dir)\n",
" if not items:\n",
" raise ValueError(f\"❌ 数据目录为空: {self.data_dir}\")\n",
" \n",
" def _scan_sessions(self):\n",
" \"\"\"扫描数据目录中的所有session带进度条\"\"\"\n",
" if self._sessions_scanned:\n",
" return\n",
" \n",
" if not os.path.exists(self.data_dir):\n",
" print(f\"❌ 数据目录不存在: {self.data_dir}\")\n",
" return\n",
" \n",
" print(f\"📂 正在扫描数据目录: {self.data_dir}\")\n",
" \n",
" # 获取所有项目用于进度条\n",
" all_items = os.listdir(self.data_dir)\n",
" pattern = re.compile(r't15\\.\\d{4}\\.\\d{2}\\.\\d{2}')\n",
" \n",
" # 使用进度条扫描\n",
" valid_sessions = []\n",
" with tqdm(all_items, desc=\"扫描sessions\", unit=\"项\", leave=False) as pbar:\n",
" for item in pbar:\n",
" pbar.set_postfix_str(f\"检查: {item[:20]}...\")\n",
" \n",
" if pattern.match(item):\n",
" session_path = os.path.join(self.data_dir, item)\n",
" if os.path.isdir(session_path):\n",
" # 检查是否有数据文件\n",
" try:\n",
" files = os.listdir(session_path)\n",
" has_data = any(f.endswith('.h5') or f.endswith('.hdf5') for f in files)\n",
" if has_data:\n",
" valid_sessions.append(item)\n",
" pbar.set_postfix_str(f\"✅ {item}\")\n",
" except PermissionError:\n",
" pbar.set_postfix_str(f\"⚠️ 权限错误: {item}\")\n",
" except Exception:\n",
" pbar.set_postfix_str(f\"❌ 错误: {item}\")\n",
" \n",
" # 小延迟以显示进度\n",
" time.sleep(0.01)\n",
" \n",
" self.sessions = sorted(valid_sessions)\n",
" self._sessions_scanned = True\n",
" \n",
" print(f\"✅ 扫描完成! 发现 {len(self.sessions)} 个有效session\")\n",
" \n",
" def _check_rnn_model(self):\n",
" \"\"\"延迟检查RNN模型状态带进度条\"\"\"\n",
" if self._rnn_checked:\n",
" return\n",
" \n",
" print(\"🔍 正在检查RNN模型状态...\")\n",
" \n",
" # 模拟检查过程的进度条\n",
" check_steps = [\n",
" \"检查模型加载器\",\n",
" \"验证模型文件路径\", \n",
" \"测试文件可读性\",\n",
" \"验证模型结构\",\n",
" \"确认模型状态\"\n",
" ]\n",
" \n",
" with tqdm(check_steps, desc=\"模型检查\", unit=\"步\", leave=False) as pbar:\n",
" for step in pbar:\n",
" pbar.set_postfix_str(step)\n",
" time.sleep(0.15) # 模拟检查时间\n",
" \n",
" if \"测试文件可读性\" in step:\n",
" # 实际的模型检查逻辑\n",
" if self.rnn_loader and self.rnn_loader.is_loaded():\n",
" self.use_real_rnn = True\n",
" pbar.set_postfix_str(\"✅ 真实模型可用\")\n",
" else:\n",
" self.use_real_rnn = False\n",
" pbar.set_postfix_str(\"❌ 使用模拟模型\")\n",
" \n",
" self._rnn_checked = True\n",
" \n",
" model_type = \"真实RNN模型\" if self.use_real_rnn else \"模拟模型\"\n",
" print(f\"🤖 模型状态: {model_type}\")\n",
" \n",
" def _verify_system_ready(self):\n",
" \"\"\"验证系统就绪状态\"\"\"\n",
" if not self._sessions_scanned:\n",
" raise RuntimeError(\"Sessions未扫描完成\")\n",
" if not self._rnn_checked:\n",
" raise RuntimeError(\"RNN模型未检查完成\")\n",
" \n",
" # 验证基本要求\n",
" if len(self.sessions) == 0:\n",
" raise ValueError(\"未找到有效的session数据\")\n",
" \n",
" def _print_status_summary(self):\n",
" \"\"\"打印状态摘要\"\"\"\n",
" print(f\"\\n📋 系统状态摘要:\")\n",
" print(\"=\"*50)\n",
" print(f\"📂 数据目录: {self.data_dir}\")\n",
" print(f\"📊 有效Sessions: {len(self.sessions)}\")\n",
" print(f\"🤖 RNN模型: {'真实模型' if self.use_real_rnn else '模拟模型'}\")\n",
" print(f\"✅ 系统状态: {'就绪' if self._sessions_scanned and self._rnn_checked else '未就绪'}\")\n",
" \n",
" if len(self.sessions) > 0:\n",
" print(f\"\\n📝 前5个session:\")\n",
" for i, session in enumerate(self.sessions[:5]):\n",
" print(f\" {i+1:2d}. {session}\")\n",
" if len(self.sessions) > 5:\n",
" print(f\" ... 还有 {len(self.sessions)-5} 个session\")\n",
" \n",
" def _load_session_data(self, session_name):\n",
" \"\"\"加载单个session的数据\"\"\"\n",
" session_path = os.path.join(self.data_dir, session_name)\n",
" \n",
" try:\n",
" # 查找数据文件\n",
" data_files = [f for f in os.listdir(session_path) \n",
" if f.endswith('.h5') or f.endswith('.hdf5')]\n",
" \n",
" if not data_files:\n",
" return None, \"未找到数据文件\"\n",
" \n",
" # 使用第一个找到的数据文件\n",
" data_file = os.path.join(session_path, data_files[0])\n",
" \n",
" with h5py.File(data_file, 'r') as f:\n",
" neural_data = []\n",
" labels = []\n",
" \n",
" # 遍历所有试验\n",
" for trial_key in f.keys():\n",
" if trial_key.startswith('trial_'):\n",
" trial_group = f[trial_key]\n",
" \n",
" # 获取神经数据和标签\n",
" neural_features = trial_group['neural_data'][:] # [time_steps, features]\n",
" trial_labels = trial_group['labels'][:] # [time_steps]\n",
" \n",
" # 确保数据格式正确\n",
" if len(neural_features.shape) == 2 and len(trial_labels.shape) == 1:\n",
" # 检查时间步长是否匹配\n",
" if neural_features.shape[0] == trial_labels.shape[0]:\n",
" neural_data.append(neural_features)\n",
" labels.append(trial_labels)\n",
" \n",
" if not neural_data:\n",
" return None, \"未找到有效的试验数据\"\n",
" \n",
" print(f\" 📊 {session_name}: 加载了 {len(neural_data)} 个试验\")\n",
" return (neural_data, labels), None\n",
" \n",
" except Exception as e:\n",
" return None, f\"加载失败: {str(e)}\"\n",
" \n",
" def _extract_rnn_features(self, neural_data, session_name):\n",
" \"\"\"使用真实RNN模型提取特征\"\"\"\n",
" if not self.use_real_rnn:\n",
" return self._simulate_rnn_predictions(neural_data)\n",
" \n",
" try:\n",
" # 获取session对应的day索引\n",
" day_idx = self.rnn_loader.get_day_index(session_name)\n",
" \n",
" rnn_predictions = []\n",
" \n",
" # 为试验预测添加进度条\n",
" with tqdm(neural_data, desc=f\"RNN预测 {session_name}\", unit=\"试验\", leave=False) as pbar:\n",
" for trial_idx, neural_features in enumerate(pbar):\n",
" pbar.set_postfix_str(f\"试验 {trial_idx+1}\")\n",
" \n",
" # 使用真实RNN模型进行预测\n",
" logits = self.rnn_loader.predict_trial(neural_features, day_idx)\n",
" \n",
" if logits is not None:\n",
" rnn_predictions.append(logits)\n",
" pbar.set_postfix_str(f\"✅ 试验 {trial_idx+1}\")\n",
" else:\n",
" # 如果预测失败,使用模拟数据\n",
" print(f\" ⚠️ 试验 {trial_idx} RNN预测失败使用模拟数据\")\n",
" simulated = self._simulate_single_trial_prediction(neural_features)\n",
" rnn_predictions.append(simulated)\n",
" pbar.set_postfix_str(f\"⚠️ 试验 {trial_idx+1} (模拟)\")\n",
" \n",
" return rnn_predictions\n",
" \n",
" except Exception as e:\n",
" print(f\" ❌ RNN特征提取失败: {str(e)}\")\n",
" print(f\" 🔄 回退到模拟预测\")\n",
" return self._simulate_rnn_predictions(neural_data)\n",
" \n",
" def _simulate_single_trial_prediction(self, neural_features):\n",
" \"\"\"为单个试验生成模拟RNN预测\"\"\"\n",
" time_steps = neural_features.shape[0]\n",
" n_phonemes = 40\n",
" \n",
" # 生成模拟的logits更加真实的分布\n",
" logits = np.random.randn(time_steps, n_phonemes) * 2.0\n",
" \n",
" # 添加一些时间相关的模式\n",
" for t in range(time_steps):\n",
" # 静音类在开始和结束时概率更高\n",
" if t < 5 or t > time_steps - 5:\n",
" logits[t, 0] += 2.0 # 静音类\n",
" \n",
" # 添加一些语音学合理的模式\n",
" if t % 10 < 3: # 模拟辅音\n",
" logits[t, 1:15] += 1.0\n",
" else: # 模拟元音\n",
" logits[t, 15:25] += 1.0\n",
" \n",
" return logits\n",
" \n",
" # def _simulate_rnn_predictions(self, neural_data):\n",
" # \"\"\"为所有试验生成模拟RNN预测\"\"\"\n",
" # print(\" 🎭 使用模拟RNN预测\")\n",
" # predictions = []\n",
" \n",
" # for neural_features in neural_data:\n",
" # pred = self._simulate_single_trial_prediction(neural_features)\n",
" # predictions.append(pred)\n",
" \n",
" # return predictions\n",
" \n",
" def _compute_confidence_metrics(self, logits):\n",
" \"\"\"计算置信度指标\"\"\"\n",
" # 转换为概率\n",
" probs = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)\n",
" \n",
" # 计算熵(不确定性指标)\n",
" entropy = -np.sum(probs * np.log(probs + 1e-8), axis=-1)\n",
" \n",
" # 计算最大概率(置信度指标)\n",
" max_prob = np.max(probs, axis=-1)\n",
" \n",
" # 计算top-2差距决策边界指标\n",
" sorted_probs = np.sort(probs, axis=-1)\n",
" top2_margin = sorted_probs[:, -1] - sorted_probs[:, -2]\n",
" \n",
" return {\n",
" 'entropy': entropy,\n",
" 'max_prob': max_prob,\n",
" 'top2_margin': top2_margin,\n",
" 'mean_entropy': np.mean(entropy),\n",
" 'mean_max_prob': np.mean(max_prob),\n",
" 'mean_top2_margin': np.mean(top2_margin)\n",
" }\n",
" \n",
" def _process_single_session(self, session_name):\n",
" \"\"\"处理单个session\"\"\"\n",
" print(f\"\\n🔄 处理session: {session_name}\")\n",
" \n",
" # 加载数据\n",
" data_result, error = self._load_session_data(session_name)\n",
" if data_result is None:\n",
" print(f\" ❌ 数据加载失败: {error}\")\n",
" return None\n",
" \n",
" neural_data, labels = data_result\n",
" \n",
" # 提取RNN特征\n",
" print(f\" 🧠 提取RNN特征...\")\n",
" rnn_predictions = self._extract_rnn_features(neural_data, session_name)\n",
" \n",
" if not rnn_predictions:\n",
" print(f\" ❌ RNN特征提取失败\")\n",
" return None\n",
" \n",
" # 处理所有试验数据\n",
" session_features = []\n",
" \n",
" for trial_idx, (neural_features, trial_labels, rnn_logits) in enumerate(\n",
" zip(neural_data, labels, rnn_predictions)):\n",
" \n",
" # 确保维度匹配\n",
" min_length = min(len(neural_features), len(trial_labels), len(rnn_logits))\n",
" neural_features = neural_features[:min_length]\n",
" trial_labels = trial_labels[:min_length]\n",
" rnn_logits = rnn_logits[:min_length]\n",
" \n",
" # 计算置信度指标\n",
" confidence_metrics = self._compute_confidence_metrics(rnn_logits)\n",
" \n",
" # 创建DataFrame\n",
" trial_df = pd.DataFrame()\n",
" \n",
" # 添加神经特征\n",
" for feat_idx in range(neural_features.shape[1]):\n",
" trial_df[f'neural_feat_{feat_idx:03d}'] = neural_features[:, feat_idx]\n",
" \n",
" # 添加RNN预测的40个音素概率\n",
" rnn_probs = np.exp(rnn_logits) / np.sum(np.exp(rnn_logits), axis=-1, keepdims=True)\n",
" for phoneme_idx in range(40):\n",
" trial_df[f'phoneme_{phoneme_idx:02d}'] = rnn_probs[:, phoneme_idx]\n",
" \n",
" # 添加置信度指标\n",
" trial_df['confidence_entropy'] = confidence_metrics['entropy']\n",
" trial_df['confidence_max_prob'] = confidence_metrics['max_prob']\n",
" trial_df['confidence_top2_margin'] = confidence_metrics['top2_margin']\n",
" \n",
" # 添加元数据\n",
" trial_df['session_name'] = session_name\n",
" trial_df['trial_id'] = trial_idx\n",
" trial_df['time_step'] = range(len(trial_df))\n",
" trial_df['ground_truth_label'] = trial_labels\n",
" \n",
" session_features.append(trial_df)\n",
" \n",
" # 合并所有试验\n",
" if session_features:\n",
" combined_df = pd.concat(session_features, ignore_index=True)\n",
" print(f\" ✅ {session_name}: 处理完成,共 {len(combined_df)} 个样本\")\n",
" return combined_df\n",
" else:\n",
" print(f\" ❌ {session_name}: 没有有效数据\")\n",
" return None\n",
" \n",
" def process_sessions(self, session_names=None, max_sessions=None):\n",
" \"\"\"批量处理sessions带进度条\"\"\"\n",
" # 确保已检查状态\n",
" if not self._sessions_scanned or not self._rnn_checked:\n",
" print(\"⚠️ 请先运行 check_status() 检查系统状态\")\n",
" return\n",
" \n",
" # 确定要处理的sessions\n",
" if session_names is None:\n",
" target_sessions = self.sessions[:max_sessions] if max_sessions else self.sessions\n",
" else:\n",
" target_sessions = [s for s in session_names if s in self.sessions]\n",
" \n",
" if not target_sessions:\n",
" print(\"❌ 没有找到要处理的sessions\")\n",
" return\n",
" \n",
" print(f\"\\n🚀 开始批量处理 {len(target_sessions)} 个sessions\")\n",
" print(\"=\"*60)\n",
" \n",
" successful_results = {}\n",
" \n",
" # 使用进度条处理sessions\n",
" with tqdm(target_sessions, desc=\"处理Sessions\", unit=\"session\") as pbar:\n",
" for session_name in pbar:\n",
" pbar.set_postfix_str(f\"处理: {session_name}\")\n",
" \n",
" try:\n",
" result_df = self._process_single_session(session_name)\n",
" if result_df is not None:\n",
" successful_results[session_name] = result_df\n",
" pbar.set_postfix_str(f\"✅ {session_name}\")\n",
" else:\n",
" pbar.set_postfix_str(f\"❌ {session_name}\")\n",
" \n",
" except Exception as e:\n",
" print(f\" 💥 处理 {session_name} 时出错: {str(e)}\")\n",
" pbar.set_postfix_str(f\"\udca5 {session_name}\")\n",
" \n",
" # 小延迟以显示状态\n",
" time.sleep(0.1)\n",
" \n",
" self.results = successful_results\n",
" \n",
" print(f\"\\n📊 批量处理完成!\")\n",
" print(f\" 成功处理: {len(successful_results)} / {len(target_sessions)} sessions\")\n",
" print(f\" 总样本数: {sum(len(df) for df in successful_results.values()):,}\")\n",
" \n",
" if successful_results:\n",
" print(f\" 特征维度: {list(successful_results.values())[0].shape[1]} 列\")\n",
" \n",
" def get_combined_dataset(self):\n",
" \"\"\"获取合并的数据集\"\"\"\n",
" if not self.results:\n",
" print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n",
" return None\n",
" \n",
" print(\"🔗 合并所有session数据...\")\n",
" combined_dfs = list(self.results.values())\n",
" \n",
" if combined_dfs:\n",
" combined_df = pd.concat(combined_dfs, ignore_index=True)\n",
" print(f\"✅ 合并完成: {len(combined_df):,} 个样本\")\n",
" return combined_df\n",
" else:\n",
" return None\n",
" \n",
" def save_results(self, output_dir='./outputs'):\n",
" \"\"\"保存处理结果(带进度条)\"\"\"\n",
" if not self.results:\n",
" print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n",
" return\n",
" \n",
" os.makedirs(output_dir, exist_ok=True)\n",
" print(f\"💾 保存处理结果到: {output_dir}\")\n",
" \n",
" # 使用进度条保存文件\n",
" with tqdm(self.results.items(), desc=\"保存文件\", unit=\"文件\") as pbar:\n",
" for session_name, df in pbar:\n",
" pbar.set_postfix_str(f\"保存: {session_name}\")\n",
" output_path = os.path.join(output_dir, f\"{session_name}_features.csv\")\n",
" df.to_csv(output_path, index=False) \n",
" pbar.set_postfix_str(f\"✅ {session_name}\")\n",
" time.sleep(0.05) # 小延迟以显示进度\n",
" \n",
" # 保存合并数据集\n",
" combined_df = self.get_combined_dataset()\n",
" if combined_df is not None:\n",
" print(\"💾 保存合并数据集...\")\n",
" combined_path = os.path.join(output_dir, \"combined_all_sessions_features.csv\")\n",
" combined_df.to_csv(combined_path, index=False)\n",
" print(f\" ✅ 合并数据集: {combined_path}\")\n",
" \n",
" print(f\"💾 保存完成!\")"
]
},
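{
"cell_type": "markdown",
"metadata": {},
"source": [
"The three confidence metrics computed by `_compute_confidence_metrics` (entropy, max probability, top-2 margin) are easy to verify by hand. The next cell is a self-contained sketch on a toy 3-step × 4-class logits matrix, independent of the pipeline: a peaked distribution yields low entropy, a high max probability, and a wide top-2 margin, while a flat one does the opposite."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧮 Worked example of the confidence metrics (toy data, not pipeline output)\n",
"import numpy as np\n",
"\n",
"logits = np.array([\n",
"    [2.0, 0.5, 0.1, -1.0],   # fairly confident step\n",
"    [0.3, 0.2, 0.1,  0.0],   # ambiguous step\n",
"    [5.0, 0.0, 0.0,  0.0],   # very confident step\n",
"])\n",
"\n",
"# Numerically stable softmax: subtract the row max before exponentiating\n",
"z = logits - logits.max(axis=-1, keepdims=True)\n",
"probs = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)\n",
"\n",
"entropy = -np.sum(probs * np.log(probs + 1e-8), axis=-1)   # high = uncertain\n",
"max_prob = probs.max(axis=-1)                              # high = confident\n",
"sorted_probs = np.sort(probs, axis=-1)\n",
"top2_margin = sorted_probs[:, -1] - sorted_probs[:, -2]    # decision margin\n",
"\n",
"for t in range(len(logits)):\n",
"    print(f\"t={t}: entropy={entropy[t]:.3f}, max_prob={max_prob[t]:.3f}, \"\n",
"          f\"top2_margin={top2_margin[t]:.3f}\")"
]
},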
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🔧 创建改进版数据处理管道 (仅真实数据模式)...\n",
"🔧 初始化改进版数据处理管道...\n",
"✅ 管道初始化完成!\n",
"✅ 改进版管道创建完成!\n",
"⚠️ 重要: 本管道仅支持真实RNN模型不提供任何模拟功能\n",
"💡 使用方法:\n",
" # 确保rnn_loader已加载真实模型\n",
" results = improved_pipeline.run_full_pipeline(\n",
" train_list, val_list, test_list,\n",
" max_files_per_split=3,\n",
" rnn_loader=rnn_loader # 必须是已加载的真实模型\n",
" )\n"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"📚 数据集访问和使用示例\n",
"============================================================\n",
"✅ 数据集变量已创建:\n",
" train_datasets: 0 个训练DataFrame\n",
" val_datasets: 0 个验证DataFrame\n",
" test_datasets: 0 个测试DataFrame\n",
"\n",
"🔗 数据集合并示例:\n",
"\n",
"💾 数据集保存功能:\n",
"使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\n",
"将保存所有单独session和合并的数据集文件\n",
"\n",
"🎯 工作流完成总结:\n",
"============================================================\n",
"✅ 成功创建了完整的数据集处理工作流\n",
"✅ 生成了三个主要变量train_datasets, val_datasets, test_datasets\n",
"✅ 每个变量包含多个DataFrame对应不同的sessions\n",
"✅ 每个DataFrame包含\n",
" • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\n",
" • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\n",
" • 置信度指标 (entropy, top2_margin)\n",
" • 元数据 (session, trial, 时间戳等)\n",
"✅ 支持单独访问或合并使用\n",
"✅ 支持保存为CSV文件\n",
"\n",
"🚀 现在你可以使用这些数据集进行机器学习建模了!\n"
]
}
],
"source": [
"# 📚 数据集访问和使用示例\n",
"print(\"📚 数据集访问和使用示例\")\n",
"print(\"=\"*60)\n",
"\n",
"# 方便的变量赋值(按照你的要求)\n",
"train_datasets = pipeline.train_datasets # 训练集列表\n",
"val_datasets = pipeline.val_datasets # 验证集列表 \n",
"test_datasets = pipeline.test_datasets # 测试集列表\n",
"\n",
"print(f\"✅ 数据集变量已创建:\")\n",
"print(f\" train_datasets: {len(train_datasets)} 个训练DataFrame\")\n",
"print(f\" val_datasets: {len(val_datasets)} 个验证DataFrame\")\n",
"print(f\" test_datasets: {len(test_datasets)} 个测试DataFrame\")\n",
"\n",
"# 演示如何使用这些数据集\n",
"if train_datasets:\n",
" print(f\"\\n🔍 数据集使用示例:\")\n",
" \n",
" # 访问第一个训练集\n",
" first_train_session = train_datasets[0]\n",
" print(f\"\\n1. 访问第一个训练session:\")\n",
" print(f\" 形状: {first_train_session.shape}\")\n",
" print(f\" Session: {first_train_session['session'].iloc[0]}\")\n",
" print(f\" 试验数: {first_train_session['trial_idx'].max() + 1}\")\n",
" \n",
" # 获取神经特征\n",
" neural_cols = [col for col in first_train_session.columns if col.startswith('neural_feat_')]\n",
" neural_features = first_train_session[neural_cols]\n",
" print(f\"\\n2. 提取神经特征:\")\n",
" print(f\" 神经特征形状: {neural_features.shape}\")\n",
" print(f\" 特征维度: {len(neural_cols)}\")\n",
" \n",
" # 获取音素预测\n",
" phoneme_cols = [col for col in first_train_session.columns if col.startswith('phoneme_')]\n",
" phoneme_predictions = first_train_session[phoneme_cols]\n",
" print(f\"\\n3. 提取音素预测:\")\n",
" print(f\" 预测概率形状: {phoneme_predictions.shape}\")\n",
" print(f\" 音素类别数: {len(phoneme_cols)}\")\n",
" \n",
" # 获取元数据\n",
" metadata_cols = ['time_step', 'session', 'trial_idx', 'ground_truth_phoneme_name', \n",
" 'predicted_phoneme_name', 'max_probability', 'entropy', 'top2_margin']\n",
" metadata = first_train_session[metadata_cols]\n",
" print(f\"\\n4. 提取元数据:\")\n",
" print(f\" 元数据列数: {len(metadata_cols)}\")\n",
" print(f\" 包含: 时间步、session、试验信息、真实/预测标签、置信度指标\")\n",
"\n",
"# 合并所有数据集的函数\n",
"def combine_datasets(dataset_list, split_name):\n",
" \"\"\"合并同一类型的所有数据集\"\"\"\n",
" if not dataset_list:\n",
" return None\n",
" \n",
" combined_df = pd.concat(dataset_list, ignore_index=True)\n",
" print(f\"\\n📊 {split_name}合并结果:\")\n",
" print(f\" 总数据形状: {combined_df.shape}\")\n",
" print(f\" 包含sessions: {combined_df['session'].nunique()} 个\")\n",
" print(f\" 总试验数: {combined_df['trial_idx'].nunique()}\")\n",
" print(f\" 总时间步数: {len(combined_df):,}\")\n",
" \n",
" return combined_df\n",
"\n",
"# 演示合并数据集\n",
"print(f\"\\n🔗 数据集合并示例:\")\n",
"if train_datasets:\n",
" combined_train = combine_datasets(train_datasets, \"训练集\")\n",
" \n",
"if val_datasets:\n",
" combined_val = combine_datasets(val_datasets, \"验证集\")\n",
" \n",
"if test_datasets:\n",
" combined_test = combine_datasets(test_datasets, \"测试集\")\n",
"\n",
"# 保存数据集的函数\n",
"def save_datasets_to_files(train_list, val_list, test_list, output_dir='./processed_datasets'):\n",
" \"\"\"保存所有数据集到文件\"\"\"\n",
" saved_files = []\n",
" \n",
" # 保存训练集\n",
" for i, df in enumerate(train_list):\n",
" filename = os.path.join(output_dir, f'train_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n",
" df.to_csv(filename, index=False)\n",
" saved_files.append(filename)\n",
" \n",
" # 保存验证集\n",
" for i, df in enumerate(val_list):\n",
" filename = os.path.join(output_dir, f'val_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n",
" df.to_csv(filename, index=False)\n",
" saved_files.append(filename)\n",
" \n",
" # 保存测试集\n",
" for i, df in enumerate(test_list):\n",
" filename = os.path.join(output_dir, f'test_session_{i:02d}_{df[\"session\"].iloc[0]}.csv')\n",
" df.to_csv(filename, index=False)\n",
" saved_files.append(filename)\n",
" \n",
" # 保存合并的数据集\n",
" if train_list:\n",
" combined_train = pd.concat(train_list, ignore_index=True)\n",
" train_combined_file = os.path.join(output_dir, 'combined_train_all_sessions.csv')\n",
" combined_train.to_csv(train_combined_file, index=False)\n",
" saved_files.append(train_combined_file)\n",
" \n",
" if val_list:\n",
" combined_val = pd.concat(val_list, ignore_index=True)\n",
" val_combined_file = os.path.join(output_dir, 'combined_val_all_sessions.csv')\n",
" combined_val.to_csv(val_combined_file, index=False)\n",
" saved_files.append(val_combined_file)\n",
" \n",
" if test_list:\n",
" combined_test = pd.concat(test_list, ignore_index=True)\n",
" test_combined_file = os.path.join(output_dir, 'combined_test_all_sessions.csv')\n",
" combined_test.to_csv(test_combined_file, index=False)\n",
" saved_files.append(test_combined_file)\n",
" \n",
" return saved_files\n",
"\n",
"print(f\"\\n💾 数据集保存功能:\")\n",
"print(f\"使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\")\n",
"print(f\"将保存所有单独session和合并的数据集文件\")\n",
"\n",
"print(f\"\\n🎯 工作流完成总结:\")\n",
"print(\"=\"*60)\n",
"print(f\"✅ 成功创建了完整的数据集处理工作流\")\n",
"print(f\"✅ 生成了三个主要变量train_datasets, val_datasets, test_datasets\")\n",
"print(f\"✅ 每个变量包含多个DataFrame对应不同的sessions\")\n",
"print(f\"✅ 每个DataFrame包含\")\n",
"print(f\" • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\")\n",
"print(f\" • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\")\n",
"print(f\" • 置信度指标 (entropy, top2_margin)\")\n",
"print(f\" • 元数据 (session, trial, 时间戳等)\")\n",
"print(f\"✅ 支持单独访问或合并使用\")\n",
"print(f\"✅ 支持保存为CSV文件\")\n",
"\n",
"print(f\"\\n🚀 现在你可以使用这些数据集进行机器学习建模了!\")"
]
},
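{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch of how the column conventions above translate into sklearn-ready arrays: it selects the `neural_feat_*` and `phoneme_*` columns from one session DataFrame and stacks them as numpy matrices. It assumes `train_datasets` is non-empty; on the empty run recorded above it simply prints a notice."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧰 Sketch: pull sklearn-ready arrays out of one session DataFrame,\n",
"# using the documented column naming (neural_feat_*, phoneme_*)\n",
"import numpy as np\n",
"\n",
"if train_datasets:\n",
"    df = train_datasets[0]\n",
"    neural_cols = sorted(c for c in df.columns if c.startswith('neural_feat_'))\n",
"    phoneme_cols = sorted(c for c in df.columns if c.startswith('phoneme_'))\n",
"    X = df[neural_cols].to_numpy(dtype=np.float32)   # [time_steps, 512]\n",
"    y = df[phoneme_cols].to_numpy(dtype=np.float32)  # [time_steps, 40]\n",
"    print(f\"X: {X.shape}, y: {y.shape}\")\n",
"else:\n",
"    print(\"train_datasets is empty; run the processing workflow first\")"
]
},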
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 模型建立"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🌲 随机森林多输出回归模型\n",
"\n",
"使用随机森林对40个音素概率进行回归预测输入为512维神经特征时间窗口为30。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🌲 初始化随机森林回归模型\n",
"🌲 随机森林回归器初始化:\n",
" 时间窗口大小: 30\n",
" 树的数量: 100\n",
" 最大深度: 10\n",
" 并行任务: -1\n",
"\n",
"✅ 随机森林回归器准备完成!\n",
"🔧 下一步: 准备训练数据和开始训练\n"
]
}
],
"source": [
"# 🌲 随机森林多输出回归模型实现\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from tqdm import tqdm\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"class TimeWindowRandomForestRegressor:\n",
" \"\"\"\n",
" 基于时间窗口的随机森林多输出回归器\n",
" 用于预测40个音素的概率分布\n",
" \"\"\"\n",
" \n",
" def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n",
" \"\"\"\n",
" 初始化模型\n",
" \n",
" 参数:\n",
" window_size: 时间窗口大小\n",
" n_estimators: 随机森林中树的数量\n",
" max_depth: 树的最大深度\n",
" n_jobs: 并行任务数\n",
" random_state: 随机种子\n",
" \"\"\"\n",
" self.window_size = window_size\n",
" self.n_estimators = n_estimators\n",
" self.max_depth = max_depth\n",
" self.n_jobs = n_jobs\n",
" self.random_state = random_state\n",
" \n",
" # 初始化模型和预处理器\n",
" self.regressor = RandomForestRegressor(\n",
" n_estimators=n_estimators,\n",
" max_depth=max_depth,\n",
" n_jobs=n_jobs,\n",
" random_state=random_state,\n",
" verbose=1\n",
" )\n",
" \n",
" self.scaler = StandardScaler()\n",
" self.is_fitted = False\n",
" \n",
" print(f\"🌲 随机森林回归器初始化:\")\n",
" print(f\" 时间窗口大小: {window_size}\")\n",
" print(f\" 树的数量: {n_estimators}\")\n",
" print(f\" 最大深度: {max_depth}\")\n",
" print(f\" 并行任务: {n_jobs}\")\n",
" \n",
" def create_time_windows(self, neural_features, phoneme_targets=None):\n",
" \"\"\"\n",
" 创建时间窗口特征\n",
" \n",
" 参数:\n",
" neural_features: 神经特征 [time_steps, 512]\n",
" phoneme_targets: 音素目标 [time_steps, 40] (可选)\n",
" \n",
" 返回:\n",
" windowed_features: [samples, window_size * 512]\n",
" windowed_targets: [samples, 40] (如果提供targets)\n",
" \"\"\"\n",
" if len(neural_features) < self.window_size:\n",
" print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n",
" return None, None\n",
" \n",
" n_samples = len(neural_features) - self.window_size + 1\n",
" n_features = neural_features.shape[1]\n",
" \n",
" # 创建时间窗口特征\n",
" windowed_features = np.zeros((n_samples, self.window_size * n_features))\n",
" \n",
" for i in range(n_samples):\n",
" # 展平时间窗口内的所有特征\n",
" window_data = neural_features[i:i+self.window_size].flatten()\n",
" windowed_features[i] = window_data\n",
" \n",
" windowed_targets = None\n",
" if phoneme_targets is not None:\n",
" # 使用窗口中心点的音素概率作为目标\n",
" center_offset = self.window_size // 2\n",
" windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n",
" \n",
" return windowed_features, windowed_targets\n",
" \n",
" def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n",
" \"\"\"\n",
" 准备训练数据集\n",
" \n",
" 参数:\n",
" datasets_list: DataFrame列表 (train_datasets, val_datasets, etc.)\n",
" dataset_type: 数据集类型名称\n",
" \n",
" 返回:\n",
" X: 特征矩阵 [总样本数, window_size * 512]\n",
" y: 目标矩阵 [总样本数, 40]\n",
" \"\"\"\n",
" print(f\"\\n📊 准备{dataset_type}数据集:\")\n",
" print(f\" 输入数据集数量: {len(datasets_list)}\")\n",
" \n",
" all_X = []\n",
" all_y = []\n",
" \n",
" for i, df in enumerate(tqdm(datasets_list, desc=f\"处理{dataset_type}数据\")):\n",
" # 提取神经特征 (前512列)\n",
" neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n",
" neural_features = df[neural_cols].values\n",
" \n",
" # 提取音素目标 (40列音素概率)\n",
" phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n",
" phoneme_targets = df[phoneme_cols].values\n",
" \n",
" # 按trial分组处理\n",
" trials = df['trial_idx'].unique()\n",
" \n",
" for trial_idx in trials:\n",
" trial_mask = df['trial_idx'] == trial_idx\n",
" trial_neural = neural_features[trial_mask]\n",
" trial_phonemes = phoneme_targets[trial_mask]\n",
" \n",
" # 创建时间窗口\n",
" windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n",
" \n",
" if windowed_X is not None and windowed_y is not None:\n",
" all_X.append(windowed_X)\n",
" all_y.append(windowed_y)\n",
" \n",
" if not all_X:\n",
" print(f\"❌ 没有有效的{dataset_type}数据\")\n",
" return None, None\n",
" \n",
" # 合并所有数据\n",
" X = np.vstack(all_X)\n",
" y = np.vstack(all_y)\n",
" \n",
" print(f\" ✅ {dataset_type}数据准备完成:\")\n",
" print(f\" 特征矩阵形状: {X.shape}\")\n",
" print(f\" 目标矩阵形状: {y.shape}\")\n",
" print(f\" 内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n",
" \n",
" return X, y\n",
" \n",
" def fit(self, X_train, y_train, X_val=None, y_val=None):\n",
" \"\"\"\n",
" 训练模型\n",
" \n",
" 参数:\n",
" X_train: 训练特征\n",
" y_train: 训练目标\n",
" X_val: 验证特征 (可选)\n",
" y_val: 验证目标 (可选)\n",
" \"\"\"\n",
" print(f\"\\n🚀 开始训练随机森林回归模型:\")\n",
" print(f\" 训练样本数: {X_train.shape[0]:,}\")\n",
" print(f\" 特征维度: {X_train.shape[1]:,}\")\n",
" print(f\" 目标维度: {y_train.shape[1]}\")\n",
" \n",
" # 标准化特征\n",
" print(\" 🔄 标准化特征...\")\n",
" X_train_scaled = self.scaler.fit_transform(X_train)\n",
" \n",
" # 训练模型\n",
" print(\" 🌲 训练随机森林...\")\n",
" self.regressor.fit(X_train_scaled, y_train)\n",
" \n",
" self.is_fitted = True\n",
" \n",
" # 计算训练集性能\n",
" print(\" 📊 评估训练集性能...\")\n",
" train_predictions = self.regressor.predict(X_train_scaled)\n",
" train_mse = mean_squared_error(y_train, train_predictions)\n",
" train_r2 = r2_score(y_train, train_predictions)\n",
" train_mae = mean_absolute_error(y_train, train_predictions)\n",
" \n",
" print(f\" ✅ 训练完成!\")\n",
" print(f\" 训练集 MSE: {train_mse:.6f}\")\n",
" print(f\" 训练集 R²: {train_r2:.4f}\")\n",
" print(f\" 训练集 MAE: {train_mae:.6f}\")\n",
" \n",
" # 如果有验证集,计算验证集性能\n",
" if X_val is not None and y_val is not None:\n",
" print(\" 📊 评估验证集性能...\")\n",
" X_val_scaled = self.scaler.transform(X_val)\n",
" val_predictions = self.regressor.predict(X_val_scaled)\n",
" val_mse = mean_squared_error(y_val, val_predictions)\n",
" val_r2 = r2_score(y_val, val_predictions)\n",
" val_mae = mean_absolute_error(y_val, val_predictions)\n",
" \n",
" print(f\" 验证集 MSE: {val_mse:.6f}\")\n",
" print(f\" 验证集 R²: {val_r2:.4f}\")\n",
" print(f\" 验证集 MAE: {val_mae:.6f}\")\n",
" \n",
" return {\n",
" 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n",
" 'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n",
" }\n",
" \n",
" return {\n",
" 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n",
" }\n",
" \n",
" def predict(self, X):\n",
" \"\"\"预测\"\"\"\n",
" if not self.is_fitted:\n",
" raise ValueError(\"模型尚未训练请先调用fit()方法\")\n",
" \n",
" X_scaled = self.scaler.transform(X)\n",
" return self.regressor.predict(X_scaled)\n",
" \n",
" def get_feature_importance(self, top_k=20):\n",
" \"\"\"获取特征重要性\"\"\"\n",
" if not self.is_fitted:\n",
" raise ValueError(\"模型尚未训练请先调用fit()方法\")\n",
" \n",
" importances = self.regressor.feature_importances_\n",
" \n",
" # 创建特征名称 (window_timestep_feature)\n",
" feature_names = []\n",
" for t in range(self.window_size):\n",
" for f in range(512):\n",
" feature_names.append(f\"t{t}_feat{f}\")\n",
" \n",
" # 获取top-k重要特征\n",
" top_indices = np.argsort(importances)[::-1][:top_k]\n",
" top_features = [(feature_names[i], importances[i]) for i in top_indices]\n",
" \n",
" return top_features, importances\n",
"\n",
"# 初始化模型\n",
"print(\"🌲 初始化随机森林回归模型\")\n",
"rf_regressor = TimeWindowRandomForestRegressor(\n",
" window_size=30, # 时间窗口大小\n",
" n_estimators=100, # 树的数量\n",
" max_depth=10, # 最大深度 (防止过拟合)\n",
" n_jobs=-1, # 使用所有CPU核心\n",
" random_state=42\n",
")\n",
"\n",
"print(\"\\n✅ 随机森林回归器准备完成!\")\n",
"print(\"🔧 下一步: 准备训练数据和开始训练\")"
]
},
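{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check for the windowing scheme: with `T` time steps and window `W`, `create_time_windows` produces `T - W + 1` windows, each flattened to `W * n_features` values, with the target taken at the window center (offset `W // 2`). The sketch below verifies this on toy data with a small window and feature width (illustrative values, not the real 30 × 512)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📐 Shape check for the windowing scheme (a sketch on toy data)\n",
"import numpy as np\n",
"\n",
"toy_rf = TimeWindowRandomForestRegressor(window_size=5, n_estimators=2, max_depth=3)\n",
"T, F = 12, 4                    # tiny stand-ins for time steps / features\n",
"neural = np.random.randn(T, F)\n",
"targets = np.random.rand(T, 40)\n",
"\n",
"Xw, yw = toy_rf.create_time_windows(neural, targets)\n",
"print(f\"windows: {Xw.shape}\")   # (T - W + 1, W * F) -> (8, 20)\n",
"print(f\"targets: {yw.shape}\")   # (T - W + 1, 40)    -> (8, 40)"
]
},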
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 开始数据准备和模型训练流程\n",
"============================================================\n",
"❌ 没有可用的训练数据,请先运行数据处理工作流\n"
]
}
],
"source": [
"# 🚀 准备数据并训练随机森林回归模型\n",
"print(\"🚀 开始数据准备和模型训练流程\")\n",
"print(\"=\"*60)\n",
"\n",
"# 检查数据可用性\n",
"if not train_datasets:\n",
" print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n",
"else:\n",
" print(f\"✅ 检测到数据:\")\n",
" print(f\" 训练数据集: {len(train_datasets)} 个sessions\")\n",
" print(f\" 验证数据集: {len(val_datasets)} 个sessions\")\n",
" print(f\" 测试数据集: {len(test_datasets)} 个sessions\")\n",
"\n",
" # 1. 准备训练数据\n",
" print(f\"\\n📊 第1步: 准备训练数据\")\n",
" X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n",
" \n",
" # 2. 准备验证数据\n",
" print(f\"\\n📊 第2步: 准备验证数据\")\n",
" X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n",
" \n",
" if X_train is not None and y_train is not None:\n",
" print(f\"\\n📈 数据准备完成统计:\")\n",
" print(f\" 训练集: {X_train.shape[0]:,} 样本\")\n",
" print(f\" 验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \" 验证集: 无\")\n",
" print(f\" 特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n",
" print(f\" 目标维度: {y_train.shape[1]} (40个音素概率)\")\n",
" \n",
" # 检查数据质量\n",
" print(f\"\\n🔍 数据质量检查:\")\n",
" print(f\" 训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n",
" print(f\" 训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n",
" print(f\" 训练特征均值: {X_train.mean():.4f}\")\n",
" print(f\" 训练目标均值: {y_train.mean():.4f}\")\n",
" \n",
" # 检查是否有NaN或无穷值\n",
" nan_count_X = np.isnan(X_train).sum()\n",
" nan_count_y = np.isnan(y_train).sum()\n",
" inf_count_X = np.isinf(X_train).sum()\n",
" inf_count_y = np.isinf(y_train).sum()\n",
" \n",
" print(f\" NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n",
" print(f\" Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n",
" \n",
" if nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0:\n",
" print(\"⚠️ 检测到异常值,将进行清理...\")\n",
" # 清理异常值\n",
" valid_mask = ~(np.isnan(X_train).any(axis=1) | np.isnan(y_train).any(axis=1) | \n",
" np.isinf(X_train).any(axis=1) | np.isinf(y_train).any(axis=1))\n",
" X_train = X_train[valid_mask]\n",
" y_train = y_train[valid_mask]\n",
" \n",
" if X_val is not None and y_val is not None:\n",
" valid_mask_val = ~(np.isnan(X_val).any(axis=1) | np.isnan(y_val).any(axis=1) | \n",
" np.isinf(X_val).any(axis=1) | np.isinf(y_val).any(axis=1))\n",
" X_val = X_val[valid_mask_val]\n",
" y_val = y_val[valid_mask_val]\n",
" \n",
" print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n",
" \n",
" # 3. 训练模型\n",
" print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n",
" training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n",
" \n",
" # 4. 分析训练结果\n",
" print(f\"\\n📊 第4步: 训练结果分析\")\n",
" print(\"=\"*50)\n",
" \n",
" for metric, value in training_results.items():\n",
" metric_name = metric.replace('_', ' ').title()\n",
" print(f\" {metric_name}: {value:.6f}\")\n",
" \n",
" # 5. 特征重要性分析\n",
" print(f\"\\n🔍 第5步: 特征重要性分析\")\n",
" top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n",
" \n",
" print(f\"\\n🏆 Top 20 重要特征:\")\n",
" print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n",
" print(\"-\" * 35)\n",
" for i, (feature_name, importance) in enumerate(top_features):\n",
" print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n",
" \n",
" # 分析时间窗口内的重要性分布\n",
" print(f\"\\n📈 时间窗口重要性分布:\")\n",
" window_importances = np.zeros(rf_regressor.window_size)\n",
" for i in range(rf_regressor.window_size):\n",
" start_idx = i * 512\n",
" end_idx = (i + 1) * 512\n",
" window_importances[i] = all_importances[start_idx:end_idx].sum()\n",
" \n",
" max_time_step = np.argmax(window_importances)\n",
" print(f\" 最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n",
" print(f\" 窗口中心位置: t{rf_regressor.window_size//2}\")\n",
" print(f\" 重要性分布: 前5个时间步的重要性\")\n",
" for i in range(min(5, len(window_importances))):\n",
" print(f\" t{i}: {window_importances[i]:.6f}\")\n",
" \n",
" print(f\"\\n✅ 随机森林回归模型训练完成!\")\n",
" print(f\"🎯 模型可以预测40个音素的概率分布\")\n",
" print(f\"📊 基于30时间步的神经特征窗口\")\n",
" print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n",
" \n",
" else:\n",
" print(\"❌ 数据准备失败,无法训练模型\")"
]
},
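{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny synthetic end-to-end run of `fit()`/`predict()`, useful for checking the training path before the real (and much larger) data is ready. Everything here is random stand-in data, so the reported scores are meaningless; the feature width is also shrunk (30 × 16 instead of 30 × 512) to keep the demo fast, which is fine because `fit()` itself does not depend on the exact width."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧪 Synthetic end-to-end check of the training path (sketch only; random data)\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"X_demo = rng.standard_normal((200, 30 * 16)).astype(np.float32)\n",
"y_demo = rng.random((200, 40)).astype(np.float32)\n",
"\n",
"demo_rf = TimeWindowRandomForestRegressor(window_size=30, n_estimators=5, max_depth=3)\n",
"demo_rf.fit(X_demo[:150], y_demo[:150], X_demo[150:], y_demo[150:])\n",
"print(demo_rf.predict(X_demo[:3]).shape)  # (3, 40)"
]
},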
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 模型尚未训练完成,请先运行训练代码\n"
]
}
],
"source": [
"# 📊 模型评估和可视化分析\n",
"def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
" \"\"\"\n",
" 评估每个音素的预测性能\n",
" \"\"\"\n",
" print(f\"\\n📊 {dataset_name}详细评估\")\n",
" print(\"=\"*50)\n",
" \n",
" # 获取预测结果\n",
" y_pred = rf_model.predict(X_test)\n",
" \n",
" # 计算每个音素的性能指标\n",
" phoneme_metrics = []\n",
" \n",
" for i in range(40): # 40个音素\n",
" phoneme_name = LOGIT_TO_PHONEME[i]\n",
" \n",
" # 计算单个音素的指标\n",
" mse = mean_squared_error(y_test[:, i], y_pred[:, i])\n",
" r2 = r2_score(y_test[:, i], y_pred[:, i])\n",
" mae = mean_absolute_error(y_test[:, i], y_pred[:, i])\n",
" \n",
" # 计算相关系数\n",
" correlation = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]\n",
" \n",
" phoneme_metrics.append({\n",
" 'phoneme_id': i,\n",
" 'phoneme_name': phoneme_name,\n",
" 'mse': mse, \n",
" 'r2': r2,\n",
" 'mae': mae,\n",
" 'correlation': correlation if not np.isnan(correlation) else 0.0\n",
" })\n",
" \n",
" # 转换为DataFrame便于分析\n",
" metrics_df = pd.DataFrame(phoneme_metrics)\n",
" \n",
" # 打印总体统计\n",
" print(f\"📈 总体性能指标:\")\n",
" print(f\" 平均 MSE: {metrics_df['mse'].mean():.6f}\")\n",
" print(f\" 平均 R²: {metrics_df['r2'].mean():.4f}\")\n",
" print(f\" 平均 MAE: {metrics_df['mae'].mean():.6f}\")\n",
" print(f\" 平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n",
" \n",
" # 找出最佳和最差预测的音素\n",
" best_r2_idx = metrics_df['r2'].idxmax()\n",
" worst_r2_idx = metrics_df['r2'].idxmin()\n",
" \n",
" print(f\"\\n🏆 最佳预测音素:\")\n",
" best_phoneme = metrics_df.loc[best_r2_idx]\n",
" print(f\" {best_phoneme['phoneme_name']} (ID: {best_phoneme['phoneme_id']})\")\n",
" print(f\" R²: {best_phoneme['r2']:.4f}, MSE: {best_phoneme['mse']:.6f}\")\n",
" \n",
" print(f\"\\n📉 最差预测音素:\")\n",
" worst_phoneme = metrics_df.loc[worst_r2_idx]\n",
" print(f\" {worst_phoneme['phoneme_name']} (ID: {worst_phoneme['phoneme_id']})\")\n",
" print(f\" R²: {worst_phoneme['r2']:.4f}, MSE: {worst_phoneme['mse']:.6f}\")\n",
" \n",
" return metrics_df, y_pred\n",
"\n",
"def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n",
" \"\"\"\n",
" 可视化预测结果\n",
" \"\"\"\n",
" print(f\"\\n📊 创建可视化图表...\")\n",
" \n",
" # 设置图表样式\n",
" plt.style.use('default')\n",
" fig = plt.figure(figsize=(20, 12))\n",
" \n",
" # 1. R²分数分布\n",
" plt.subplot(2, 3, 1)\n",
" plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n",
" plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n",
" label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n",
" plt.xlabel('R² Score')\n",
" plt.ylabel('音素数量')\n",
" plt.title('R² Score 分布')\n",
" plt.legend()\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 2. MSE分布\n",
" plt.subplot(2, 3, 2)\n",
" plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n",
" plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n",
" label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n",
" plt.xlabel('Mean Squared Error')\n",
" plt.ylabel('音素数量')\n",
" plt.title('MSE 分布')\n",
" plt.legend()\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 3. 前10个音素的性能对比\n",
" plt.subplot(2, 3, 3)\n",
" top_10 = metrics_df.nlargest(10, 'r2')\n",
" bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n",
" plt.xlabel('音素排名')\n",
" plt.ylabel('R² Score')\n",
" plt.title('Top 10 音素预测性能')\n",
" plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 添加数值标签\n",
" for i, bar in enumerate(bars):\n",
" height = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n",
" f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n",
" \n",
" # 4. 真实值 vs 预测值散点图 (选择最佳音素)\n",
" plt.subplot(2, 3, 4)\n",
" best_phoneme_idx = metrics_df['r2'].idxmax()\n",
" phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n",
" phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n",
" \n",
" # 随机采样1000个点以避免图表过于密集\n",
" sample_size = min(1000, len(y_true))\n",
" sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n",
" \n",
" plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n",
" alpha=0.6, s=20, color='blue')\n",
" \n",
" # 添加对角线 (完美预测线)\n",
" min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n",
" max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n",
" plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n",
" \n",
" plt.xlabel('真实值')\n",
" plt.ylabel('预测值')\n",
" plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n",
" plt.legend()\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 5. 相关系数热力图 (前20个音素)\n",
" plt.subplot(2, 3, 5)\n",
" top_20_correlations = metrics_df.nlargest(20, 'correlation')\n",
" corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n",
" \n",
" # 创建热力图数据\n",
" heatmap_data = corr_data.values.reshape(-1, 1)\n",
" im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n",
" \n",
" plt.colorbar(im, shrink=0.8)\n",
" plt.yticks([0], ['相关系数'])\n",
" plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n",
" rotation=45, ha='right')\n",
" plt.title('Top 20 音素相关系数')\n",
" \n",
" # 6. 各音素预测误差箱线图 (前10个音素)\n",
" plt.subplot(2, 3, 6)\n",
" top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n",
" errors_data = []\n",
" labels = []\n",
" \n",
" for phoneme_id in top_10_ids:\n",
" errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n",
" errors_data.append(errors)\n",
" labels.append(LOGIT_TO_PHONEME[phoneme_id])\n",
" \n",
" plt.boxplot(errors_data, labels=labels)\n",
" plt.xlabel('音素')\n",
" plt.ylabel('绝对误差')\n",
" plt.title('Top 10 音素预测误差分布')\n",
" plt.xticks(rotation=45)\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" plt.tight_layout()\n",
" \n",
" if save_plots:\n",
" plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n",
" print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n",
" \n",
" plt.show()\n",
"\n",
"# 如果模型训练成功,进行评估\n",
"if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n",
" print(f\"\\n🎯 开始模型评估和可视化\")\n",
" \n",
" # 评估验证集\n",
" if X_val is not None and y_val is not None:\n",
" val_metrics, val_predictions = evaluate_phoneme_predictions(\n",
" rf_regressor, X_val, y_val, \"验证集\"\n",
" )\n",
" \n",
" # 可视化结果\n",
" visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n",
" \n",
" # 保存详细结果\n",
" val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n",
" print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n",
" \n",
" # 准备测试集数据 (如果有)\n",
" if test_datasets:\n",
" print(f\"\\n🔮 准备测试集预测...\")\n",
" X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n",
" \n",
" if X_test is not None:\n",
" test_metrics, test_predictions = evaluate_phoneme_predictions(\n",
" rf_regressor, X_test, y_test, \"测试集\"\n",
" )\n",
" print(f\"\\n✅ 测试集评估完成\")\n",
" else:\n",
" print(f\"⚠️ 测试集数据准备失败\")\n",
" \n",
" print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n",
" print(f\"📊 生成了详细的性能分析和可视化图表\")\n",
" print(f\"🔧 模型已准备好用于实际预测任务\")\n",
" \n",
"else:\n",
" print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")"
]
},
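{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-phoneme loop above calls `r2_score` once per column; sklearn can return all 40 per-output scores in a single call via `multioutput='raw_values'`. A sketch on hypothetical stand-in arrays with the documented shapes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📊 Vectorized per-phoneme R² (sketch; y_true_demo / y_pred_demo are stand-ins)\n",
"import numpy as np\n",
"from sklearn.metrics import r2_score\n",
"\n",
"rng = np.random.default_rng(1)\n",
"y_true_demo = rng.random((500, 40))\n",
"y_pred_demo = y_true_demo + 0.05 * rng.standard_normal((500, 40))\n",
"\n",
"per_phoneme_r2 = r2_score(y_true_demo, y_pred_demo, multioutput='raw_values')\n",
"print(per_phoneme_r2.shape)              # (40,)\n",
"print(f\"mean R²: {per_phoneme_r2.mean():.4f}\")"
]
},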
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ 回归转分类分析功能已创建!\n",
"🎯 主要功能:\n",
"• 将40维概率回归结果转换为分类预测\n",
"• 计算分类准确率和置信度分析\n",
"• 提供Top-K准确率评估\n",
"• 生成详细的混淆矩阵和错误分析\n",
"• 创建全面的可视化图表\n"
]
}
],
"source": [
"# 🎯 回归结果转分类结果分析\n",
"def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n",
" \"\"\"\n",
" 将回归预测的40个音素概率转换为分类结果并进行分析\n",
" \n",
" 参数:\n",
" y_true_probs: 真实的40个音素概率 [n_samples, 40]\n",
" y_pred_probs: 预测的40个音素概率 [n_samples, 40]\n",
" show_detailed_metrics: 是否显示详细的分类指标\n",
" \n",
" 返回:\n",
" classification_results: 包含分类结果的字典\n",
" \"\"\"\n",
" print(\"🎯 回归结果转分类结果分析\")\n",
" print(\"=\"*60)\n",
" \n",
" # 1. 将概率转换为分类标签\n",
" y_true_classes = np.argmax(y_true_probs, axis=1) # 真实类别\n",
" y_pred_classes = np.argmax(y_pred_probs, axis=1) # 预测类别\n",
" \n",
" # 2. 计算分类准确率\n",
" accuracy = (y_true_classes == y_pred_classes).mean()\n",
" \n",
" print(f\"📊 分类结果概览:\")\n",
" print(f\" 总样本数: {len(y_true_classes):,}\")\n",
" print(f\" 分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n",
" print(f\" 正确预测: {(y_true_classes == y_pred_classes).sum():,}\")\n",
" print(f\" 错误预测: {(y_true_classes != y_pred_classes).sum():,}\")\n",
" \n",
" # 3. 分析预测置信度\n",
" pred_confidences = np.max(y_pred_probs, axis=1) # 预测的最大概率\n",
" true_confidences = np.max(y_true_probs, axis=1) # 真实的最大概率\n",
" \n",
" print(f\"\\n🔍 预测置信度分析:\")\n",
" print(f\" 预测置信度均值: {pred_confidences.mean():.4f}\")\n",
" print(f\" 预测置信度标准差: {pred_confidences.std():.4f}\")\n",
" print(f\" 预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n",
" print(f\" 真实置信度均值: {true_confidences.mean():.4f}\")\n",
" \n",
" # 4. 按置信度分组的准确率分析\n",
" confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n",
" print(f\"\\n📈 按预测置信度分组的准确率:\")\n",
" print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n",
" print(\"-\" * 40)\n",
" \n",
" for i in range(len(confidence_bins)-1):\n",
" low, high = confidence_bins[i], confidence_bins[i+1]\n",
" mask = (pred_confidences >= low) & (pred_confidences < high)\n",
" if i == len(confidence_bins)-2: # 最后一个区间包含等号\n",
" mask = (pred_confidences >= low) & (pred_confidences <= high)\n",
" \n",
" if mask.sum() > 0:\n",
" bin_accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
" count = mask.sum()\n",
" percentage = count / len(y_true_classes) * 100\n",
" print(f\"[{low:.1f}, {high:.1f}{')'if i<len(confidence_bins)-2 else ']':>1} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n",
" \n",
" # 5. 混淆矩阵分析Top-K音素\n",
" from collections import Counter\n",
" \n",
" # 找出最常见的音素\n",
" true_counter = Counter(y_true_classes)\n",
" pred_counter = Counter(y_pred_classes)\n",
" \n",
" most_common_true = true_counter.most_common(10)\n",
" most_common_pred = pred_counter.most_common(10)\n",
" \n",
" print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n",
" print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n",
" print(\"-\" * 42)\n",
" \n",
" for i in range(min(len(most_common_true), len(most_common_pred))):\n",
" true_id, true_count = most_common_true[i]\n",
" pred_id, pred_count = most_common_pred[i]\n",
" true_name = LOGIT_TO_PHONEME[true_id]\n",
" pred_name = LOGIT_TO_PHONEME[pred_id]\n",
" print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n",
" \n",
" # 6. 每个音素的分类性能\n",
" if show_detailed_metrics:\n",
" from sklearn.metrics import classification_report, confusion_matrix\n",
" \n",
" print(f\"\\n📋 详细分类报告 (前20个最常见音素):\")\n",
" \n",
" # 获取前20个最常见的音素\n",
" top_20_phonemes = [phoneme_id for phoneme_id, _ in most_common_true[:20]]\n",
" \n",
" # 创建掩码,只包含这些音素\n",
" mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n",
" y_true_top20 = y_true_classes[mask_top20]\n",
" y_pred_top20 = y_pred_classes[mask_top20]\n",
" \n",
" # 生成分类报告\n",
" target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n",
" \n",
" try:\n",
" report = classification_report(\n",
" y_true_top20, y_pred_top20, \n",
" labels=top_20_phonemes,\n",
" target_names=target_names,\n",
" output_dict=True,\n",
" zero_division=0\n",
" )\n",
" \n",
" # 打印格式化的报告\n",
" print(f\"{'音素':>8} {'精确率':>8} {'召回率':>8} {'F1分数':>8} {'支持数':>8}\")\n",
" print(\"-\" * 48)\n",
" \n",
" for phoneme_id in top_20_phonemes:\n",
" phoneme_name = LOGIT_TO_PHONEME[phoneme_id]\n",
" if phoneme_name in report:\n",
" metrics = report[phoneme_name]\n",
" print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n",
" f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n",
" \n",
" # 总体指标\n",
" macro_avg = report['macro avg']\n",
" weighted_avg = report['weighted avg']\n",
" print(\"-\" * 48)\n",
" print(f\"{'宏平均':>8} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n",
" f\"{macro_avg['f1-score']:>8.4f}\")\n",
" print(f\"{'加权平均':>8} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n",
" f\"{weighted_avg['f1-score']:>8.4f}\")\n",
" \n",
" except Exception as e:\n",
" print(f\"分类报告生成失败: {e}\")\n",
" \n",
" # 7. Top-K准确率分析\n",
" print(f\"\\n🎯 Top-K 准确率分析:\")\n",
" for k in [1, 3, 5, 10]:\n",
" # 计算Top-K准确率\n",
" top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:] # 取概率最高的K个\n",
" top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
" print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
" \n",
" # 8. 错误分析 - 最常见的预测错误\n",
" print(f\"\\n❌ 最常见的预测错误:\")\n",
" error_mask = y_true_classes != y_pred_classes\n",
" error_pairs = list(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n",
" error_counter = Counter(error_pairs)\n",
" \n",
" print(f\"{'真实音素':>12} {'预测音素':>12} {'错误次数':>8}\")\n",
" print(\"-\" * 36)\n",
" for (true_id, pred_id), count in error_counter.most_common(10):\n",
" true_name = LOGIT_TO_PHONEME[true_id]\n",
" pred_name = LOGIT_TO_PHONEME[pred_id]\n",
" print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n",
" \n",
" # 返回结果字典\n",
" classification_results = {\n",
" 'accuracy': accuracy,\n",
" 'y_true_classes': y_true_classes,\n",
" 'y_pred_classes': y_pred_classes,\n",
" 'pred_confidences': pred_confidences,\n",
" 'true_confidences': true_confidences,\n",
" 'most_common_errors': error_counter.most_common(10)\n",
" }\n",
" \n",
" return classification_results\n",
"\n",
"def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n",
" \"\"\"\n",
" 为分类结果创建可视化图表\n",
" \"\"\"\n",
" print(f\"\\n📊 创建分类结果可视化...\")\n",
" \n",
" fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
" fig.suptitle('随机森林回归转分类结果分析', fontsize=16, fontweight='bold')\n",
" \n",
" y_true_classes = classification_results['y_true_classes']\n",
" y_pred_classes = classification_results['y_pred_classes']\n",
" pred_confidences = classification_results['pred_confidences']\n",
" \n",
" # 1. 预测置信度分布\n",
" axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n",
" axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n",
" label=f'均值: {pred_confidences.mean():.3f}')\n",
" axes[0, 0].set_xlabel('预测置信度')\n",
" axes[0, 0].set_ylabel('样本数量')\n",
" axes[0, 0].set_title('预测置信度分布')\n",
" axes[0, 0].legend()\n",
" axes[0, 0].grid(True, alpha=0.3)\n",
" \n",
" # 2. 准确率 vs 置信度\n",
" confidence_bins = np.linspace(0, 1, 21)\n",
" bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n",
" bin_accuracies = []\n",
" bin_counts = []\n",
" \n",
" for i in range(len(confidence_bins)-1):\n",
" mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n",
" if mask.sum() > 0:\n",
" accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
" bin_accuracies.append(accuracy)\n",
" bin_counts.append(mask.sum())\n",
" else:\n",
" bin_accuracies.append(0)\n",
" bin_counts.append(0)\n",
" \n",
" # 只显示有数据的bins\n",
" valid_bins = np.array(bin_counts) > 0\n",
" axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n",
" 'bo-', linewidth=2, markersize=6)\n",
" axes[0, 1].set_xlabel('预测置信度')\n",
" axes[0, 1].set_ylabel('准确率')\n",
" axes[0, 1].set_title('准确率 vs 预测置信度')\n",
" axes[0, 1].grid(True, alpha=0.3)\n",
" axes[0, 1].set_ylim(0, 1)\n",
" \n",
" # 3. 最常见音素的预测准确率\n",
" from collections import Counter\n",
" true_counter = Counter(y_true_classes)\n",
" most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n",
" \n",
" phoneme_accuracies = []\n",
" phoneme_names = []\n",
" for phoneme_id in most_common_phonemes:\n",
" mask = y_true_classes == phoneme_id\n",
" if mask.sum() > 0:\n",
" accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n",
" phoneme_accuracies.append(accuracy)\n",
" phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n",
" \n",
" bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n",
" color='lightgreen', alpha=0.7)\n",
" axes[0, 2].set_xlabel('音素')\n",
" axes[0, 2].set_ylabel('准确率')\n",
" axes[0, 2].set_title('Top 15 音素的分类准确率')\n",
" axes[0, 2].set_xticks(range(len(phoneme_names)))\n",
" axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
" axes[0, 2].grid(True, alpha=0.3)\n",
" \n",
" # 添加数值标签\n",
" for bar, acc in zip(bars, phoneme_accuracies):\n",
" height = bar.get_height()\n",
" axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n",
" f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n",
" \n",
" # 4. 混淆矩阵前10个最常见音素\n",
" from sklearn.metrics import confusion_matrix\n",
" top_10_phonemes = most_common_phonemes[:10]\n",
" mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n",
" \n",
" if mask_top10.sum() > 0:\n",
" cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n",
" labels=top_10_phonemes)\n",
" \n",
" im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n",
" axes[1, 0].set_title('混淆矩阵 (Top 10 音素)')\n",
" \n",
" # 添加颜色条\n",
" cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n",
" cbar.set_label('预测次数')\n",
" \n",
" # 设置标签\n",
" tick_marks = np.arange(len(top_10_phonemes))\n",
" top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n",
" axes[1, 0].set_xticks(tick_marks)\n",
" axes[1, 0].set_yticks(tick_marks)\n",
" axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n",
" axes[1, 0].set_yticklabels(top_10_names)\n",
" axes[1, 0].set_xlabel('预测音素')\n",
" axes[1, 0].set_ylabel('真实音素')\n",
" \n",
" # 5. Top-K准确率\n",
" k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n",
" top_k_accuracies = []\n",
" \n",
" for k in k_values:\n",
" top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n",
" top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
" top_k_accuracies.append(top_k_accuracy)\n",
" \n",
" axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n",
" axes[1, 1].set_xlabel('K 值')\n",
" axes[1, 1].set_ylabel('Top-K 准确率')\n",
" axes[1, 1].set_title('Top-K 准确率曲线')\n",
" axes[1, 1].grid(True, alpha=0.3)\n",
" axes[1, 1].set_ylim(0, 1)\n",
" \n",
" # 添加数值标签\n",
" for k, acc in zip(k_values, top_k_accuracies):\n",
" axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n",
" xytext=(0,10), ha='center')\n",
" \n",
" # 6. 错误分析 - 最常见错误的热力图\n",
" error_pairs = classification_results['most_common_errors'][:25] # 前25个最常见错误\n",
" if error_pairs:\n",
" # 创建错误矩阵\n",
" unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n",
" error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n",
" \n",
" phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n",
" \n",
" for (true_id, pred_id), count in error_pairs:\n",
" if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n",
" true_idx = phoneme_to_idx[true_id]\n",
" pred_idx = phoneme_to_idx[pred_id]\n",
" error_matrix[true_idx, pred_idx] = count\n",
" \n",
" im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n",
" axes[1, 2].set_title('最常见错误分布')\n",
" \n",
" # 设置标签\n",
" phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n",
" axes[1, 2].set_xticks(range(len(phoneme_names)))\n",
" axes[1, 2].set_yticks(range(len(phoneme_names)))\n",
" axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
" axes[1, 2].set_yticklabels(phoneme_names)\n",
" axes[1, 2].set_xlabel('预测音素')\n",
" axes[1, 2].set_ylabel('真实音素')\n",
" \n",
" # 添加颜色条\n",
" cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n",
" cbar.set_label('错误次数')\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n",
" print(\"📁 分类分析图表已保存至: ./processed_datasets/classification_analysis.png\")\n",
" plt.show()\n",
"\n",
"print(\"✅ 回归转分类分析功能已创建!\")\n",
"print(\"🎯 主要功能:\")\n",
"print(\"• 将40维概率回归结果转换为分类预测\")\n",
"print(\"• 计算分类准确率和置信度分析\")\n",
"print(\"• 提供Top-K准确率评估\")\n",
"print(\"• 生成详细的混淆矩阵和错误分析\")\n",
"print(\"• 创建全面的可视化图表\")"
]
},
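{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Top-K accuracy above is computed with a per-sample Python list comprehension; the same quantity can be computed fully vectorized with a broadcast comparison. A self-contained sketch on toy random data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ⚡ Vectorized Top-K accuracy (sketch; equivalent to the loop version above)\n",
"import numpy as np\n",
"\n",
"def top_k_accuracy(y_true_classes, y_pred_probs, k):\n",
"    # Indices of the k highest-probability classes per sample: [n_samples, k]\n",
"    top_k = np.argsort(y_pred_probs, axis=1)[:, -k:]\n",
"    # A sample counts as a hit if its true class appears anywhere in the top k\n",
"    hits = (top_k == y_true_classes[:, None]).any(axis=1)\n",
"    return hits.mean()\n",
"\n",
"# Toy check on random data\n",
"rng = np.random.default_rng(2)\n",
"probs = rng.random((1000, 40))\n",
"truth = rng.integers(0, 40, size=1000)\n",
"for k in (1, 3, 5):\n",
"    print(f\"Top-{k}: {top_k_accuracy(truth, probs, k):.4f}\")"
]
},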
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 随机森林模型尚未训练完成\n",
"💡 请先运行前面的训练代码\n"
]
}
],
"source": [
"# 🎯 完整的回归转分类评估流程\n",
"def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
" \"\"\"\n",
" 完整的回归模型转分类结果评估流程\n",
" \"\"\"\n",
" print(f\"\\n🎯 {dataset_name}完整评估: 回归 → 分类\")\n",
" print(\"=\"*70)\n",
" \n",
" # 1. 获取回归预测结果\n",
" print(\"📊 第1步: 获取回归预测...\")\n",
" y_pred_probs = rf_model.predict(X_test)\n",
" \n",
" # 确保概率值在合理范围内\n",
" y_pred_probs = np.clip(y_pred_probs, 0, 1)\n",
" \n",
" # 2. 回归性能评估\n",
" print(\"\\n📈 第2步: 回归性能评估...\")\n",
" mse = mean_squared_error(y_test, y_pred_probs) \n",
" mae = mean_absolute_error(y_test, y_pred_probs)\n",
" r2 = r2_score(y_test, y_pred_probs)\n",
" \n",
" print(f\" 回归 MSE: {mse:.6f}\")\n",
" print(f\" 回归 MAE: {mae:.6f}\")\n",
" print(f\" 回归 R²: {r2:.4f}\")\n",
" \n",
" # 3. 概率归一化softmax\n",
" print(\"\\n🔄 第3步: 概率归一化...\")\n",
" def softmax(x, axis=-1):\n",
" exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
" return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n",
" \n",
" # 对预测结果应用softmax使其成为真正的概率分布\n",
" y_pred_probs_normalized = softmax(y_pred_probs)\n",
" y_test_normalized = softmax(y_test) # 也对真实标签归一化\n",
" \n",
" print(f\" 预测概率归一化前: 每行和均值 = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n",
" print(f\" 预测概率归一化后: 每行和均值 = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n",
" \n",
" # 4. 分类结果分析\n",
" print(\"\\n🎯 第4步: 分类结果分析...\")\n",
" classification_results = regression_to_classification_analysis(\n",
" y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n",
" )\n",
" \n",
" # 5. 创建可视化\n",
" print(\"\\n📊 第5步: 创建可视化图表...\")\n",
" create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n",
" \n",
" # 6. 保存结果\n",
" print(\"\\n💾 第6步: 保存分析结果...\")\n",
" \n",
" # 保存分类结果\n",
" results_df = pd.DataFrame({\n",
" 'true_class': classification_results['y_true_classes'],\n",
" 'pred_class': classification_results['y_pred_classes'],\n",
" 'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n",
" 'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n",
" 'pred_confidence': classification_results['pred_confidences'],\n",
" 'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n",
" })\n",
" \n",
" results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n",
" \n",
" # 保存详细的概率预测\n",
" prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n",
" columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n",
" prob_results_df['true_class'] = classification_results['y_true_classes']\n",
" prob_results_df['pred_class'] = classification_results['y_pred_classes']\n",
" \n",
" prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n",
" \n",
" print(\"📁 结果已保存:\")\n",
" print(\" • ./processed_datasets/classification_results.csv (分类结果)\")\n",
" print(\" • ./processed_datasets/probability_predictions.csv (概率预测)\")\n",
" print(\" • ./processed_datasets/classification_analysis.png (可视化图表)\")\n",
" \n",
" # 7. 总结报告\n",
" print(f\"\\n📋 {dataset_name}评估总结:\")\n",
" print(\"=\"*50)\n",
" print(f\"🔸 回归性能:\")\n",
" print(f\" MSE: {mse:.6f}\")\n",
" print(f\" R²: {r2:.4f}\")\n",
" print(f\"🔸 分类性能:\")\n",
" print(f\" 准确率: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n",
" print(f\" 平均置信度: {classification_results['pred_confidences'].mean():.4f}\")\n",
" \n",
" # 计算Top-K准确率\n",
" for k in [1, 3, 5]:\n",
" top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n",
" top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n",
" for i in range(len(classification_results['y_true_classes']))])\n",
" print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
" \n",
" return {\n",
" 'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n",
" 'classification_results': classification_results,\n",
" 'normalized_predictions': y_pred_probs_normalized,\n",
" 'normalized_true': y_test_normalized\n",
" }\n",
"\n",
"# 如果模型已训练且有验证数据,执行完整评估\n",
"if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n",
" if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n",
" print(\"🚀 开始完整的回归转分类评估...\")\n",
" \n",
" # 执行完整评估\n",
" evaluation_results = complete_regression_classification_evaluation(\n",
" rf_regressor, X_val, y_val, \"验证集\"\n",
" )\n",
" \n",
" print(f\"\\n🎉 评估完成!\")\n",
" print(f\"✅ 随机森林回归模型成功转换为分类结果\")\n",
" print(f\"📊 生成了详细的性能分析和可视化\")\n",
" print(f\"💾 所有结果已保存到文件\")\n",
" \n",
" # 如果有测试数据,也进行评估\n",
" if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n",
" print(f\"\\n🔮 开始测试集评估...\")\n",
" test_evaluation_results = complete_regression_classification_evaluation(\n",
" rf_regressor, X_test, y_test, \"测试集\"\n",
" )\n",
" else:\n",
" print(\"⚠️ 没有可用的验证数据进行评估\")\n",
"else:\n",
" print(\"⚠️ 随机森林模型尚未训练完成\")\n",
" print(\"💡 请先运行前面的训练代码\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}