{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 环境配置 与 utils" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu126\n", "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n", "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in 
/usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n", "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", "Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n", "Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n", "Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n", "Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n", "Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n", "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n", "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n", "Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n", "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n", "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n", "Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n", "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n", "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n", "Requirement already satisfied: nbconvert in 
/usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n", "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n", "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n", "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n", "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n", "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n", "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n", "Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n", "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n", "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) 
(0.5.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n", "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n", "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n", "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from 
torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n", "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n", "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n", "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n", "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n", "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n", "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n", "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n", "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n", "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n", "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n", "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n", "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n", "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n", "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n", "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n", "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n", "Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: argon2-cffi in 
/usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n", "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n", "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n", "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n", "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n", "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n", "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n", "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n", "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n", "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: setuptools>=18.5 in 
/usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n", "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n", "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n", "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n", "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n", "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n", "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n", "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n", "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n", "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n", "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n", "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n", "Requirement already satisfied: argon2-cffi-bindings in 
/usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n", "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n", "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n", "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n", "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n", "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n", "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n", "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n", "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n", "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n", "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages 
(from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n", "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n", "Obtaining file:///kaggle/working/nejm-brain-to-text\n", " Preparing metadata (setup.py): started\n", " Preparing metadata (setup.py): finished with status 'done'\n", "Installing collected packages: nejm_b2txt_utils\n", " Running setup.py develop for nejm_b2txt_utils\n", "Successfully installed nejm_b2txt_utils-0.0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Cloning into 'nejm-brain-to-text'...\n" ] } ], "source": [ "%%bash\n", "cd /kaggle/working/\n", "rm -rf /kaggle/working/nejm-brain-to-text/\n", "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n", "cd /kaggle/working/nejm-brain-to-text/\n", "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n", "\n", "# Install the local package\n", "\n", "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data\n", "\n", "# # Install PyTorch with CUDA 12.6\n", "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", "\n", "# Install additional packages with compatible versions\n", "\n", "pip install \\\n", " redis==5.2.1 \\\n", " jupyter==1.1.1 \\\n", " \"numpy>=1.26.0,<2.1.0\" \\\n", " pandas==2.3.0 \\\n", " matplotlib==3.10.1 \\\n", " scipy==1.15.2 \\\n", " scikit-learn==1.6.1 \\\n", " tqdm==4.67.1 \\\n", " g2p_en==2.1.0 \\\n", " h5py==3.13.0 \\\n", " omegaconf==2.3.0 \\\n", " editdistance==0.8.1 \\\n", " huggingface-hub==0.33.1 \\\n", " transformers==4.53.0 \\\n", " tokenizers==0.21.2 \\\n", " accelerate==1.8.1 \\\n", " bitsandbytes==0.46.0\n", "pip install -e ." 
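{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick sanity check of the environment: the illustrative cell below only confirms that the CUDA build of PyTorch imports and that a GPU is visible; exact versions and device names depend on the Kaggle image." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal environment check (illustrative): verify PyTorch and GPU availability\n", "import torch\n", "\n", "print('torch version:', torch.__version__)\n", "print('CUDA available:', torch.cuda.is_available())\n", "if torch.cuda.is_available():\n", "    print('GPU:', torch.cuda.get_device_name(0))" ] },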
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text\n", "import numpy as np\n", "import os\n", "import pickle\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from g2p_en import G2p\n", "import pandas as pd\n", "import numpy as np\n", "from nejm_b2txt_utils.general_utils import *\n", "\n", "matplotlib.rcParams['pdf.fonttype'] = 42\n", "matplotlib.rcParams['ps.fonttype'] = 42\n", "matplotlib.rcParams['font.family'] = 'sans-serif'\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text/model_training\n" ] } ], "source": [ "%cd model_training/\n", "from data_augmentations import gauss_smooth\n", "# single decoding step function that also returns smoothed input\n", "# smooths data and puts it through the model, returning both logits and smoothed input.\n", "def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n", "\n", " # Use autocast for efficiency\n", " with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n", "\n", " smoothed_x = gauss_smooth(\n", " inputs = x, \n", " device = device,\n", " smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n", " smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n", " padding = 'valid',\n", " )\n", "\n", " with torch.no_grad():\n", " logits, _ = model(\n", " x = smoothed_x,\n", " day_idx = torch.tensor([input_layer], device=device),\n", " states = None, # no initial states\n", " return_state = True,\n", " )\n", "\n", " # convert both logits and smoothed input from bfloat16 to float32\n", " logits = logits.float().cpu().numpy()\n", " smoothed_input = smoothed_x.float().cpu().numpy()\n", "\n", " # # original order is [BLANK, phonemes..., SIL]\n", " # # rearrange so the order is [BLANK, SIL, phonemes...]\n", " # logits = rearrange_speech_logits_pt(logits)\n", "\n", " return logits, smoothed_input\n", "\n", "\n", "import h5py\n", "def load_h5py_file(file_path, b2txt_csv_df):\n", " data = {\n", " 'neural_features': [],\n", " 'n_time_steps': [],\n", " 'seq_class_ids': [],\n", " 'seq_len': [],\n", " 'transcriptions': [],\n", " 'sentence_label': [],\n", " 'session': [],\n", " 'block_num': [],\n", " 'trial_num': [],\n", " 'corpus': [],\n", " }\n", " # Open the hdf5 file for that day\n", " with h5py.File(file_path, 'r') as f:\n", "\n", " keys = list(f.keys())\n", "\n", " # For each trial in the selected trials in that day\n", " for key in keys:\n", " g = f[key]\n", "\n", " neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n", " n_time_steps = g.attrs['n_time_steps']\n", " seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n", " seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n", " transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n", " sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n", " session = g.attrs['session']\n", " block_num = g.attrs['block_num']\n", " trial_num = g.attrs['trial_num']\n", "\n", " # match this trial up with the csv to get the corpus name\n", " year, month, day = session.split('.')[1:] # pyright: 
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Data analysis and preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data preparation" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text/" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "data = load_h5py_file(file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n", " b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Task overview**: use machine learning to solve a pattern-recognition problem on high-dimensional signals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Our dataset's labels carry no per-phoneme timestamps, so what we are about to do is essentially semi-supervised learning." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Split each trial's time span evenly across its phonemes, or set the initial segment lengths from reference duration data. Then filter out outliers, extract a usable training set, normalize the segment lengths, and inspect the length distribution of each sample class." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n", " -0.7091646 , -0.11018186],\n", " [-0.5859305 , -0.78699756, -0.64687246, ..., 0.3122117 ,\n", " 1.7943763 , -0.76884896],\n", " [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n", " -0.8481289 , -0.7648201 ],\n", " ...,\n", " [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n", " 0.9710176 , 2.5397465 ],\n", " [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n", " -0.68723625, 0.10479005],\n", " [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n", " 0.7417906 , -0.7008622 ]], dtype=float32),\n", " 'n_time_steps': 321,\n", " 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n", " ..., 0, 0, 0], dtype=int32),\n", " 'seq_len': 14,\n", " 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n", " 101, 114, 46, 0, 0, ..., 0, 0, 0], dtype=int32),\n", " 'sentence_label': 'Bring it closer.',\n", " 'session': 't15.2023.08.11',\n", " 'block_num': 2,\n", " 'trial_num': 0,\n", " 'corpus': '50-Word'}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def data_patch(data, index):\n", " data_patch = {}\n", " data_patch['neural_features'] = data['neural_features'][index]\n", " data_patch['n_time_steps'] = data['n_time_steps'][index]\n", " data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n", " data_patch['seq_len'] = data['seq_len'][index]\n", " data_patch['transcriptions'] = data['transcriptions'][index]\n", " data_patch['sentence_label'] = data['sentence_label'][index]\n", " data_patch['session'] = 
data['session'][index]\n", " data_patch['block_num'] = data['block_num'][index]\n", " data_patch['trial_num'] = data['trial_num'][index]\n", " data_patch['corpus'] = data['corpus'][index]\n", " return data_patch\n", "\n", "data_patch(data, 0)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "d1 = data_patch(data, 0)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Transcriptions non-zero length: 16\n", "Seq class ids non-zero length: 14\n", "Seq len: 14\n" ] } ], "source": [ "trans_len = len([x for x in d1['transcriptions'] if x != 0])\n", "seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n", "seq_len = d1['seq_len']\n", "print(f\"Transcriptions non-zero length: {trans_len}\")\n", "print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n", "print(f\"Seq len: {seq_len}\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of feature sequences: 14\n", "Shape of first sequence: (22, 512)\n" ] } ], "source": [ "def create_time_windows(d1):\n", " import numpy as np\n", " n_time_steps = d1['n_time_steps']\n", " seq_len = d1['seq_len']\n", " # Create equal windows\n", " edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n", " windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n", " \n", " # Extract feature sequences for each window\n", " feature_sequences = []\n", " for start, end in windows:\n", " seq = d1['neural_features'][start:end, :]\n", " feature_sequences.append(seq)\n", " \n", " return feature_sequences\n", "\n", "# Example usage\n", "feature_sequences = create_time_windows(d1)\n", "print(\"Number of feature sequences:\", len(feature_sequences))\n", "print(\"Shape of first sequence:\", feature_sequences[0].shape)\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: 45, Val: 41, Test: 41\n", "Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n", "Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n", "Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n" ] } ], "source": [ "import os\n", "\n", "def scan_hdf5_files(base_path):\n", " train_files = []\n", " val_files = []\n", " test_files = []\n", " for root, dirs, files in os.walk(base_path):\n", " for file in files:\n", " if file.endswith('.hdf5'):\n", " abs_path = os.path.abspath(os.path.join(root, file))\n", " if 'data_train.hdf5' in file:\n", " train_files.append(abs_path)\n", " elif 'data_val.hdf5' in file:\n", " val_files.append(abs_path)\n", " elif 'data_test.hdf5' in file:\n", " test_files.append(abs_path)\n", " return train_files, val_files, test_files\n", "\n", "# Example 
usage\n", "FILE_PATH = 'data/hdf5_data_final'\n", "train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n", "print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n", "print(\"Train files (first 3):\", train_list[:3])\n", "print(\"Val files (first 3):\", val_list[:3])\n", "print(\"Test files (first 3):\", test_list[:3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 标签处理" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# def classify_windows_by_labels(d1):\n", "# seq_class_ids = d1['seq_class_ids'][:d1['seq_len']] # Take only the non-zero part\n", "# windows = create_time_windows(d1)\n", " \n", "# classified_windows = {}\n", "# for i, label in enumerate(seq_class_ids):\n", "# char = LOGIT_TO_PHONEME[label]\n", "# if char not in classified_windows:\n", "# classified_windows[char] = []\n", "# classified_windows[char].append(windows[i])\n", " \n", "# return classified_windows\n", "\n", "# # Example usage\n", "# classified = classify_windows_by_labels(d1)\n", "# print(\"Classified windows by label:\", classified)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "# import pandas as pd\n", "# import tqdm\n", "\n", "# b2txt_csv_df = pd.read_csv('data/t15_copyTaskData_description.csv')\n", "\n", "# def workflow(max_files=None):\n", "# group_by_labels = {}\n", "# files_to_process = train_list[:max_files] if max_files is not None else train_list\n", "# for file_path in tqdm.tqdm(files_to_process):\n", "# data = load_h5py_file(file_path, b2txt_csv_df)\n", "# for i in tqdm.tqdm(range(len(data['neural_features'])), leave=False):\n", "# # Process only the first trial for simplicity\n", "# d1 = data_patch(data, i)\n", "# classified = classify_windows_by_labels(d1)\n", "# for key, value in classified.items():\n", "# if key not in group_by_labels:\n", "# group_by_labels[key] = []\n", "# group_by_labels[key].extend(value)\n", "# return group_by_labels\n", "\n", "# # Example usage\n", "# result = workflow()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 核函数扭曲时间\n", "控制音素时间长度相同" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "# total = 0\n", "# count = 0\n", "# time_distribution = []\n", "# for i in result.values():\n", "# for j in i:\n", "# total += j.shape[0]\n", "# count += 1\n", "# time_distribution.append(j.shape[0])\n", "# print(f\"Total time steps: {total}, Total windows: {count}, Average window length: {total/count:.2f}\")\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "# import numpy as np\n", "# plt.figure(figsize=(12, 6))\n", "# plt.hist(time_distribution, bins=50, alpha=0.7, edgecolor='black')\n", "# plt.xlabel('Window Length (time steps)')\n", "# plt.ylabel('Frequency')\n", "# plt.title('Distribution of Time Window Lengths')\n", "# plt.grid(True, alpha=0.3)\n", "# plt.axvline(np.mean(time_distribution), color='red', linestyle='--', label=f'Mean: {np.mean(time_distribution):.2f}')\n", "# plt.axvline(np.median(time_distribution), color='green', linestyle='--', label=f'Median: {np.median(time_distribution):.2f}')\n", "# plt.legend()\n", "# plt.show()\n", "\n", "# print(f\"Statistics:\")\n", "# print(f\"Mean: {np.mean(time_distribution):.2f}\")\n", "# print(f\"Median: {np.median(time_distribution):.2f}\")\n", "# print(f\"Std: {np.std(time_distribution):.2f}\")\n", "# print(f\"Min: {np.min(time_distribution)}\")\n", "# print(f\"Max: 
{np.max(time_distribution)}\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "# # 把时间序列对齐到这个长度上,然后做聚类去异常值\n", "# MEAN_WINDOWS_SIZE = 33\n", "\n", "# from scipy import interpolate\n", "# import numpy as np\n", "\n", "# def kernel_time_warp(sequence, target_length=33):\n", "# \"\"\"\n", "# 使用核函数处理时间序列,将序列长度标准化到target_length\n", "# - 对于长度 < target_length 的序列:使用插值扩展\n", "# - 对于长度 > target_length 的序列:使用压缩采样\n", "# - 对于长度 = target_length 的序列:直接返回\n", " \n", "# Args:\n", "# sequence: 输入时间序列 (time_steps, features)\n", "# target_length: 目标长度\n", " \n", "# Returns:\n", "# warped_sequence: 处理后的时间序列 (target_length, features)\n", "# \"\"\"\n", "# original_length = sequence.shape[0]\n", "# n_features = sequence.shape[1]\n", " \n", "# # 如果序列长度已经等于目标长度,直接返回\n", "# if original_length == target_length:\n", "# return sequence\n", " \n", "# # 处理边界情况\n", "# if original_length == 0:\n", "# return np.zeros((target_length, n_features))\n", " \n", "# if original_length == 1:\n", "# # 如果只有一个时间步,复制到所有目标时间步\n", "# return np.repeat(sequence, target_length, axis=0)\n", " \n", "# warped_sequence = np.zeros((target_length, n_features))\n", " \n", "# if original_length > target_length:\n", "# # 压缩:长序列 -> 短序列\n", "# # 使用均匀采样 + 局部平均的方式压缩\n", "# compression_ratio = original_length / target_length\n", " \n", "# for i in range(target_length):\n", "# # 计算当前目标位置对应的原始序列范围\n", "# start_idx = int(i * compression_ratio)\n", "# end_idx = int((i + 1) * compression_ratio)\n", " \n", "# # 确保不超出边界\n", "# start_idx = max(0, start_idx)\n", "# end_idx = min(original_length, end_idx)\n", " \n", "# if start_idx == end_idx:\n", "# # 避免空范围\n", "# end_idx = min(start_idx + 1, original_length)\n", " \n", "# # 对该范围内的数据取平均(压缩)\n", "# warped_sequence[i] = np.mean(sequence[start_idx:end_idx], axis=0)\n", " \n", "# else:\n", "# # 扩展:短序列 -> 长序列\n", "# # 使用插值的方式扩展\n", "# original_indices = np.linspace(0, 1, original_length)\n", "# target_indices = np.linspace(0, 1, target_length)\n", " \n", "# for feature_idx in range(n_features):\n", "# # 根据原始序列长度选择插值方法\n", "# if original_length >= 3:\n", "# # 对于长度>=3的序列,使用三次样条插值\n", "# interpolator = interpolate.interp1d(\n", "# original_indices, \n", "# sequence[:, feature_idx], \n", "# kind='cubic', \n", "# bounds_error=False, \n", "# fill_value='extrapolate'\n", "# )\n", "# else:\n", "# # 对于长度=2的序列,使用线性插值\n", "# interpolator = interpolate.interp1d(\n", "# original_indices, \n", "# sequence[:, feature_idx], \n", "# kind='linear', \n", "# bounds_error=False, \n", "# fill_value='extrapolate'\n", "# )\n", " \n", "# warped_sequence[:, feature_idx] = interpolator(target_indices)\n", " \n", "# return warped_sequence\n", "\n", "# def gaussian_kernel_weight(x, sigma=0.1):\n", "# \"\"\"\n", "# 高斯核函数权重,用于平滑处理\n", "# \"\"\"\n", "# return np.exp(-0.5 * (x / sigma) ** 2)\n", "\n", "# def process_result_with_kernel(result_dict, target_length=33):\n", "# \"\"\"\n", "# 使用核函数处理result字典中的所有时间序列\n", " \n", "# Args:\n", "# result_dict: 包含时间序列的字典\n", "# target_length: 目标长度\n", " \n", "# Returns:\n", "# processed_result: 处理后的字典\n", "# \"\"\"\n", "# processed_result = {}\n", " \n", "# print(\"Processing time series with kernel warping...\")\n", "# print(f\"Target length: {target_length}\")\n", " \n", "# # 统计不同长度的序列\n", "# length_stats = {}\n", "# total_sequences = 0\n", " \n", "# for label, sequences in tqdm.tqdm(result_dict.items()):\n", "# processed_sequences = []\n", " \n", "# for seq in sequences:\n", "# original_length = seq.shape[0]\n", " \n", "# # 统计长度分布\n", "# if original_length not 
in length_stats:\n", "# length_stats[original_length] = 0\n", "# length_stats[original_length] += 1\n", "# total_sequences += 1\n", " \n", "# # 应用核函数时间扭曲(包括压缩和插值)\n", "# warped_seq = kernel_time_warp(seq, target_length)\n", "# processed_sequences.append(warped_seq)\n", " \n", "# processed_result[label] = processed_sequences\n", "# print(f\"Label '{label}': {len(sequences)} sequences -> {len(processed_sequences)} sequences\")\n", " \n", "# # 打印长度统计信息\n", "# print(f\"\\n原始序列长度分布:\")\n", "# sorted_lengths = sorted(length_stats.items())\n", "# short_count = sum(count for length, count in sorted_lengths if length < target_length)\n", "# equal_count = sum(count for length, count in sorted_lengths if length == target_length)\n", "# long_count = sum(count for length, count in sorted_lengths if length > target_length)\n", " \n", "# print(f\"短于目标长度({target_length}),需要插值扩展: {short_count} 个序列\")\n", "# print(f\"等于目标长度({target_length}),无需处理: {equal_count} 个序列\")\n", "# print(f\"长于目标长度({target_length}),需要压缩: {long_count} 个序列\")\n", "# print(f\"总序列数: {total_sequences}\")\n", " \n", "# if len(sorted_lengths) <= 20: # 如果长度种类不多,显示详细分布\n", "# print(\"\\n详细长度分布:\")\n", "# for length, count in sorted_lengths:\n", "# percentage = (count / total_sequences) * 100\n", "# operation = \"\"\n", "# if length < target_length:\n", "# operation = \" (插值扩展)\"\n", "# elif length > target_length:\n", "# operation = \" (压缩)\"\n", "# else:\n", "# operation = \" (无需处理)\"\n", "# print(f\" 长度 {length}: {count} 个 ({percentage:.1f}%){operation}\")\n", " \n", "# return processed_result\n", "\n", "# # 处理result字典\n", "# processed_result = process_result_with_kernel(result, MEAN_WINDOWS_SIZE)\n", "\n", "# # 验证处理结果\n", "# print(\"\\n处理后的统计信息:\")\n", "# total_sequences = 0\n", "# for label, sequences in processed_result.items():\n", "# total_sequences += len(sequences)\n", "# if sequences: # 如果列表不为空\n", "# print(f\"Label '{label}': {len(sequences)} sequences, shape: {sequences[0].shape}\")\n", "\n", "# print(f\"总共处理了 {total_sequences} 个时间序列,目标长度: {MEAN_WINDOWS_SIZE}\")\n", "\n", "# # 验证所有序列现在都是目标长度\n", "# all_correct_length = True\n", "# for label, sequences in processed_result.items():\n", "# for seq in sequences:\n", "# if seq.shape[0] != MEAN_WINDOWS_SIZE:\n", "# print(f\"错误: 发现长度不正确的序列 - Label: {label}, Shape: {seq.shape}\")\n", "# all_correct_length = False\n", "# break\n", "# if not all_correct_length:\n", "# break\n", "\n", "# if all_correct_length:\n", "# print(f\"✅ 验证通过: 所有序列长度都已正确调整为 {MEAN_WINDOWS_SIZE}\")\n", "# print(\" - 长序列已通过压缩(局部平均)缩短\")\n", "# print(\" - 短序列已通过插值扩展\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "去掉异常的标签,去除的时候记得保存正常标签的原始索引,我们可能不用去除后的来训练模型。而是训练适应多个时间大小窗口的模型,通过单独扫描的WER大小来确定权重。再把两个一起扫描,按照权重赋值,同样用极大值抑制或者CTC来处理。\n", "建议CTC。毕竟我们多个长度窗口的模型已经和RNN差距不大了。" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# import pickle\n", "# with open('processed_time_series.pkl', 'wb') as f:\n", "# pickle.dump(processed_result, f)\n", " \n", "# with open('time_series_format.pkl', 'wb') as f:\n", "# pickle.dump(result, f)" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "# processed_result.keys()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "# from sklearn.cluster import KMeans\n", "# from sklearn.metrics import silhouette_score\n", "# from sklearn.metrics import silhouette_samples\n", "\n", "# # for key, value in processed_result.items():\n", "# key = 'IH'\n", "# value = 
processed_result[key]\n", "# print(f\"Label: {key}, Number of sequences: {len(value)}, Shape of first sequence: {value[0].shape if value else 'N/A'}\")\n", "\n", "# # Apply KMeans clustering to the sequences for the selected label\n", "\n", "# # First, we need to reshape the sequences to 2D for clustering\n", "# # Since all sequences now have the same length (33), we can flatten them\n", "# sequences = value\n", "# flattened_sequences = []\n", "\n", "# for seq in sequences:\n", "# # Flatten each sequence to 1D\n", "# flattened_seq = seq.flatten()\n", "# flattened_sequences.append(flattened_seq)\n", "\n", "# flattened_sequences = np.array(flattened_sequences)\n", "# print(f\"Flattened sequences shape: {flattened_sequences.shape}\")\n", "\n", "# # Perform KMeans clustering\n", "# n_clusters = 5 # You can adjust this number\n", "# kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)\n", "# cluster_labels = kmeans.fit_predict(flattened_sequences)\n", "\n", "# # Calculate silhouette score to evaluate clustering quality\n", "# silhouette_avg = silhouette_score(flattened_sequences, cluster_labels)\n", "\n", "# print(f\"Number of clusters: {n_clusters}\")\n", "# print(f\"Silhouette score: {silhouette_avg:.3f}\")\n", "# print(f\"Cluster distribution: {np.bincount(cluster_labels)}\")\n", "\n", "# # Visualize clustering results\n", "# plt.figure(figsize=(12, 8))\n", "\n", "# # Plot 1: Cluster distribution\n", "# plt.subplot(2, 2, 1)\n", "# unique_labels, counts = np.unique(cluster_labels, return_counts=True)\n", "# plt.bar(unique_labels, counts)\n", "# plt.xlabel('Cluster')\n", "# plt.ylabel('Number of Sequences')\n", "# plt.title(f'Cluster Distribution for Label \"{key}\"')\n", "\n", "# # Plot 2: First few dimensions of the data colored by cluster\n", "# plt.subplot(2, 2, 2)\n", "# for i in range(n_clusters):\n", "# mask = cluster_labels == i\n", "# plt.scatter(flattened_sequences[mask, 0], flattened_sequences[mask, 1], \n", "# label=f'Cluster {i}', alpha=0.6)\n", "# plt.xlabel('Feature 0')\n", "# plt.ylabel('Feature 1')\n", "# plt.title('Clusters in Feature Space (First 2 Dimensions)')\n", "# plt.legend()\n", "\n", "# # Plot 3: Silhouette analysis\n", "# plt.subplot(2, 2, 3)\n", "# silhouette_vals = silhouette_samples(flattened_sequences, cluster_labels)\n", "# y_lower = 10\n", "# for i in range(n_clusters):\n", "# cluster_silhouette_vals = silhouette_vals[cluster_labels == i]\n", "# cluster_silhouette_vals.sort()\n", " \n", "# size_cluster_i = cluster_silhouette_vals.shape[0]\n", "# y_upper = y_lower + size_cluster_i\n", " \n", "# color = plt.cm.nipy_spectral(float(i) / n_clusters)\n", "# plt.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals,\n", "# facecolor=color, edgecolor=color, alpha=0.7)\n", " \n", "# plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))\n", "# y_lower = y_upper + 10\n", "\n", "# plt.axvline(x=silhouette_avg, color=\"red\", linestyle=\"--\")\n", "# plt.xlabel('Silhouette coefficient values')\n", "# plt.ylabel('Cluster label')\n", "# plt.title('Silhouette Analysis')\n", "\n", "# # Plot 4: Try different numbers of clusters\n", "# plt.subplot(2, 2, 4)\n", "# cluster_range = range(2, min(11, len(sequences)//2))\n", "# silhouette_scores = []\n", "# inertias = []\n", "\n", "# for n_clust in cluster_range:\n", "# kmeans_temp = KMeans(n_clusters=n_clust, random_state=42, n_init=10)\n", "# cluster_labels_temp = kmeans_temp.fit_predict(flattened_sequences)\n", "# silhouette_scores.append(silhouette_score(flattened_sequences, 
cluster_labels_temp))\n", "# inertias.append(kmeans_temp.inertia_)\n", "\n", "# plt.plot(cluster_range, silhouette_scores, 'bo-')\n", "# plt.xlabel('Number of Clusters')\n", "# plt.ylabel('Average Silhouette Score')\n", "# plt.title('Silhouette Score vs Number of Clusters')\n", "# plt.grid(True)\n", "\n", "# plt.tight_layout()\n", "# plt.show()\n", "\n", "# # Print cluster centers information\n", "# print(f\"\\nCluster centers shape: {kmeans.cluster_centers_.shape}\")\n", "# print(f\"Each cluster center represents the average pattern for sequences in that cluster\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "# # 余弦距离分析和距离度量对比\n", "# print(\"\\n\" + \"=\"*70)\n", "# print(\"余弦距离分析\")\n", "# print(\"=\"*70)\n", "\n", "# # 5. 计算余弦距离\n", "# print(\"\\n计算余弦距离...\")\n", "# cosine_matrix, phonemes_cosine = calculate_inter_class_distances(centroids, 'cosine')\n", "\n", "# # 6. 可视化余弦距离\n", "# print(\"\\n可视化余弦距离矩阵...\")\n", "# df_cosine, most_sim_cos, most_diff_cos = visualize_distance_matrix(\n", "# cosine_matrix, phonemes_cosine, 'cosine'\n", "# )\n", "\n", "# # 7. 分析相似音素群组(余弦距离)\n", "# similar_pairs_cos, threshold_cos = analyze_phoneme_groups(cosine_matrix, phonemes_cosine)\n", "\n", "# print(f\"\\n余弦距离分析完成!\")\n", "# print(f\"最相似音素对: {most_sim_cos}\")\n", "# print(f\"最不相似音素对: {most_diff_cos}\")\n", "\n", "# # 8. 比较两种距离度量\n", "# print(\"\\n\" + \"=\"*70)\n", "# print(\"距离度量对比分析\")\n", "# print(\"=\"*70)\n", "\n", "# def compare_distance_metrics(euclidean_matrix, cosine_matrix, phonemes):\n", "# \"\"\"\n", "# 比较不同距离度量的结果\n", "# \"\"\"\n", "# # 提取上三角矩阵的距离值\n", "# upper_triangle_indices = np.triu_indices_from(euclidean_matrix, k=1)\n", "# euclidean_distances = euclidean_matrix[upper_triangle_indices]\n", "# cosine_distances = cosine_matrix[upper_triangle_indices]\n", " \n", "# # 计算相关性\n", "# correlation = np.corrcoef(euclidean_distances, cosine_distances)[0, 1]\n", " \n", "# # 创建比较图\n", "# fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", " \n", "# # 图1: 距离分布对比\n", "# ax1 = axes[0, 0]\n", "# ax1.hist(euclidean_distances, bins=30, alpha=0.6, label='欧氏距离', color='blue')\n", "# ax1.hist(cosine_distances, bins=30, alpha=0.6, label='余弦距离', color='red')\n", "# ax1.set_xlabel('距离值')\n", "# ax1.set_ylabel('频次')\n", "# ax1.set_title('距离分布对比')\n", "# ax1.legend()\n", "# ax1.grid(True, alpha=0.3)\n", " \n", "# # 图2: 距离相关性散点图\n", "# ax2 = axes[0, 1]\n", "# ax2.scatter(euclidean_distances, cosine_distances, alpha=0.6, s=10)\n", "# ax2.set_xlabel('欧氏距离')\n", "# ax2.set_ylabel('余弦距离')\n", "# ax2.set_title(f'距离度量相关性 (r={correlation:.3f})')\n", "# ax2.grid(True, alpha=0.3)\n", " \n", "# # 添加拟合线\n", "# z = np.polyfit(euclidean_distances, cosine_distances, 1)\n", "# p = np.poly1d(z)\n", "# ax2.plot(euclidean_distances, p(euclidean_distances), \"r--\", alpha=0.8)\n", " \n", "# # 图3: 最相似音素对比较\n", "# ax3 = axes[1, 0]\n", "# ax3.axis('off')\n", " \n", "# # 获取每种距离度量下的前10对最相似音素\n", "# eucl_top10 = similar_pairs_eucl[:10]\n", "# cos_top10 = similar_pairs_cos[:10]\n", " \n", "# comparison_text = \"最相似音素对比较 (前10对)\\n\\n\"\n", "# comparison_text += f\"{'欧氏距离':<30} {'余弦距离':<30}\\n\"\n", "# comparison_text += \"-\" * 60 + \"\\n\"\n", " \n", "# for i in range(min(10, len(eucl_top10), len(cos_top10))):\n", "# eucl_pair = f\"{eucl_top10[i][0]}-{eucl_top10[i][1]} ({eucl_top10[i][2]:.3f})\"\n", "# cos_pair = f\"{cos_top10[i][0]}-{cos_top10[i][1]} ({cos_top10[i][2]:.4f})\"\n", "# comparison_text += f\"{eucl_pair:<30} {cos_pair:<30}\\n\"\n", " \n", "# ax3.text(0.05, 0.95, 
comparison_text, transform=ax3.transAxes, fontsize=9,\n", "# verticalalignment='top', fontfamily='monospace',\n", "# bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))\n", " \n", "# # 图4: 统计对比\n", "# ax4 = axes[1, 1]\n", "# ax4.axis('off')\n", " \n", "# # 计算统计信息\n", "# eucl_stats = {\n", "# 'mean': np.mean(euclidean_distances),\n", "# 'median': np.median(euclidean_distances),\n", "# 'std': np.std(euclidean_distances),\n", "# 'min': np.min(euclidean_distances),\n", "# 'max': np.max(euclidean_distances)\n", "# }\n", " \n", "# cos_stats = {\n", "# 'mean': np.mean(cosine_distances),\n", "# 'median': np.median(cosine_distances),\n", "# 'std': np.std(cosine_distances),\n", "# 'min': np.min(cosine_distances),\n", "# 'max': np.max(cosine_distances)\n", "# }\n", " \n", "# stats_text = f\"\"\"\n", "# 距离度量统计对比\n", "\n", "# 指标 欧氏距离 余弦距离\n", "# {'='*45}\n", "# 平均值 {eucl_stats['mean']:.4f} {cos_stats['mean']:.4f}\n", "# 中位数 {eucl_stats['median']:.4f} {cos_stats['median']:.4f}\n", "# 标准差 {eucl_stats['std']:.4f} {cos_stats['std']:.4f}\n", "# 最小值 {eucl_stats['min']:.4f} {cos_stats['min']:.4f}\n", "# 最大值 {eucl_stats['max']:.4f} {cos_stats['max']:.4f}\n", "\n", "# 相关性系数: {correlation:.4f}\n", "\n", "# 解释:\n", "# - 欧氏距离: 测量特征空间中的直线距离\n", "# - 余弦距离: 测量向量间的角度差异\n", "# - 高相关性表明两种度量捕获相似的模式\n", "# \"\"\"\n", " \n", "# ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes, fontsize=9,\n", "# verticalalignment='top', fontfamily='monospace',\n", "# bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))\n", " \n", "# plt.tight_layout()\n", "# plt.show()\n", " \n", "# return correlation, eucl_stats, cos_stats\n", "\n", "# # 执行距离度量比较\n", "# correlation, eucl_stats, cos_stats = compare_distance_metrics(\n", "# euclidean_matrix, cosine_matrix, phonemes\n", "# )\n", "\n", "# # 9. 
音素聚类分析\n", "# print(\"\\n\" + \"=\"*70)\n", "# print(\"基于距离的音素聚类分析\")\n", "# print(\"=\"*70)\n", "\n", "# def analyze_phoneme_clusters(distance_matrix, phonemes, n_clusters_range=[3, 4, 5, 6]):\n", "# \"\"\"\n", "# 使用不同数量的聚类分析音素群组\n", "# \"\"\"\n", "# from sklearn.cluster import AgglomerativeClustering\n", " \n", "# results = {}\n", " \n", "# fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", "# axes = axes.flatten()\n", " \n", "# for i, n_clusters in enumerate(n_clusters_range):\n", "# if i >= 4: # 最多显示4个图\n", "# break\n", " \n", "# # 执行层次聚类\n", "# clustering = AgglomerativeClustering(\n", "# n_clusters=n_clusters, \n", "# metric='precomputed',\n", "# linkage='average'\n", "# )\n", " \n", "# cluster_labels = clustering.fit_predict(distance_matrix)\n", " \n", "# # 分析聚类结果\n", "# clusters = {}\n", "# for phoneme, label in zip(phonemes, cluster_labels):\n", "# if label not in clusters:\n", "# clusters[label] = []\n", "# clusters[label].append(phoneme)\n", " \n", "# results[n_clusters] = clusters\n", " \n", "# # 可视化聚类结果\n", "# ax = axes[i]\n", " \n", "# # 创建颜色映射\n", "# colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))\n", " \n", "# # 为每个音素分配颜色\n", "# phoneme_colors = [colors[cluster_labels[j]] for j in range(len(phonemes))]\n", " \n", "# # 使用PCA降维可视化(使用质心数据)\n", "# from sklearn.decomposition import PCA\n", " \n", "# # 重新获取质心矩阵\n", "# centroid_matrix = np.array([centroids[phoneme] for phoneme in phonemes])\n", " \n", "# if centroid_matrix.shape[1] > 2:\n", "# pca = PCA(n_components=2)\n", "# pca_result = pca.fit_transform(centroid_matrix)\n", "# else:\n", "# pca_result = centroid_matrix\n", " \n", "# # 绘制散点图\n", "# for cluster_id in range(n_clusters):\n", "# mask = cluster_labels == cluster_id\n", "# ax.scatter(pca_result[mask, 0], pca_result[mask, 1], \n", "# c=[colors[cluster_id]], label=f'聚类 {cluster_id}', s=50, alpha=0.7)\n", " \n", "# # 添加音素标签\n", "# for j, phoneme in enumerate(phonemes):\n", "# if cluster_labels[j] == cluster_id:\n", "# ax.annotate(phoneme, (pca_result[j, 0], pca_result[j, 1]), \n", "# xytext=(5, 5), textcoords='offset points', fontsize=8)\n", " \n", "# ax.set_title(f'{n_clusters} 个聚类')\n", "# ax.set_xlabel('PC1')\n", "# ax.set_ylabel('PC2')\n", "# ax.legend()\n", "# ax.grid(True, alpha=0.3)\n", " \n", "# plt.tight_layout()\n", "# plt.show()\n", " \n", "# # 打印聚类结果\n", "# for n_clusters, clusters in results.items():\n", "# print(f\"\\n{n_clusters} 个聚类的结果:\")\n", "# for cluster_id, phonemes_in_cluster in clusters.items():\n", "# print(f\" 聚类 {cluster_id}: {', '.join(phonemes_in_cluster)}\")\n", " \n", "# return results\n", "\n", "# # 执行音素聚类分析\n", "# clustering_results = analyze_phoneme_clusters(euclidean_matrix, phonemes)\n", "\n", "# print(f\"\\n音素类间距离分析完成!\")\n", "# print(f\"发现了丰富的音素相似性模式和聚类结构。\")" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "# # 优化的DBSCAN异常值检测分析\n", "# import numpy as np\n", "# import matplotlib.pyplot as plt\n", "# from sklearn.cluster import DBSCAN\n", "# from sklearn.preprocessing import StandardScaler\n", "# from sklearn.decomposition import PCA\n", "# from sklearn.metrics import silhouette_score\n", "# import pandas as pd\n", "\n", "# print(\"=\"*70)\n", "# print(\"优化的DBSCAN异常值检测分析\")\n", "# print(\"=\"*70)\n", "\n", "# def determine_optimal_pca_components(data, min_variance_ratio=0.80, max_components=500):\n", "# \"\"\"\n", "# 确定最优的PCA组件数,保证解释方差比达到要求\n", "# \"\"\"\n", "# print(f\"原始数据维度: {data.shape}\")\n", " \n", "# # 先用少量组件快速估计\n", "# n_samples, n_features = data.shape\n", " \n", "# # 
确保组件数不超过样本数和特征数的最小值\n", "# max_possible_components = min(n_samples, n_features, max_components)\n", " \n", "# # 快速测试不同的组件数\n", "# test_components = [50, 100, 200, 300, 400, 500]\n", "# test_components = [c for c in test_components if c <= max_possible_components]\n", " \n", "# if not test_components:\n", "# test_components = [min(50, max_possible_components)]\n", " \n", "# print(f\"测试的组件数: {test_components}\")\n", " \n", "# best_components = test_components[0]\n", "# best_ratio = 0\n", " \n", "# for n_comp in test_components:\n", "# pca_temp = PCA(n_components=n_comp, random_state=42)\n", "# pca_temp.fit(data)\n", "# variance_ratio = np.sum(pca_temp.explained_variance_ratio_)\n", " \n", "# print(f\" {n_comp} 组件: 解释方差比 = {variance_ratio:.4f}\")\n", " \n", "# if variance_ratio >= min_variance_ratio:\n", "# best_components = n_comp\n", "# best_ratio = variance_ratio\n", "# break\n", "# elif variance_ratio > best_ratio:\n", "# best_components = n_comp\n", "# best_ratio = variance_ratio\n", " \n", "# print(f\"选择 {best_components} 个组件 (解释方差比: {best_ratio:.4f})\")\n", "# return best_components\n", "\n", "# def smart_dbscan_outlier_detection(processed_result, target_phonemes=None, min_variance_ratio=0.75):\n", "# \"\"\"\n", "# 智能DBSCAN异常值检测,使用合理的PCA降维\n", "# \"\"\"\n", "# outlier_results = {}\n", " \n", "# # 如果没有指定目标音素,选择样本数适中的音素\n", "# if target_phonemes is None:\n", "# phoneme_counts = [(p, len(seqs)) for p, seqs in processed_result.items()]\n", "# # 选择样本数在50-1000之间的音素\n", "# target_phonemes = [p for p, count in phoneme_counts if 50 <= count <= 1000]\n", "# target_phonemes = target_phonemes[:5] # 最多处理5个音素\n", " \n", "# print(f\"将处理的音素: {target_phonemes}\")\n", " \n", "# for phoneme in target_phonemes:\n", "# if phoneme not in processed_result:\n", "# continue\n", " \n", "# sequences = processed_result[phoneme]\n", "# print(f\"\\n\" + \"=\"*50)\n", "# print(f\"分析音素 '{phoneme}' ({len(sequences)} 个样本)\")\n", "# print(\"=\"*50)\n", " \n", "# # 展平序列数据\n", "# flattened_sequences = []\n", "# for seq in sequences:\n", "# flattened_sequences.append(seq.flatten())\n", " \n", "# flattened_sequences = np.array(flattened_sequences)\n", "# print(f\"原始数据形状: {flattened_sequences.shape}\")\n", " \n", "# # 标准化数据\n", "# scaler = StandardScaler()\n", "# scaled_data = scaler.fit_transform(flattened_sequences)\n", " \n", "# # 智能确定PCA组件数\n", "# optimal_components = determine_optimal_pca_components(\n", "# scaled_data, min_variance_ratio=min_variance_ratio\n", "# )\n", " \n", "# # 执行PCA降维\n", "# pca = PCA(n_components=optimal_components, random_state=42)\n", "# pca_data = pca.fit_transform(scaled_data)\n", "# variance_explained = np.sum(pca.explained_variance_ratio_)\n", " \n", "# print(f\"PCA降维结果:\")\n", "# print(f\" 原始维度: {scaled_data.shape[1]}\")\n", "# print(f\" 降维后: {pca_data.shape[1]}\")\n", "# print(f\" 解释方差比: {variance_explained:.4f}\")\n", "# print(f\" 信息保留率: {variance_explained*100:.2f}%\")\n", " \n", "# # 使用简化的DBSCAN参数搜索\n", "# print(f\"\\n开始DBSCAN参数搜索...\")\n", " \n", "# # 基于数据估计合理的eps范围\n", "# from sklearn.neighbors import NearestNeighbors\n", "# k = min(10, len(sequences)//10)\n", "# nbrs = NearestNeighbors(n_neighbors=k)\n", "# nbrs.fit(pca_data)\n", "# distances, _ = nbrs.kneighbors(pca_data)\n", "# k_distances = np.sort(distances[:, k-1])\n", " \n", "# # 选择eps候选值\n", "# eps_candidates = [\n", "# np.percentile(k_distances, 25),\n", "# np.percentile(k_distances, 50),\n", "# np.percentile(k_distances, 75),\n", "# np.percentile(k_distances, 90)\n", "# ]\n", " \n", "# min_samples_candidates = [5, 10, 15, 
20]\n", " \n", "# print(f\"eps候选值: {[f'{e:.3f}' for e in eps_candidates]}\")\n", "# print(f\"min_samples候选值: {min_samples_candidates}\")\n", " \n", "# best_score = -1\n", "# best_result = None\n", " \n", "# for eps in eps_candidates:\n", "# for min_samples in min_samples_candidates:\n", "# if min_samples >= len(sequences) // 5: # min_samples不能太大\n", "# continue\n", " \n", "# dbscan = DBSCAN(eps=eps, min_samples=min_samples)\n", "# labels = dbscan.fit_predict(pca_data)\n", " \n", "# n_outliers = np.sum(labels == -1)\n", "# n_clusters = len(set(labels)) - (1 if -1 in labels else 0)\n", "# outlier_ratio = n_outliers / len(labels)\n", " \n", "# # 计算评分\n", "# if n_clusters > 0 and 0.05 <= outlier_ratio <= 0.40: # 异常值比例在5%-40%之间\n", "# try:\n", "# if len(set(labels[labels != -1])) > 1:\n", "# silhouette = silhouette_score(pca_data[labels != -1], labels[labels != -1])\n", "# else:\n", "# silhouette = 0.5 # 单聚类给中等分数\n", "# except:\n", "# silhouette = 0\n", " \n", "# # 综合评分:轮廓系数 + 合理的异常值比例奖励\n", "# if 0.1 <= outlier_ratio <= 0.25: # 最理想的异常值比例\n", "# ratio_bonus = 0.2\n", "# else:\n", "# ratio_bonus = 0.1\n", " \n", "# score = silhouette + ratio_bonus\n", " \n", "# if score > best_score:\n", "# best_score = score\n", "# best_result = {\n", "# 'eps': eps,\n", "# 'min_samples': min_samples,\n", "# 'labels': labels.copy(),\n", "# 'outliers': np.where(labels == -1)[0],\n", "# 'n_clusters': n_clusters,\n", "# 'outlier_ratio': outlier_ratio,\n", "# 'silhouette': silhouette\n", "# }\n", " \n", "# print(f\" eps={eps:.3f}, min_samples={min_samples}: \"\n", "# f\"{n_clusters}聚类, {n_outliers}异常值 ({outlier_ratio*100:.1f}%)\")\n", " \n", "# if best_result is not None:\n", "# outlier_results[phoneme] = {\n", "# **best_result,\n", "# 'pca_data': pca_data,\n", "# 'pca_model': pca,\n", "# 'variance_explained': variance_explained,\n", "# 'original_data': flattened_sequences,\n", "# 'scaled_data': scaled_data\n", "# }\n", " \n", "# print(f\"\\n✅ 找到最佳参数:\")\n", "# print(f\" eps: {best_result['eps']:.3f}\")\n", "# print(f\" min_samples: {best_result['min_samples']}\")\n", "# print(f\" 聚类数: {best_result['n_clusters']}\")\n", "# print(f\" 异常值: {len(best_result['outliers'])} ({best_result['outlier_ratio']*100:.1f}%)\")\n", "# print(f\" 轮廓系数: {best_result['silhouette']:.3f}\")\n", "# print(f\" 综合评分: {best_score:.3f}\")\n", "# else:\n", "# print(f\"\\n❌ 未找到合适的DBSCAN参数\")\n", " \n", "# return outlier_results\n", "\n", "# def visualize_smart_dbscan_results(outlier_results):\n", "# \"\"\"\n", "# 可视化智能DBSCAN结果\n", "# \"\"\"\n", "# if not outlier_results:\n", "# print(\"没有结果可可视化\")\n", "# return\n", " \n", "# n_phonemes = len(outlier_results)\n", "# fig, axes = plt.subplots(2, n_phonemes, figsize=(6*n_phonemes, 10))\n", " \n", "# if n_phonemes == 1:\n", "# axes = axes.reshape(2, 1)\n", " \n", "# for i, (phoneme, result) in enumerate(outlier_results.items()):\n", "# # 上图:PCA散点图\n", "# ax1 = axes[0, i]\n", "# pca_data = result['pca_data']\n", "# labels = result['labels']\n", " \n", "# # 绘制聚类\n", "# unique_labels = set(labels)\n", "# colors = plt.cm.Set3(np.linspace(0, 1, len(unique_labels)))\n", " \n", "# for j, label in enumerate(unique_labels):\n", "# if label == -1:\n", "# mask = labels == label\n", "# ax1.scatter(pca_data[mask, 0], pca_data[mask, 1], \n", "# c='red', marker='x', s=60, label='异常值', alpha=0.8)\n", "# else:\n", "# mask = labels == label\n", "# ax1.scatter(pca_data[mask, 0], pca_data[mask, 1], \n", "# c=[colors[j]], label=f'聚类 {label}', alpha=0.7, s=30)\n", " \n", "# ax1.set_title(f'音素 \"{phoneme}\" 
DBSCAN结果\\n'\n", "# f'{result[\"n_clusters\"]}聚类, {len(result[\"outliers\"])}异常值 '\n", "# f'({result[\"outlier_ratio\"]*100:.1f}%)')\n", "# ax1.set_xlabel('PC1')\n", "# ax1.set_ylabel('PC2')\n", "# ax1.legend()\n", "# ax1.grid(True, alpha=0.3)\n", " \n", "# # 下图:方差解释图\n", "# ax2 = axes[1, i]\n", "# pca_model = result['pca_model']\n", "# n_components = len(pca_model.explained_variance_ratio_)\n", " \n", "# # 绘制累计方差解释比\n", "# cumsum_var = np.cumsum(pca_model.explained_variance_ratio_)\n", "# ax2.plot(range(1, n_components+1), cumsum_var, 'b-', marker='o')\n", "# ax2.axhline(y=0.8, color='r', linestyle='--', label='80%阈值')\n", "# ax2.axhline(y=0.9, color='g', linestyle='--', label='90%阈值')\n", " \n", "# ax2.set_xlabel('主成分数量')\n", "# ax2.set_ylabel('累计解释方差比')\n", "# ax2.set_title(f'PCA方差解释 (总计: {result[\"variance_explained\"]:.3f})')\n", "# ax2.legend()\n", "# ax2.grid(True, alpha=0.3)\n", "# ax2.set_ylim(0, 1)\n", " \n", "# plt.tight_layout()\n", "# plt.show()\n", "\n", "# def analyze_outliers_detailed(outlier_results):\n", "# \"\"\"\n", "# 详细分析异常值\n", "# \"\"\"\n", "# print(\"\\n\" + \"=\"*70)\n", "# print(\"详细异常值分析报告\")\n", "# print(\"=\"*70)\n", " \n", "# for phoneme, result in outlier_results.items():\n", "# print(f\"\\n音素 '{phoneme}' 异常值分析:\")\n", "# print(\"-\" * 40)\n", " \n", "# outlier_indices = result['outliers']\n", "# normal_indices = [i for i in range(len(result['labels'])) if i not in outlier_indices]\n", " \n", "# print(f\"总样本数: {len(result['labels'])}\")\n", "# print(f\"正常样本: {len(normal_indices)} ({len(normal_indices)/len(result['labels'])*100:.1f}%)\")\n", "# print(f\"异常样本: {len(outlier_indices)} ({len(outlier_indices)/len(result['labels'])*100:.1f}%)\")\n", "# print(f\"聚类数量: {result['n_clusters']}\")\n", "# print(f\"PCA维度: {result['pca_data'].shape[1]}\")\n", "# print(f\"信息保留: {result['variance_explained']*100:.2f}%\")\n", " \n", "# # 分析异常值的特征\n", "# if len(outlier_indices) > 0:\n", "# outlier_data = result['pca_data'][outlier_indices]\n", "# normal_data = result['pca_data'][normal_indices] if len(normal_indices) > 0 else None\n", " \n", "# print(f\"\\n异常值特征分析:\")\n", "# print(f\" PC1均值: {np.mean(outlier_data[:, 0]):.3f} ± {np.std(outlier_data[:, 0]):.3f}\")\n", "# print(f\" PC2均值: {np.mean(outlier_data[:, 1]):.3f} ± {np.std(outlier_data[:, 1]):.3f}\")\n", " \n", "# if normal_data is not None:\n", "# print(f\"正常值特征对比:\")\n", "# print(f\" PC1均值: {np.mean(normal_data[:, 0]):.3f} ± {np.std(normal_data[:, 0]):.3f}\")\n", "# print(f\" PC2均值: {np.mean(normal_data[:, 1]):.3f} ± {np.std(normal_data[:, 1]):.3f}\")\n", "\n", "# # 执行优化的DBSCAN异常值检测\n", "# print(\"开始智能DBSCAN异常值检测...\")\n", "\n", "# # 选择几个有代表性的音素进行分析\n", "# target_phonemes = ['IH', 'T', 'S', 'N', 'AH'] # 手动选择一些常见音素\n", "\n", "# outlier_results = smart_dbscan_outlier_detection(\n", "# processed_result, \n", "# target_phonemes=target_phonemes,\n", "# min_variance_ratio=0.75 # 保留至少75%的方差\n", "# )\n", "\n", "# if outlier_results:\n", "# print(f\"\\n✅ 成功检测到 {len(outlier_results)} 个音素的异常值!\")\n", " \n", "# # 可视化结果\n", "# visualize_smart_dbscan_results(outlier_results)\n", " \n", "# # 详细分析\n", "# analyze_outliers_detailed(outlier_results)\n", " \n", "# print(f\"\\n📊 异常值检测总结:\")\n", "# for phoneme, result in outlier_results.items():\n", "# print(f\" {phoneme}: {len(result['outliers'])}/{len(result['labels'])} \"\n", "# f\"({result['outlier_ratio']*100:.1f}%) 异常值\")\n", " \n", "# else:\n", "# print(\"❌ 未检测到任何有效的异常值结果\")\n", "\n", "# print(f\"\\n💡 关键改进:\")\n", "# print(f\"1. 使用自适应PCA降维,保留75%以上的方差信息\")\n", "# print(f\"2. 
基于数据分布智能选择DBSCAN参数\")\n", "# print(f\"3. 合理的异常值比例范围(5%-40%)\")\n", "# print(f\"4. 综合评分机制平衡聚类质量和异常值检测\")" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "# # 类内相似度分析\n", "# import numpy as np\n", "# import matplotlib.pyplot as plt\n", "# from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances\n", "# from sklearn.preprocessing import StandardScaler\n", "# import pandas as pd\n", "# import seaborn as sns\n", "\n", "# print(\"=\"*70)\n", "# print(\"类内相似度分析 - 音素分类有效性评估\")\n", "# print(\"=\"*70)\n", "\n", "# def calculate_intra_class_similarity(sequences, metric='cosine'):\n", "# \"\"\"\n", "# 计算类内相似度\n", "# \"\"\"\n", "# if len(sequences) < 2:\n", "# return np.nan, np.nan, np.nan\n", " \n", "# # 展平序列\n", "# flattened_sequences = []\n", "# for seq in sequences:\n", "# flattened_sequences.append(seq.flatten())\n", " \n", "# flattened_sequences = np.array(flattened_sequences)\n", " \n", "# # 标准化\n", "# scaler = StandardScaler()\n", "# scaled_sequences = scaler.fit_transform(flattened_sequences)\n", " \n", "# if metric == 'cosine':\n", "# # 计算余弦相似度\n", "# similarity_matrix = cosine_similarity(scaled_sequences)\n", "# # 提取上三角矩阵(排除对角线)\n", "# upper_tri = np.triu(similarity_matrix, k=1)\n", "# similarities = upper_tri[upper_tri > 0]\n", "# elif metric == 'euclidean':\n", "# # 计算欧氏距离,然后转换为相似度\n", "# distance_matrix = euclidean_distances(scaled_sequences)\n", "# # 转换为相似度(距离越小,相似度越高)\n", "# max_dist = np.max(distance_matrix)\n", "# similarity_matrix = 1 - (distance_matrix / max_dist)\n", "# upper_tri = np.triu(similarity_matrix, k=1)\n", "# similarities = upper_tri[upper_tri > 0]\n", " \n", "# mean_similarity = np.mean(similarities)\n", "# std_similarity = np.std(similarities)\n", "# median_similarity = np.median(similarities)\n", " \n", "# return mean_similarity, std_similarity, median_similarity\n", "\n", "# def analyze_phoneme_similarity(processed_result, metric='cosine', sample_limit=500):\n", "# \"\"\"\n", "# 分析每个音素的类内相似度\n", "# \"\"\"\n", "# print(f\"使用 {metric} 相似度度量\")\n", "# print(f\"每个音素最多分析 {sample_limit} 个样本\")\n", "# print(\"-\" * 50)\n", " \n", "# phoneme_similarities = {}\n", " \n", "# for phoneme, sequences in processed_result.items():\n", "# if len(sequences) < 5: # 跳过样本数太少的音素\n", "# print(f\"跳过音素 '{phoneme}' (样本数太少: {len(sequences)})\")\n", "# continue\n", " \n", "# # 如果样本太多,随机采样\n", "# if len(sequences) > sample_limit:\n", "# indices = np.random.choice(len(sequences), sample_limit, replace=False)\n", "# sampled_sequences = [sequences[i] for i in indices]\n", "# else:\n", "# sampled_sequences = sequences\n", " \n", "# mean_sim, std_sim, median_sim = calculate_intra_class_similarity(\n", "# sampled_sequences, metric=metric\n", "# )\n", " \n", "# phoneme_similarities[phoneme] = {\n", "# 'mean': mean_sim,\n", "# 'std': std_sim,\n", "# 'median': median_sim,\n", "# 'n_samples': len(sampled_sequences),\n", "# 'n_pairs': len(sampled_sequences) * (len(sampled_sequences) - 1) // 2\n", "# }\n", " \n", "# print(f\"音素 '{phoneme}': 平均相似度={mean_sim:.4f} ± {std_sim:.4f}, \"\n", "# f\"中位数={median_sim:.4f}, 样本数={len(sampled_sequences)}\")\n", " \n", "# return phoneme_similarities\n", "\n", "# def calculate_overall_similarity(processed_result, metric='cosine', sample_per_phoneme=50):\n", "# \"\"\"\n", "# 计算全部音素作为一类的相似度\n", "# \"\"\"\n", "# print(f\"\\n计算全体音素相似度 (每个音素采样 {sample_per_phoneme} 个)\")\n", "# print(\"-\" * 50)\n", " \n", "# all_sequences = []\n", "# phoneme_labels = []\n", " \n", "# # 从每个音素中采样一定数量的序列\n", "# for phoneme, 
sequences in processed_result.items():\n", "# if len(sequences) < 5:\n", "# continue\n", " \n", "# n_sample = min(sample_per_phoneme, len(sequences))\n", "# indices = np.random.choice(len(sequences), n_sample, replace=False)\n", " \n", "# for i in indices:\n", "# all_sequences.append(sequences[i])\n", "# phoneme_labels.append(phoneme)\n", " \n", "# print(f\"总共收集了 {len(all_sequences)} 个序列,来自 {len(set(phoneme_labels))} 个音素\")\n", " \n", "# # 计算整体相似度\n", "# mean_sim, std_sim, median_sim = calculate_intra_class_similarity(\n", "# all_sequences, metric=metric\n", "# )\n", " \n", "# overall_result = {\n", "# 'mean': mean_sim,\n", "# 'std': std_sim,\n", "# 'median': median_sim,\n", "# 'n_samples': len(all_sequences),\n", "# 'n_pairs': len(all_sequences) * (len(all_sequences) - 1) // 2,\n", "# 'n_phonemes': len(set(phoneme_labels))\n", "# }\n", " \n", "# print(f\"全体音素: 平均相似度={mean_sim:.4f} ± {std_sim:.4f}, \"\n", "# f\"中位数={median_sim:.4f}, 样本数={len(all_sequences)}\")\n", " \n", "# return overall_result, phoneme_labels\n", "\n", "# def visualize_similarity_comparison(phoneme_similarities, overall_result, metric='cosine'):\n", "# \"\"\"\n", "# 可视化相似度比较\n", "# \"\"\"\n", "# fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", " \n", "# # 准备数据\n", "# phonemes = list(phoneme_similarities.keys())\n", "# mean_similarities = [phoneme_similarities[p]['mean'] for p in phonemes]\n", "# std_similarities = [phoneme_similarities[p]['std'] for p in phonemes]\n", "# sample_counts = [phoneme_similarities[p]['n_samples'] for p in phonemes]\n", " \n", "# overall_mean = overall_result['mean']\n", "# overall_std = overall_result['std']\n", " \n", "# # 图1: 每个音素的平均相似度 vs 整体相似度\n", "# ax1 = axes[0, 0]\n", "# bars = ax1.bar(range(len(phonemes)), mean_similarities, \n", "# yerr=std_similarities, capsize=3, alpha=0.7, color='skyblue')\n", "# ax1.axhline(y=overall_mean, color='red', linestyle='--', linewidth=2, \n", "# label=f'全体音素平均: {overall_mean:.4f}')\n", "# ax1.fill_between(range(len(phonemes)), \n", "# overall_mean - overall_std, \n", "# overall_mean + overall_std, \n", "# alpha=0.2, color='red', label=f'全体音素范围: ±{overall_std:.4f}')\n", " \n", "# ax1.set_xlabel('音素')\n", "# ax1.set_ylabel(f'{metric.title()} 相似度')\n", "# ax1.set_title('各音素类内相似度 vs 全体音素相似度')\n", "# ax1.set_xticks(range(len(phonemes)))\n", "# ax1.set_xticklabels(phonemes, rotation=45)\n", "# ax1.legend()\n", "# ax1.grid(True, alpha=0.3)\n", " \n", "# # 在柱状图上显示数值\n", "# for i, (bar, mean_val) in enumerate(zip(bars, mean_similarities)):\n", "# if mean_val > overall_mean:\n", "# ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, \n", "# f'{mean_val:.3f}', ha='center', va='bottom', fontsize=8, \n", "# color='green', weight='bold')\n", "# else:\n", "# ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, \n", "# f'{mean_val:.3f}', ha='center', va='bottom', fontsize=8, \n", "# color='red', weight='bold')\n", " \n", "# # 图2: 相似度提升程度\n", "# ax2 = axes[0, 1]\n", "# improvements = [(sim - overall_mean) for sim in mean_similarities]\n", "# colors = ['green' if imp > 0 else 'red' for imp in improvements]\n", " \n", "# bars2 = ax2.bar(range(len(phonemes)), improvements, color=colors, alpha=0.7)\n", "# ax2.axhline(y=0, color='black', linestyle='-', linewidth=1)\n", "# ax2.set_xlabel('音素')\n", "# ax2.set_ylabel('相似度提升 (相对于全体)')\n", "# ax2.set_title('音素分类的相似度提升效果')\n", "# ax2.set_xticks(range(len(phonemes)))\n", "# ax2.set_xticklabels(phonemes, rotation=45)\n", "# ax2.grid(True, alpha=0.3)\n", " \n", "# # 显示数值\n", "# for bar, imp in 
zip(bars2, improvements):\n", "# ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001, \n", "# f'{imp:+.3f}', ha='center', va='bottom' if imp > 0 else 'top', \n", "# fontsize=8, weight='bold')\n", " \n", "# # 图3: 样本数量 vs 相似度\n", "# ax3 = axes[1, 0]\n", "# scatter = ax3.scatter(sample_counts, mean_similarities, \n", "# c=improvements, cmap='RdYlGn', s=60, alpha=0.7)\n", "# ax3.axhline(y=overall_mean, color='red', linestyle='--', alpha=0.7)\n", "# ax3.set_xlabel('样本数量')\n", "# ax3.set_ylabel(f'{metric.title()} 相似度')\n", "# ax3.set_title('样本数量 vs 相似度')\n", "# ax3.grid(True, alpha=0.3)\n", " \n", "# # 添加颜色条\n", "# cbar = plt.colorbar(scatter, ax=ax3)\n", "# cbar.set_label('相似度提升')\n", " \n", "# # 图4: 相似度分布统计\n", "# ax4 = axes[1, 1]\n", "# ax4.axis('off')\n", " \n", "# # 计算统计信息\n", "# positive_improvements = [imp for imp in improvements if imp > 0]\n", "# negative_improvements = [imp for imp in improvements if imp <= 0]\n", " \n", "# avg_improvement = np.mean(improvements)\n", "# max_improvement = np.max(improvements)\n", "# min_improvement = np.min(improvements)\n", " \n", "# best_phoneme = phonemes[improvements.index(max_improvement)]\n", "# worst_phoneme = phonemes[improvements.index(min_improvement)]\n", " \n", "# stats_text = f\"\"\"\n", "# 类内相似度分析统计报告\n", "\n", "# 度量方法: {metric.title()} 相似度\n", "\n", "# 全体音素基线:\n", "# - 平均相似度: {overall_mean:.4f} ± {overall_std:.4f}\n", "# - 样本总数: {overall_result['n_samples']}\n", "# - 音素数量: {overall_result['n_phonemes']}\n", "\n", "# 音素分类效果:\n", "# - 分析音素数: {len(phonemes)}\n", "# - 平均提升: {avg_improvement:+.4f}\n", "# - 最大提升: {max_improvement:+.4f} ({best_phoneme})\n", "# - 最小提升: {min_improvement:+.4f} ({worst_phoneme})\n", "\n", "# 提升统计:\n", "# - 相似度提升音素: {len(positive_improvements)}/{len(phonemes)} ({len(positive_improvements)/len(phonemes)*100:.1f}%)\n", "# - 相似度下降音素: {len(negative_improvements)}/{len(phonemes)} ({len(negative_improvements)/len(phonemes)*100:.1f}%)\n", "\n", "# 结论:\n", "# \"\"\"\n", " \n", "# if avg_improvement > 0:\n", "# stats_text += f\"✅ 音素分类整体有效 (平均提升 {avg_improvement:.4f})\\n\"\n", "# stats_text += f\" 按音素分类比混合所有音素更能保持相似性\"\n", "# else:\n", "# stats_text += f\"❌ 音素分类效果有限 (平均下降 {avg_improvement:.4f})\\n\"\n", "# stats_text += f\" 可能需要重新考虑分类策略\"\n", " \n", "# ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes, fontsize=10,\n", "# verticalalignment='top', fontfamily='monospace',\n", "# bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))\n", " \n", "# plt.tight_layout()\n", "# plt.show()\n", " \n", "# return improvements\n", "\n", "# def create_similarity_dataframe(phoneme_similarities, overall_result):\n", "# \"\"\"\n", "# 创建相似度对比的DataFrame\n", "# \"\"\"\n", "# data = []\n", "# overall_mean = overall_result['mean']\n", " \n", "# for phoneme, sim_data in phoneme_similarities.items():\n", "# improvement = sim_data['mean'] - overall_mean\n", "# relative_improvement = (improvement / overall_mean) * 100\n", " \n", "# data.append({\n", "# 'phoneme': phoneme,\n", "# 'intra_class_similarity': sim_data['mean'],\n", "# 'similarity_std': sim_data['std'],\n", "# 'n_samples': sim_data['n_samples'],\n", "# 'overall_baseline': overall_mean,\n", "# 'absolute_improvement': improvement,\n", "# 'relative_improvement_pct': relative_improvement,\n", "# 'is_better': improvement > 0\n", "# })\n", " \n", "# df = pd.DataFrame(data)\n", "# df = df.sort_values('absolute_improvement', ascending=False)\n", " \n", "# return df\n", "\n", "# # 执行类内相似度分析\n", "# print(\"开始类内相似度分析...\")\n", "\n", "# # 1. 分析每个音素的类内相似度\n", "# print(\"\\n1. 
计算各音素类内相似度\")\n", "# phoneme_similarities_cosine = analyze_phoneme_similarity(\n", "# processed_result, metric='cosine', sample_limit=300\n", "# )\n", "\n", "# # 2. 计算全体音素的相似度\n", "# print(\"\\n2. 计算全体音素相似度作为基线\")\n", "# overall_result_cosine, all_phoneme_labels = calculate_overall_similarity(\n", "# processed_result, metric='cosine', sample_per_phoneme=30\n", "# )\n", "\n", "# # 3. 可视化比较\n", "# print(\"\\n3. 可视化相似度比较\")\n", "# improvements = visualize_similarity_comparison(\n", "# phoneme_similarities_cosine, overall_result_cosine, metric='cosine'\n", "# )\n", "\n", "# # 4. 创建详细对比表\n", "# print(\"\\n4. 详细相似度对比表\")\n", "# df_similarity = create_similarity_dataframe(phoneme_similarities_cosine, overall_result_cosine)\n", "# print(df_similarity.to_string(index=False, float_format='%.4f'))\n", "\n", "# print(f\"\\n✅ 类内相似度分析完成!\")\n", "# print(f\"这个分析揭示了音素分类对神经信号相似性的影响。\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 🔗 数据集批量处理工作流" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text/model_training\n", "======================================================================\n", "🚀 RNN数据批量处理工具 - 新版本\n", "======================================================================\n", "🔧 创建RNN数据处理器...\n", "🔧 初始化RNN数据处理器...\n", " 模型路径: ../data/t15_pretrained_rnn_baseline\n", " 数据目录: ../data/hdf5_data_final\n", " 计算设备: cuda:0\n", "📋 模型配置:\n", " Sessions数量: 45\n", " 神经特征维度: 512\n", " Patch size: 14\n", " Patch stride: 4\n", " 输出类别数: 41\n", "✅ 模型加载成功\n", "📊 CSV数据加载完成: 265 条记录\n", "✅ 初始化完成!\n", "✅ RNN数据处理器创建成功!\n", "✅ 模型加载成功\n", "📊 CSV数据加载完成: 265 条记录\n", "✅ 初始化完成!\n", "✅ RNN数据处理器创建成功!\n" ] } ], "source": [ "%cd model_training\n", "# 🚀 RNN数据批量处理工具 - 完整版\n", "import os\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "from omegaconf import OmegaConf\n", "import time\n", "from tqdm import tqdm\n", "import h5py\n", "from pathlib import Path\n", "\n", "# 导入模型相关模块\n", "import sys\n", "sys.path.append('../model_training')\n", "from rnn_model import GRUDecoder\n", "from evaluate_model_helpers import *\n", "from data_augmentations import gauss_smooth\n", "\n", "print(\"=\"*70)\n", "print(\"🚀 RNN数据批量处理工具 - 新版本\")\n", "print(\"=\"*70)\n", "\n", "class RNNDataProcessor:\n", " \"\"\"\n", " RNN数据批量处理器 - 生成RNN输入输出拼接数据\n", " \n", " 核心功能:\n", " 1. 加载预训练RNN模型\n", " 2. 处理原始神经数据(高斯平滑 + patch操作)\n", " 3. 获取RNN输出(40类置信度分数)\n", " 4. 拼接处理后的输入和输出\n", " 5. 
批量保存所有session数据\n", " \"\"\"\n", " \n", " def __init__(self, model_path, data_dir, csv_path, device='auto'):\n", " \"\"\"\n", " 初始化处理器\n", " \n", " 参数:\n", " model_path: 预训练RNN模型路径\n", " data_dir: 数据目录路径 \n", " csv_path: 数据描述CSV文件路径\n", " device: 计算设备 ('auto', 'cpu', 'cuda:0'等)\n", " \"\"\"\n", " self.model_path = model_path\n", " self.data_dir = data_dir\n", " self.csv_path = csv_path\n", " \n", " # 设备选择\n", " if device == 'auto':\n", " self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", " else:\n", " self.device = torch.device(device)\n", " \n", " print(f\"🔧 初始化RNN数据处理器...\")\n", " print(f\" 模型路径: {model_path}\")\n", " print(f\" 数据目录: {data_dir}\")\n", " print(f\" 计算设备: {self.device}\")\n", " \n", " # 加载配置和模型\n", " self._load_config()\n", " self._load_model()\n", " self._load_csv()\n", " \n", " print(f\"✅ 初始化完成!\")\n", " \n", " def _load_config(self):\n", " \"\"\"加载模型配置\"\"\"\n", " config_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n", " if not os.path.exists(config_path):\n", " raise FileNotFoundError(f\"配置文件不存在: {config_path}\")\n", " \n", " self.model_args = OmegaConf.load(config_path)\n", " \n", " print(f\"📋 模型配置:\")\n", " print(f\" Sessions数量: {len(self.model_args['dataset']['sessions'])}\")\n", " print(f\" 神经特征维度: {self.model_args['model']['n_input_features']}\")\n", " print(f\" Patch size: {self.model_args['model']['patch_size']}\")\n", " print(f\" Patch stride: {self.model_args['model']['patch_stride']}\")\n", " print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n", " \n", " def _load_model(self):\n", " \"\"\"加载预训练RNN模型\"\"\"\n", " try:\n", " # 创建模型\n", " self.model = GRUDecoder(\n", " neural_dim=self.model_args['model']['n_input_features'],\n", " n_units=self.model_args['model']['n_units'], \n", " n_days=len(self.model_args['dataset']['sessions']),\n", " n_classes=self.model_args['dataset']['n_classes'],\n", " rnn_dropout=self.model_args['model']['rnn_dropout'],\n", " input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n", " n_layers=self.model_args['model']['n_layers'],\n", " patch_size=self.model_args['model']['patch_size'],\n", " patch_stride=self.model_args['model']['patch_stride'],\n", " )\n", " \n", " # 加载权重\n", " checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n", " try:\n", " checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n", " except TypeError:\n", " checkpoint = torch.load(checkpoint_path, map_location=self.device)\n", " \n", " # 清理键名\n", " for key in list(checkpoint['model_state_dict'].keys()):\n", " checkpoint['model_state_dict'][key.replace(\"module.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n", " checkpoint['model_state_dict'][key.replace(\"_orig_mod.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n", " \n", " self.model.load_state_dict(checkpoint['model_state_dict'])\n", " self.model.to(self.device)\n", " self.model.eval()\n", " \n", " print(f\"✅ 模型加载成功\")\n", " \n", " except Exception as e:\n", " print(f\"❌ 模型加载失败: {e}\")\n", " raise\n", " \n", " def _load_csv(self):\n", " \"\"\"加载数据描述文件\"\"\"\n", " if not os.path.exists(self.csv_path):\n", " raise FileNotFoundError(f\"CSV文件不存在: {self.csv_path}\")\n", " \n", " self.csv_df = pd.read_csv(self.csv_path)\n", " print(f\"📊 CSV数据加载完成: {len(self.csv_df)} 条记录\")\n", " \n", " def _process_single_trial(self, neural_data, session_idx):\n", " \"\"\"\n", " 处理单个试验数据\n", " \n", " 参数:\n", " neural_data: 原始神经数据 [time_steps, features]\n", " session_idx: 会话索引\n", " 
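note (added): the returned confidence_scores has n_classes columns, i.e. 41 in this config (likely 40 phonemes plus a CTC blank); a few comments below still say 40\n",
" 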
\n", " 返回:\n", " dict: 包含拼接数据和统计信息\n", " \"\"\"\n", " # 添加batch维度\n", " neural_input = np.expand_dims(neural_data, axis=0)\n", " neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)\n", " \n", " # 高斯平滑\n", " with torch.autocast(device_type=\"cuda\" if self.device.type == \"cuda\" else \"cpu\", \n", " enabled=self.model_args.get('use_amp', False), dtype=torch.bfloat16):\n", " \n", " smoothed_data = gauss_smooth(\n", " inputs=neural_tensor,\n", " device=self.device,\n", " smooth_kernel_std=self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n", " smooth_kernel_size=self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n", " padding='valid',\n", " )\n", " \n", " # Patch操作(复制模型内部逻辑)\n", " processed_data = smoothed_data\n", " if self.model.patch_size > 0:\n", " processed_data = processed_data.unsqueeze(1) # [batch, 1, time, features]\n", " processed_data = processed_data.permute(0, 3, 1, 2) # [batch, features, 1, time]\n", " \n", " # 滑动窗口提取\n", " patches = processed_data.unfold(3, self.model.patch_size, self.model.patch_stride)\n", " patches = patches.squeeze(2) # [batch, features, patches, patch_size]\n", " patches = patches.permute(0, 2, 3, 1) # [batch, patches, patch_size, features]\n", " \n", " # 展平最后两个维度\n", " processed_data = patches.reshape(patches.size(0), patches.size(1), -1)\n", " \n", " # RNN推理\n", " with torch.no_grad():\n", " logits, _ = self.model(\n", " x=smoothed_data,\n", " day_idx=torch.tensor([session_idx], device=self.device),\n", " states=None,\n", " return_state=True,\n", " )\n", " \n", " # 转换为numpy\n", " processed_features = processed_data.float().cpu().numpy()[0] # [time_steps, processed_features]\n", " confidence_scores = logits.float().cpu().numpy()[0] # [time_steps, 40]\n", " \n", " # 拼接数据\n", " concatenated = np.concatenate([processed_features, confidence_scores], axis=1)\n", " \n", " return {\n", " 'concatenated_data': concatenated,\n", " 'processed_features': processed_features,\n", " 'confidence_scores': confidence_scores,\n", " 'original_time_steps': neural_data.shape[0],\n", " 'processed_time_steps': concatenated.shape[0],\n", " 'feature_reduction_ratio': concatenated.shape[0] / neural_data.shape[0]\n", " }\n", " \n", " def process_session(self, session_name, data_types=['train', 'val', 'test']):\n", " \"\"\"\n", " 处理单个session的数据\n", " \n", " 参数:\n", " session_name: 会话名称\n", " data_types: 要处理的数据类型列表\n", " \n", " 返回:\n", " dict: 处理结果\n", " \"\"\"\n", " print(f\"\\n🔄 处理会话: {session_name}\")\n", " \n", " session_idx = self.model_args['dataset']['sessions'].index(session_name)\n", " session_results = {}\n", " \n", " for data_type in data_types:\n", " data_file = os.path.join(self.data_dir, session_name, f'data_{data_type}.hdf5')\n", " \n", " if not os.path.exists(data_file):\n", " print(f\" ⚠️ {data_type} 数据文件不存在,跳过\")\n", " continue\n", " \n", " print(f\" 📁 处理 {data_type} 数据...\")\n", " \n", " try:\n", " # 加载数据\n", " data = load_h5py_file(data_file, self.csv_df)\n", " num_trials = len(data['neural_features'])\n", " \n", " if num_trials == 0:\n", " print(f\" ⚠️ {data_type} 数据为空\")\n", " continue\n", " \n", " # 处理所有试验\n", " results = {\n", " 'concatenated_data': [],\n", " 'processed_features': [],\n", " 'confidence_scores': [],\n", " 'trial_metadata': [],\n", " 'processing_stats': []\n", " }\n", " \n", " for trial_idx in tqdm(range(num_trials), desc=f\" {data_type}\", leave=False):\n", " neural_data = data['neural_features'][trial_idx]\n", " \n", " # 处理单个试验\n", " trial_result = 
self._process_single_trial(neural_data, session_idx)\n", " \n", " # 保存结果\n", " results['concatenated_data'].append(trial_result['concatenated_data'])\n", " results['processed_features'].append(trial_result['processed_features'])\n", " results['confidence_scores'].append(trial_result['confidence_scores'])\n", " \n", " # 保存元数据\n", " metadata = {\n", " 'session': session_name,\n", " 'data_type': data_type,\n", " 'trial_idx': trial_idx,\n", " 'block_num': data.get('block_num', [None])[trial_idx],\n", " 'trial_num': data.get('trial_num', [None])[trial_idx],\n", " **trial_result\n", " }\n", " \n", " # 添加真实标签(如果可用)\n", " if data_type in ['train', 'val'] and 'sentence_label' in data:\n", " metadata.update({\n", " 'sentence_label': data['sentence_label'][trial_idx],\n", " 'seq_class_ids': data['seq_class_ids'][trial_idx],\n", " 'seq_len': data['seq_len'][trial_idx]\n", " })\n", " \n", " results['trial_metadata'].append(metadata)\n", " results['processing_stats'].append(trial_result)\n", " \n", " # 统计信息\n", " if results['concatenated_data']:\n", " time_steps = [data.shape[0] for data in results['concatenated_data']]\n", " feature_dims = [data.shape[1] for data in results['concatenated_data']]\n", " \n", " print(f\" ✅ {data_type} 处理完成:\")\n", " print(f\" 试验数: {len(results['concatenated_data'])}\")\n", " print(f\" 时间步范围: {min(time_steps)}-{max(time_steps)}\")\n", " print(f\" 特征维度: {feature_dims[0]} (处理后特征: {feature_dims[0]-40}, 置信度: 40)\")\n", " \n", " avg_reduction = np.mean([stat['feature_reduction_ratio'] for stat in results['processing_stats']])\n", " print(f\" 平均时间压缩比: {avg_reduction:.3f}\")\n", " \n", " session_results[data_type] = results\n", " \n", " except Exception as e:\n", " print(f\" ❌ {data_type} 处理失败: {e}\")\n", " continue\n", " \n", " return session_results\n", " \n", " def process_all_sessions(self, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n", " \"\"\"\n", " 批量处理所有sessions\n", " \n", " 参数:\n", " data_types: 要处理的数据类型\n", " save_dir: 保存目录\n", " \n", " 返回:\n", " dict: 所有处理结果\n", " \"\"\"\n", " print(f\"\\n🚀 开始批量处理所有会话...\")\n", " print(f\" 目标数据类型: {data_types}\")\n", " print(f\" 保存目录: {save_dir}\")\n", " \n", " save_path = Path(save_dir)\n", " save_path.mkdir(parents=True, exist_ok=True)\n", " \n", " all_results = {}\n", " sessions = self.model_args['dataset']['sessions']\n", " \n", " start_time = time.time()\n", " \n", " for i, session in enumerate(sessions):\n", " print(f\"\\n📊 进度: {i+1}/{len(sessions)}\")\n", " \n", " try:\n", " session_results = self.process_session(session, data_types)\n", " \n", " if session_results:\n", " all_results[session] = session_results\n", " \n", " # 保存单个session结果\n", " for data_type, data in session_results.items():\n", " filename = f\"{session}_{data_type}_rnn_processed.npz\"\n", " filepath = save_path / filename\n", " \n", " save_data = {\n", " 'concatenated_data': np.array(data['concatenated_data'], dtype=object),\n", " 'processed_features': np.array(data['processed_features'], dtype=object),\n", " 'confidence_scores': np.array(data['confidence_scores'], dtype=object),\n", " 'trial_metadata': np.array(data['trial_metadata'], dtype=object),\n", " }\n", " \n", " np.savez_compressed(str(filepath), **save_data)\n", " print(f\" 💾 保存: {filename}\")\n", " \n", " except Exception as e:\n", " print(f\"❌ 会话 {session} 处理失败: {e}\")\n", " continue\n", " \n", " # 生成总结\n", " end_time = time.time()\n", " processing_time = end_time - start_time\n", " \n", " total_trials = sum(\n", " len(session_data[data_type]['concatenated_data'])\n", " 
for session_data in all_results.values()\n", " for data_type in session_data.keys()\n", " )\n", " \n", " print(f\"\\n🎉 批量处理完成!\")\n", " print(f\"⏱️ 总耗时: {processing_time/60:.2f} 分钟\")\n", " print(f\"📊 处理统计:\")\n", " print(f\" 成功会话: {len(all_results)}/{len(sessions)}\")\n", " print(f\" 总试验数: {total_trials}\")\n", " print(f\"💾 数据保存在: {save_dir}\")\n", " \n", " # 保存总结信息\n", " summary = {\n", " 'processing_time': processing_time,\n", " 'total_sessions': len(all_results),\n", " 'total_trials': total_trials,\n", " 'data_types': data_types,\n", " 'sessions': list(all_results.keys()),\n", " 'model_config': {\n", " 'patch_size': self.model_args['model']['patch_size'],\n", " 'patch_stride': self.model_args['model']['patch_stride'],\n", " 'smooth_kernel_size': self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n", " 'smooth_kernel_std': self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n", " }\n", " }\n", " \n", " import json\n", " with open(save_path / 'processing_summary.json', 'w') as f:\n", " json.dump(summary, f, indent=2)\n", " \n", " return all_results\n", "\n", "# 创建处理器实例\n", "print(\"🔧 创建RNN数据处理器...\")\n", "\n", "try:\n", " processor = RNNDataProcessor(\n", " model_path='../data/t15_pretrained_rnn_baseline',\n", " data_dir='../data/hdf5_data_final',\n", " csv_path='../data/t15_copyTaskData_description.csv',\n", " device='auto'\n", " )\n", " \n", " print(f\"✅ RNN数据处理器创建成功!\")\n", " \n", "except Exception as e:\n", " print(f\"❌ 处理器创建失败: {e}\")\n", " processor = None" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "🎯 RNN数据批量处理 - 使用示例\n", "======================================================================\n", "\n", "📋 可用的处理方法:\n", "1️⃣ 单session处理: processor.process_session('session_name')\n", "2️⃣ 批量处理所有: processor.process_all_sessions()\n", "\n", "📊 可用会话数量: 45\n", "📝 前5个会话: ['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25']\n", "\n", "🧪 快速测试: 处理会话 't15.2023.08.13' 的训练数据...\n", "\n", "🔄 处理会话: t15.2023.08.13\n", " 📁 处理 train 数据...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] }, { "name": "stdout", "output_type": "stream", "text": [ " ✅ train 处理完成:\n", " 试验数: 348\n", " 时间步范围: 55-352\n", " 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n", " 平均时间压缩比: 0.243\n", "\n", "✅ 测试完成!结果概览:\n", " 处理的试验数: 348\n", " 第一个试验数据形状: (251, 7209)\n", " 特征维度详情:\n", " - 处理后的神经特征: 7168 维\n", " - RNN置信度分数: 41 维\n", " - 总拼接特征: 7209 维\n", " - 时间步数: 251\n", " 样本元数据:\n", " - 原始时间步: 1023\n", " - 处理后时间步: 251\n", " - 时间压缩比: 0.245\n", " - 句子标签: Which is most unfortunate because we all lose out.\n", "\n", "💡 要批量处理所有数据,运行:\n", " results = processor.process_all_sessions()\n", " # 这将处理所有45个sessions的train/val/test数据\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r" ] } ], "source": [ "# 🎯 使用示例和批量处理\n", "\n", "print(\"=\"*70)\n", "print(\"🎯 RNN数据批量处理 - 使用示例\")\n", "print(\"=\"*70)\n", "\n", "if processor is not None:\n", " \n", " # 方法1: 处理单个session (推荐用于测试)\n", " print(\"\\n📋 可用的处理方法:\")\n", " print(\"1️⃣ 单session处理: processor.process_session('session_name')\")\n", " print(\"2️⃣ 批量处理所有: processor.process_all_sessions()\")\n", " \n", " # 显示可用的sessions\n", " sessions = processor.model_args['dataset']['sessions']\n", " print(f\"\\n📊 可用会话数量: {len(sessions)}\")\n", " print(f\"📝 前5个会话: {sessions[:5]}\")\n", " \n", " # 快速测试 - 处理第一个session的部分数据\n", " test_session = sessions[1] 
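# note (added): sessions[1] is 't15.2023.08.13' in the run shown below; the name on the next line is sessions[0] and looks stale\n",
" 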
# 't15.2023.08.11'\n", " \n", " print(f\"\\n🧪 快速测试: 处理会话 '{test_session}' 的训练数据...\")\n", " \n", " # 处理单个session(仅train数据进行测试)\n", " single_result = processor.process_session(test_session, ['train'])\n", " \n", " if single_result and 'train' in single_result:\n", " train_data = single_result['train']\n", " \n", " print(f\"\\n✅ 测试完成!结果概览:\")\n", " print(f\" 处理的试验数: {len(train_data['concatenated_data'])}\")\n", " \n", " if len(train_data['concatenated_data']) > 0:\n", " sample_data = train_data['concatenated_data'][0]\n", " print(f\" 第一个试验数据形状: {sample_data.shape}\")\n", " print(f\" 特征维度详情:\")\n", " print(f\" - 处理后的神经特征: {sample_data.shape[1] - 41} 维\")\n", " print(f\" - RNN置信度分数: 41 维\")\n", " print(f\" - 总拼接特征: {sample_data.shape[1]} 维\")\n", " print(f\" - 时间步数: {sample_data.shape[0]}\")\n", " \n", " # 显示一些样本元数据\n", " sample_metadata = train_data['trial_metadata'][0]\n", " print(f\" 样本元数据:\")\n", " print(f\" - 原始时间步: {sample_metadata['original_time_steps']}\")\n", " print(f\" - 处理后时间步: {sample_metadata['processed_time_steps']}\")\n", " print(f\" - 时间压缩比: {sample_metadata['feature_reduction_ratio']:.3f}\")\n", " \n", " if 'sentence_label' in sample_metadata:\n", " print(f\" - 句子标签: {sample_metadata['sentence_label']}\")\n", " \n", " print(f\"\\n💡 要批量处理所有数据,运行:\")\n", " print(f\" results = processor.process_all_sessions()\")\n", " print(f\" # 这将处理所有45个sessions的train/val/test数据\")\n", " \n", "else:\n", " print(\"❌ 处理器未创建成功,请检查上面的错误信息\")" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "🚀 批量处理选项\n", "======================================================================\n", "📊 批量处理配置:\n", " 启用批量处理: False\n", " 保存目录: ./rnn_processed_data\n", " 数据类型: ['train', 'val', 'test']\n", " 总会话数: 45\n", "\n", "💡 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\n", " 或者手动运行: processor.process_all_sessions()\n", "\n", "📋 数据使用说明:\n", "✅ 处理完成后,每个文件包含:\n", " - concatenated_data: 拼接后的特征 [神经特征(7168) + 置信度(41)]\n", " - processed_features: 仅处理后的神经特征\n", " - confidence_scores: 仅RNN输出的41类置信度分数\n", " - trial_metadata: 试验元数据(标签、时间步等)\n", "\n", "🔧 加载保存的数据:\n", " data = np.load('session_name_train_rnn_processed.npz', allow_pickle=True)\n", " features = data['concatenated_data'] # 用于训练分类器\n", " metadata = data['trial_metadata'] # 获取标签和其他信息\n" ] } ], "source": [ "# 🚀 批量处理所有数据 (可选择运行)\n", "\n", "print(\"=\"*70)\n", "print(\"🚀 批量处理选项\")\n", "print(\"=\"*70)\n", "\n", "# 设置参数\n", "ENABLE_FULL_PROCESSING = False # 设为True开始批量处理\n", "SAVE_DIR = \"./rnn_processed_data\" # 保存目录\n", "DATA_TYPES = ['train', 'val', 'test'] # 要处理的数据类型\n", "\n", "print(f\"📊 批量处理配置:\")\n", "print(f\" 启用批量处理: {ENABLE_FULL_PROCESSING}\")\n", "print(f\" 保存目录: {SAVE_DIR}\")\n", "print(f\" 数据类型: {DATA_TYPES}\")\n", "print(f\" 总会话数: {len(processor.model_args['dataset']['sessions'])}\")\n", "\n", "if ENABLE_FULL_PROCESSING and processor is not None:\n", " print(f\"\\n🚀 开始批量处理所有数据...\")\n", " print(f\"⚠️ 这可能需要较长时间(预计30-60分钟)\")\n", " \n", " # 批量处理\n", " all_results = processor.process_all_sessions(\n", " data_types=DATA_TYPES,\n", " save_dir=SAVE_DIR\n", " )\n", " \n", " print(f\"🎉 批量处理完成!结果保存在: {SAVE_DIR}\")\n", " \n", "else:\n", " print(f\"\\n💡 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\")\n", " print(f\" 或者手动运行: processor.process_all_sessions()\")\n", "\n", "print(f\"\\n📋 数据使用说明:\")\n", "print(f\"✅ 处理完成后,每个文件包含:\")\n", "print(f\" - concatenated_data: 拼接后的特征 [神经特征(7168) + 置信度(41)]\")\n", "print(f\" - 
processed_features: 仅处理后的神经特征\")\n", "print(f\" - confidence_scores: 仅RNN输出的41类置信度分数\")\n", "print(f\" - trial_metadata: 试验元数据(标签、时间步等)\")\n", "print(f\"\")\n", "print(f\"🔧 加载保存的数据:\")\n", "print(f\" data = np.load('session_name_train_rnn_processed.npz', allow_pickle=True)\")\n", "print(f\" features = data['concatenated_data'] # 用于训练分类器\")\n", "print(f\" metadata = data['trial_metadata'] # 获取标签和其他信息\")" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(212, 7209)" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_result['train']['concatenated_data'][2].shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 模型建立" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🌲 初始化随机森林回归模型\n", "🌲 随机森林回归器初始化:\n", " 时间窗口大小: 30\n", " 树的数量: 100\n", " 最大深度: 10\n", " 并行任务: -1\n", "\n", "✅ 随机森林回归器准备完成!\n", "🔧 下一步: 准备训练数据和开始训练\n" ] } ], "source": [ "# 🌲 随机森林多输出回归模型实现\n", "import numpy as np\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n", "from sklearn.preprocessing import StandardScaler\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from tqdm import tqdm\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "class TimeWindowRandomForestRegressor:\n", " \"\"\"\n", " 基于时间窗口的随机森林多输出回归器\n", " 用于预测40个音素的概率分布\n", " \"\"\"\n", " \n", " def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n", " \"\"\"\n", " 初始化模型\n", " \n", " 参数:\n", " window_size: 时间窗口大小\n", " n_estimators: 随机森林中树的数量\n", " max_depth: 树的最大深度\n", " n_jobs: 并行任务数\n", " random_state: 随机种子\n", " \"\"\"\n", " self.window_size = window_size\n", " self.n_estimators = n_estimators\n", " self.max_depth = max_depth\n", " self.n_jobs = n_jobs\n", " self.random_state = random_state\n", " \n", " # 初始化模型和预处理器\n", " self.regressor = RandomForestRegressor(\n", " n_estimators=n_estimators,\n", " max_depth=max_depth,\n", " n_jobs=n_jobs,\n", " random_state=random_state,\n", " verbose=1\n", " )\n", " \n", " self.scaler = StandardScaler()\n", " self.is_fitted = False\n", " \n", " print(f\"🌲 随机森林回归器初始化:\")\n", " print(f\" 时间窗口大小: {window_size}\")\n", " print(f\" 树的数量: {n_estimators}\")\n", " print(f\" 最大深度: {max_depth}\")\n", " print(f\" 并行任务: {n_jobs}\")\n", " \n", " def create_time_windows(self, neural_features, phoneme_targets=None):\n", " \"\"\"\n", " 创建时间窗口特征\n", " \n", " 参数:\n", " neural_features: 神经特征 [time_steps, 512]\n", " phoneme_targets: 音素目标 [time_steps, 40] (可选)\n", " \n", " 返回:\n", " windowed_features: [samples, window_size * 512]\n", " windowed_targets: [samples, 40] (如果提供targets)\n", " \"\"\"\n", " if len(neural_features) < self.window_size:\n", " print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n", " return None, None\n", " \n", " n_samples = len(neural_features) - self.window_size + 1\n", " n_features = neural_features.shape[1]\n", " \n", " # 创建时间窗口特征\n", " windowed_features = np.zeros((n_samples, self.window_size * n_features))\n", " \n", " for i in range(n_samples):\n", " # 展平时间窗口内的所有特征\n", " window_data = neural_features[i:i+self.window_size].flatten()\n", " windowed_features[i] = window_data\n", " \n", " windowed_targets = None\n", " if phoneme_targets is not None:\n", " # 使用窗口中心点的音素概率作为目标\n", " 
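# e.g. with window_size=30, windowed sample i covers frames [i, i+29] and its target is frame i+15;\n",
" # added note: this assumes phoneme_targets is frame-aligned with neural_features\n",
" 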
center_offset = self.window_size // 2\n", " windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n", " \n", " return windowed_features, windowed_targets\n", " \n", " def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n", " \"\"\"\n", " 准备训练数据集\n", " \n", " 参数:\n", " datasets_list: DataFrame列表 (train_datasets, val_datasets, etc.)\n", " dataset_type: 数据集类型名称\n", " \n", " 返回:\n", " X: 特征矩阵 [总样本数, window_size * 512]\n", " y: 目标矩阵 [总样本数, 40]\n", " \"\"\"\n", " print(f\"\\n📊 准备{dataset_type}数据集:\")\n", " print(f\" 输入数据集数量: {len(datasets_list)}\")\n", " \n", " all_X = []\n", " all_y = []\n", " \n", " for i, df in enumerate(tqdm(datasets_list, desc=f\"处理{dataset_type}数据\")):\n", " # 提取神经特征 (前512列)\n", " neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n", " neural_features = df[neural_cols].values\n", " \n", " # 提取音素目标 (40列音素概率)\n", " phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n", " phoneme_targets = df[phoneme_cols].values\n", " \n", " # 按trial分组处理\n", " trials = df['trial_idx'].unique()\n", " \n", " for trial_idx in trials:\n", " trial_mask = df['trial_idx'] == trial_idx\n", " trial_neural = neural_features[trial_mask]\n", " trial_phonemes = phoneme_targets[trial_mask]\n", " \n", " # 创建时间窗口\n", " windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n", " \n", " if windowed_X is not None and windowed_y is not None:\n", " all_X.append(windowed_X)\n", " all_y.append(windowed_y)\n", " \n", " if not all_X:\n", " print(f\"❌ 没有有效的{dataset_type}数据\")\n", " return None, None\n", " \n", " # 合并所有数据\n", " X = np.vstack(all_X)\n", " y = np.vstack(all_y)\n", " \n", " print(f\" ✅ {dataset_type}数据准备完成:\")\n", " print(f\" 特征矩阵形状: {X.shape}\")\n", " print(f\" 目标矩阵形状: {y.shape}\")\n", " print(f\" 内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n", " \n", " return X, y\n", " \n", " def fit(self, X_train, y_train, X_val=None, y_val=None):\n", " \"\"\"\n", " 训练模型\n", " \n", " 参数:\n", " X_train: 训练特征\n", " y_train: 训练目标\n", " X_val: 验证特征 (可选)\n", " y_val: 验证目标 (可选)\n", " \"\"\"\n", " print(f\"\\n🚀 开始训练随机森林回归模型:\")\n", " print(f\" 训练样本数: {X_train.shape[0]:,}\")\n", " print(f\" 特征维度: {X_train.shape[1]:,}\")\n", " print(f\" 目标维度: {y_train.shape[1]}\")\n", " \n", " # 标准化特征\n", " print(\" 🔄 标准化特征...\")\n", " X_train_scaled = self.scaler.fit_transform(X_train)\n", " \n", " # 训练模型\n", " print(\" 🌲 训练随机森林...\")\n", " self.regressor.fit(X_train_scaled, y_train)\n", " \n", " self.is_fitted = True\n", " \n", " # 计算训练集性能\n", " print(\" 📊 评估训练集性能...\")\n", " train_predictions = self.regressor.predict(X_train_scaled)\n", " train_mse = mean_squared_error(y_train, train_predictions)\n", " train_r2 = r2_score(y_train, train_predictions)\n", " train_mae = mean_absolute_error(y_train, train_predictions)\n", " \n", " print(f\" ✅ 训练完成!\")\n", " print(f\" 训练集 MSE: {train_mse:.6f}\")\n", " print(f\" 训练集 R²: {train_r2:.4f}\")\n", " print(f\" 训练集 MAE: {train_mae:.6f}\")\n", " \n", " # 如果有验证集,计算验证集性能\n", " if X_val is not None and y_val is not None:\n", " print(\" 📊 评估验证集性能...\")\n", " X_val_scaled = self.scaler.transform(X_val)\n", " val_predictions = self.regressor.predict(X_val_scaled)\n", " val_mse = mean_squared_error(y_val, val_predictions)\n", " val_r2 = r2_score(y_val, val_predictions)\n", " val_mae = mean_absolute_error(y_val, val_predictions)\n", " \n", " print(f\" 验证集 MSE: {val_mse:.6f}\")\n", " print(f\" 验证集 R²: {val_r2:.4f}\")\n", " print(f\" 验证集 MAE: {val_mae:.6f}\")\n", " 
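# note (added): these scores are uniform averages over all phoneme output dimensions;\n",
" # r2_score / mean_absolute_error accept multioutput='raw_values' for per-phoneme breakdowns\n",
" 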
\n", " return {\n", " 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n", " 'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n", " }\n", " \n", " return {\n", " 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n", " }\n", " \n", " def predict(self, X):\n", " \"\"\"预测\"\"\"\n", " if not self.is_fitted:\n", " raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n", " \n", " X_scaled = self.scaler.transform(X)\n", " return self.regressor.predict(X_scaled)\n", " \n", " def get_feature_importance(self, top_k=20):\n", " \"\"\"获取特征重要性\"\"\"\n", " if not self.is_fitted:\n", " raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n", " \n", " importances = self.regressor.feature_importances_\n", " \n", " # 创建特征名称 (window_timestep_feature)\n", " feature_names = []\n", " for t in range(self.window_size):\n", " for f in range(512):\n", " feature_names.append(f\"t{t}_feat{f}\")\n", " \n", " # 获取top-k重要特征\n", " top_indices = np.argsort(importances)[::-1][:top_k]\n", " top_features = [(feature_names[i], importances[i]) for i in top_indices]\n", " \n", " return top_features, importances\n", "\n", "# 初始化模型\n", "print(\"🌲 初始化随机森林回归模型\")\n", "rf_regressor = TimeWindowRandomForestRegressor(\n", " window_size=30, # 时间窗口大小\n", " n_estimators=100, # 树的数量\n", " max_depth=10, # 最大深度 (防止过拟合)\n", " n_jobs=-1, # 使用所有CPU核心\n", " random_state=42\n", ")\n", "\n", "print(\"\\n✅ 随机森林回归器准备完成!\")\n", "print(\"🔧 下一步: 准备训练数据和开始训练\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 开始数据准备和模型训练流程\n", "============================================================\n" ] }, { "ename": "NameError", "evalue": "name 'train_datasets' is not defined", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_37/3627267466.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m# 检查数据可用性\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtrain_datasets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"❌ 没有可用的训练数据,请先运行数据处理工作流\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mNameError\u001b[0m: name 'train_datasets' is not defined" ] } ], "source": [ "# 🚀 准备数据并训练随机森林回归模型\n", "print(\"🚀 开始数据准备和模型训练流程\")\n", "print(\"=\"*60)\n", "\n", "# 检查数据可用性\n", "if not train_datasets:\n", " print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n", "else:\n", " print(f\"✅ 检测到数据:\")\n", " print(f\" 训练数据集: {len(train_datasets)} 个sessions\")\n", " print(f\" 验证数据集: {len(val_datasets)} 个sessions\")\n", " print(f\" 测试数据集: {len(test_datasets)} 个sessions\")\n", "\n", " # 1. 准备训练数据\n", " print(f\"\\n📊 第1步: 准备训练数据\")\n", " X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n", " \n", " # 2. 
准备验证数据\n", " print(f\"\\n📊 第2步: 准备验证数据\")\n", " X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n", " \n", " if X_train is not None and y_train is not None:\n", " print(f\"\\n📈 数据准备完成统计:\")\n", " print(f\" 训练集: {X_train.shape[0]:,} 样本\")\n", " print(f\" 验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \" 验证集: 无\")\n", " print(f\" 特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n", " print(f\" 目标维度: {y_train.shape[1]} (40个音素概率)\")\n", " \n", " # 检查数据质量\n", " print(f\"\\n🔍 数据质量检查:\")\n", " print(f\" 训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n", " print(f\" 训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n", " print(f\" 训练特征均值: {X_train.mean():.4f}\")\n", " print(f\" 训练目标均值: {y_train.mean():.4f}\")\n", " \n", " # 检查是否有NaN或无穷值\n", " nan_count_X = np.isnan(X_train).sum()\n", " nan_count_y = np.isnan(y_train).sum()\n", " inf_count_X = np.isinf(X_train).sum()\n", " inf_count_y = np.isinf(y_train).sum()\n", " \n", " print(f\" NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n", " print(f\" Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n", " \n", " if nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0:\n", " print(\"⚠️ 检测到异常值,将进行清理...\")\n", " # 清理异常值\n", " valid_mask = ~(np.isnan(X_train).any(axis=1) | np.isnan(y_train).any(axis=1) | \n", " np.isinf(X_train).any(axis=1) | np.isinf(y_train).any(axis=1))\n", " X_train = X_train[valid_mask]\n", " y_train = y_train[valid_mask]\n", " \n", " if X_val is not None and y_val is not None:\n", " valid_mask_val = ~(np.isnan(X_val).any(axis=1) | np.isnan(y_val).any(axis=1) | \n", " np.isinf(X_val).any(axis=1) | np.isinf(y_val).any(axis=1))\n", " X_val = X_val[valid_mask_val]\n", " y_val = y_val[valid_mask_val]\n", " \n", " print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n", " \n", " # 3. 训练模型\n", " print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n", " training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n", " \n", " # 4. 分析训练结果\n", " print(f\"\\n📊 第4步: 训练结果分析\")\n", " print(\"=\"*50)\n", " \n", " for metric, value in training_results.items():\n", " metric_name = metric.replace('_', ' ').title()\n", " print(f\" {metric_name}: {value:.6f}\")\n", " \n", " # 5. 
特征重要性分析\n", " print(f\"\\n🔍 第5步: 特征重要性分析\")\n", " top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n", " \n", " print(f\"\\n🏆 Top 20 重要特征:\")\n", " print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n", " print(\"-\" * 35)\n", " for i, (feature_name, importance) in enumerate(top_features):\n", " print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n", " \n", " # 分析时间窗口内的重要性分布\n", " print(f\"\\n📈 时间窗口重要性分布:\")\n", " window_importances = np.zeros(rf_regressor.window_size)\n", " for i in range(rf_regressor.window_size):\n", " start_idx = i * 512\n", " end_idx = (i + 1) * 512\n", " window_importances[i] = all_importances[start_idx:end_idx].sum()\n", " \n", " max_time_step = np.argmax(window_importances)\n", " print(f\" 最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n", " print(f\" 窗口中心位置: t{rf_regressor.window_size//2}\")\n", " print(f\" 重要性分布: 前5个时间步的重要性\")\n", " for i in range(min(5, len(window_importances))):\n", " print(f\" t{i}: {window_importances[i]:.6f}\")\n", " \n", " print(f\"\\n✅ 随机森林回归模型训练完成!\")\n", " print(f\"🎯 模型可以预测40个音素的概率分布\")\n", " print(f\"📊 基于30时间步的神经特征窗口\")\n", " print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n", " \n", " else:\n", " print(\"❌ 数据准备失败,无法训练模型\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 📊 模型评估和可视化分析\n", "def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n", " \"\"\"\n", " 评估每个音素的预测性能\n", " \"\"\"\n", " print(f\"\\n📊 {dataset_name}详细评估\")\n", " print(\"=\"*50)\n", " \n", " # 获取预测结果\n", " y_pred = rf_model.predict(X_test)\n", " \n", " # 计算每个音素的性能指标\n", " phoneme_metrics = []\n", " \n", " for i in range(40): # 40个音素\n", " phoneme_name = LOGIT_TO_PHONEME[i]\n", " \n", " # 计算单个音素的指标\n", " mse = mean_squared_error(y_test[:, i], y_pred[:, i])\n", " r2 = r2_score(y_test[:, i], y_pred[:, i])\n", " mae = mean_absolute_error(y_test[:, i], y_pred[:, i])\n", " \n", " # 计算相关系数\n", " correlation = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]\n", " \n", " phoneme_metrics.append({\n", " 'phoneme_id': i,\n", " 'phoneme_name': phoneme_name,\n", " 'mse': mse, \n", " 'r2': r2,\n", " 'mae': mae,\n", " 'correlation': correlation if not np.isnan(correlation) else 0.0\n", " })\n", " \n", " # 转换为DataFrame便于分析\n", " metrics_df = pd.DataFrame(phoneme_metrics)\n", " \n", " # 打印总体统计\n", " print(f\"📈 总体性能指标:\")\n", " print(f\" 平均 MSE: {metrics_df['mse'].mean():.6f}\")\n", " print(f\" 平均 R²: {metrics_df['r2'].mean():.4f}\")\n", " print(f\" 平均 MAE: {metrics_df['mae'].mean():.6f}\")\n", " print(f\" 平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n", " \n", " # 找出最佳和最差预测的音素\n", " best_r2_idx = metrics_df['r2'].idxmax()\n", " worst_r2_idx = metrics_df['r2'].idxmin()\n", " \n", " print(f\"\\n🏆 最佳预测音素:\")\n", " best_phoneme = metrics_df.loc[best_r2_idx]\n", " print(f\" {best_phoneme['phoneme_name']} (ID: {best_phoneme['phoneme_id']})\")\n", " print(f\" R²: {best_phoneme['r2']:.4f}, MSE: {best_phoneme['mse']:.6f}\")\n", " \n", " print(f\"\\n📉 最差预测音素:\")\n", " worst_phoneme = metrics_df.loc[worst_r2_idx]\n", " print(f\" {worst_phoneme['phoneme_name']} (ID: {worst_phoneme['phoneme_id']})\")\n", " print(f\" R²: {worst_phoneme['r2']:.4f}, MSE: {worst_phoneme['mse']:.6f}\")\n", " \n", " return metrics_df, y_pred\n", "\n", "def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n", " \"\"\"\n", " 可视化预测结果\n", " \"\"\"\n", " print(f\"\\n📊 创建可视化图表...\")\n", " \n", " # 设置图表样式\n", " plt.style.use('default')\n", " fig = 
plt.figure(figsize=(20, 12))\n", " \n", " # 1. R²分数分布\n", " plt.subplot(2, 3, 1)\n", " plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n", " plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n", " label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n", " plt.xlabel('R² Score')\n", " plt.ylabel('音素数量')\n", " plt.title('R² Score 分布')\n", " plt.legend()\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 2. MSE分布\n", " plt.subplot(2, 3, 2)\n", " plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n", " plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n", " label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n", " plt.xlabel('Mean Squared Error')\n", " plt.ylabel('音素数量')\n", " plt.title('MSE 分布')\n", " plt.legend()\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 3. 前10个音素的性能对比\n", " plt.subplot(2, 3, 3)\n", " top_10 = metrics_df.nlargest(10, 'r2')\n", " bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n", " plt.xlabel('音素排名')\n", " plt.ylabel('R² Score')\n", " plt.title('Top 10 音素预测性能')\n", " plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 添加数值标签\n", " for i, bar in enumerate(bars):\n", " height = bar.get_height()\n", " plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n", " f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n", " \n", " # 4. 真实值 vs 预测值散点图 (选择最佳音素)\n", " plt.subplot(2, 3, 4)\n", " best_phoneme_idx = metrics_df['r2'].idxmax()\n", " phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n", " phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n", " \n", " # 随机采样1000个点以避免图表过于密集\n", " sample_size = min(1000, len(y_true))\n", " sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n", " \n", " plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n", " alpha=0.6, s=20, color='blue')\n", " \n", " # 添加对角线 (完美预测线)\n", " min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n", " max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n", " plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n", " \n", " plt.xlabel('真实值')\n", " plt.ylabel('预测值')\n", " plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n", " plt.legend()\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 5. 相关系数热力图 (前20个音素)\n", " plt.subplot(2, 3, 5)\n", " top_20_correlations = metrics_df.nlargest(20, 'correlation')\n", " corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n", " \n", " # 创建热力图数据\n", " heatmap_data = corr_data.values.reshape(-1, 1)\n", " im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n", " \n", " plt.colorbar(im, shrink=0.8)\n", " plt.yticks([0], ['相关系数'])\n", " plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n", " rotation=45, ha='right')\n", " plt.title('Top 20 音素相关系数')\n", " \n", " # 6. 
各音素预测误差箱线图 (前10个音素)\n", " plt.subplot(2, 3, 6)\n", " top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n", " errors_data = []\n", " labels = []\n", " \n", " for phoneme_id in top_10_ids:\n", " errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n", " errors_data.append(errors)\n", " labels.append(LOGIT_TO_PHONEME[phoneme_id])\n", " \n", " plt.boxplot(errors_data, labels=labels)\n", " plt.xlabel('音素')\n", " plt.ylabel('绝对误差')\n", " plt.title('Top 10 音素预测误差分布')\n", " plt.xticks(rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " plt.tight_layout()\n", " \n", " if save_plots:\n", " plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n", " print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n", " \n", " plt.show()\n", "\n", "# 如果模型训练成功,进行评估\n", "if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n", " print(f\"\\n🎯 开始模型评估和可视化\")\n", " \n", " # 评估验证集\n", " if X_val is not None and y_val is not None:\n", " val_metrics, val_predictions = evaluate_phoneme_predictions(\n", " rf_regressor, X_val, y_val, \"验证集\"\n", " )\n", " \n", " # 可视化结果\n", " visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n", " \n", " # 保存详细结果\n", " val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n", " print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n", " \n", " # 准备测试集数据 (如果有)\n", " if test_datasets:\n", " print(f\"\\n🔮 准备测试集预测...\")\n", " X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n", " \n", " if X_test is not None:\n", " test_metrics, test_predictions = evaluate_phoneme_predictions(\n", " rf_regressor, X_test, y_test, \"测试集\"\n", " )\n", " print(f\"\\n✅ 测试集评估完成\")\n", " else:\n", " print(f\"⚠️ 测试集数据准备失败\")\n", " \n", " print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n", " print(f\"📊 生成了详细的性能分析和可视化图表\")\n", " print(f\"🔧 模型已准备好用于实际预测任务\")\n", " \n", "else:\n", " print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🎯 回归结果转分类结果分析\n", "def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n", " \"\"\"\n", " 将回归预测的40个音素概率转换为分类结果并进行分析\n", " \n", " 参数:\n", " y_true_probs: 真实的40个音素概率 [n_samples, 40]\n", " y_pred_probs: 预测的40个音素概率 [n_samples, 40]\n", " show_detailed_metrics: 是否显示详细的分类指标\n", " \n", " 返回:\n", " classification_results: 包含分类结果的字典\n", " \"\"\"\n", " print(\"🎯 回归结果转分类结果分析\")\n", " print(\"=\"*60)\n", " \n", " # 1. 将概率转换为分类标签\n", " y_true_classes = np.argmax(y_true_probs, axis=1) # 真实类别\n", " y_pred_classes = np.argmax(y_pred_probs, axis=1) # 预测类别\n", " \n", " # 2. 计算分类准确率\n", " accuracy = (y_true_classes == y_pred_classes).mean()\n", " \n", " print(f\"📊 分类结果概览:\")\n", " print(f\" 总样本数: {len(y_true_classes):,}\")\n", " print(f\" 分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n", " print(f\" 正确预测: {(y_true_classes == y_pred_classes).sum():,}\")\n", " print(f\" 错误预测: {(y_true_classes != y_pred_classes).sum():,}\")\n", " \n", " # 3. 分析预测置信度\n", " pred_confidences = np.max(y_pred_probs, axis=1) # 预测的最大概率\n", " true_confidences = np.max(y_true_probs, axis=1) # 真实的最大概率\n", " \n", " print(f\"\\n🔍 预测置信度分析:\")\n", " print(f\" 预测置信度均值: {pred_confidences.mean():.4f}\")\n", " print(f\" 预测置信度标准差: {pred_confidences.std():.4f}\")\n", " print(f\" 预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n", " print(f\" 真实置信度均值: {true_confidences.mean():.4f}\")\n", " \n", " # 4. 
按置信度分组的准确率分析\n", " confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n", " print(f\"\\n📈 按预测置信度分组的准确率:\")\n", " print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n", " print(\"-\" * 40)\n", " \n", " for i in range(len(confidence_bins)-1):\n", " low, high = confidence_bins[i], confidence_bins[i+1]\n", " mask = (pred_confidences >= low) & (pred_confidences < high)\n", " if i == len(confidence_bins)-2: # 最后一个区间包含等号\n", " mask = (pred_confidences >= low) & (pred_confidences <= high)\n", " \n", " if mask.sum() > 0:\n", " bin_accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n", " count = mask.sum()\n", " percentage = count / len(y_true_classes) * 100\n", " print(f\"[{low:.1f}, {high:.1f}{']' if i == len(confidence_bins)-2 else ')'} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n", " \n", " # 5. 最常见音素统计 (真实 vs 预测)\n", " from collections import Counter\n", " \n", " # 找出最常见的音素\n", " true_counter = Counter(y_true_classes)\n", " pred_counter = Counter(y_pred_classes)\n", " \n", " most_common_true = true_counter.most_common(10)\n", " most_common_pred = pred_counter.most_common(10)\n", " \n", " print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n", " print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n", " print(\"-\" * 42)\n", " \n", " for i in range(min(len(most_common_true), len(most_common_pred))):\n", " true_id, true_count = most_common_true[i]\n", " pred_id, pred_count = most_common_pred[i]\n", " true_name = LOGIT_TO_PHONEME[true_id]\n", " pred_name = LOGIT_TO_PHONEME[pred_id]\n", " print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n", " \n", " # 6. 每个音素的分类性能\n", " if show_detailed_metrics:\n", " from sklearn.metrics import classification_report, confusion_matrix\n", " \n", " print(f\"\\n📋 详细分类报告 (前20个最常见音素):\")\n", " \n", " # 获取前20个最常见的音素\n", " top_20_phonemes = [phoneme_id for phoneme_id, _ in most_common_true[:20]]\n", " \n", " # 创建掩码,只包含这些音素\n", " mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n", " y_true_top20 = y_true_classes[mask_top20]\n", " y_pred_top20 = y_pred_classes[mask_top20]\n", " \n", " # 生成分类报告\n", " target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n", " \n", " try:\n", " report = classification_report(\n", " y_true_top20, y_pred_top20, \n", " labels=top_20_phonemes,\n", " target_names=target_names,\n", " output_dict=True,\n", " zero_division=0\n", " )\n", " \n", " # 打印格式化的报告\n", " print(f\"{'音素':>8} {'精确率':>8} {'召回率':>8} {'F1分数':>8} {'支持数':>8}\")\n", " print(\"-\" * 48)\n", " \n", " for phoneme_id in top_20_phonemes:\n", " phoneme_name = LOGIT_TO_PHONEME[phoneme_id]\n", " if phoneme_name in report:\n", " metrics = report[phoneme_name]\n", " print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n", " f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n", " \n", " # 总体指标\n", " macro_avg = report['macro avg']\n", " weighted_avg = report['weighted avg']\n", " print(\"-\" * 48)\n", " print(f\"{'宏平均':>8} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n", " f\"{macro_avg['f1-score']:>8.4f}\")\n", " print(f\"{'加权平均':>8} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n", " f\"{weighted_avg['f1-score']:>8.4f}\")\n", " \n", " except Exception as e:\n", " print(f\"分类报告生成失败: {e}\")\n", " \n", " # 7. 
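Top-K准确率分析\n", " # 说明: Top-K 准确率指真实音素落在预测概率最高的 K 个候选之内的比例,K=1 时即等于普通分类准确率\n", " # 开始 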
Top-K准确率分析\n", " print(f\"\\n🎯 Top-K 准确率分析:\")\n", " for k in [1, 3, 5, 10]:\n", " # 计算Top-K准确率\n", " top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:] # 取概率最高的K个\n", " top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n", " print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n", " \n", " # 8. 错误分析 - 最常见的预测错误\n", " print(f\"\\n❌ 最常见的预测错误:\")\n", " error_mask = y_true_classes != y_pred_classes\n", " error_pairs = list(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n", " error_counter = Counter(error_pairs)\n", " \n", " print(f\"{'真实音素':>12} {'预测音素':>12} {'错误次数':>8}\")\n", " print(\"-\" * 36)\n", " for (true_id, pred_id), count in error_counter.most_common(10):\n", " true_name = LOGIT_TO_PHONEME[true_id]\n", " pred_name = LOGIT_TO_PHONEME[pred_id]\n", " print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n", " \n", " # 返回结果字典\n", " classification_results = {\n", " 'accuracy': accuracy,\n", " 'y_true_classes': y_true_classes,\n", " 'y_pred_classes': y_pred_classes,\n", " 'pred_confidences': pred_confidences,\n", " 'true_confidences': true_confidences,\n", " 'most_common_errors': error_counter.most_common(10)\n", " }\n", " \n", " return classification_results\n", "\n", "def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n", " \"\"\"\n", " 为分类结果创建可视化图表\n", " \"\"\"\n", " print(f\"\\n📊 创建分类结果可视化...\")\n", " \n", " fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", " fig.suptitle('随机森林回归转分类结果分析', fontsize=16, fontweight='bold')\n", " \n", " y_true_classes = classification_results['y_true_classes']\n", " y_pred_classes = classification_results['y_pred_classes']\n", " pred_confidences = classification_results['pred_confidences']\n", " \n", " # 1. 预测置信度分布\n", " axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n", " axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n", " label=f'均值: {pred_confidences.mean():.3f}')\n", " axes[0, 0].set_xlabel('预测置信度')\n", " axes[0, 0].set_ylabel('样本数量')\n", " axes[0, 0].set_title('预测置信度分布')\n", " axes[0, 0].legend()\n", " axes[0, 0].grid(True, alpha=0.3)\n", " \n", " # 2. 准确率 vs 置信度\n", " confidence_bins = np.linspace(0, 1, 21)\n", " bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n", " bin_accuracies = []\n", " bin_counts = []\n", " \n", " for i in range(len(confidence_bins)-1):\n", " mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n", " if mask.sum() > 0:\n", " accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n", " bin_accuracies.append(accuracy)\n", " bin_counts.append(mask.sum())\n", " else:\n", " bin_accuracies.append(0)\n", " bin_counts.append(0)\n", " \n", " # 只显示有数据的bins\n", " valid_bins = np.array(bin_counts) > 0\n", " axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n", " 'bo-', linewidth=2, markersize=6)\n", " axes[0, 1].set_xlabel('预测置信度')\n", " axes[0, 1].set_ylabel('准确率')\n", " axes[0, 1].set_title('准确率 vs 预测置信度')\n", " axes[0, 1].grid(True, alpha=0.3)\n", " axes[0, 1].set_ylim(0, 1)\n", " \n", " # 3. 
最常见音素的预测准确率\n", " from collections import Counter\n", " true_counter = Counter(y_true_classes)\n", " most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n", " \n", " phoneme_accuracies = []\n", " phoneme_names = []\n", " for phoneme_id in most_common_phonemes:\n", " mask = y_true_classes == phoneme_id\n", " if mask.sum() > 0:\n", " accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n", " phoneme_accuracies.append(accuracy)\n", " phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n", " \n", " bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n", " color='lightgreen', alpha=0.7)\n", " axes[0, 2].set_xlabel('音素')\n", " axes[0, 2].set_ylabel('准确率')\n", " axes[0, 2].set_title('Top 15 音素的分类准确率')\n", " axes[0, 2].set_xticks(range(len(phoneme_names)))\n", " axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n", " axes[0, 2].grid(True, alpha=0.3)\n", " \n", " # 添加数值标签\n", " for bar, acc in zip(bars, phoneme_accuracies):\n", " height = bar.get_height()\n", " axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n", " f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n", " \n", " # 4. 混淆矩阵(前10个最常见音素)\n", " from sklearn.metrics import confusion_matrix\n", " top_10_phonemes = most_common_phonemes[:10]\n", " mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n", " \n", " if mask_top10.sum() > 0:\n", " cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n", " labels=top_10_phonemes)\n", " \n", " im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n", " axes[1, 0].set_title('混淆矩阵 (Top 10 音素)')\n", " \n", " # 添加颜色条\n", " cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n", " cbar.set_label('预测次数')\n", " \n", " # 设置标签\n", " tick_marks = np.arange(len(top_10_phonemes))\n", " top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n", " axes[1, 0].set_xticks(tick_marks)\n", " axes[1, 0].set_yticks(tick_marks)\n", " axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n", " axes[1, 0].set_yticklabels(top_10_names)\n", " axes[1, 0].set_xlabel('预测音素')\n", " axes[1, 0].set_ylabel('真实音素')\n", " \n", " # 5. Top-K准确率\n", " k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n", " top_k_accuracies = []\n", " \n", " for k in k_values:\n", " top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n", " top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n", " top_k_accuracies.append(top_k_accuracy)\n", " \n", " axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n", " axes[1, 1].set_xlabel('K 值')\n", " axes[1, 1].set_ylabel('Top-K 准确率')\n", " axes[1, 1].set_title('Top-K 准确率曲线')\n", " axes[1, 1].grid(True, alpha=0.3)\n", " axes[1, 1].set_ylim(0, 1)\n", " \n", " # 添加数值标签\n", " for k, acc in zip(k_values, top_k_accuracies):\n", " axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n", " xytext=(0,10), ha='center')\n", " \n", " # 6. 
错误分析 - 最常见错误的热力图\n", " error_pairs = classification_results['most_common_errors'][:25] # 前25个最常见错误\n", " if error_pairs:\n", " # 创建错误矩阵\n", " unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n", " error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n", " \n", " phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n", " \n", " for (true_id, pred_id), count in error_pairs:\n", " if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n", " true_idx = phoneme_to_idx[true_id]\n", " pred_idx = phoneme_to_idx[pred_id]\n", " error_matrix[true_idx, pred_idx] = count\n", " \n", " im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n", " axes[1, 2].set_title('最常见错误分布')\n", " \n", " # 设置标签\n", " phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n", " axes[1, 2].set_xticks(range(len(phoneme_names)))\n", " axes[1, 2].set_yticks(range(len(phoneme_names)))\n", " axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n", " axes[1, 2].set_yticklabels(phoneme_names)\n", " axes[1, 2].set_xlabel('预测音素')\n", " axes[1, 2].set_ylabel('真实音素')\n", " \n", " # 添加颜色条\n", " cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n", " cbar.set_label('错误次数')\n", " \n", " plt.tight_layout()\n", " plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n", " print(\"📁 分类分析图表已保存至: ./processed_datasets/classification_analysis.png\")\n", " plt.show()\n", "\n", "print(\"✅ 回归转分类分析功能已创建!\")\n", "print(\"🎯 主要功能:\")\n", "print(\"• 将40维概率回归结果转换为分类预测\")\n", "print(\"• 计算分类准确率和置信度分析\")\n", "print(\"• 提供Top-K准确率评估\")\n", "print(\"• 生成详细的混淆矩阵和错误分析\")\n", "print(\"• 创建全面的可视化图表\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🎯 完整的回归转分类评估流程\n", "def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n", " \"\"\"\n", " 完整的回归模型转分类结果评估流程\n", " \"\"\"\n", " print(f\"\\n🎯 {dataset_name}完整评估: 回归 → 分类\")\n", " print(\"=\"*70)\n", " \n", " # 1. 获取回归预测结果\n", " print(\"📊 第1步: 获取回归预测...\")\n", " y_pred_probs = rf_model.predict(X_test)\n", " \n", " # 确保概率值在合理范围内\n", " y_pred_probs = np.clip(y_pred_probs, 0, 1)\n", " \n", " # 2. 回归性能评估\n", " print(\"\\n📈 第2步: 回归性能评估...\")\n", " mse = mean_squared_error(y_test, y_pred_probs) \n", " mae = mean_absolute_error(y_test, y_pred_probs)\n", " r2 = r2_score(y_test, y_pred_probs)\n", " \n", " print(f\" 回归 MSE: {mse:.6f}\")\n", " print(f\" 回归 MAE: {mae:.6f}\")\n", " print(f\" 回归 R²: {r2:.4f}\")\n", " \n", " # 3. 概率归一化(softmax)\n", " print(\"\\n🔄 第3步: 概率归一化...\")\n", " def softmax(x, axis=-1):\n", " exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n", " return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n", " \n", " # 对预测结果应用softmax,使其成为真正的概率分布\n", " y_pred_probs_normalized = softmax(y_pred_probs)\n", " y_test_normalized = softmax(y_test) # 也对真实标签归一化\n", " \n", " print(f\" 预测概率归一化前: 每行和均值 = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n", " print(f\" 预测概率归一化后: 每行和均值 = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n", " \n", " # 4. 分类结果分析\n", " print(\"\\n🎯 第4步: 分类结果分析...\")\n", " classification_results = regression_to_classification_analysis(\n", " y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n", " )\n", " \n", " # 5. 创建可视化\n", " print(\"\\n📊 第5步: 创建可视化图表...\")\n", " create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n", " \n", " # 6. 
保存结果\n", " print(\"\\n💾 第6步: 保存分析结果...\")\n", " \n", " # 保存分类结果\n", " results_df = pd.DataFrame({\n", " 'true_class': classification_results['y_true_classes'],\n", " 'pred_class': classification_results['y_pred_classes'],\n", " 'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n", " 'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n", " 'pred_confidence': classification_results['pred_confidences'],\n", " 'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n", " })\n", " \n", " results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n", " \n", " # 保存详细的概率预测\n", " prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n", " columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n", " prob_results_df['true_class'] = classification_results['y_true_classes']\n", " prob_results_df['pred_class'] = classification_results['y_pred_classes']\n", " \n", " prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n", " \n", " print(\"📁 结果已保存:\")\n", " print(\" • ./processed_datasets/classification_results.csv (分类结果)\")\n", " print(\" • ./processed_datasets/probability_predictions.csv (概率预测)\")\n", " print(\" • ./processed_datasets/classification_analysis.png (可视化图表)\")\n", " \n", " # 7. 总结报告\n", " print(f\"\\n📋 {dataset_name}评估总结:\")\n", " print(\"=\"*50)\n", " print(f\"🔸 回归性能:\")\n", " print(f\" MSE: {mse:.6f}\")\n", " print(f\" R²: {r2:.4f}\")\n", " print(f\"🔸 分类性能:\")\n", " print(f\" 准确率: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n", " print(f\" 平均置信度: {classification_results['pred_confidences'].mean():.4f}\")\n", " \n", " # 计算Top-K准确率\n", " for k in [1, 3, 5]:\n", " top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n", " top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n", " for i in range(len(classification_results['y_true_classes']))])\n", " print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n", " \n", " return {\n", " 'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n", " 'classification_results': classification_results,\n", " 'normalized_predictions': y_pred_probs_normalized,\n", " 'normalized_true': y_test_normalized\n", " }\n", "\n", "# 如果模型已训练且有验证数据,执行完整评估\n", "if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n", " if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n", " print(\"🚀 开始完整的回归转分类评估...\")\n", " \n", " # 执行完整评估\n", " evaluation_results = complete_regression_classification_evaluation(\n", " rf_regressor, X_val, y_val, \"验证集\"\n", " )\n", " \n", " print(f\"\\n🎉 评估完成!\")\n", " print(f\"✅ 随机森林回归模型成功转换为分类结果\")\n", " print(f\"📊 生成了详细的性能分析和可视化\")\n", " print(f\"💾 所有结果已保存到文件\")\n", " \n", " # 如果有测试数据,也进行评估\n", " if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n", " print(f\"\\n🔮 开始测试集评估...\")\n", " test_evaluation_results = complete_regression_classification_evaluation(\n", " rf_regressor, X_test, y_test, \"测试集\"\n", " )\n", " else:\n", " print(\"⚠️ 没有可用的验证数据进行评估\")\n", "else:\n", " print(\"⚠️ 随机森林模型尚未训练完成\")\n", " print(\"💡 请先运行前面的训练代码\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([17.75 , -0.80859375, -2.03125 , -1.8046875 , -0.85546875,\n", " -1.421875 , 
1.765625 , -2.703125 , -1.984375 , 4.0625 ,\n", " 2. , -3.15625 , 0.72265625, -0.8671875 , -1.90625 ,\n", " -2.0625 , -1.28125 , -1.03125 , 0.21289062, -1.890625 ,\n", " -0.4453125 , -0.5546875 , 0.5625 , -0.421875 , -0.22460938,\n", " 0.3515625 , -2.375 , -1.8984375 , 2.796875 , 0.3515625 ,\n", " -2.484375 , 1.453125 , 0.30078125, -2.390625 , 0.19335938,\n", " 0.35742188, -1.484375 , -2.8125 , -0.84375 , -3.0625 ,\n", " 4.96875 ], dtype=float32)" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_result['train']['confidence_scores'][0][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "dict_keys(['concatenated_data', 'processed_features', 'confidence_scores', 'trial_metadata', 'processing_stats'])" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "single_result['train'].keys()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🌲 随机森林" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "======================================================================\n", "🎯 重新定义任务:音素分类任务\n", "======================================================================\n", "📋 任务重新定义:\n", " 输入: 神经特征 (前7168维)\n", " 输出: 音素置信度 (后41维RNN输出)\n", " 目标: 预测每个时间步的音素概率分布\n", " 分类: 选择最大置信度对应的音素\n", "\n", "🔧 数据重新处理:\n", " 神经特征矩阵: (348, 7168)\n", " 音素logits矩阵: (348, 41)\n", " 音素类别标签: (348,) (值范围: 0-0)\n", " 音素分布: 1 个不同音素\n", " 样本最多的前5个音素:\n", " 音素 0: 348 次\n", "\n", "🔄 训练测试集切分:\n", " 训练集: 278 样本\n", " 测试集: 70 样本\n", " 音素类别数: 1\n", " 训练集音素分布: 1 个不同音素\n", " 测试集音素分布: 1 个不同音素\n", "\n", "======================================================================\n", "🌲 随机森林回归 + 分类\n", "======================================================================\n", "📊 方案1: 多输出回归 (神经特征 → 音素logits)\n", "🚀 训练回归模型...\n" ] }, { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_186/3764425816.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"🚀 训练回归模型...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0mrf_regressor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train_neu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_logits_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0mregression_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"✅ 回归训练完成!耗时: {regression_time:.2f} 
秒\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1387\u001b[0m )\n\u001b[1;32m 1388\u001b[0m ):\n\u001b[0;32m-> 1389\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1391\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/multioutput.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight, **fit_params)\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0mrouted_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"sample_weight\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m self.estimators_ = Parallel(n_jobs=self.n_jobs)(\n\u001b[0m\u001b[1;32m 275\u001b[0m delayed(_fit_estimator)(\n\u001b[1;32m 276\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mrouted_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdelayed_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m )\n\u001b[0;32m---> 77\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable_with_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1984\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_sequential_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1985\u001b[0m 
\u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1986\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_generator\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1987\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1988\u001b[0m \u001b[0;31m# Let's create an ID that uniquely identifies the current call. If the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1912\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_dispatched_batches\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1913\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_dispatched_tasks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1914\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1915\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_completed_tasks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1916\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_progress\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/multioutput.py\u001b[0m in \u001b[0;36m_fit_estimator\u001b[0;34m(estimator, X, y, sample_weight, **fit_params)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1387\u001b[0m )\n\u001b[1;32m 1388\u001b[0m ):\n\u001b[0;32m-> 1389\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1391\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/ensemble/_forest.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0;31m# parallel_backend contexts set at a higher level,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[0;31m# since correctness does not rely on using threads.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 487\u001b[0;31m trees = Parallel(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdelayed_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m )\n\u001b[0;32m---> 77\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable_with_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 2070\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2071\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2072\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_generator\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2073\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2074\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m 1680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1681\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieval_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1682\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_retrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1683\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1684\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mGeneratorExit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_retrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1798\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mTASK_PENDING\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1799\u001b[0m ):\n\u001b[0;32m-> 1800\u001b[0;31m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1801\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1802\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# 🎯 正确的任务:音素分类 (神经特征 → 音素置信度回归)\n", "\n", "print(\"\\n\" + \"=\"*70)\n", "print(\"🎯 重新定义任务:音素分类任务\")\n", "print(\"=\"*70)\n", "\n", "print(\"📋 任务重新定义:\")\n", "print(\" 输入: 神经特征 (前7168维)\")\n", "print(\" 输出: 音素置信度 (后41维RNN输出)\")\n", "print(\" 目标: 预测每个时间步的音素概率分布\")\n", "print(\" 分类: 选择最大置信度对应的音素\") \n", "\n", "# 重新准备数据\n", "print(f\"\\n🔧 数据重新处理:\")\n", "\n", "# 分离输入特征和目标\n", "X_neural = [] # 神经特征 (7168维)\n", 
"y_phoneme_logits = [] # 音素置信度 (41维)\n", "\n", "for features in valid_features:\n", " # features shape: [time_steps, 7209]\n", " neural_part = features[:, :7168] # 前7168维是神经特征\n", " rnn_part = features[:, 7168:] # 后41维是RNN输出(音素logits)\n", " \n", " # 对时间维度做平均\n", " neural_pooled = np.mean(neural_part, axis=0) # [7168]\n", " rnn_pooled = np.mean(rnn_part, axis=0) # [41]\n", " \n", " X_neural.append(neural_pooled)\n", " y_phoneme_logits.append(rnn_pooled)\n", "\n", "X_neural = np.array(X_neural) # [348, 7168]\n", "y_phoneme_logits = np.array(y_phoneme_logits) # [348, 41]\n", "\n", "print(f\" 神经特征矩阵: {X_neural.shape}\")\n", "print(f\" 音素logits矩阵: {y_phoneme_logits.shape}\")\n", "\n", "# 从音素logits得到分类标签\n", "y_phoneme_class = np.argmax(y_phoneme_logits, axis=1) # 选择最大值的索引\n", "print(f\" 音素类别标签: {y_phoneme_class.shape} (值范围: {y_phoneme_class.min()}-{y_phoneme_class.max()})\")\n", "\n", "# 显示音素分布\n", "from collections import Counter\n", "phoneme_dist = Counter(y_phoneme_class)\n", "print(f\" 音素分布: {len(phoneme_dist)} 个不同音素\")\n", "print(f\" 样本最多的前5个音素:\")\n", "for phoneme_id, count in phoneme_dist.most_common(5):\n", " print(f\" 音素 {phoneme_id}: {count} 次\")\n", "\n", "# 训练测试集切分\n", "print(f\"\\n🔄 训练测试集切分:\")\n", "X_train_neu, X_test_neu, y_logits_train, y_logits_test, y_class_train, y_class_test = train_test_split(\n", " X_neural, y_phoneme_logits, y_phoneme_class,\n", " test_size=0.2, random_state=42, stratify=y_phoneme_class\n", ")\n", "\n", "print(f\" 训练集: {X_train_neu.shape[0]} 样本\")\n", "print(f\" 测试集: {X_test_neu.shape[0]} 样本\")\n", "print(f\" 音素类别数: {len(np.unique(y_phoneme_class))}\")\n", "\n", "# 检查类别分布\n", "train_dist = Counter(y_class_train)\n", "test_dist = Counter(y_class_test)\n", "print(f\" 训练集音素分布: {len(train_dist)} 个不同音素\")\n", "print(f\" 测试集音素分布: {len(test_dist)} 个不同音素\")\n", "\n", "print(\"\\n\" + \"=\"*70)\n", "print(\"🌲 随机森林回归 + 分类\")\n", "print(\"=\"*70)\n", "\n", "# 方案1: 多输出回归 (预测41维音素logits)\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.multioutput import MultiOutputRegressor\n", "\n", "print(\"📊 方案1: 多输出回归 (神经特征 → 音素logits)\")\n", "rf_regressor = MultiOutputRegressor(\n", " RandomForestRegressor(\n", " n_estimators=100,\n", " max_depth=10,\n", " random_state=42,\n", " n_jobs=-1\n", " )\n", ")\n", "\n", "print(\"🚀 训练回归模型...\")\n", "start_time = time.time()\n", "rf_regressor.fit(X_train_neu, y_logits_train)\n", "regression_time = time.time() - start_time\n", "print(f\"✅ 回归训练完成!耗时: {regression_time:.2f} 秒\")\n", "\n", "# 预测音素logits\n", "print(\"🔮 预测音素置信度...\")\n", "y_logits_pred_train = rf_regressor.predict(X_train_neu)\n", "y_logits_pred_test = rf_regressor.predict(X_test_neu)\n", "\n", "# 从预测的logits得到分类结果\n", "y_class_pred_train = np.argmax(y_logits_pred_train, axis=1)\n", "y_class_pred_test = np.argmax(y_logits_pred_test, axis=1)\n", "\n", "# 评估分类性能\n", "from sklearn.metrics import accuracy_score, classification_report\n", "\n", "train_acc = accuracy_score(y_class_train, y_class_pred_train)\n", "test_acc = accuracy_score(y_class_test, y_class_pred_test)\n", "\n", "print(f\"\\n📊 音素分类性能评估:\")\n", "print(f\" 训练集准确率: {train_acc:.4f} ({train_acc*100:.2f}%)\")\n", "print(f\" 测试集准确率: {test_acc:.4f} ({test_acc*100:.2f}%)\")\n", "print(f\" 过拟合程度: {(train_acc - test_acc)*100:.2f}%\")\n", "\n", "# 评估回归性能\n", "from sklearn.metrics import mean_squared_error, r2_score\n", "\n", "mse_train = mean_squared_error(y_logits_train, y_logits_pred_train)\n", "mse_test = mean_squared_error(y_logits_test, y_logits_pred_test)\n", "r2_train = r2_score(y_logits_train, 
y_logits_pred_train)\n", "r2_test = r2_score(y_logits_test, y_logits_pred_test)\n", "\n", "print(f\"\\n📈 音素logits回归性能:\")\n", "print(f\" 训练集 MSE: {mse_train:.6f}\")\n", "print(f\" 测试集 MSE: {mse_test:.6f}\")\n", "print(f\" 训练集 R²: {r2_train:.4f}\")\n", "print(f\" 测试集 R²: {r2_test:.4f}\")\n", "\n", "print(f\"\\n✨ 任务修正完成!现在是正确的音素分类任务\")" ] } ], "metadata": { "kaggle": { "accelerator": "tpu1vmV38", "dataSources": [ { "databundleVersionId": 13056355, "sourceId": 106809, "sourceType": "competition" } ], "dockerImageVersionId": 31091, "isGpuEnabled": false, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }