Files
b2txt25/brain-to-text-25/brain-to-text-25.ipynb
2025-10-06 15:17:44 +08:00

4148 lines
221 KiB
Plaintext
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境配置 与 utils"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n",
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
"Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
"Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n",
"Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n",
"Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
"Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n",
"Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n",
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
"Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n",
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
"Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n",
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
"Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n",
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Installing collected packages: nejm_b2txt_utils\n",
" Running setup.py develop for nejm_b2txt_utils\n",
"Successfully installed nejm_b2txt_utils-0.0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cloning into 'nejm-brain-to-text'...\n"
]
}
],
"source": [
"%%bash\n",
"cd /kaggle/working/\n",
"# rm -rf /kaggle/working/nejm-brain-to-text/\n",
"# git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
"# cd /kaggle/working/nejm-brain-to-text/\n",
"# cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
"\n",
"# Install the local package\n",
"\n",
"# ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
"# ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
"# ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data\n",
"\n",
"# # Install PyTorch with CUDA 12.6\n",
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
"\n",
"# Install additional packages with compatible versions\n",
"# TODO: remove redis\n",
"pip install \\\n",
" jupyter==1.1.1 \\\n",
" \"numpy>=1.26.0,<2.1.0\" \\\n",
" pandas==2.3.0 \\\n",
" matplotlib==3.10.1 \\\n",
" scipy==1.15.2 \\\n",
" scikit-learn==1.6.1 \\\n",
" tqdm==4.67.1 \\\n",
" g2p_en==2.1.0 \\\n",
" h5py==3.13.0 \\\n",
" omegaconf==2.3.0 \\\n",
" editdistance==0.8.1 \\\n",
" huggingface-hub==0.33.1 \\\n",
" transformers==4.53.0 \\\n",
" tokenizers==0.21.2 \\\n",
" accelerate==1.8.1 \\\n",
" bitsandbytes==0.46.0\n",
"pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from g2p_en import G2p\n",
"import pandas as pd\n",
"import numpy as np\n",
"from nejm_b2txt_utils.general_utils import *\n",
"\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n"
]
}
],
"source": [
"%cd model_training/\n",
"from data_augmentations import gauss_smooth\n",
"# single decoding step function that also returns smoothed input\n",
"# smooths data and puts it through the model, returning both logits and smoothed input.\n",
"def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n",
"\n",
" # Use autocast for efficiency\n",
" with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n",
"\n",
" smoothed_x = gauss_smooth(\n",
" inputs = x, \n",
" device = device,\n",
" smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
" smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
" padding = 'valid',\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" logits, _ = model(\n",
" x = smoothed_x,\n",
" day_idx = torch.tensor([input_layer], device=device),\n",
" states = None, # no initial states\n",
" return_state = True,\n",
" )\n",
"\n",
" # convert both logits and smoothed input from bfloat16 to float32\n",
" logits = logits.float().cpu().numpy()\n",
" smoothed_input = smoothed_x.float().cpu().numpy()\n",
"\n",
" # # original order is [BLANK, phonemes..., SIL]\n",
" # # rearrange so the order is [BLANK, SIL, phonemes...]\n",
" # logits = rearrange_speech_logits_pt(logits)\n",
"\n",
" return logits, smoothed_input\n",
"\n",
"\n",
"import h5py\n",
"def load_h5py_file(file_path, b2txt_csv_df):\n",
" data = {\n",
" 'neural_features': [],\n",
" 'n_time_steps': [],\n",
" 'seq_class_ids': [],\n",
" 'seq_len': [],\n",
" 'transcriptions': [],\n",
" 'sentence_label': [],\n",
" 'session': [],\n",
" 'block_num': [],\n",
" 'trial_num': [],\n",
" 'corpus': [],\n",
" }\n",
" # Open the hdf5 file for that day\n",
" with h5py.File(file_path, 'r') as f:\n",
"\n",
" keys = list(f.keys())\n",
"\n",
" # For each trial in the selected trials in that day\n",
" for key in keys:\n",
" g = f[key]\n",
"\n",
" neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
" n_time_steps = g.attrs['n_time_steps']\n",
" seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
" seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
" transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
" sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
" session = g.attrs['session']\n",
" block_num = g.attrs['block_num']\n",
" trial_num = g.attrs['trial_num']\n",
"\n",
" # match this trial up with the csv to get the corpus name\n",
" year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
" date = f'{year}-{month}-{day}'\n",
" row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
" corpus_name = row['Corpus'].values[0]\n",
"\n",
" data['neural_features'].append(neural_features)\n",
" data['n_time_steps'].append(n_time_steps)\n",
" data['seq_class_ids'].append(seq_class_ids)\n",
" data['seq_len'].append(seq_len)\n",
" data['transcriptions'].append(transcription)\n",
" data['sentence_label'].append(sentence_label)\n",
" data['session'].append(session)\n",
" data['block_num'].append(block_num)\n",
" data['trial_num'].append(trial_num)\n",
" data['corpus'].append(corpus_name)\n",
" return data"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
" 'BLANK',\n",
" 'AA', 'AE', 'AH', 'AO', 'AW',\n",
" 'AY', 'B', 'CH', 'D', 'DH',\n",
" 'EH', 'ER', 'EY', 'F', 'G',\n",
" 'HH', 'IH', 'IY', 'JH', 'K',\n",
" 'L', 'M', 'N', 'NG', 'OW',\n",
" 'OY', 'P', 'R', 'S', 'SH',\n",
" 'T', 'TH', 'UH', 'UW', 'V',\n",
" 'W', 'Y', 'Z', 'ZH',\n",
" ' | ',\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 数据分析与预处理"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据准备"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text/"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"data = load_h5py_file(file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
" b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **任务介绍** :机器学习解决高维信号的模式识别问题"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们的数据集标签缺少时间戳,现在要进行的是半监督学习"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n",
" -0.7091646 , -0.11018186],\n",
" [-0.5859305 , -0.78699756, -0.64687246, ..., 0.3122117 ,\n",
" 1.7943763 , -0.76884896],\n",
" [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n",
" -0.8481289 , -0.7648201 ],\n",
" ...,\n",
" [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n",
" 0.9710176 , 2.5397465 ],\n",
" [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n",
" -0.68723625, 0.10479005],\n",
" [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n",
" 0.7417906 , -0.7008622 ]], dtype=float32),\n",
" 'n_time_steps': 321,\n",
" 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0], dtype=int32),\n",
" 'seq_len': 14,\n",
" 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n",
" 101, 114, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0], dtype=int32),\n",
" 'sentence_label': 'Bring it closer.',\n",
" 'session': 't15.2023.08.11',\n",
" 'block_num': 2,\n",
" 'trial_num': 0,\n",
" 'corpus': '50-Word'}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def data_patch(data, index):\n",
" data_patch = {}\n",
" data_patch['neural_features'] = data['neural_features'][index]\n",
" data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
" data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
" data_patch['seq_len'] = data['seq_len'][index]\n",
" data_patch['transcriptions'] = data['transcriptions'][index]\n",
" data_patch['sentence_label'] = data['sentence_label'][index]\n",
" data_patch['session'] = data['session'][index]\n",
" data_patch['block_num'] = data['block_num'][index]\n",
" data_patch['trial_num'] = data['trial_num'][index]\n",
" data_patch['corpus'] = data['corpus'][index]\n",
" return data_patch\n",
"\n",
"data_patch(data, 0)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"d1 = data_patch(data, 0)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transcriptions non-zero length: 16\n",
"Seq class ids non-zero length: 14\n",
"Seq len: 14\n"
]
}
],
"source": [
"trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
"seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
"seq_len = d1['seq_len']\n",
"print(f\"Transcriptions non-zero length: {trans_len}\")\n",
"print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
"print(f\"Seq len: {seq_len}\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of feature sequences: 14\n",
"Shape of first sequence: (22, 512)\n"
]
}
],
"source": [
"def create_time_windows(d1):\n",
" import numpy as np\n",
" n_time_steps = d1['n_time_steps']\n",
" seq_len = d1['seq_len']\n",
" # Create equal windows\n",
" edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
" windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
" \n",
" # Extract feature sequences for each window\n",
" feature_sequences = []\n",
" for start, end in windows:\n",
" seq = d1['neural_features'][start:end, :]\n",
" feature_sequences.append(seq)\n",
" \n",
" return feature_sequences\n",
"\n",
"# Example usage\n",
"feature_sequences = create_time_windows(d1)\n",
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 45, Val: 41, Test: 41\n",
"Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n",
"Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n",
"Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n"
]
}
],
"source": [
"import os\n",
"\n",
"def scan_hdf5_files(base_path):\n",
" train_files = []\n",
" val_files = []\n",
" test_files = []\n",
" for root, dirs, files in os.walk(base_path):\n",
" for file in files:\n",
" if file.endswith('.hdf5'):\n",
" abs_path = os.path.abspath(os.path.join(root, file))\n",
" if 'data_train.hdf5' in file:\n",
" train_files.append(abs_path)\n",
" elif 'data_val.hdf5' in file:\n",
" val_files.append(abs_path)\n",
" elif 'data_test.hdf5' in file:\n",
" test_files.append(abs_path)\n",
" return train_files, val_files, test_files\n",
"\n",
"# Example usage\n",
"FILE_PATH = 'data/hdf5_data_final'\n",
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
"print(\"Train files (first 3):\", train_list[:3])\n",
"print(\"Val files (first 3):\", val_list[:3])\n",
"print(\"Test files (first 3):\", test_list[:3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 标签处理"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# def classify_windows_by_labels(d1):\n",
"# seq_class_ids = d1['seq_class_ids'][:d1['seq_len']] # Take only the non-zero part\n",
"# windows = create_time_windows(d1)\n",
" \n",
"# classified_windows = {}\n",
"# for i, label in enumerate(seq_class_ids):\n",
"# char = LOGIT_TO_PHONEME[label]\n",
"# if char not in classified_windows:\n",
"# classified_windows[char] = []\n",
"# classified_windows[char].append(windows[i])\n",
" \n",
"# return classified_windows\n",
"\n",
"# # Example usage\n",
"# classified = classify_windows_by_labels(d1)\n",
"# print(\"Classified windows by label:\", classified)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# import pandas as pd\n",
"# import tqdm\n",
"\n",
"# b2txt_csv_df = pd.read_csv('data/t15_copyTaskData_description.csv')\n",
"\n",
"# def workflow(max_files=None):\n",
"# group_by_labels = {}\n",
"# files_to_process = train_list[:max_files] if max_files is not None else train_list\n",
"# for file_path in tqdm.tqdm(files_to_process):\n",
"# data = load_h5py_file(file_path, b2txt_csv_df)\n",
"# for i in tqdm.tqdm(range(len(data['neural_features'])), leave=False):\n",
"# # Process only the first trial for simplicity\n",
"# d1 = data_patch(data, i)\n",
"# classified = classify_windows_by_labels(d1)\n",
"# for key, value in classified.items():\n",
"# if key not in group_by_labels:\n",
"# group_by_labels[key] = []\n",
"# group_by_labels[key].extend(value)\n",
"# return group_by_labels\n",
"\n",
"# # Example usage\n",
"# result = workflow()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 核函数扭曲时间\n",
"控制音素时间长度相同"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# total = 0\n",
"# count = 0\n",
"# time_distribution = []\n",
"# for i in result.values():\n",
"# for j in i:\n",
"# total += j.shape[0]\n",
"# count += 1\n",
"# time_distribution.append(j.shape[0])\n",
"# print(f\"Total time steps: {total}, Total windows: {count}, Average window length: {total/count:.2f}\")\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# import numpy as np\n",
"# plt.figure(figsize=(12, 6))\n",
"# plt.hist(time_distribution, bins=50, alpha=0.7, edgecolor='black')\n",
"# plt.xlabel('Window Length (time steps)')\n",
"# plt.ylabel('Frequency')\n",
"# plt.title('Distribution of Time Window Lengths')\n",
"# plt.grid(True, alpha=0.3)\n",
"# plt.axvline(np.mean(time_distribution), color='red', linestyle='--', label=f'Mean: {np.mean(time_distribution):.2f}')\n",
"# plt.axvline(np.median(time_distribution), color='green', linestyle='--', label=f'Median: {np.median(time_distribution):.2f}')\n",
"# plt.legend()\n",
"# plt.show()\n",
"\n",
"# print(f\"Statistics:\")\n",
"# print(f\"Mean: {np.mean(time_distribution):.2f}\")\n",
"# print(f\"Median: {np.median(time_distribution):.2f}\")\n",
"# print(f\"Std: {np.std(time_distribution):.2f}\")\n",
"# print(f\"Min: {np.min(time_distribution)}\")\n",
"# print(f\"Max: {np.max(time_distribution)}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# # 把时间序列对齐到这个长度上,然后做聚类去异常值\n",
"# MEAN_WINDOWS_SIZE = 33\n",
"\n",
"# from scipy import interpolate\n",
"# import numpy as np\n",
"\n",
"# def kernel_time_warp(sequence, target_length=33):\n",
"# \"\"\"\n",
"# 使用核函数处理时间序列将序列长度标准化到target_length\n",
"# - 对于长度 < target_length 的序列:使用插值扩展\n",
"# - 对于长度 > target_length 的序列:使用压缩采样\n",
"# - 对于长度 = target_length 的序列:直接返回\n",
" \n",
"# Args:\n",
"# sequence: 输入时间序列 (time_steps, features)\n",
"# target_length: 目标长度\n",
" \n",
"# Returns:\n",
"# warped_sequence: 处理后的时间序列 (target_length, features)\n",
"# \"\"\"\n",
"# original_length = sequence.shape[0]\n",
"# n_features = sequence.shape[1]\n",
" \n",
"# # 如果序列长度已经等于目标长度,直接返回\n",
"# if original_length == target_length:\n",
"# return sequence\n",
" \n",
"# # 处理边界情况\n",
"# if original_length == 0:\n",
"# return np.zeros((target_length, n_features))\n",
" \n",
"# if original_length == 1:\n",
"# # 如果只有一个时间步,复制到所有目标时间步\n",
"# return np.repeat(sequence, target_length, axis=0)\n",
" \n",
"# warped_sequence = np.zeros((target_length, n_features))\n",
" \n",
"# if original_length > target_length:\n",
"# # 压缩:长序列 -> 短序列\n",
"# # 使用均匀采样 + 局部平均的方式压缩\n",
"# compression_ratio = original_length / target_length\n",
" \n",
"# for i in range(target_length):\n",
"# # 计算当前目标位置对应的原始序列范围\n",
"# start_idx = int(i * compression_ratio)\n",
"# end_idx = int((i + 1) * compression_ratio)\n",
" \n",
"# # 确保不超出边界\n",
"# start_idx = max(0, start_idx)\n",
"# end_idx = min(original_length, end_idx)\n",
" \n",
"# if start_idx == end_idx:\n",
"# # 避免空范围\n",
"# end_idx = min(start_idx + 1, original_length)\n",
" \n",
"# # 对该范围内的数据取平均(压缩)\n",
"# warped_sequence[i] = np.mean(sequence[start_idx:end_idx], axis=0)\n",
" \n",
"# else:\n",
"# # 扩展:短序列 -> 长序列\n",
"# # 使用插值的方式扩展\n",
"# original_indices = np.linspace(0, 1, original_length)\n",
"# target_indices = np.linspace(0, 1, target_length)\n",
" \n",
"# for feature_idx in range(n_features):\n",
"# # 根据原始序列长度选择插值方法\n",
"# if original_length >= 3:\n",
"# # 对于长度>=3的序列使用三次样条插值\n",
"# interpolator = interpolate.interp1d(\n",
"# original_indices, \n",
"# sequence[:, feature_idx], \n",
"# kind='cubic', \n",
"# bounds_error=False, \n",
"# fill_value='extrapolate'\n",
"# )\n",
"# else:\n",
"# # 对于长度=2的序列使用线性插值\n",
"# interpolator = interpolate.interp1d(\n",
"# original_indices, \n",
"# sequence[:, feature_idx], \n",
"# kind='linear', \n",
"# bounds_error=False, \n",
"# fill_value='extrapolate'\n",
"# )\n",
" \n",
"# warped_sequence[:, feature_idx] = interpolator(target_indices)\n",
" \n",
"# return warped_sequence\n",
"\n",
"# def gaussian_kernel_weight(x, sigma=0.1):\n",
"# \"\"\"\n",
"# 高斯核函数权重,用于平滑处理\n",
"# \"\"\"\n",
"# return np.exp(-0.5 * (x / sigma) ** 2)\n",
"\n",
"# def process_result_with_kernel(result_dict, target_length=33):\n",
"# \"\"\"\n",
"# 使用核函数处理result字典中的所有时间序列\n",
" \n",
"# Args:\n",
"# result_dict: 包含时间序列的字典\n",
"# target_length: 目标长度\n",
" \n",
"# Returns:\n",
"# processed_result: 处理后的字典\n",
"# \"\"\"\n",
"# processed_result = {}\n",
" \n",
"# print(\"Processing time series with kernel warping...\")\n",
"# print(f\"Target length: {target_length}\")\n",
" \n",
"# # 统计不同长度的序列\n",
"# length_stats = {}\n",
"# total_sequences = 0\n",
" \n",
"# for label, sequences in tqdm.tqdm(result_dict.items()):\n",
"# processed_sequences = []\n",
" \n",
"# for seq in sequences:\n",
"# original_length = seq.shape[0]\n",
" \n",
"# # 统计长度分布\n",
"# if original_length not in length_stats:\n",
"# length_stats[original_length] = 0\n",
"# length_stats[original_length] += 1\n",
"# total_sequences += 1\n",
" \n",
"# # 应用核函数时间扭曲(包括压缩和插值)\n",
"# warped_seq = kernel_time_warp(seq, target_length)\n",
"# processed_sequences.append(warped_seq)\n",
" \n",
"# processed_result[label] = processed_sequences\n",
"# print(f\"Label '{label}': {len(sequences)} sequences -> {len(processed_sequences)} sequences\")\n",
" \n",
"# # 打印长度统计信息\n",
"# print(f\"\\n原始序列长度分布:\")\n",
"# sorted_lengths = sorted(length_stats.items())\n",
"# short_count = sum(count for length, count in sorted_lengths if length < target_length)\n",
"# equal_count = sum(count for length, count in sorted_lengths if length == target_length)\n",
"# long_count = sum(count for length, count in sorted_lengths if length > target_length)\n",
" \n",
"# print(f\"短于目标长度({target_length}),需要插值扩展: {short_count} 个序列\")\n",
"# print(f\"等于目标长度({target_length}),无需处理: {equal_count} 个序列\")\n",
"# print(f\"长于目标长度({target_length}),需要压缩: {long_count} 个序列\")\n",
"# print(f\"总序列数: {total_sequences}\")\n",
" \n",
"# if len(sorted_lengths) <= 20: # 如果长度种类不多,显示详细分布\n",
"# print(\"\\n详细长度分布:\")\n",
"# for length, count in sorted_lengths:\n",
"# percentage = (count / total_sequences) * 100\n",
"# operation = \"\"\n",
"# if length < target_length:\n",
"# operation = \" (插值扩展)\"\n",
"# elif length > target_length:\n",
"# operation = \" (压缩)\"\n",
"# else:\n",
"# operation = \" (无需处理)\"\n",
"# print(f\" 长度 {length}: {count} 个 ({percentage:.1f}%){operation}\")\n",
" \n",
"# return processed_result\n",
"\n",
"# # 处理result字典\n",
"# processed_result = process_result_with_kernel(result, MEAN_WINDOWS_SIZE)\n",
"\n",
"# # 验证处理结果\n",
"# print(\"\\n处理后的统计信息:\")\n",
"# total_sequences = 0\n",
"# for label, sequences in processed_result.items():\n",
"# total_sequences += len(sequences)\n",
"# if sequences: # 如果列表不为空\n",
"# print(f\"Label '{label}': {len(sequences)} sequences, shape: {sequences[0].shape}\")\n",
"\n",
"# print(f\"总共处理了 {total_sequences} 个时间序列,目标长度: {MEAN_WINDOWS_SIZE}\")\n",
"\n",
"# # 验证所有序列现在都是目标长度\n",
"# all_correct_length = True\n",
"# for label, sequences in processed_result.items():\n",
"# for seq in sequences:\n",
"# if seq.shape[0] != MEAN_WINDOWS_SIZE:\n",
"# print(f\"错误: 发现长度不正确的序列 - Label: {label}, Shape: {seq.shape}\")\n",
"# all_correct_length = False\n",
"# break\n",
"# if not all_correct_length:\n",
"# break\n",
"\n",
"# if all_correct_length:\n",
"# print(f\"✅ 验证通过: 所有序列长度都已正确调整为 {MEAN_WINDOWS_SIZE}\")\n",
"# print(\" - 长序列已通过压缩(局部平均)缩短\")\n",
"# print(\" - 短序列已通过插值扩展\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"去掉异常的标签去除的时候记得保存正常标签的原始索引我们可能不用去除后的来训练模型。而是训练适应多个时间大小窗口的模型通过单独扫描的WER大小来确定权重。再把两个一起扫描按照权重赋值同样用极大值抑制或者CTC来处理。\n",
"建议CTC。毕竟我们多个长度窗口的模型已经和RNN差距不大了。"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# import pickle\n",
"# with open('processed_time_series.pkl', 'wb') as f:\n",
"# pickle.dump(processed_result, f)\n",
" \n",
"# with open('time_series_format.pkl', 'wb') as f:\n",
"# pickle.dump(result, f)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# processed_result.keys()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# from sklearn.cluster import KMeans\n",
"# from sklearn.metrics import silhouette_score\n",
"# from sklearn.metrics import silhouette_samples\n",
"\n",
"# # for key, value in processed_result.items():\n",
"# key = 'IH'\n",
"# value = processed_result[key]\n",
"# print(f\"Label: {key}, Number of sequences: {len(value)}, Shape of first sequence: {value[0].shape if value else 'N/A'}\")\n",
"\n",
"# # Apply KMeans clustering to the sequences for the selected label\n",
"\n",
"# # First, we need to reshape the sequences to 2D for clustering\n",
"# # Since all sequences now have the same length (33), we can flatten them\n",
"# sequences = value\n",
"# flattened_sequences = []\n",
"\n",
"# for seq in sequences:\n",
"# # Flatten each sequence to 1D\n",
"# flattened_seq = seq.flatten()\n",
"# flattened_sequences.append(flattened_seq)\n",
"\n",
"# flattened_sequences = np.array(flattened_sequences)\n",
"# print(f\"Flattened sequences shape: {flattened_sequences.shape}\")\n",
"\n",
"# # Perform KMeans clustering\n",
"# n_clusters = 5 # You can adjust this number\n",
"# kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)\n",
"# cluster_labels = kmeans.fit_predict(flattened_sequences)\n",
"\n",
"# # Calculate silhouette score to evaluate clustering quality\n",
"# silhouette_avg = silhouette_score(flattened_sequences, cluster_labels)\n",
"\n",
"# print(f\"Number of clusters: {n_clusters}\")\n",
"# print(f\"Silhouette score: {silhouette_avg:.3f}\")\n",
"# print(f\"Cluster distribution: {np.bincount(cluster_labels)}\")\n",
"\n",
"# # Visualize clustering results\n",
"# plt.figure(figsize=(12, 8))\n",
"\n",
"# # Plot 1: Cluster distribution\n",
"# plt.subplot(2, 2, 1)\n",
"# unique_labels, counts = np.unique(cluster_labels, return_counts=True)\n",
"# plt.bar(unique_labels, counts)\n",
"# plt.xlabel('Cluster')\n",
"# plt.ylabel('Number of Sequences')\n",
"# plt.title(f'Cluster Distribution for Label \"{key}\"')\n",
"\n",
"# # Plot 2: First few dimensions of the data colored by cluster\n",
"# plt.subplot(2, 2, 2)\n",
"# for i in range(n_clusters):\n",
"# mask = cluster_labels == i\n",
"# plt.scatter(flattened_sequences[mask, 0], flattened_sequences[mask, 1], \n",
"# label=f'Cluster {i}', alpha=0.6)\n",
"# plt.xlabel('Feature 0')\n",
"# plt.ylabel('Feature 1')\n",
"# plt.title('Clusters in Feature Space (First 2 Dimensions)')\n",
"# plt.legend()\n",
"\n",
"# # Plot 3: Silhouette analysis\n",
"# plt.subplot(2, 2, 3)\n",
"# silhouette_vals = silhouette_samples(flattened_sequences, cluster_labels)\n",
"# y_lower = 10\n",
"# for i in range(n_clusters):\n",
"# cluster_silhouette_vals = silhouette_vals[cluster_labels == i]\n",
"# cluster_silhouette_vals.sort()\n",
" \n",
"# size_cluster_i = cluster_silhouette_vals.shape[0]\n",
"# y_upper = y_lower + size_cluster_i\n",
" \n",
"# color = plt.cm.nipy_spectral(float(i) / n_clusters)\n",
"# plt.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals,\n",
"# facecolor=color, edgecolor=color, alpha=0.7)\n",
" \n",
"# plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))\n",
"# y_lower = y_upper + 10\n",
"\n",
"# plt.axvline(x=silhouette_avg, color=\"red\", linestyle=\"--\")\n",
"# plt.xlabel('Silhouette coefficient values')\n",
"# plt.ylabel('Cluster label')\n",
"# plt.title('Silhouette Analysis')\n",
"\n",
"# # Plot 4: Try different numbers of clusters\n",
"# plt.subplot(2, 2, 4)\n",
"# cluster_range = range(2, min(11, len(sequences)//2))\n",
"# silhouette_scores = []\n",
"# inertias = []\n",
"\n",
"# for n_clust in cluster_range:\n",
"# kmeans_temp = KMeans(n_clusters=n_clust, random_state=42, n_init=10)\n",
"# cluster_labels_temp = kmeans_temp.fit_predict(flattened_sequences)\n",
"# silhouette_scores.append(silhouette_score(flattened_sequences, cluster_labels_temp))\n",
"# inertias.append(kmeans_temp.inertia_)\n",
"\n",
"# plt.plot(cluster_range, silhouette_scores, 'bo-')\n",
"# plt.xlabel('Number of Clusters')\n",
"# plt.ylabel('Average Silhouette Score')\n",
"# plt.title('Silhouette Score vs Number of Clusters')\n",
"# plt.grid(True)\n",
"\n",
"# plt.tight_layout()\n",
"# plt.show()\n",
"\n",
"# # Print cluster centers information\n",
"# print(f\"\\nCluster centers shape: {kmeans.cluster_centers_.shape}\")\n",
"# print(f\"Each cluster center represents the average pattern for sequences in that cluster\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"# # 余弦距离分析和距离度量对比\n",
"# print(\"\\n\" + \"=\"*70)\n",
"# print(\"余弦距离分析\")\n",
"# print(\"=\"*70)\n",
"\n",
"# # 5. 计算余弦距离\n",
"# print(\"\\n计算余弦距离...\")\n",
"# cosine_matrix, phonemes_cosine = calculate_inter_class_distances(centroids, 'cosine')\n",
"\n",
"# # 6. 可视化余弦距离\n",
"# print(\"\\n可视化余弦距离矩阵...\")\n",
"# df_cosine, most_sim_cos, most_diff_cos = visualize_distance_matrix(\n",
"# cosine_matrix, phonemes_cosine, 'cosine'\n",
"# )\n",
"\n",
"# # 7. 分析相似音素群组(余弦距离)\n",
"# similar_pairs_cos, threshold_cos = analyze_phoneme_groups(cosine_matrix, phonemes_cosine)\n",
"\n",
"# print(f\"\\n余弦距离分析完成!\")\n",
"# print(f\"最相似音素对: {most_sim_cos}\")\n",
"# print(f\"最不相似音素对: {most_diff_cos}\")\n",
"\n",
"# # 8. 比较两种距离度量\n",
"# print(\"\\n\" + \"=\"*70)\n",
"# print(\"距离度量对比分析\")\n",
"# print(\"=\"*70)\n",
"\n",
"# def compare_distance_metrics(euclidean_matrix, cosine_matrix, phonemes):\n",
"# \"\"\"\n",
"# 比较不同距离度量的结果\n",
"# \"\"\"\n",
"# # 提取上三角矩阵的距离值\n",
"# upper_triangle_indices = np.triu_indices_from(euclidean_matrix, k=1)\n",
"# euclidean_distances = euclidean_matrix[upper_triangle_indices]\n",
"# cosine_distances = cosine_matrix[upper_triangle_indices]\n",
" \n",
"# # 计算相关性\n",
"# correlation = np.corrcoef(euclidean_distances, cosine_distances)[0, 1]\n",
" \n",
"# # 创建比较图\n",
"# fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
" \n",
"# # 图1: 距离分布对比\n",
"# ax1 = axes[0, 0]\n",
"# ax1.hist(euclidean_distances, bins=30, alpha=0.6, label='欧氏距离', color='blue')\n",
"# ax1.hist(cosine_distances, bins=30, alpha=0.6, label='余弦距离', color='red')\n",
"# ax1.set_xlabel('距离值')\n",
"# ax1.set_ylabel('频次')\n",
"# ax1.set_title('距离分布对比')\n",
"# ax1.legend()\n",
"# ax1.grid(True, alpha=0.3)\n",
" \n",
"# # 图2: 距离相关性散点图\n",
"# ax2 = axes[0, 1]\n",
"# ax2.scatter(euclidean_distances, cosine_distances, alpha=0.6, s=10)\n",
"# ax2.set_xlabel('欧氏距离')\n",
"# ax2.set_ylabel('余弦距离')\n",
"# ax2.set_title(f'距离度量相关性 (r={correlation:.3f})')\n",
"# ax2.grid(True, alpha=0.3)\n",
" \n",
"# # 添加拟合线\n",
"# z = np.polyfit(euclidean_distances, cosine_distances, 1)\n",
"# p = np.poly1d(z)\n",
"# ax2.plot(euclidean_distances, p(euclidean_distances), \"r--\", alpha=0.8)\n",
" \n",
"# # 图3: 最相似音素对比较\n",
"# ax3 = axes[1, 0]\n",
"# ax3.axis('off')\n",
" \n",
"# # 获取每种距离度量下的前10对最相似音素\n",
"# eucl_top10 = similar_pairs_eucl[:10]\n",
"# cos_top10 = similar_pairs_cos[:10]\n",
" \n",
"# comparison_text = \"最相似音素对比较 (前10对)\\n\\n\"\n",
"# comparison_text += f\"{'欧氏距离':<30} {'余弦距离':<30}\\n\"\n",
"# comparison_text += \"-\" * 60 + \"\\n\"\n",
" \n",
"# for i in range(min(10, len(eucl_top10), len(cos_top10))):\n",
"# eucl_pair = f\"{eucl_top10[i][0]}-{eucl_top10[i][1]} ({eucl_top10[i][2]:.3f})\"\n",
"# cos_pair = f\"{cos_top10[i][0]}-{cos_top10[i][1]} ({cos_top10[i][2]:.4f})\"\n",
"# comparison_text += f\"{eucl_pair:<30} {cos_pair:<30}\\n\"\n",
" \n",
"# ax3.text(0.05, 0.95, comparison_text, transform=ax3.transAxes, fontsize=9,\n",
"# verticalalignment='top', fontfamily='monospace',\n",
"# bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))\n",
" \n",
"# # 图4: 统计对比\n",
"# ax4 = axes[1, 1]\n",
"# ax4.axis('off')\n",
" \n",
"# # 计算统计信息\n",
"# eucl_stats = {\n",
"# 'mean': np.mean(euclidean_distances),\n",
"# 'median': np.median(euclidean_distances),\n",
"# 'std': np.std(euclidean_distances),\n",
"# 'min': np.min(euclidean_distances),\n",
"# 'max': np.max(euclidean_distances)\n",
"# }\n",
" \n",
"# cos_stats = {\n",
"# 'mean': np.mean(cosine_distances),\n",
"# 'median': np.median(cosine_distances),\n",
"# 'std': np.std(cosine_distances),\n",
"# 'min': np.min(cosine_distances),\n",
"# 'max': np.max(cosine_distances)\n",
"# }\n",
" \n",
"# stats_text = f\"\"\"\n",
"# 距离度量统计对比\n",
"\n",
"# 指标 欧氏距离 余弦距离\n",
"# {'='*45}\n",
"# 平均值 {eucl_stats['mean']:.4f} {cos_stats['mean']:.4f}\n",
"# 中位数 {eucl_stats['median']:.4f} {cos_stats['median']:.4f}\n",
"# 标准差 {eucl_stats['std']:.4f} {cos_stats['std']:.4f}\n",
"# 最小值 {eucl_stats['min']:.4f} {cos_stats['min']:.4f}\n",
"# 最大值 {eucl_stats['max']:.4f} {cos_stats['max']:.4f}\n",
"\n",
"# 相关性系数: {correlation:.4f}\n",
"\n",
"# 解释:\n",
"# - 欧氏距离: 测量特征空间中的直线距离\n",
"# - 余弦距离: 测量向量间的角度差异\n",
"# - 高相关性表明两种度量捕获相似的模式\n",
"# \"\"\"\n",
" \n",
"# ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes, fontsize=9,\n",
"# verticalalignment='top', fontfamily='monospace',\n",
"# bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))\n",
" \n",
"# plt.tight_layout()\n",
"# plt.show()\n",
" \n",
"# return correlation, eucl_stats, cos_stats\n",
"\n",
"# # 执行距离度量比较\n",
"# correlation, eucl_stats, cos_stats = compare_distance_metrics(\n",
"# euclidean_matrix, cosine_matrix, phonemes\n",
"# )\n",
"\n",
"# # 9. 音素聚类分析\n",
"# print(\"\\n\" + \"=\"*70)\n",
"# print(\"基于距离的音素聚类分析\")\n",
"# print(\"=\"*70)\n",
"\n",
"# def analyze_phoneme_clusters(distance_matrix, phonemes, n_clusters_range=[3, 4, 5, 6]):\n",
"# \"\"\"\n",
"# 使用不同数量的聚类分析音素群组\n",
"# \"\"\"\n",
"# from sklearn.cluster import AgglomerativeClustering\n",
" \n",
"# results = {}\n",
" \n",
"# fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
"# axes = axes.flatten()\n",
" \n",
"# for i, n_clusters in enumerate(n_clusters_range):\n",
"# if i >= 4: # 最多显示4个图\n",
"# break\n",
" \n",
"# # 执行层次聚类\n",
"# clustering = AgglomerativeClustering(\n",
"# n_clusters=n_clusters, \n",
"# metric='precomputed',\n",
"# linkage='average'\n",
"# )\n",
" \n",
"# cluster_labels = clustering.fit_predict(distance_matrix)\n",
" \n",
"# # 分析聚类结果\n",
"# clusters = {}\n",
"# for phoneme, label in zip(phonemes, cluster_labels):\n",
"# if label not in clusters:\n",
"# clusters[label] = []\n",
"# clusters[label].append(phoneme)\n",
" \n",
"# results[n_clusters] = clusters\n",
" \n",
"# # 可视化聚类结果\n",
"# ax = axes[i]\n",
" \n",
"# # 创建颜色映射\n",
"# colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))\n",
" \n",
"# # 为每个音素分配颜色\n",
"# phoneme_colors = [colors[cluster_labels[j]] for j in range(len(phonemes))]\n",
" \n",
"# # 使用PCA降维可视化使用质心数据\n",
"# from sklearn.decomposition import PCA\n",
" \n",
"# # 重新获取质心矩阵\n",
"# centroid_matrix = np.array([centroids[phoneme] for phoneme in phonemes])\n",
" \n",
"# if centroid_matrix.shape[1] > 2:\n",
"# pca = PCA(n_components=2)\n",
"# pca_result = pca.fit_transform(centroid_matrix)\n",
"# else:\n",
"# pca_result = centroid_matrix\n",
" \n",
"# # 绘制散点图\n",
"# for cluster_id in range(n_clusters):\n",
"# mask = cluster_labels == cluster_id\n",
"# ax.scatter(pca_result[mask, 0], pca_result[mask, 1], \n",
"# c=[colors[cluster_id]], label=f'聚类 {cluster_id}', s=50, alpha=0.7)\n",
" \n",
"# # 添加音素标签\n",
"# for j, phoneme in enumerate(phonemes):\n",
"# if cluster_labels[j] == cluster_id:\n",
"# ax.annotate(phoneme, (pca_result[j, 0], pca_result[j, 1]), \n",
"# xytext=(5, 5), textcoords='offset points', fontsize=8)\n",
" \n",
"# ax.set_title(f'{n_clusters} 个聚类')\n",
"# ax.set_xlabel('PC1')\n",
"# ax.set_ylabel('PC2')\n",
"# ax.legend()\n",
"# ax.grid(True, alpha=0.3)\n",
" \n",
"# plt.tight_layout()\n",
"# plt.show()\n",
" \n",
"# # 打印聚类结果\n",
"# for n_clusters, clusters in results.items():\n",
"# print(f\"\\n{n_clusters} 个聚类的结果:\")\n",
"# for cluster_id, phonemes_in_cluster in clusters.items():\n",
"# print(f\" 聚类 {cluster_id}: {', '.join(phonemes_in_cluster)}\")\n",
" \n",
"# return results\n",
"\n",
"# # 执行音素聚类分析\n",
"# clustering_results = analyze_phoneme_clusters(euclidean_matrix, phonemes)\n",
"\n",
"# print(f\"\\n音素类间距离分析完成\")\n",
"# print(f\"发现了丰富的音素相似性模式和聚类结构。\")"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# # 优化的DBSCAN异常值检测分析\n",
"# import numpy as np\n",
"# import matplotlib.pyplot as plt\n",
"# from sklearn.cluster import DBSCAN\n",
"# from sklearn.preprocessing import StandardScaler\n",
"# from sklearn.decomposition import PCA\n",
"# from sklearn.metrics import silhouette_score\n",
"# import pandas as pd\n",
"\n",
"# print(\"=\"*70)\n",
"# print(\"优化的DBSCAN异常值检测分析\")\n",
"# print(\"=\"*70)\n",
"\n",
"# def determine_optimal_pca_components(data, min_variance_ratio=0.80, max_components=500):\n",
"# \"\"\"\n",
"# 确定最优的PCA组件数保证解释方差比达到要求\n",
"# \"\"\"\n",
"# print(f\"原始数据维度: {data.shape}\")\n",
" \n",
"# # 先用少量组件快速估计\n",
"# n_samples, n_features = data.shape\n",
" \n",
"# # 确保组件数不超过样本数和特征数的最小值\n",
"# max_possible_components = min(n_samples, n_features, max_components)\n",
" \n",
"# # 快速测试不同的组件数\n",
"# test_components = [50, 100, 200, 300, 400, 500]\n",
"# test_components = [c for c in test_components if c <= max_possible_components]\n",
" \n",
"# if not test_components:\n",
"# test_components = [min(50, max_possible_components)]\n",
" \n",
"# print(f\"测试的组件数: {test_components}\")\n",
" \n",
"# best_components = test_components[0]\n",
"# best_ratio = 0\n",
" \n",
"# for n_comp in test_components:\n",
"# pca_temp = PCA(n_components=n_comp, random_state=42)\n",
"# pca_temp.fit(data)\n",
"# variance_ratio = np.sum(pca_temp.explained_variance_ratio_)\n",
" \n",
"# print(f\" {n_comp} 组件: 解释方差比 = {variance_ratio:.4f}\")\n",
" \n",
"# if variance_ratio >= min_variance_ratio:\n",
"# best_components = n_comp\n",
"# best_ratio = variance_ratio\n",
"# break\n",
"# elif variance_ratio > best_ratio:\n",
"# best_components = n_comp\n",
"# best_ratio = variance_ratio\n",
" \n",
"# print(f\"选择 {best_components} 个组件 (解释方差比: {best_ratio:.4f})\")\n",
"# return best_components\n",
"\n",
"# def smart_dbscan_outlier_detection(processed_result, target_phonemes=None, min_variance_ratio=0.75):\n",
"# \"\"\"\n",
"# 智能DBSCAN异常值检测使用合理的PCA降维\n",
"# \"\"\"\n",
"# outlier_results = {}\n",
" \n",
"# # 如果没有指定目标音素,选择样本数适中的音素\n",
"# if target_phonemes is None:\n",
"# phoneme_counts = [(p, len(seqs)) for p, seqs in processed_result.items()]\n",
"# # 选择样本数在50-1000之间的音素\n",
"# target_phonemes = [p for p, count in phoneme_counts if 50 <= count <= 1000]\n",
"# target_phonemes = target_phonemes[:5] # 最多处理5个音素\n",
" \n",
"# print(f\"将处理的音素: {target_phonemes}\")\n",
" \n",
"# for phoneme in target_phonemes:\n",
"# if phoneme not in processed_result:\n",
"# continue\n",
" \n",
"# sequences = processed_result[phoneme]\n",
"# print(f\"\\n\" + \"=\"*50)\n",
"# print(f\"分析音素 '{phoneme}' ({len(sequences)} 个样本)\")\n",
"# print(\"=\"*50)\n",
" \n",
"# # 展平序列数据\n",
"# flattened_sequences = []\n",
"# for seq in sequences:\n",
"# flattened_sequences.append(seq.flatten())\n",
" \n",
"# flattened_sequences = np.array(flattened_sequences)\n",
"# print(f\"原始数据形状: {flattened_sequences.shape}\")\n",
" \n",
"# # 标准化数据\n",
"# scaler = StandardScaler()\n",
"# scaled_data = scaler.fit_transform(flattened_sequences)\n",
" \n",
"# # 智能确定PCA组件数\n",
"# optimal_components = determine_optimal_pca_components(\n",
"# scaled_data, min_variance_ratio=min_variance_ratio\n",
"# )\n",
" \n",
"# # 执行PCA降维\n",
"# pca = PCA(n_components=optimal_components, random_state=42)\n",
"# pca_data = pca.fit_transform(scaled_data)\n",
"# variance_explained = np.sum(pca.explained_variance_ratio_)\n",
" \n",
"# print(f\"PCA降维结果:\")\n",
"# print(f\" 原始维度: {scaled_data.shape[1]}\")\n",
"# print(f\" 降维后: {pca_data.shape[1]}\")\n",
"# print(f\" 解释方差比: {variance_explained:.4f}\")\n",
"# print(f\" 信息保留率: {variance_explained*100:.2f}%\")\n",
" \n",
"# # 使用简化的DBSCAN参数搜索\n",
"# print(f\"\\n开始DBSCAN参数搜索...\")\n",
" \n",
"# # 基于数据估计合理的eps范围\n",
"# from sklearn.neighbors import NearestNeighbors\n",
"# k = min(10, len(sequences)//10)\n",
"# nbrs = NearestNeighbors(n_neighbors=k)\n",
"# nbrs.fit(pca_data)\n",
"# distances, _ = nbrs.kneighbors(pca_data)\n",
"# k_distances = np.sort(distances[:, k-1])\n",
" \n",
"# # 选择eps候选值\n",
"# eps_candidates = [\n",
"# np.percentile(k_distances, 25),\n",
"# np.percentile(k_distances, 50),\n",
"# np.percentile(k_distances, 75),\n",
"# np.percentile(k_distances, 90)\n",
"# ]\n",
" \n",
"# min_samples_candidates = [5, 10, 15, 20]\n",
" \n",
"# print(f\"eps候选值: {[f'{e:.3f}' for e in eps_candidates]}\")\n",
"# print(f\"min_samples候选值: {min_samples_candidates}\")\n",
" \n",
"# best_score = -1\n",
"# best_result = None\n",
" \n",
"# for eps in eps_candidates:\n",
"# for min_samples in min_samples_candidates:\n",
"# if min_samples >= len(sequences) // 5: # min_samples不能太大\n",
"# continue\n",
" \n",
"# dbscan = DBSCAN(eps=eps, min_samples=min_samples)\n",
"# labels = dbscan.fit_predict(pca_data)\n",
" \n",
"# n_outliers = np.sum(labels == -1)\n",
"# n_clusters = len(set(labels)) - (1 if -1 in labels else 0)\n",
"# outlier_ratio = n_outliers / len(labels)\n",
" \n",
"# # 计算评分\n",
"# if n_clusters > 0 and 0.05 <= outlier_ratio <= 0.40: # 异常值比例在5%-40%之间\n",
"# try:\n",
"# if len(set(labels[labels != -1])) > 1:\n",
"# silhouette = silhouette_score(pca_data[labels != -1], labels[labels != -1])\n",
"# else:\n",
"# silhouette = 0.5 # 单聚类给中等分数\n",
"# except:\n",
"# silhouette = 0\n",
" \n",
"# # 综合评分:轮廓系数 + 合理的异常值比例奖励\n",
"# if 0.1 <= outlier_ratio <= 0.25: # 最理想的异常值比例\n",
"# ratio_bonus = 0.2\n",
"# else:\n",
"# ratio_bonus = 0.1\n",
" \n",
"# score = silhouette + ratio_bonus\n",
" \n",
"# if score > best_score:\n",
"# best_score = score\n",
"# best_result = {\n",
"# 'eps': eps,\n",
"# 'min_samples': min_samples,\n",
"# 'labels': labels.copy(),\n",
"# 'outliers': np.where(labels == -1)[0],\n",
"# 'n_clusters': n_clusters,\n",
"# 'outlier_ratio': outlier_ratio,\n",
"# 'silhouette': silhouette\n",
"# }\n",
" \n",
"# print(f\" eps={eps:.3f}, min_samples={min_samples}: \"\n",
"# f\"{n_clusters}聚类, {n_outliers}异常值 ({outlier_ratio*100:.1f}%)\")\n",
" \n",
"# if best_result is not None:\n",
"# outlier_results[phoneme] = {\n",
"# **best_result,\n",
"# 'pca_data': pca_data,\n",
"# 'pca_model': pca,\n",
"# 'variance_explained': variance_explained,\n",
"# 'original_data': flattened_sequences,\n",
"# 'scaled_data': scaled_data\n",
"# }\n",
" \n",
"# print(f\"\\n✅ 找到最佳参数:\")\n",
"# print(f\" eps: {best_result['eps']:.3f}\")\n",
"# print(f\" min_samples: {best_result['min_samples']}\")\n",
"# print(f\" 聚类数: {best_result['n_clusters']}\")\n",
"# print(f\" 异常值: {len(best_result['outliers'])} ({best_result['outlier_ratio']*100:.1f}%)\")\n",
"# print(f\" 轮廓系数: {best_result['silhouette']:.3f}\")\n",
"# print(f\" 综合评分: {best_score:.3f}\")\n",
"# else:\n",
"# print(f\"\\n❌ 未找到合适的DBSCAN参数\")\n",
" \n",
"# return outlier_results\n",
"\n",
"# def visualize_smart_dbscan_results(outlier_results):\n",
"# \"\"\"\n",
"# 可视化智能DBSCAN结果\n",
"# \"\"\"\n",
"# if not outlier_results:\n",
"# print(\"没有结果可可视化\")\n",
"# return\n",
" \n",
"# n_phonemes = len(outlier_results)\n",
"# fig, axes = plt.subplots(2, n_phonemes, figsize=(6*n_phonemes, 10))\n",
" \n",
"# if n_phonemes == 1:\n",
"# axes = axes.reshape(2, 1)\n",
" \n",
"# for i, (phoneme, result) in enumerate(outlier_results.items()):\n",
"# # 上图PCA散点图\n",
"# ax1 = axes[0, i]\n",
"# pca_data = result['pca_data']\n",
"# labels = result['labels']\n",
" \n",
"# # 绘制聚类\n",
"# unique_labels = set(labels)\n",
"# colors = plt.cm.Set3(np.linspace(0, 1, len(unique_labels)))\n",
" \n",
"# for j, label in enumerate(unique_labels):\n",
"# if label == -1:\n",
"# mask = labels == label\n",
"# ax1.scatter(pca_data[mask, 0], pca_data[mask, 1], \n",
"# c='red', marker='x', s=60, label='异常值', alpha=0.8)\n",
"# else:\n",
"# mask = labels == label\n",
"# ax1.scatter(pca_data[mask, 0], pca_data[mask, 1], \n",
"# c=[colors[j]], label=f'聚类 {label}', alpha=0.7, s=30)\n",
" \n",
"# ax1.set_title(f'音素 \"{phoneme}\" DBSCAN结果\\n'\n",
"# f'{result[\"n_clusters\"]}聚类, {len(result[\"outliers\"])}异常值 '\n",
"# f'({result[\"outlier_ratio\"]*100:.1f}%)')\n",
"# ax1.set_xlabel('PC1')\n",
"# ax1.set_ylabel('PC2')\n",
"# ax1.legend()\n",
"# ax1.grid(True, alpha=0.3)\n",
" \n",
"# # 下图:方差解释图\n",
"# ax2 = axes[1, i]\n",
"# pca_model = result['pca_model']\n",
"# n_components = len(pca_model.explained_variance_ratio_)\n",
" \n",
"# # 绘制累计方差解释比\n",
"# cumsum_var = np.cumsum(pca_model.explained_variance_ratio_)\n",
"# ax2.plot(range(1, n_components+1), cumsum_var, 'b-', marker='o')\n",
"# ax2.axhline(y=0.8, color='r', linestyle='--', label='80%阈值')\n",
"# ax2.axhline(y=0.9, color='g', linestyle='--', label='90%阈值')\n",
" \n",
"# ax2.set_xlabel('主成分数量')\n",
"# ax2.set_ylabel('累计解释方差比')\n",
"# ax2.set_title(f'PCA方差解释 (总计: {result[\"variance_explained\"]:.3f})')\n",
"# ax2.legend()\n",
"# ax2.grid(True, alpha=0.3)\n",
"# ax2.set_ylim(0, 1)\n",
" \n",
"# plt.tight_layout()\n",
"# plt.show()\n",
"\n",
"# def analyze_outliers_detailed(outlier_results):\n",
"# \"\"\"\n",
"# 详细分析异常值\n",
"# \"\"\"\n",
"# print(\"\\n\" + \"=\"*70)\n",
"# print(\"详细异常值分析报告\")\n",
"# print(\"=\"*70)\n",
" \n",
"# for phoneme, result in outlier_results.items():\n",
"# print(f\"\\n音素 '{phoneme}' 异常值分析:\")\n",
"# print(\"-\" * 40)\n",
" \n",
"# outlier_indices = result['outliers']\n",
"# normal_indices = [i for i in range(len(result['labels'])) if i not in outlier_indices]\n",
" \n",
"# print(f\"总样本数: {len(result['labels'])}\")\n",
"# print(f\"正常样本: {len(normal_indices)} ({len(normal_indices)/len(result['labels'])*100:.1f}%)\")\n",
"# print(f\"异常样本: {len(outlier_indices)} ({len(outlier_indices)/len(result['labels'])*100:.1f}%)\")\n",
"# print(f\"聚类数量: {result['n_clusters']}\")\n",
"# print(f\"PCA维度: {result['pca_data'].shape[1]}\")\n",
"# print(f\"信息保留: {result['variance_explained']*100:.2f}%\")\n",
" \n",
"# # 分析异常值的特征\n",
"# if len(outlier_indices) > 0:\n",
"# outlier_data = result['pca_data'][outlier_indices]\n",
"# normal_data = result['pca_data'][normal_indices] if len(normal_indices) > 0 else None\n",
" \n",
"# print(f\"\\n异常值特征分析:\")\n",
"# print(f\" PC1均值: {np.mean(outlier_data[:, 0]):.3f} ± {np.std(outlier_data[:, 0]):.3f}\")\n",
"# print(f\" PC2均值: {np.mean(outlier_data[:, 1]):.3f} ± {np.std(outlier_data[:, 1]):.3f}\")\n",
" \n",
"# if normal_data is not None:\n",
"# print(f\"正常值特征对比:\")\n",
"# print(f\" PC1均值: {np.mean(normal_data[:, 0]):.3f} ± {np.std(normal_data[:, 0]):.3f}\")\n",
"# print(f\" PC2均值: {np.mean(normal_data[:, 1]):.3f} ± {np.std(normal_data[:, 1]):.3f}\")\n",
"\n",
"# # 执行优化的DBSCAN异常值检测\n",
"# print(\"开始智能DBSCAN异常值检测...\")\n",
"\n",
"# # 选择几个有代表性的音素进行分析\n",
"# target_phonemes = ['IH', 'T', 'S', 'N', 'AH'] # 手动选择一些常见音素\n",
"\n",
"# outlier_results = smart_dbscan_outlier_detection(\n",
"# processed_result, \n",
"# target_phonemes=target_phonemes,\n",
"# min_variance_ratio=0.75 # 保留至少75%的方差\n",
"# )\n",
"\n",
"# if outlier_results:\n",
"# print(f\"\\n✅ 成功检测到 {len(outlier_results)} 个音素的异常值!\")\n",
" \n",
"# # 可视化结果\n",
"# visualize_smart_dbscan_results(outlier_results)\n",
" \n",
"# # 详细分析\n",
"# analyze_outliers_detailed(outlier_results)\n",
" \n",
"# print(f\"\\n📊 异常值检测总结:\")\n",
"# for phoneme, result in outlier_results.items():\n",
"# print(f\" {phoneme}: {len(result['outliers'])}/{len(result['labels'])} \"\n",
"# f\"({result['outlier_ratio']*100:.1f}%) 异常值\")\n",
" \n",
"# else:\n",
"# print(\"❌ 未检测到任何有效的异常值结果\")\n",
"\n",
"# print(f\"\\n💡 关键改进:\")\n",
"# print(f\"1. 使用自适应PCA降维保留75%以上的方差信息\")\n",
"# print(f\"2. 基于数据分布智能选择DBSCAN参数\")\n",
"# print(f\"3. 合理的异常值比例范围(5%-40%)\")\n",
"# print(f\"4. 综合评分机制平衡聚类质量和异常值检测\")"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# # 类内相似度分析\n",
"# import numpy as np\n",
"# import matplotlib.pyplot as plt\n",
"# from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances\n",
"# from sklearn.preprocessing import StandardScaler\n",
"# import pandas as pd\n",
"# import seaborn as sns\n",
"\n",
"# print(\"=\"*70)\n",
"# print(\"类内相似度分析 - 音素分类有效性评估\")\n",
"# print(\"=\"*70)\n",
"\n",
"# def calculate_intra_class_similarity(sequences, metric='cosine'):\n",
"# \"\"\"\n",
"# 计算类内相似度\n",
"# \"\"\"\n",
"# if len(sequences) < 2:\n",
"# return np.nan, np.nan, np.nan\n",
" \n",
"# # 展平序列\n",
"# flattened_sequences = []\n",
"# for seq in sequences:\n",
"# flattened_sequences.append(seq.flatten())\n",
" \n",
"# flattened_sequences = np.array(flattened_sequences)\n",
" \n",
"# # 标准化\n",
"# scaler = StandardScaler()\n",
"# scaled_sequences = scaler.fit_transform(flattened_sequences)\n",
" \n",
"# if metric == 'cosine':\n",
"# # 计算余弦相似度\n",
"# similarity_matrix = cosine_similarity(scaled_sequences)\n",
"# # 提取上三角矩阵(排除对角线)\n",
"# upper_tri = np.triu(similarity_matrix, k=1)\n",
"# similarities = upper_tri[upper_tri > 0]\n",
"# elif metric == 'euclidean':\n",
"# # 计算欧氏距离,然后转换为相似度\n",
"# distance_matrix = euclidean_distances(scaled_sequences)\n",
"# # 转换为相似度(距离越小,相似度越高)\n",
"# max_dist = np.max(distance_matrix)\n",
"# similarity_matrix = 1 - (distance_matrix / max_dist)\n",
"# upper_tri = np.triu(similarity_matrix, k=1)\n",
"# similarities = upper_tri[upper_tri > 0]\n",
" \n",
"# mean_similarity = np.mean(similarities)\n",
"# std_similarity = np.std(similarities)\n",
"# median_similarity = np.median(similarities)\n",
" \n",
"# return mean_similarity, std_similarity, median_similarity\n",
"\n",
"# def analyze_phoneme_similarity(processed_result, metric='cosine', sample_limit=500):\n",
"# \"\"\"\n",
"# 分析每个音素的类内相似度\n",
"# \"\"\"\n",
"# print(f\"使用 {metric} 相似度度量\")\n",
"# print(f\"每个音素最多分析 {sample_limit} 个样本\")\n",
"# print(\"-\" * 50)\n",
" \n",
"# phoneme_similarities = {}\n",
" \n",
"# for phoneme, sequences in processed_result.items():\n",
"# if len(sequences) < 5: # 跳过样本数太少的音素\n",
"# print(f\"跳过音素 '{phoneme}' (样本数太少: {len(sequences)})\")\n",
"# continue\n",
" \n",
"# # 如果样本太多,随机采样\n",
"# if len(sequences) > sample_limit:\n",
"# indices = np.random.choice(len(sequences), sample_limit, replace=False)\n",
"# sampled_sequences = [sequences[i] for i in indices]\n",
"# else:\n",
"# sampled_sequences = sequences\n",
" \n",
"# mean_sim, std_sim, median_sim = calculate_intra_class_similarity(\n",
"# sampled_sequences, metric=metric\n",
"# )\n",
" \n",
"# phoneme_similarities[phoneme] = {\n",
"# 'mean': mean_sim,\n",
"# 'std': std_sim,\n",
"# 'median': median_sim,\n",
"# 'n_samples': len(sampled_sequences),\n",
"# 'n_pairs': len(sampled_sequences) * (len(sampled_sequences) - 1) // 2\n",
"# }\n",
" \n",
"# print(f\"音素 '{phoneme}': 平均相似度={mean_sim:.4f} ± {std_sim:.4f}, \"\n",
"# f\"中位数={median_sim:.4f}, 样本数={len(sampled_sequences)}\")\n",
" \n",
"# return phoneme_similarities\n",
"\n",
"# def calculate_overall_similarity(processed_result, metric='cosine', sample_per_phoneme=50):\n",
"# \"\"\"\n",
"# 计算全部音素作为一类的相似度\n",
"# \"\"\"\n",
"# print(f\"\\n计算全体音素相似度 (每个音素采样 {sample_per_phoneme} 个)\")\n",
"# print(\"-\" * 50)\n",
" \n",
"# all_sequences = []\n",
"# phoneme_labels = []\n",
" \n",
"# # 从每个音素中采样一定数量的序列\n",
"# for phoneme, sequences in processed_result.items():\n",
"# if len(sequences) < 5:\n",
"# continue\n",
" \n",
"# n_sample = min(sample_per_phoneme, len(sequences))\n",
"# indices = np.random.choice(len(sequences), n_sample, replace=False)\n",
" \n",
"# for i in indices:\n",
"# all_sequences.append(sequences[i])\n",
"# phoneme_labels.append(phoneme)\n",
" \n",
"# print(f\"总共收集了 {len(all_sequences)} 个序列,来自 {len(set(phoneme_labels))} 个音素\")\n",
" \n",
"# # 计算整体相似度\n",
"# mean_sim, std_sim, median_sim = calculate_intra_class_similarity(\n",
"# all_sequences, metric=metric\n",
"# )\n",
" \n",
"# overall_result = {\n",
"# 'mean': mean_sim,\n",
"# 'std': std_sim,\n",
"# 'median': median_sim,\n",
"# 'n_samples': len(all_sequences),\n",
"# 'n_pairs': len(all_sequences) * (len(all_sequences) - 1) // 2,\n",
"# 'n_phonemes': len(set(phoneme_labels))\n",
"# }\n",
" \n",
"# print(f\"全体音素: 平均相似度={mean_sim:.4f} ± {std_sim:.4f}, \"\n",
"# f\"中位数={median_sim:.4f}, 样本数={len(all_sequences)}\")\n",
" \n",
"# return overall_result, phoneme_labels\n",
"\n",
"# def visualize_similarity_comparison(phoneme_similarities, overall_result, metric='cosine'):\n",
"# \"\"\"\n",
"# 可视化相似度比较\n",
"# \"\"\"\n",
"# fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
" \n",
"# # 准备数据\n",
"# phonemes = list(phoneme_similarities.keys())\n",
"# mean_similarities = [phoneme_similarities[p]['mean'] for p in phonemes]\n",
"# std_similarities = [phoneme_similarities[p]['std'] for p in phonemes]\n",
"# sample_counts = [phoneme_similarities[p]['n_samples'] for p in phonemes]\n",
" \n",
"# overall_mean = overall_result['mean']\n",
"# overall_std = overall_result['std']\n",
" \n",
"# # 图1: 每个音素的平均相似度 vs 整体相似度\n",
"# ax1 = axes[0, 0]\n",
"# bars = ax1.bar(range(len(phonemes)), mean_similarities, \n",
"# yerr=std_similarities, capsize=3, alpha=0.7, color='skyblue')\n",
"# ax1.axhline(y=overall_mean, color='red', linestyle='--', linewidth=2, \n",
"# label=f'全体音素平均: {overall_mean:.4f}')\n",
"# ax1.fill_between(range(len(phonemes)), \n",
"# overall_mean - overall_std, \n",
"# overall_mean + overall_std, \n",
"# alpha=0.2, color='red', label=f'全体音素范围: ±{overall_std:.4f}')\n",
" \n",
"# ax1.set_xlabel('音素')\n",
"# ax1.set_ylabel(f'{metric.title()} 相似度')\n",
"# ax1.set_title('各音素类内相似度 vs 全体音素相似度')\n",
"# ax1.set_xticks(range(len(phonemes)))\n",
"# ax1.set_xticklabels(phonemes, rotation=45)\n",
"# ax1.legend()\n",
"# ax1.grid(True, alpha=0.3)\n",
" \n",
"# # 在柱状图上显示数值\n",
"# for i, (bar, mean_val) in enumerate(zip(bars, mean_similarities)):\n",
"# if mean_val > overall_mean:\n",
"# ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, \n",
"# f'{mean_val:.3f}', ha='center', va='bottom', fontsize=8, \n",
"# color='green', weight='bold')\n",
"# else:\n",
"# ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, \n",
"# f'{mean_val:.3f}', ha='center', va='bottom', fontsize=8, \n",
"# color='red', weight='bold')\n",
" \n",
"# # 图2: 相似度提升程度\n",
"# ax2 = axes[0, 1]\n",
"# improvements = [(sim - overall_mean) for sim in mean_similarities]\n",
"# colors = ['green' if imp > 0 else 'red' for imp in improvements]\n",
" \n",
"# bars2 = ax2.bar(range(len(phonemes)), improvements, color=colors, alpha=0.7)\n",
"# ax2.axhline(y=0, color='black', linestyle='-', linewidth=1)\n",
"# ax2.set_xlabel('音素')\n",
"# ax2.set_ylabel('相似度提升 (相对于全体)')\n",
"# ax2.set_title('音素分类的相似度提升效果')\n",
"# ax2.set_xticks(range(len(phonemes)))\n",
"# ax2.set_xticklabels(phonemes, rotation=45)\n",
"# ax2.grid(True, alpha=0.3)\n",
" \n",
"# # 显示数值\n",
"# for bar, imp in zip(bars2, improvements):\n",
"# ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001, \n",
"# f'{imp:+.3f}', ha='center', va='bottom' if imp > 0 else 'top', \n",
"# fontsize=8, weight='bold')\n",
" \n",
"# # 图3: 样本数量 vs 相似度\n",
"# ax3 = axes[1, 0]\n",
"# scatter = ax3.scatter(sample_counts, mean_similarities, \n",
"# c=improvements, cmap='RdYlGn', s=60, alpha=0.7)\n",
"# ax3.axhline(y=overall_mean, color='red', linestyle='--', alpha=0.7)\n",
"# ax3.set_xlabel('样本数量')\n",
"# ax3.set_ylabel(f'{metric.title()} 相似度')\n",
"# ax3.set_title('样本数量 vs 相似度')\n",
"# ax3.grid(True, alpha=0.3)\n",
" \n",
"# # 添加颜色条\n",
"# cbar = plt.colorbar(scatter, ax=ax3)\n",
"# cbar.set_label('相似度提升')\n",
" \n",
"# # 图4: 相似度分布统计\n",
"# ax4 = axes[1, 1]\n",
"# ax4.axis('off')\n",
" \n",
"# # 计算统计信息\n",
"# positive_improvements = [imp for imp in improvements if imp > 0]\n",
"# negative_improvements = [imp for imp in improvements if imp <= 0]\n",
" \n",
"# avg_improvement = np.mean(improvements)\n",
"# max_improvement = np.max(improvements)\n",
"# min_improvement = np.min(improvements)\n",
" \n",
"# best_phoneme = phonemes[improvements.index(max_improvement)]\n",
"# worst_phoneme = phonemes[improvements.index(min_improvement)]\n",
" \n",
"# stats_text = f\"\"\"\n",
"# 类内相似度分析统计报告\n",
"\n",
"# 度量方法: {metric.title()} 相似度\n",
"\n",
"# 全体音素基线:\n",
"# - 平均相似度: {overall_mean:.4f} ± {overall_std:.4f}\n",
"# - 样本总数: {overall_result['n_samples']}\n",
"# - 音素数量: {overall_result['n_phonemes']}\n",
"\n",
"# 音素分类效果:\n",
"# - 分析音素数: {len(phonemes)}\n",
"# - 平均提升: {avg_improvement:+.4f}\n",
"# - 最大提升: {max_improvement:+.4f} ({best_phoneme})\n",
"# - 最小提升: {min_improvement:+.4f} ({worst_phoneme})\n",
"\n",
"# 提升统计:\n",
"# - 相似度提升音素: {len(positive_improvements)}/{len(phonemes)} ({len(positive_improvements)/len(phonemes)*100:.1f}%)\n",
"# - 相似度下降音素: {len(negative_improvements)}/{len(phonemes)} ({len(negative_improvements)/len(phonemes)*100:.1f}%)\n",
"\n",
"# 结论:\n",
"# \"\"\"\n",
" \n",
"# if avg_improvement > 0:\n",
"# stats_text += f\"✅ 音素分类整体有效 (平均提升 {avg_improvement:.4f})\\n\"\n",
"# stats_text += f\" 按音素分类比混合所有音素更能保持相似性\"\n",
"# else:\n",
"# stats_text += f\"❌ 音素分类效果有限 (平均下降 {avg_improvement:.4f})\\n\"\n",
"# stats_text += f\" 可能需要重新考虑分类策略\"\n",
" \n",
"# ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes, fontsize=10,\n",
"# verticalalignment='top', fontfamily='monospace',\n",
"# bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))\n",
" \n",
"# plt.tight_layout()\n",
"# plt.show()\n",
" \n",
"# return improvements\n",
"\n",
"# def create_similarity_dataframe(phoneme_similarities, overall_result):\n",
"# \"\"\"\n",
"# 创建相似度对比的DataFrame\n",
"# \"\"\"\n",
"# data = []\n",
"# overall_mean = overall_result['mean']\n",
" \n",
"# for phoneme, sim_data in phoneme_similarities.items():\n",
"# improvement = sim_data['mean'] - overall_mean\n",
"# relative_improvement = (improvement / overall_mean) * 100\n",
" \n",
"# data.append({\n",
"# 'phoneme': phoneme,\n",
"# 'intra_class_similarity': sim_data['mean'],\n",
"# 'similarity_std': sim_data['std'],\n",
"# 'n_samples': sim_data['n_samples'],\n",
"# 'overall_baseline': overall_mean,\n",
"# 'absolute_improvement': improvement,\n",
"# 'relative_improvement_pct': relative_improvement,\n",
"# 'is_better': improvement > 0\n",
"# })\n",
" \n",
"# df = pd.DataFrame(data)\n",
"# df = df.sort_values('absolute_improvement', ascending=False)\n",
" \n",
"# return df\n",
"\n",
"# # 执行类内相似度分析\n",
"# print(\"开始类内相似度分析...\")\n",
"\n",
"# # 1. 分析每个音素的类内相似度\n",
"# print(\"\\n1. 计算各音素类内相似度\")\n",
"# phoneme_similarities_cosine = analyze_phoneme_similarity(\n",
"# processed_result, metric='cosine', sample_limit=300\n",
"# )\n",
"\n",
"# # 2. 计算全体音素的相似度\n",
"# print(\"\\n2. 计算全体音素相似度作为基线\")\n",
"# overall_result_cosine, all_phoneme_labels = calculate_overall_similarity(\n",
"# processed_result, metric='cosine', sample_per_phoneme=30\n",
"# )\n",
"\n",
"# # 3. 可视化比较\n",
"# print(\"\\n3. 可视化相似度比较\")\n",
"# improvements = visualize_similarity_comparison(\n",
"# phoneme_similarities_cosine, overall_result_cosine, metric='cosine'\n",
"# )\n",
"\n",
"# # 4. 创建详细对比表\n",
"# print(\"\\n4. 详细相似度对比表\")\n",
"# df_similarity = create_similarity_dataframe(phoneme_similarities_cosine, overall_result_cosine)\n",
"# print(df_similarity.to_string(index=False, float_format='%.4f'))\n",
"\n",
"# print(f\"\\n✅ 类内相似度分析完成!\")\n",
"# print(f\"这个分析揭示了音素分类对神经信号相似性的影响。\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🔗 数据集批量处理工作流"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n",
"======================================================================\n",
"🚀 RNN数据批量处理工具 - 新版本\n",
"======================================================================\n",
"🔧 创建RNN数据处理器...\n",
"🔧 初始化RNN数据处理器...\n",
" 模型路径: ../data/t15_pretrained_rnn_baseline\n",
" 数据目录: ../data/hdf5_data_final\n",
" 计算设备: cuda:0\n",
"📋 模型配置:\n",
" Sessions数量: 45\n",
" 神经特征维度: 512\n",
" Patch size: 14\n",
" Patch stride: 4\n",
" 输出类别数: 41\n",
"✅ 模型加载成功\n",
"📊 CSV数据加载完成: 265 条记录\n",
"✅ 初始化完成!\n",
"✅ RNN数据处理器创建成功\n",
"✅ 模型加载成功\n",
"📊 CSV数据加载完成: 265 条记录\n",
"✅ 初始化完成!\n",
"✅ RNN数据处理器创建成功\n"
]
}
],
"source": [
"%cd model_training\n",
"# 🚀 RNN数据批量处理工具 - 完整版\n",
"import os\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"from omegaconf import OmegaConf\n",
"import time\n",
"from tqdm import tqdm\n",
"import h5py\n",
"from pathlib import Path\n",
"\n",
"# 导入模型相关模块\n",
"import sys\n",
"sys.path.append('../model_training')\n",
"from rnn_model import GRUDecoder\n",
"from evaluate_model_helpers import *\n",
"from data_augmentations import gauss_smooth\n",
"\n",
"print(\"=\"*70)\n",
"print(\"🚀 RNN数据批量处理工具 - 新版本\")\n",
"print(\"=\"*70)\n",
"\n",
"class RNNDataProcessor:\n",
" \"\"\"\n",
" RNN数据批量处理器 - 生成RNN输入输出拼接数据\n",
" \n",
" 核心功能:\n",
" 1. 加载预训练RNN模型\n",
" 2. 处理原始神经数据(高斯平滑 + patch操作\n",
" 3. 获取RNN输出40类置信度分数\n",
" 4. 拼接处理后的输入和输出\n",
" 5. 批量保存所有session数据\n",
" \"\"\"\n",
" \n",
" def __init__(self, model_path, data_dir, csv_path, device='auto'):\n",
" \"\"\"\n",
" 初始化处理器\n",
" \n",
" 参数:\n",
" model_path: 预训练RNN模型路径\n",
" data_dir: 数据目录路径 \n",
" csv_path: 数据描述CSV文件路径\n",
" device: 计算设备 ('auto', 'cpu', 'cuda:0'等)\n",
" \"\"\"\n",
" self.model_path = model_path\n",
" self.data_dir = data_dir\n",
" self.csv_path = csv_path\n",
" \n",
" # 设备选择\n",
" if device == 'auto':\n",
" self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
" else:\n",
" self.device = torch.device(device)\n",
" \n",
" print(f\"🔧 初始化RNN数据处理器...\")\n",
" print(f\" 模型路径: {model_path}\")\n",
" print(f\" 数据目录: {data_dir}\")\n",
" print(f\" 计算设备: {self.device}\")\n",
" \n",
" # 加载配置和模型\n",
" self._load_config()\n",
" self._load_model()\n",
" self._load_csv()\n",
" \n",
" print(f\"✅ 初始化完成!\")\n",
" \n",
" def _load_config(self):\n",
" \"\"\"加载模型配置\"\"\"\n",
" config_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
" if not os.path.exists(config_path):\n",
" raise FileNotFoundError(f\"配置文件不存在: {config_path}\")\n",
" \n",
" self.model_args = OmegaConf.load(config_path)\n",
" \n",
" print(f\"📋 模型配置:\")\n",
" print(f\" Sessions数量: {len(self.model_args['dataset']['sessions'])}\")\n",
" print(f\" 神经特征维度: {self.model_args['model']['n_input_features']}\")\n",
" print(f\" Patch size: {self.model_args['model']['patch_size']}\")\n",
" print(f\" Patch stride: {self.model_args['model']['patch_stride']}\")\n",
" print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n",
" \n",
" def _load_model(self):\n",
" \"\"\"加载预训练RNN模型\"\"\"\n",
" try:\n",
" # 创建模型\n",
" self.model = GRUDecoder(\n",
" neural_dim=self.model_args['model']['n_input_features'],\n",
" n_units=self.model_args['model']['n_units'], \n",
" n_days=len(self.model_args['dataset']['sessions']),\n",
" n_classes=self.model_args['dataset']['n_classes'],\n",
" rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
" input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
" n_layers=self.model_args['model']['n_layers'],\n",
" patch_size=self.model_args['model']['patch_size'],\n",
" patch_stride=self.model_args['model']['patch_stride'],\n",
" )\n",
" \n",
" # 加载权重\n",
" checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
" try:\n",
" checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
" except TypeError:\n",
" checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
" \n",
" # 清理键名\n",
" for key in list(checkpoint['model_state_dict'].keys()):\n",
" checkpoint['model_state_dict'][key.replace(\"module.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
" checkpoint['model_state_dict'][key.replace(\"_orig_mod.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
" \n",
" self.model.load_state_dict(checkpoint['model_state_dict'])\n",
" self.model.to(self.device)\n",
" self.model.eval()\n",
" \n",
" print(f\"✅ 模型加载成功\")\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ 模型加载失败: {e}\")\n",
" raise\n",
" \n",
" def _load_csv(self):\n",
" \"\"\"加载数据描述文件\"\"\"\n",
" if not os.path.exists(self.csv_path):\n",
" raise FileNotFoundError(f\"CSV文件不存在: {self.csv_path}\")\n",
" \n",
" self.csv_df = pd.read_csv(self.csv_path)\n",
" print(f\"📊 CSV数据加载完成: {len(self.csv_df)} 条记录\")\n",
" \n",
" def _process_single_trial(self, neural_data, session_idx):\n",
" \"\"\"\n",
" 处理单个试验数据\n",
" \n",
" 参数:\n",
" neural_data: 原始神经数据 [time_steps, features]\n",
" session_idx: 会话索引\n",
" \n",
" 返回:\n",
" dict: 包含拼接数据和统计信息\n",
" \"\"\"\n",
" # 添加batch维度\n",
" neural_input = np.expand_dims(neural_data, axis=0)\n",
" neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)\n",
" \n",
" # 高斯平滑\n",
" with torch.autocast(device_type=\"cuda\" if self.device.type == \"cuda\" else \"cpu\", \n",
" enabled=self.model_args.get('use_amp', False), dtype=torch.bfloat16):\n",
" \n",
" smoothed_data = gauss_smooth(\n",
" inputs=neural_tensor,\n",
" device=self.device,\n",
" smooth_kernel_std=self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
" smooth_kernel_size=self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
" padding='valid',\n",
" )\n",
" \n",
" # Patch操作复制模型内部逻辑\n",
" processed_data = smoothed_data\n",
" if self.model.patch_size > 0:\n",
" processed_data = processed_data.unsqueeze(1) # [batch, 1, time, features]\n",
" processed_data = processed_data.permute(0, 3, 1, 2) # [batch, features, 1, time]\n",
" \n",
" # 滑动窗口提取\n",
" patches = processed_data.unfold(3, self.model.patch_size, self.model.patch_stride)\n",
" patches = patches.squeeze(2) # [batch, features, patches, patch_size]\n",
" patches = patches.permute(0, 2, 3, 1) # [batch, patches, patch_size, features]\n",
" \n",
" # 展平最后两个维度\n",
" processed_data = patches.reshape(patches.size(0), patches.size(1), -1)\n",
" \n",
" # RNN推理\n",
" with torch.no_grad():\n",
" logits, _ = self.model(\n",
" x=smoothed_data,\n",
" day_idx=torch.tensor([session_idx], device=self.device),\n",
" states=None,\n",
" return_state=True,\n",
" )\n",
" \n",
" # 转换为numpy\n",
" processed_features = processed_data.float().cpu().numpy()[0] # [time_steps, processed_features]\n",
" confidence_scores = logits.float().cpu().numpy()[0] # [time_steps, 40]\n",
" \n",
" # 拼接数据\n",
" concatenated = np.concatenate([processed_features, confidence_scores], axis=1)\n",
" \n",
" return {\n",
" 'concatenated_data': concatenated,\n",
" 'processed_features': processed_features,\n",
" 'confidence_scores': confidence_scores,\n",
" 'original_time_steps': neural_data.shape[0],\n",
" 'processed_time_steps': concatenated.shape[0],\n",
" 'feature_reduction_ratio': concatenated.shape[0] / neural_data.shape[0]\n",
" }\n",
" \n",
" def process_session(self, session_name, data_types=['train', 'val', 'test']):\n",
" \"\"\"\n",
" 处理单个session的数据\n",
" \n",
" 参数:\n",
" session_name: 会话名称\n",
" data_types: 要处理的数据类型列表\n",
" \n",
" 返回:\n",
" dict: 处理结果\n",
" \"\"\"\n",
" print(f\"\\n🔄 处理会话: {session_name}\")\n",
" \n",
" session_idx = self.model_args['dataset']['sessions'].index(session_name)\n",
" session_results = {}\n",
" \n",
" for data_type in data_types:\n",
" data_file = os.path.join(self.data_dir, session_name, f'data_{data_type}.hdf5')\n",
" \n",
" if not os.path.exists(data_file):\n",
" print(f\" ⚠️ {data_type} 数据文件不存在,跳过\")\n",
" continue\n",
" \n",
" print(f\" 📁 处理 {data_type} 数据...\")\n",
" \n",
" try:\n",
" # 加载数据\n",
" data = load_h5py_file(data_file, self.csv_df)\n",
" num_trials = len(data['neural_features'])\n",
" \n",
" if num_trials == 0:\n",
" print(f\" ⚠️ {data_type} 数据为空\")\n",
" continue\n",
" \n",
" # 处理所有试验\n",
" results = {\n",
" 'concatenated_data': [],\n",
" 'processed_features': [],\n",
" 'confidence_scores': [],\n",
" 'trial_metadata': [],\n",
" 'processing_stats': []\n",
" }\n",
" \n",
" for trial_idx in tqdm(range(num_trials), desc=f\" {data_type}\", leave=False):\n",
" neural_data = data['neural_features'][trial_idx]\n",
" \n",
" # 处理单个试验\n",
" trial_result = self._process_single_trial(neural_data, session_idx)\n",
" \n",
" # 保存结果\n",
" results['concatenated_data'].append(trial_result['concatenated_data'])\n",
" results['processed_features'].append(trial_result['processed_features'])\n",
" results['confidence_scores'].append(trial_result['confidence_scores'])\n",
" \n",
" # 保存元数据\n",
" metadata = {\n",
" 'session': session_name,\n",
" 'data_type': data_type,\n",
" 'trial_idx': trial_idx,\n",
" 'block_num': data.get('block_num', [None])[trial_idx],\n",
" 'trial_num': data.get('trial_num', [None])[trial_idx],\n",
" **trial_result\n",
" }\n",
" \n",
" # 添加真实标签(如果可用)\n",
" if data_type in ['train', 'val'] and 'sentence_label' in data:\n",
" metadata.update({\n",
" 'sentence_label': data['sentence_label'][trial_idx],\n",
" 'seq_class_ids': data['seq_class_ids'][trial_idx],\n",
" 'seq_len': data['seq_len'][trial_idx]\n",
" })\n",
" \n",
" results['trial_metadata'].append(metadata)\n",
" results['processing_stats'].append(trial_result)\n",
" \n",
" # 统计信息\n",
" if results['concatenated_data']:\n",
" time_steps = [data.shape[0] for data in results['concatenated_data']]\n",
" feature_dims = [data.shape[1] for data in results['concatenated_data']]\n",
" \n",
" print(f\" ✅ {data_type} 处理完成:\")\n",
" print(f\" 试验数: {len(results['concatenated_data'])}\")\n",
" print(f\" 时间步范围: {min(time_steps)}-{max(time_steps)}\")\n",
" print(f\" 特征维度: {feature_dims[0]} (处理后特征: {feature_dims[0]-40}, 置信度: 40)\")\n",
" \n",
" avg_reduction = np.mean([stat['feature_reduction_ratio'] for stat in results['processing_stats']])\n",
" print(f\" 平均时间压缩比: {avg_reduction:.3f}\")\n",
" \n",
" session_results[data_type] = results\n",
" \n",
" except Exception as e:\n",
" print(f\" ❌ {data_type} 处理失败: {e}\")\n",
" continue\n",
" \n",
" return session_results\n",
" \n",
" def process_all_sessions(self, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n",
" \"\"\"\n",
" 批量处理所有sessions\n",
" \n",
" 参数:\n",
" data_types: 要处理的数据类型\n",
" save_dir: 保存目录\n",
" \n",
" 返回:\n",
" dict: 所有处理结果\n",
" \"\"\"\n",
" print(f\"\\n🚀 开始批量处理所有会话...\")\n",
" print(f\" 目标数据类型: {data_types}\")\n",
" print(f\" 保存目录: {save_dir}\")\n",
" \n",
" save_path = Path(save_dir)\n",
" save_path.mkdir(parents=True, exist_ok=True)\n",
" \n",
" all_results = {}\n",
" sessions = self.model_args['dataset']['sessions']\n",
" \n",
" start_time = time.time()\n",
" \n",
" for i, session in enumerate(sessions):\n",
" print(f\"\\n📊 进度: {i+1}/{len(sessions)}\")\n",
" \n",
" try:\n",
" session_results = self.process_session(session, data_types)\n",
" \n",
" if session_results:\n",
" all_results[session] = session_results\n",
" \n",
" # 保存单个session结果\n",
" for data_type, data in session_results.items():\n",
" filename = f\"{session}_{data_type}_rnn_processed.npz\"\n",
" filepath = save_path / filename\n",
" \n",
" save_data = {\n",
" 'concatenated_data': np.array(data['concatenated_data'], dtype=object),\n",
" 'processed_features': np.array(data['processed_features'], dtype=object),\n",
" 'confidence_scores': np.array(data['confidence_scores'], dtype=object),\n",
" 'trial_metadata': np.array(data['trial_metadata'], dtype=object),\n",
" }\n",
" \n",
" np.savez_compressed(str(filepath), **save_data)\n",
" print(f\" 💾 保存: {filename}\")\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ 会话 {session} 处理失败: {e}\")\n",
" continue\n",
" \n",
" # 生成总结\n",
" end_time = time.time()\n",
" processing_time = end_time - start_time\n",
" \n",
" total_trials = sum(\n",
" len(session_data[data_type]['concatenated_data'])\n",
" for session_data in all_results.values()\n",
" for data_type in session_data.keys()\n",
" )\n",
" \n",
" print(f\"\\n🎉 批量处理完成!\")\n",
" print(f\"⏱️ 总耗时: {processing_time/60:.2f} 分钟\")\n",
" print(f\"📊 处理统计:\")\n",
" print(f\" 成功会话: {len(all_results)}/{len(sessions)}\")\n",
" print(f\" 总试验数: {total_trials}\")\n",
" print(f\"💾 数据保存在: {save_dir}\")\n",
" \n",
" # 保存总结信息\n",
" summary = {\n",
" 'processing_time': processing_time,\n",
" 'total_sessions': len(all_results),\n",
" 'total_trials': total_trials,\n",
" 'data_types': data_types,\n",
" 'sessions': list(all_results.keys()),\n",
" 'model_config': {\n",
" 'patch_size': self.model_args['model']['patch_size'],\n",
" 'patch_stride': self.model_args['model']['patch_stride'],\n",
" 'smooth_kernel_size': self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
" 'smooth_kernel_std': self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
" }\n",
" }\n",
" \n",
" import json\n",
" with open(save_path / 'processing_summary.json', 'w') as f:\n",
" json.dump(summary, f, indent=2)\n",
" \n",
" return all_results\n",
"\n",
"# 创建处理器实例\n",
"print(\"🔧 创建RNN数据处理器...\")\n",
"\n",
"try:\n",
" processor = RNNDataProcessor(\n",
" model_path='../data/t15_pretrained_rnn_baseline',\n",
" data_dir='../data/hdf5_data_final',\n",
" csv_path='../data/t15_copyTaskData_description.csv',\n",
" device='auto'\n",
" )\n",
" \n",
" print(f\"✅ RNN数据处理器创建成功\")\n",
" \n",
"except Exception as e:\n",
" print(f\"❌ 处理器创建失败: {e}\")\n",
" processor = None"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================================================================\n",
"🎯 RNN数据批量处理 - 使用示例\n",
"======================================================================\n",
"\n",
"📋 可用的处理方法:\n",
"1⃣ 单session处理: processor.process_session('session_name')\n",
"2⃣ 批量处理所有: processor.process_all_sessions()\n",
"\n",
"📊 可用会话数量: 45\n",
"📝 前5个会话: ['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25']\n",
"\n",
"🧪 快速测试: 处理会话 't15.2023.08.13' 的训练数据...\n",
"\n",
"🔄 处理会话: t15.2023.08.13\n",
" 📁 处理 train 数据...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" ✅ train 处理完成:\n",
" 试验数: 348\n",
" 时间步范围: 55-352\n",
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
" 平均时间压缩比: 0.243\n",
"\n",
"✅ 测试完成!结果概览:\n",
" 处理的试验数: 348\n",
" 第一个试验数据形状: (251, 7209)\n",
" 特征维度详情:\n",
" - 处理后的神经特征: 7168 维\n",
" - RNN置信度分数: 41 维\n",
" - 总拼接特征: 7209 维\n",
" - 时间步数: 251\n",
" 样本元数据:\n",
" - 原始时间步: 1023\n",
" - 处理后时间步: 251\n",
" - 时间压缩比: 0.245\n",
" - 句子标签: Which is most unfortunate because we all lose out.\n",
"\n",
"💡 要批量处理所有数据,运行:\n",
" results = processor.process_all_sessions()\n",
" # 这将处理所有45个sessions的train/val/test数据\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r"
]
}
],
"source": [
"# 🎯 使用示例和批量处理\n",
"\n",
"print(\"=\"*70)\n",
"print(\"🎯 RNN数据批量处理 - 使用示例\")\n",
"print(\"=\"*70)\n",
"\n",
"if processor is not None:\n",
" \n",
" # 方法1: 处理单个session (推荐用于测试)\n",
" print(\"\\n📋 可用的处理方法:\")\n",
" print(\"1⃣ 单session处理: processor.process_session('session_name')\")\n",
" print(\"2⃣ 批量处理所有: processor.process_all_sessions()\")\n",
" \n",
" # 显示可用的sessions\n",
" sessions = processor.model_args['dataset']['sessions']\n",
" print(f\"\\n📊 可用会话数量: {len(sessions)}\")\n",
" print(f\"📝 前5个会话: {sessions[:5]}\")\n",
" \n",
" # 快速测试 - 处理第一个session的部分数据\n",
" test_session = sessions[1] # 't15.2023.08.11'\n",
" \n",
" print(f\"\\n🧪 快速测试: 处理会话 '{test_session}' 的训练数据...\")\n",
" \n",
" # 处理单个session仅train数据进行测试\n",
" single_result = processor.process_session(test_session, ['train'])\n",
" \n",
" if single_result and 'train' in single_result:\n",
" train_data = single_result['train']\n",
" \n",
" print(f\"\\n✅ 测试完成!结果概览:\")\n",
" print(f\" 处理的试验数: {len(train_data['concatenated_data'])}\")\n",
" \n",
" if len(train_data['concatenated_data']) > 0:\n",
" sample_data = train_data['concatenated_data'][0]\n",
" print(f\" 第一个试验数据形状: {sample_data.shape}\")\n",
" print(f\" 特征维度详情:\")\n",
" print(f\" - 处理后的神经特征: {sample_data.shape[1] - 41} 维\")\n",
" print(f\" - RNN置信度分数: 41 维\")\n",
" print(f\" - 总拼接特征: {sample_data.shape[1]} 维\")\n",
" print(f\" - 时间步数: {sample_data.shape[0]}\")\n",
" \n",
" # 显示一些样本元数据\n",
" sample_metadata = train_data['trial_metadata'][0]\n",
" print(f\" 样本元数据:\")\n",
" print(f\" - 原始时间步: {sample_metadata['original_time_steps']}\")\n",
" print(f\" - 处理后时间步: {sample_metadata['processed_time_steps']}\")\n",
" print(f\" - 时间压缩比: {sample_metadata['feature_reduction_ratio']:.3f}\")\n",
" \n",
" if 'sentence_label' in sample_metadata:\n",
" print(f\" - 句子标签: {sample_metadata['sentence_label']}\")\n",
" \n",
" print(f\"\\n💡 要批量处理所有数据,运行:\")\n",
" print(f\" results = processor.process_all_sessions()\")\n",
" print(f\" # 这将处理所有45个sessions的train/val/test数据\")\n",
" \n",
"else:\n",
" print(\"❌ 处理器未创建成功,请检查上面的错误信息\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================================================================\n",
"🚀 批量处理选项\n",
"======================================================================\n",
"📊 批量处理配置:\n",
" 启用批量处理: False\n",
" 保存目录: ./rnn_processed_data\n",
" 数据类型: ['train', 'val', 'test']\n",
" 总会话数: 45\n",
"\n",
"💡 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\n",
" 或者手动运行: processor.process_all_sessions()\n",
"\n",
"📋 数据使用说明:\n",
"✅ 处理完成后,每个文件包含:\n",
" - concatenated_data: 拼接后的特征 [神经特征(7168) + 置信度(41)]\n",
" - processed_features: 仅处理后的神经特征\n",
" - confidence_scores: 仅RNN输出的41类置信度分数\n",
" - trial_metadata: 试验元数据(标签、时间步等)\n",
"\n",
"🔧 加载保存的数据:\n",
" data = np.load('session_name_train_rnn_processed.npz', allow_pickle=True)\n",
" features = data['concatenated_data'] # 用于训练分类器\n",
" metadata = data['trial_metadata'] # 获取标签和其他信息\n"
]
}
],
"source": [
"# 🚀 批量处理所有数据 (可选择运行)\n",
"\n",
"print(\"=\"*70)\n",
"print(\"🚀 批量处理选项\")\n",
"print(\"=\"*70)\n",
"\n",
"# 设置参数\n",
"ENABLE_FULL_PROCESSING = False # 设为True开始批量处理\n",
"SAVE_DIR = \"./rnn_processed_data\" # 保存目录\n",
"DATA_TYPES = ['train', 'val', 'test'] # 要处理的数据类型\n",
"\n",
"print(f\"📊 批量处理配置:\")\n",
"print(f\" 启用批量处理: {ENABLE_FULL_PROCESSING}\")\n",
"print(f\" 保存目录: {SAVE_DIR}\")\n",
"print(f\" 数据类型: {DATA_TYPES}\")\n",
"print(f\" 总会话数: {len(processor.model_args['dataset']['sessions'])}\")\n",
"\n",
"if ENABLE_FULL_PROCESSING and processor is not None:\n",
" print(f\"\\n🚀 开始批量处理所有数据...\")\n",
" print(f\"⚠️ 这可能需要较长时间预计30-60分钟\")\n",
" \n",
" # 批量处理\n",
" all_results = processor.process_all_sessions(\n",
" data_types=DATA_TYPES,\n",
" save_dir=SAVE_DIR\n",
" )\n",
" \n",
" print(f\"🎉 批量处理完成!结果保存在: {SAVE_DIR}\")\n",
" \n",
"else:\n",
" print(f\"\\n💡 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\")\n",
" print(f\" 或者手动运行: processor.process_all_sessions()\")\n",
"\n",
"print(f\"\\n📋 数据使用说明:\")\n",
"print(f\"✅ 处理完成后,每个文件包含:\")\n",
"print(f\" - concatenated_data: 拼接后的特征 [神经特征(7168) + 置信度(41)]\")\n",
"print(f\" - processed_features: 仅处理后的神经特征\")\n",
"print(f\" - confidence_scores: 仅RNN输出的41类置信度分数\")\n",
"print(f\" - trial_metadata: 试验元数据(标签、时间步等)\")\n",
"print(f\"\")\n",
"print(f\"🔧 加载保存的数据:\")\n",
"print(f\" data = np.load('session_name_train_rnn_processed.npz', allow_pickle=True)\")\n",
"print(f\" features = data['concatenated_data'] # 用于训练分类器\")\n",
"print(f\" metadata = data['trial_metadata'] # 获取标签和其他信息\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(212, 7209)"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"single_result['train']['concatenated_data'][2].shape"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 模型建立"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🌲 初始化随机森林回归模型\n",
"🌲 随机森林回归器初始化:\n",
" 时间窗口大小: 30\n",
" 树的数量: 100\n",
" 最大深度: 10\n",
" 并行任务: -1\n",
"\n",
"✅ 随机森林回归器准备完成!\n",
"🔧 下一步: 准备训练数据和开始训练\n"
]
}
],
"source": [
"# 🌲 随机森林多输出回归模型实现\n",
"import numpy as np\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
"from sklearn.preprocessing import StandardScaler\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from tqdm import tqdm\n",
"import warnings\n",
"warnings.filterwarnings('ignore')\n",
"\n",
"class TimeWindowRandomForestRegressor:\n",
" \"\"\"\n",
" 基于时间窗口的随机森林多输出回归器\n",
" 用于预测40个音素的概率分布\n",
" \"\"\"\n",
" \n",
" def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n",
" \"\"\"\n",
" 初始化模型\n",
" \n",
" 参数:\n",
" window_size: 时间窗口大小\n",
" n_estimators: 随机森林中树的数量\n",
" max_depth: 树的最大深度\n",
" n_jobs: 并行任务数\n",
" random_state: 随机种子\n",
" \"\"\"\n",
" self.window_size = window_size\n",
" self.n_estimators = n_estimators\n",
" self.max_depth = max_depth\n",
" self.n_jobs = n_jobs\n",
" self.random_state = random_state\n",
" \n",
" # 初始化模型和预处理器\n",
" self.regressor = RandomForestRegressor(\n",
" n_estimators=n_estimators,\n",
" max_depth=max_depth,\n",
" n_jobs=n_jobs,\n",
" random_state=random_state,\n",
" verbose=1\n",
" )\n",
" \n",
" self.scaler = StandardScaler()\n",
" self.is_fitted = False\n",
" \n",
" print(f\"🌲 随机森林回归器初始化:\")\n",
" print(f\" 时间窗口大小: {window_size}\")\n",
" print(f\" 树的数量: {n_estimators}\")\n",
" print(f\" 最大深度: {max_depth}\")\n",
" print(f\" 并行任务: {n_jobs}\")\n",
" \n",
" def create_time_windows(self, neural_features, phoneme_targets=None):\n",
" \"\"\"\n",
" 创建时间窗口特征\n",
" \n",
" 参数:\n",
" neural_features: 神经特征 [time_steps, 512]\n",
" phoneme_targets: 音素目标 [time_steps, 40] (可选)\n",
" \n",
" 返回:\n",
" windowed_features: [samples, window_size * 512]\n",
" windowed_targets: [samples, 40] (如果提供targets)\n",
" \"\"\"\n",
" if len(neural_features) < self.window_size:\n",
" print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n",
" return None, None\n",
" \n",
" n_samples = len(neural_features) - self.window_size + 1\n",
" n_features = neural_features.shape[1]\n",
" \n",
" # 创建时间窗口特征\n",
" windowed_features = np.zeros((n_samples, self.window_size * n_features))\n",
" \n",
" for i in range(n_samples):\n",
" # 展平时间窗口内的所有特征\n",
" window_data = neural_features[i:i+self.window_size].flatten()\n",
" windowed_features[i] = window_data\n",
" \n",
" windowed_targets = None\n",
" if phoneme_targets is not None:\n",
" # 使用窗口中心点的音素概率作为目标\n",
" center_offset = self.window_size // 2\n",
" windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n",
" \n",
" return windowed_features, windowed_targets\n",
" \n",
" def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n",
" \"\"\"\n",
" 准备训练数据集\n",
" \n",
" 参数:\n",
" datasets_list: DataFrame列表 (train_datasets, val_datasets, etc.)\n",
" dataset_type: 数据集类型名称\n",
" \n",
" 返回:\n",
" X: 特征矩阵 [总样本数, window_size * 512]\n",
" y: 目标矩阵 [总样本数, 40]\n",
" \"\"\"\n",
" print(f\"\\n📊 准备{dataset_type}数据集:\")\n",
" print(f\" 输入数据集数量: {len(datasets_list)}\")\n",
" \n",
" all_X = []\n",
" all_y = []\n",
" \n",
" for i, df in enumerate(tqdm(datasets_list, desc=f\"处理{dataset_type}数据\")):\n",
" # 提取神经特征 (前512列)\n",
" neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n",
" neural_features = df[neural_cols].values\n",
" \n",
" # 提取音素目标 (40列音素概率)\n",
" phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n",
" phoneme_targets = df[phoneme_cols].values\n",
" \n",
" # 按trial分组处理\n",
" trials = df['trial_idx'].unique()\n",
" \n",
" for trial_idx in trials:\n",
" trial_mask = df['trial_idx'] == trial_idx\n",
" trial_neural = neural_features[trial_mask]\n",
" trial_phonemes = phoneme_targets[trial_mask]\n",
" \n",
" # 创建时间窗口\n",
" windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n",
" \n",
" if windowed_X is not None and windowed_y is not None:\n",
" all_X.append(windowed_X)\n",
" all_y.append(windowed_y)\n",
" \n",
" if not all_X:\n",
" print(f\"❌ 没有有效的{dataset_type}数据\")\n",
" return None, None\n",
" \n",
" # 合并所有数据\n",
" X = np.vstack(all_X)\n",
" y = np.vstack(all_y)\n",
" \n",
" print(f\" ✅ {dataset_type}数据准备完成:\")\n",
" print(f\" 特征矩阵形状: {X.shape}\")\n",
" print(f\" 目标矩阵形状: {y.shape}\")\n",
" print(f\" 内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n",
" \n",
" return X, y\n",
" \n",
" def fit(self, X_train, y_train, X_val=None, y_val=None):\n",
" \"\"\"\n",
" 训练模型\n",
" \n",
" 参数:\n",
" X_train: 训练特征\n",
" y_train: 训练目标\n",
" X_val: 验证特征 (可选)\n",
" y_val: 验证目标 (可选)\n",
" \"\"\"\n",
" print(f\"\\n🚀 开始训练随机森林回归模型:\")\n",
" print(f\" 训练样本数: {X_train.shape[0]:,}\")\n",
" print(f\" 特征维度: {X_train.shape[1]:,}\")\n",
" print(f\" 目标维度: {y_train.shape[1]}\")\n",
" \n",
" # 标准化特征\n",
" print(\" 🔄 标准化特征...\")\n",
" X_train_scaled = self.scaler.fit_transform(X_train)\n",
" \n",
" # 训练模型\n",
" print(\" 🌲 训练随机森林...\")\n",
" self.regressor.fit(X_train_scaled, y_train)\n",
" \n",
" self.is_fitted = True\n",
" \n",
" # 计算训练集性能\n",
" print(\" 📊 评估训练集性能...\")\n",
" train_predictions = self.regressor.predict(X_train_scaled)\n",
" train_mse = mean_squared_error(y_train, train_predictions)\n",
" train_r2 = r2_score(y_train, train_predictions)\n",
" train_mae = mean_absolute_error(y_train, train_predictions)\n",
" \n",
" print(f\" ✅ 训练完成!\")\n",
" print(f\" 训练集 MSE: {train_mse:.6f}\")\n",
" print(f\" 训练集 R²: {train_r2:.4f}\")\n",
" print(f\" 训练集 MAE: {train_mae:.6f}\")\n",
" \n",
" # 如果有验证集,计算验证集性能\n",
" if X_val is not None and y_val is not None:\n",
" print(\" 📊 评估验证集性能...\")\n",
" X_val_scaled = self.scaler.transform(X_val)\n",
" val_predictions = self.regressor.predict(X_val_scaled)\n",
" val_mse = mean_squared_error(y_val, val_predictions)\n",
" val_r2 = r2_score(y_val, val_predictions)\n",
" val_mae = mean_absolute_error(y_val, val_predictions)\n",
" \n",
" print(f\" 验证集 MSE: {val_mse:.6f}\")\n",
" print(f\" 验证集 R²: {val_r2:.4f}\")\n",
" print(f\" 验证集 MAE: {val_mae:.6f}\")\n",
" \n",
" return {\n",
" 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n",
" 'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n",
" }\n",
" \n",
" return {\n",
" 'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n",
" }\n",
" \n",
" def predict(self, X):\n",
" \"\"\"预测\"\"\"\n",
" if not self.is_fitted:\n",
" raise ValueError(\"模型尚未训练请先调用fit()方法\")\n",
" \n",
" X_scaled = self.scaler.transform(X)\n",
" return self.regressor.predict(X_scaled)\n",
" \n",
" def get_feature_importance(self, top_k=20):\n",
" \"\"\"获取特征重要性\"\"\"\n",
" if not self.is_fitted:\n",
" raise ValueError(\"模型尚未训练请先调用fit()方法\")\n",
" \n",
" importances = self.regressor.feature_importances_\n",
" \n",
" # 创建特征名称 (window_timestep_feature)\n",
" feature_names = []\n",
" for t in range(self.window_size):\n",
" for f in range(512):\n",
" feature_names.append(f\"t{t}_feat{f}\")\n",
" \n",
" # 获取top-k重要特征\n",
" top_indices = np.argsort(importances)[::-1][:top_k]\n",
" top_features = [(feature_names[i], importances[i]) for i in top_indices]\n",
" \n",
" return top_features, importances\n",
"\n",
"# 初始化模型\n",
"print(\"🌲 初始化随机森林回归模型\")\n",
"rf_regressor = TimeWindowRandomForestRegressor(\n",
" window_size=30, # 时间窗口大小\n",
" n_estimators=100, # 树的数量\n",
" max_depth=10, # 最大深度 (防止过拟合)\n",
" n_jobs=-1, # 使用所有CPU核心\n",
" random_state=42\n",
")\n",
"\n",
"print(\"\\n✅ 随机森林回归器准备完成!\")\n",
"print(\"🔧 下一步: 准备训练数据和开始训练\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 开始数据准备和模型训练流程\n",
"============================================================\n"
]
},
{
"ename": "NameError",
"evalue": "name 'train_datasets' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_37/3627267466.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;31m# 检查数据可用性\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtrain_datasets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"❌ 没有可用的训练数据,请先运行数据处理工作流\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mNameError\u001b[0m: name 'train_datasets' is not defined"
]
}
],
"source": [
"# 🚀 准备数据并训练随机森林回归模型\n",
"print(\"🚀 开始数据准备和模型训练流程\")\n",
"print(\"=\"*60)\n",
"\n",
"# 检查数据可用性\n",
"if not train_datasets:\n",
" print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n",
"else:\n",
" print(f\"✅ 检测到数据:\")\n",
" print(f\" 训练数据集: {len(train_datasets)} 个sessions\")\n",
" print(f\" 验证数据集: {len(val_datasets)} 个sessions\")\n",
" print(f\" 测试数据集: {len(test_datasets)} 个sessions\")\n",
"\n",
" # 1. 准备训练数据\n",
" print(f\"\\n📊 第1步: 准备训练数据\")\n",
" X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n",
" \n",
" # 2. 准备验证数据\n",
" print(f\"\\n📊 第2步: 准备验证数据\")\n",
" X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n",
" \n",
" if X_train is not None and y_train is not None:\n",
" print(f\"\\n📈 数据准备完成统计:\")\n",
" print(f\" 训练集: {X_train.shape[0]:,} 样本\")\n",
" print(f\" 验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \" 验证集: 无\")\n",
" print(f\" 特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n",
" print(f\" 目标维度: {y_train.shape[1]} (40个音素概率)\")\n",
" \n",
" # 检查数据质量\n",
" print(f\"\\n🔍 数据质量检查:\")\n",
" print(f\" 训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n",
" print(f\" 训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n",
" print(f\" 训练特征均值: {X_train.mean():.4f}\")\n",
" print(f\" 训练目标均值: {y_train.mean():.4f}\")\n",
" \n",
" # 检查是否有NaN或无穷值\n",
" nan_count_X = np.isnan(X_train).sum()\n",
" nan_count_y = np.isnan(y_train).sum()\n",
" inf_count_X = np.isinf(X_train).sum()\n",
" inf_count_y = np.isinf(y_train).sum()\n",
" \n",
" print(f\" NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n",
" print(f\" Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n",
" \n",
" if nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0:\n",
" print(\"⚠️ 检测到异常值,将进行清理...\")\n",
" # 清理异常值\n",
" valid_mask = ~(np.isnan(X_train).any(axis=1) | np.isnan(y_train).any(axis=1) | \n",
" np.isinf(X_train).any(axis=1) | np.isinf(y_train).any(axis=1))\n",
" X_train = X_train[valid_mask]\n",
" y_train = y_train[valid_mask]\n",
" \n",
" if X_val is not None and y_val is not None:\n",
" valid_mask_val = ~(np.isnan(X_val).any(axis=1) | np.isnan(y_val).any(axis=1) | \n",
" np.isinf(X_val).any(axis=1) | np.isinf(y_val).any(axis=1))\n",
" X_val = X_val[valid_mask_val]\n",
" y_val = y_val[valid_mask_val]\n",
" \n",
" print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n",
" \n",
" # 3. 训练模型\n",
" print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n",
" training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n",
" \n",
" # 4. 分析训练结果\n",
" print(f\"\\n📊 第4步: 训练结果分析\")\n",
" print(\"=\"*50)\n",
" \n",
" for metric, value in training_results.items():\n",
" metric_name = metric.replace('_', ' ').title()\n",
" print(f\" {metric_name}: {value:.6f}\")\n",
" \n",
" # 5. 特征重要性分析\n",
" print(f\"\\n🔍 第5步: 特征重要性分析\")\n",
" top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n",
" \n",
" print(f\"\\n🏆 Top 20 重要特征:\")\n",
" print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n",
" print(\"-\" * 35)\n",
" for i, (feature_name, importance) in enumerate(top_features):\n",
" print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n",
" \n",
" # 分析时间窗口内的重要性分布\n",
" print(f\"\\n📈 时间窗口重要性分布:\")\n",
" window_importances = np.zeros(rf_regressor.window_size)\n",
" for i in range(rf_regressor.window_size):\n",
" start_idx = i * 512\n",
" end_idx = (i + 1) * 512\n",
" window_importances[i] = all_importances[start_idx:end_idx].sum()\n",
" \n",
" max_time_step = np.argmax(window_importances)\n",
" print(f\" 最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n",
" print(f\" 窗口中心位置: t{rf_regressor.window_size//2}\")\n",
" print(f\" 重要性分布: 前5个时间步的重要性\")\n",
" for i in range(min(5, len(window_importances))):\n",
" print(f\" t{i}: {window_importances[i]:.6f}\")\n",
" \n",
" print(f\"\\n✅ 随机森林回归模型训练完成!\")\n",
" print(f\"🎯 模型可以预测40个音素的概率分布\")\n",
" print(f\"📊 基于30时间步的神经特征窗口\")\n",
" print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n",
" \n",
" else:\n",
" print(\"❌ 数据准备失败,无法训练模型\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📊 模型评估和可视化分析\n",
"def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
" \"\"\"\n",
" 评估每个音素的预测性能\n",
" \"\"\"\n",
" print(f\"\\n📊 {dataset_name}详细评估\")\n",
" print(\"=\"*50)\n",
" \n",
" # 获取预测结果\n",
" y_pred = rf_model.predict(X_test)\n",
" \n",
" # 计算每个音素的性能指标\n",
" phoneme_metrics = []\n",
" \n",
" for i in range(40): # 40个音素\n",
" phoneme_name = LOGIT_TO_PHONEME[i]\n",
" \n",
" # 计算单个音素的指标\n",
" mse = mean_squared_error(y_test[:, i], y_pred[:, i])\n",
" r2 = r2_score(y_test[:, i], y_pred[:, i])\n",
" mae = mean_absolute_error(y_test[:, i], y_pred[:, i])\n",
" \n",
" # 计算相关系数\n",
" correlation = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]\n",
" \n",
" phoneme_metrics.append({\n",
" 'phoneme_id': i,\n",
" 'phoneme_name': phoneme_name,\n",
" 'mse': mse, \n",
" 'r2': r2,\n",
" 'mae': mae,\n",
" 'correlation': correlation if not np.isnan(correlation) else 0.0\n",
" })\n",
" \n",
" # 转换为DataFrame便于分析\n",
" metrics_df = pd.DataFrame(phoneme_metrics)\n",
" \n",
" # 打印总体统计\n",
" print(f\"📈 总体性能指标:\")\n",
" print(f\" 平均 MSE: {metrics_df['mse'].mean():.6f}\")\n",
" print(f\" 平均 R²: {metrics_df['r2'].mean():.4f}\")\n",
" print(f\" 平均 MAE: {metrics_df['mae'].mean():.6f}\")\n",
" print(f\" 平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n",
" \n",
" # 找出最佳和最差预测的音素\n",
" best_r2_idx = metrics_df['r2'].idxmax()\n",
" worst_r2_idx = metrics_df['r2'].idxmin()\n",
" \n",
" print(f\"\\n🏆 最佳预测音素:\")\n",
" best_phoneme = metrics_df.loc[best_r2_idx]\n",
" print(f\" {best_phoneme['phoneme_name']} (ID: {best_phoneme['phoneme_id']})\")\n",
" print(f\" R²: {best_phoneme['r2']:.4f}, MSE: {best_phoneme['mse']:.6f}\")\n",
" \n",
" print(f\"\\n📉 最差预测音素:\")\n",
" worst_phoneme = metrics_df.loc[worst_r2_idx]\n",
" print(f\" {worst_phoneme['phoneme_name']} (ID: {worst_phoneme['phoneme_id']})\")\n",
" print(f\" R²: {worst_phoneme['r2']:.4f}, MSE: {worst_phoneme['mse']:.6f}\")\n",
" \n",
" return metrics_df, y_pred\n",
"\n",
"def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n",
" \"\"\"\n",
" 可视化预测结果\n",
" \"\"\"\n",
" print(f\"\\n📊 创建可视化图表...\")\n",
" \n",
" # 设置图表样式\n",
" plt.style.use('default')\n",
" fig = plt.figure(figsize=(20, 12))\n",
" \n",
" # 1. R²分数分布\n",
" plt.subplot(2, 3, 1)\n",
" plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n",
" plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n",
" label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n",
" plt.xlabel('R² Score')\n",
" plt.ylabel('音素数量')\n",
" plt.title('R² Score 分布')\n",
" plt.legend()\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 2. MSE分布\n",
" plt.subplot(2, 3, 2)\n",
" plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n",
" plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n",
" label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n",
" plt.xlabel('Mean Squared Error')\n",
" plt.ylabel('音素数量')\n",
" plt.title('MSE 分布')\n",
" plt.legend()\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 3. 前10个音素的性能对比\n",
" plt.subplot(2, 3, 3)\n",
" top_10 = metrics_df.nlargest(10, 'r2')\n",
" bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n",
" plt.xlabel('音素排名')\n",
" plt.ylabel('R² Score')\n",
" plt.title('Top 10 音素预测性能')\n",
" plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 添加数值标签\n",
" for i, bar in enumerate(bars):\n",
" height = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n",
" f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n",
" \n",
" # 4. 真实值 vs 预测值散点图 (选择最佳音素)\n",
" plt.subplot(2, 3, 4)\n",
" best_phoneme_idx = metrics_df['r2'].idxmax()\n",
" phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n",
" phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n",
" \n",
" # 随机采样1000个点以避免图表过于密集\n",
" sample_size = min(1000, len(y_true))\n",
" sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n",
" \n",
" plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n",
" alpha=0.6, s=20, color='blue')\n",
" \n",
" # 添加对角线 (完美预测线)\n",
" min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n",
" max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n",
" plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n",
" \n",
" plt.xlabel('真实值')\n",
" plt.ylabel('预测值')\n",
" plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n",
" plt.legend()\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" # 5. 相关系数热力图 (前20个音素)\n",
" plt.subplot(2, 3, 5)\n",
" top_20_correlations = metrics_df.nlargest(20, 'correlation')\n",
" corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n",
" \n",
" # 创建热力图数据\n",
" heatmap_data = corr_data.values.reshape(-1, 1)\n",
" im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n",
" \n",
" plt.colorbar(im, shrink=0.8)\n",
" plt.yticks([0], ['相关系数'])\n",
" plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n",
" rotation=45, ha='right')\n",
" plt.title('Top 20 音素相关系数')\n",
" \n",
" # 6. 各音素预测误差箱线图 (前10个音素)\n",
" plt.subplot(2, 3, 6)\n",
" top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n",
" errors_data = []\n",
" labels = []\n",
" \n",
" for phoneme_id in top_10_ids:\n",
" errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n",
" errors_data.append(errors)\n",
" labels.append(LOGIT_TO_PHONEME[phoneme_id])\n",
" \n",
" plt.boxplot(errors_data, labels=labels)\n",
" plt.xlabel('音素')\n",
" plt.ylabel('绝对误差')\n",
" plt.title('Top 10 音素预测误差分布')\n",
" plt.xticks(rotation=45)\n",
" plt.grid(True, alpha=0.3)\n",
" \n",
" plt.tight_layout()\n",
" \n",
" if save_plots:\n",
" plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n",
" print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n",
" \n",
" plt.show()\n",
"\n",
"# 如果模型训练成功,进行评估\n",
"if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n",
" print(f\"\\n🎯 开始模型评估和可视化\")\n",
" \n",
" # 评估验证集\n",
" if X_val is not None and y_val is not None:\n",
" val_metrics, val_predictions = evaluate_phoneme_predictions(\n",
" rf_regressor, X_val, y_val, \"验证集\"\n",
" )\n",
" \n",
" # 可视化结果\n",
" visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n",
" \n",
" # 保存详细结果\n",
" val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n",
" print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n",
" \n",
" # 准备测试集数据 (如果有)\n",
" if test_datasets:\n",
" print(f\"\\n🔮 准备测试集预测...\")\n",
" X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n",
" \n",
" if X_test is not None:\n",
" test_metrics, test_predictions = evaluate_phoneme_predictions(\n",
" rf_regressor, X_test, y_test, \"测试集\"\n",
" )\n",
" print(f\"\\n✅ 测试集评估完成\")\n",
" else:\n",
" print(f\"⚠️ 测试集数据准备失败\")\n",
" \n",
" print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n",
" print(f\"📊 生成了详细的性能分析和可视化图表\")\n",
" print(f\"🔧 模型已准备好用于实际预测任务\")\n",
" \n",
"else:\n",
" print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🎯 回归结果转分类结果分析\n",
"def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n",
" \"\"\"\n",
" 将回归预测的40个音素概率转换为分类结果并进行分析\n",
" \n",
" 参数:\n",
" y_true_probs: 真实的40个音素概率 [n_samples, 40]\n",
" y_pred_probs: 预测的40个音素概率 [n_samples, 40]\n",
" show_detailed_metrics: 是否显示详细的分类指标\n",
" \n",
" 返回:\n",
" classification_results: 包含分类结果的字典\n",
" \"\"\"\n",
" print(\"🎯 回归结果转分类结果分析\")\n",
" print(\"=\"*60)\n",
" \n",
" # 1. 将概率转换为分类标签\n",
" y_true_classes = np.argmax(y_true_probs, axis=1) # 真实类别\n",
" y_pred_classes = np.argmax(y_pred_probs, axis=1) # 预测类别\n",
" \n",
" # 2. 计算分类准确率\n",
" accuracy = (y_true_classes == y_pred_classes).mean()\n",
" \n",
" print(f\"📊 分类结果概览:\")\n",
" print(f\" 总样本数: {len(y_true_classes):,}\")\n",
" print(f\" 分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n",
" print(f\" 正确预测: {(y_true_classes == y_pred_classes).sum():,}\")\n",
" print(f\" 错误预测: {(y_true_classes != y_pred_classes).sum():,}\")\n",
" \n",
" # 3. 分析预测置信度\n",
" pred_confidences = np.max(y_pred_probs, axis=1) # 预测的最大概率\n",
" true_confidences = np.max(y_true_probs, axis=1) # 真实的最大概率\n",
" \n",
" print(f\"\\n🔍 预测置信度分析:\")\n",
" print(f\" 预测置信度均值: {pred_confidences.mean():.4f}\")\n",
" print(f\" 预测置信度标准差: {pred_confidences.std():.4f}\")\n",
" print(f\" 预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n",
" print(f\" 真实置信度均值: {true_confidences.mean():.4f}\")\n",
" \n",
" # 4. 按置信度分组的准确率分析\n",
" confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n",
" print(f\"\\n📈 按预测置信度分组的准确率:\")\n",
" print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n",
" print(\"-\" * 40)\n",
" \n",
" for i in range(len(confidence_bins)-1):\n",
" low, high = confidence_bins[i], confidence_bins[i+1]\n",
" mask = (pred_confidences >= low) & (pred_confidences < high)\n",
" if i == len(confidence_bins)-2: # 最后一个区间包含等号\n",
" mask = (pred_confidences >= low) & (pred_confidences <= high)\n",
" \n",
" if mask.sum() > 0:\n",
" bin_accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
" count = mask.sum()\n",
" percentage = count / len(y_true_classes) * 100\n",
" print(f\"[{low:.1f}, {high:.1f}{')'if i<len(confidence_bins)-2 else ']':>1} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n",
" \n",
" # 5. 混淆矩阵分析Top-K音素\n",
" from collections import Counter\n",
" \n",
" # 找出最常见的音素\n",
" true_counter = Counter(y_true_classes)\n",
" pred_counter = Counter(y_pred_classes)\n",
" \n",
" most_common_true = true_counter.most_common(10)\n",
" most_common_pred = pred_counter.most_common(10)\n",
" \n",
" print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n",
" print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n",
" print(\"-\" * 42)\n",
" \n",
" for i in range(min(len(most_common_true), len(most_common_pred))):\n",
" true_id, true_count = most_common_true[i]\n",
" pred_id, pred_count = most_common_pred[i]\n",
" true_name = LOGIT_TO_PHONEME[true_id]\n",
" pred_name = LOGIT_TO_PHONEME[pred_id]\n",
" print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n",
" \n",
" # 6. 每个音素的分类性能\n",
" if show_detailed_metrics:\n",
" from sklearn.metrics import classification_report, confusion_matrix\n",
" \n",
" print(f\"\\n📋 详细分类报告 (前20个最常见音素):\")\n",
" \n",
" # 获取前20个最常见的音素\n",
" top_20_phonemes = [phoneme_id for phoneme_id, _ in most_common_true[:20]]\n",
" \n",
" # 创建掩码,只包含这些音素\n",
" mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n",
" y_true_top20 = y_true_classes[mask_top20]\n",
" y_pred_top20 = y_pred_classes[mask_top20]\n",
" \n",
" # 生成分类报告\n",
" target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n",
" \n",
" try:\n",
" report = classification_report(\n",
" y_true_top20, y_pred_top20, \n",
" labels=top_20_phonemes,\n",
" target_names=target_names,\n",
" output_dict=True,\n",
" zero_division=0\n",
" )\n",
" \n",
" # 打印格式化的报告\n",
" print(f\"{'音素':>8} {'精确率':>8} {'召回率':>8} {'F1分数':>8} {'支持数':>8}\")\n",
" print(\"-\" * 48)\n",
" \n",
" for phoneme_id in top_20_phonemes:\n",
" phoneme_name = LOGIT_TO_PHONEME[phoneme_id]\n",
" if phoneme_name in report:\n",
" metrics = report[phoneme_name]\n",
" print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n",
" f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n",
" \n",
" # 总体指标\n",
" macro_avg = report['macro avg']\n",
" weighted_avg = report['weighted avg']\n",
" print(\"-\" * 48)\n",
" print(f\"{'宏平均':>8} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n",
" f\"{macro_avg['f1-score']:>8.4f}\")\n",
" print(f\"{'加权平均':>8} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n",
" f\"{weighted_avg['f1-score']:>8.4f}\")\n",
" \n",
" except Exception as e:\n",
" print(f\"分类报告生成失败: {e}\")\n",
" \n",
" # 7. Top-K准确率分析\n",
" print(f\"\\n🎯 Top-K 准确率分析:\")\n",
" for k in [1, 3, 5, 10]:\n",
" # 计算Top-K准确率\n",
" top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:] # 取概率最高的K个\n",
" top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
" print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
" \n",
" # 8. 错误分析 - 最常见的预测错误\n",
" print(f\"\\n❌ 最常见的预测错误:\")\n",
" error_mask = y_true_classes != y_pred_classes\n",
" error_pairs = list(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n",
" error_counter = Counter(error_pairs)\n",
" \n",
" print(f\"{'真实音素':>12} {'预测音素':>12} {'错误次数':>8}\")\n",
" print(\"-\" * 36)\n",
" for (true_id, pred_id), count in error_counter.most_common(10):\n",
" true_name = LOGIT_TO_PHONEME[true_id]\n",
" pred_name = LOGIT_TO_PHONEME[pred_id]\n",
" print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n",
" \n",
" # 返回结果字典\n",
" classification_results = {\n",
" 'accuracy': accuracy,\n",
" 'y_true_classes': y_true_classes,\n",
" 'y_pred_classes': y_pred_classes,\n",
" 'pred_confidences': pred_confidences,\n",
" 'true_confidences': true_confidences,\n",
" 'most_common_errors': error_counter.most_common(10)\n",
" }\n",
" \n",
" return classification_results\n",
"\n",
"def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n",
" \"\"\"\n",
" 为分类结果创建可视化图表\n",
" \"\"\"\n",
" print(f\"\\n📊 创建分类结果可视化...\")\n",
" \n",
" fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
" fig.suptitle('随机森林回归转分类结果分析', fontsize=16, fontweight='bold')\n",
" \n",
" y_true_classes = classification_results['y_true_classes']\n",
" y_pred_classes = classification_results['y_pred_classes']\n",
" pred_confidences = classification_results['pred_confidences']\n",
" \n",
" # 1. 预测置信度分布\n",
" axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n",
" axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n",
" label=f'均值: {pred_confidences.mean():.3f}')\n",
" axes[0, 0].set_xlabel('预测置信度')\n",
" axes[0, 0].set_ylabel('样本数量')\n",
" axes[0, 0].set_title('预测置信度分布')\n",
" axes[0, 0].legend()\n",
" axes[0, 0].grid(True, alpha=0.3)\n",
" \n",
" # 2. 准确率 vs 置信度\n",
" confidence_bins = np.linspace(0, 1, 21)\n",
" bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n",
" bin_accuracies = []\n",
" bin_counts = []\n",
" \n",
" for i in range(len(confidence_bins)-1):\n",
" mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n",
" if mask.sum() > 0:\n",
" accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
" bin_accuracies.append(accuracy)\n",
" bin_counts.append(mask.sum())\n",
" else:\n",
" bin_accuracies.append(0)\n",
" bin_counts.append(0)\n",
" \n",
" # 只显示有数据的bins\n",
" valid_bins = np.array(bin_counts) > 0\n",
" axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n",
" 'bo-', linewidth=2, markersize=6)\n",
" axes[0, 1].set_xlabel('预测置信度')\n",
" axes[0, 1].set_ylabel('准确率')\n",
" axes[0, 1].set_title('准确率 vs 预测置信度')\n",
" axes[0, 1].grid(True, alpha=0.3)\n",
" axes[0, 1].set_ylim(0, 1)\n",
" \n",
" # 3. 最常见音素的预测准确率\n",
" from collections import Counter\n",
" true_counter = Counter(y_true_classes)\n",
" most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n",
" \n",
" phoneme_accuracies = []\n",
" phoneme_names = []\n",
" for phoneme_id in most_common_phonemes:\n",
" mask = y_true_classes == phoneme_id\n",
" if mask.sum() > 0:\n",
" accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n",
" phoneme_accuracies.append(accuracy)\n",
" phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n",
" \n",
" bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n",
" color='lightgreen', alpha=0.7)\n",
" axes[0, 2].set_xlabel('音素')\n",
" axes[0, 2].set_ylabel('准确率')\n",
" axes[0, 2].set_title('Top 15 音素的分类准确率')\n",
" axes[0, 2].set_xticks(range(len(phoneme_names)))\n",
" axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
" axes[0, 2].grid(True, alpha=0.3)\n",
" \n",
" # 添加数值标签\n",
" for bar, acc in zip(bars, phoneme_accuracies):\n",
" height = bar.get_height()\n",
" axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n",
" f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n",
" \n",
" # 4. 混淆矩阵前10个最常见音素\n",
" from sklearn.metrics import confusion_matrix\n",
" top_10_phonemes = most_common_phonemes[:10]\n",
" mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n",
" \n",
" if mask_top10.sum() > 0:\n",
" cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n",
" labels=top_10_phonemes)\n",
" \n",
" im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n",
" axes[1, 0].set_title('混淆矩阵 (Top 10 音素)')\n",
" \n",
" # 添加颜色条\n",
" cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n",
" cbar.set_label('预测次数')\n",
" \n",
" # 设置标签\n",
" tick_marks = np.arange(len(top_10_phonemes))\n",
" top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n",
" axes[1, 0].set_xticks(tick_marks)\n",
" axes[1, 0].set_yticks(tick_marks)\n",
" axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n",
" axes[1, 0].set_yticklabels(top_10_names)\n",
" axes[1, 0].set_xlabel('预测音素')\n",
" axes[1, 0].set_ylabel('真实音素')\n",
" \n",
" # 5. Top-K准确率\n",
" k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n",
" top_k_accuracies = []\n",
" \n",
" for k in k_values:\n",
" top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n",
" top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
" top_k_accuracies.append(top_k_accuracy)\n",
" \n",
" axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n",
" axes[1, 1].set_xlabel('K 值')\n",
" axes[1, 1].set_ylabel('Top-K 准确率')\n",
" axes[1, 1].set_title('Top-K 准确率曲线')\n",
" axes[1, 1].grid(True, alpha=0.3)\n",
" axes[1, 1].set_ylim(0, 1)\n",
" \n",
" # 添加数值标签\n",
" for k, acc in zip(k_values, top_k_accuracies):\n",
" axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n",
" xytext=(0,10), ha='center')\n",
" \n",
" # 6. 错误分析 - 最常见错误的热力图\n",
" error_pairs = classification_results['most_common_errors'][:25] # 前25个最常见错误\n",
" if error_pairs:\n",
" # 创建错误矩阵\n",
" unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n",
" error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n",
" \n",
" phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n",
" \n",
" for (true_id, pred_id), count in error_pairs:\n",
" if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n",
" true_idx = phoneme_to_idx[true_id]\n",
" pred_idx = phoneme_to_idx[pred_id]\n",
" error_matrix[true_idx, pred_idx] = count\n",
" \n",
" im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n",
" axes[1, 2].set_title('最常见错误分布')\n",
" \n",
" # 设置标签\n",
" phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n",
" axes[1, 2].set_xticks(range(len(phoneme_names)))\n",
" axes[1, 2].set_yticks(range(len(phoneme_names)))\n",
" axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
" axes[1, 2].set_yticklabels(phoneme_names)\n",
" axes[1, 2].set_xlabel('预测音素')\n",
" axes[1, 2].set_ylabel('真实音素')\n",
" \n",
" # 添加颜色条\n",
" cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n",
" cbar.set_label('错误次数')\n",
" \n",
" plt.tight_layout()\n",
" plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n",
" print(\"📁 分类分析图表已保存至: ./processed_datasets/classification_analysis.png\")\n",
" plt.show()\n",
"\n",
"print(\"✅ 回归转分类分析功能已创建!\")\n",
"print(\"🎯 主要功能:\")\n",
"print(\"• 将40维概率回归结果转换为分类预测\")\n",
"print(\"• 计算分类准确率和置信度分析\")\n",
"print(\"• 提供Top-K准确率评估\")\n",
"print(\"• 生成详细的混淆矩阵和错误分析\")\n",
"print(\"• 创建全面的可视化图表\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🎯 完整的回归转分类评估流程\n",
"def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
" \"\"\"\n",
" 完整的回归模型转分类结果评估流程\n",
" \"\"\"\n",
" print(f\"\\n🎯 {dataset_name}完整评估: 回归 → 分类\")\n",
" print(\"=\"*70)\n",
" \n",
" # 1. 获取回归预测结果\n",
" print(\"📊 第1步: 获取回归预测...\")\n",
" y_pred_probs = rf_model.predict(X_test)\n",
" \n",
" # 确保概率值在合理范围内\n",
" y_pred_probs = np.clip(y_pred_probs, 0, 1)\n",
" \n",
" # 2. 回归性能评估\n",
" print(\"\\n📈 第2步: 回归性能评估...\")\n",
" mse = mean_squared_error(y_test, y_pred_probs) \n",
" mae = mean_absolute_error(y_test, y_pred_probs)\n",
" r2 = r2_score(y_test, y_pred_probs)\n",
" \n",
" print(f\" 回归 MSE: {mse:.6f}\")\n",
" print(f\" 回归 MAE: {mae:.6f}\")\n",
" print(f\" 回归 R²: {r2:.4f}\")\n",
" \n",
" # 3. 概率归一化softmax\n",
" print(\"\\n🔄 第3步: 概率归一化...\")\n",
" def softmax(x, axis=-1):\n",
" exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
" return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n",
" \n",
" # 对预测结果应用softmax使其成为真正的概率分布\n",
" y_pred_probs_normalized = softmax(y_pred_probs)\n",
" y_test_normalized = softmax(y_test) # 也对真实标签归一化\n",
" \n",
" print(f\" 预测概率归一化前: 每行和均值 = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n",
" print(f\" 预测概率归一化后: 每行和均值 = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n",
" \n",
" # 4. 分类结果分析\n",
" print(\"\\n🎯 第4步: 分类结果分析...\")\n",
" classification_results = regression_to_classification_analysis(\n",
" y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n",
" )\n",
" \n",
" # 5. 创建可视化\n",
" print(\"\\n📊 第5步: 创建可视化图表...\")\n",
" create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n",
" \n",
" # 6. 保存结果\n",
" print(\"\\n💾 第6步: 保存分析结果...\")\n",
" \n",
" # 保存分类结果\n",
" results_df = pd.DataFrame({\n",
" 'true_class': classification_results['y_true_classes'],\n",
" 'pred_class': classification_results['y_pred_classes'],\n",
" 'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n",
" 'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n",
" 'pred_confidence': classification_results['pred_confidences'],\n",
" 'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n",
" })\n",
" \n",
" results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n",
" \n",
" # 保存详细的概率预测\n",
" prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n",
" columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n",
" prob_results_df['true_class'] = classification_results['y_true_classes']\n",
" prob_results_df['pred_class'] = classification_results['y_pred_classes']\n",
" \n",
" prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n",
" \n",
" print(\"📁 结果已保存:\")\n",
" print(\" • ./processed_datasets/classification_results.csv (分类结果)\")\n",
" print(\" • ./processed_datasets/probability_predictions.csv (概率预测)\")\n",
" print(\" • ./processed_datasets/classification_analysis.png (可视化图表)\")\n",
" \n",
" # 7. 总结报告\n",
" print(f\"\\n📋 {dataset_name}评估总结:\")\n",
" print(\"=\"*50)\n",
" print(f\"🔸 回归性能:\")\n",
" print(f\" MSE: {mse:.6f}\")\n",
" print(f\" R²: {r2:.4f}\")\n",
" print(f\"🔸 分类性能:\")\n",
" print(f\" 准确率: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n",
" print(f\" 平均置信度: {classification_results['pred_confidences'].mean():.4f}\")\n",
" \n",
" # 计算Top-K准确率\n",
" for k in [1, 3, 5]:\n",
" top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n",
" top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n",
" for i in range(len(classification_results['y_true_classes']))])\n",
" print(f\" Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
" \n",
" return {\n",
" 'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n",
" 'classification_results': classification_results,\n",
" 'normalized_predictions': y_pred_probs_normalized,\n",
" 'normalized_true': y_test_normalized\n",
" }\n",
"\n",
"# 如果模型已训练且有验证数据,执行完整评估\n",
"if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n",
" if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n",
" print(\"🚀 开始完整的回归转分类评估...\")\n",
" \n",
" # 执行完整评估\n",
" evaluation_results = complete_regression_classification_evaluation(\n",
" rf_regressor, X_val, y_val, \"验证集\"\n",
" )\n",
" \n",
" print(f\"\\n🎉 评估完成!\")\n",
" print(f\"✅ 随机森林回归模型成功转换为分类结果\")\n",
" print(f\"📊 生成了详细的性能分析和可视化\")\n",
" print(f\"💾 所有结果已保存到文件\")\n",
" \n",
" # 如果有测试数据,也进行评估\n",
" if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n",
" print(f\"\\n🔮 开始测试集评估...\")\n",
" test_evaluation_results = complete_regression_classification_evaluation(\n",
" rf_regressor, X_test, y_test, \"测试集\"\n",
" )\n",
" else:\n",
" print(\"⚠️ 没有可用的验证数据进行评估\")\n",
"else:\n",
" print(\"⚠️ 随机森林模型尚未训练完成\")\n",
" print(\"💡 请先运行前面的训练代码\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([17.75 , -0.80859375, -2.03125 , -1.8046875 , -0.85546875,\n",
" -1.421875 , 1.765625 , -2.703125 , -1.984375 , 4.0625 ,\n",
" 2. , -3.15625 , 0.72265625, -0.8671875 , -1.90625 ,\n",
" -2.0625 , -1.28125 , -1.03125 , 0.21289062, -1.890625 ,\n",
" -0.4453125 , -0.5546875 , 0.5625 , -0.421875 , -0.22460938,\n",
" 0.3515625 , -2.375 , -1.8984375 , 2.796875 , 0.3515625 ,\n",
" -2.484375 , 1.453125 , 0.30078125, -2.390625 , 0.19335938,\n",
" 0.35742188, -1.484375 , -2.8125 , -0.84375 , -3.0625 ,\n",
" 4.96875 ], dtype=float32)"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"single_result['train']['confidence_scores'][0][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['concatenated_data', 'processed_features', 'confidence_scores', 'trial_metadata', 'processing_stats'])"
]
},
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"single_result['train'].keys()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🌲 随机森林"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"======================================================================\n",
"🎯 重新定义任务:音素分类任务\n",
"======================================================================\n",
"📋 任务重新定义:\n",
" 输入: 神经特征 (前7168维)\n",
" 输出: 音素置信度 (后41维RNN输出)\n",
" 目标: 预测每个时间步的音素概率分布\n",
" 分类: 选择最大置信度对应的音素\n",
"\n",
"🔧 数据重新处理:\n",
" 神经特征矩阵: (348, 7168)\n",
" 音素logits矩阵: (348, 41)\n",
" 音素类别标签: (348,) (值范围: 0-0)\n",
" 音素分布: 1 个不同音素\n",
" 样本最多的前5个音素:\n",
" 音素 0: 348 次\n",
"\n",
"🔄 训练测试集切分:\n",
" 训练集: 278 样本\n",
" 测试集: 70 样本\n",
" 音素类别数: 1\n",
" 训练集音素分布: 1 个不同音素\n",
" 测试集音素分布: 1 个不同音素\n",
"\n",
"======================================================================\n",
"🌲 随机森林回归 + 分类\n",
"======================================================================\n",
"📊 方案1: 多输出回归 (神经特征 → 音素logits)\n",
"🚀 训练回归模型...\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_186/3764425816.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"🚀 训练回归模型...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0mrf_regressor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train_neu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_logits_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 88\u001b[0m \u001b[0mregression_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"✅ 回归训练完成!耗时: {regression_time:.2f} 秒\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1387\u001b[0m )\n\u001b[1;32m 1388\u001b[0m ):\n\u001b[0;32m-> 1389\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1391\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/multioutput.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight, **fit_params)\u001b[0m\n\u001b[1;32m 272\u001b[0m \u001b[0mrouted_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"sample_weight\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m self.estimators_ = Parallel(n_jobs=self.n_jobs)(\n\u001b[0m\u001b[1;32m 275\u001b[0m delayed(_fit_estimator)(\n\u001b[1;32m 276\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mrouted_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdelayed_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m )\n\u001b[0;32m---> 77\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable_with_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1984\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_sequential_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1985\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1986\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_generator\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1987\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1988\u001b[0m \u001b[0;31m# Let's create an ID that uniquely identifies the current call. If the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 1912\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_dispatched_batches\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1913\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_dispatched_tasks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1914\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1915\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_completed_tasks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1916\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_progress\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/multioutput.py\u001b[0m in \u001b[0;36m_fit_estimator\u001b[0;34m(estimator, X, y, sample_weight, **fit_params)\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 62\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 64\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1387\u001b[0m )\n\u001b[1;32m 1388\u001b[0m ):\n\u001b[0;32m-> 1389\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1391\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/ensemble/_forest.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0;31m# parallel_backend contexts set at a higher level,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 486\u001b[0m \u001b[0;31m# since correctness does not rely on using threads.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 487\u001b[0;31m trees = Parallel(\n\u001b[0m\u001b[1;32m 488\u001b[0m \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 489\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdelayed_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m )\n\u001b[0;32m---> 77\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable_with_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m 2070\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2071\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2072\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_generator\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2073\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2074\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m 1680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1681\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieval_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1682\u001b[0;31m \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_retrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1683\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1684\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mGeneratorExit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_retrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 1798\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mTASK_PENDING\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1799\u001b[0m ):\n\u001b[0;32m-> 1800\u001b[0;31m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1801\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1802\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# 🎯 正确的任务:音素分类 (神经特征 → 音素置信度回归)\n",
"\n",
"print(\"\\n\" + \"=\"*70)\n",
"print(\"🎯 重新定义任务:音素分类任务\")\n",
"print(\"=\"*70)\n",
"\n",
"print(\"📋 任务重新定义:\")\n",
"print(\" 输入: 神经特征 (前7168维)\")\n",
"print(\" 输出: 音素置信度 (后41维RNN输出)\")\n",
"print(\" 目标: 预测每个时间步的音素概率分布\")\n",
"print(\" 分类: 选择最大置信度对应的音素\") \n",
"\n",
"# 重新准备数据\n",
"print(f\"\\n🔧 数据重新处理:\")\n",
"\n",
"# 分离输入特征和目标\n",
"X_neural = [] # 神经特征 (7168维)\n",
"y_phoneme_logits = [] # 音素置信度 (41维)\n",
"\n",
"for features in valid_features:\n",
" # features shape: [time_steps, 7209]\n",
" neural_part = features[:, :7168] # 前7168维是神经特征\n",
" rnn_part = features[:, 7168:] # 后41维是RNN输出(音素logits)\n",
" \n",
" # 对时间维度做平均\n",
" neural_pooled = np.mean(neural_part, axis=0) # [7168]\n",
" rnn_pooled = np.mean(rnn_part, axis=0) # [41]\n",
" \n",
" X_neural.append(neural_pooled)\n",
" y_phoneme_logits.append(rnn_pooled)\n",
"\n",
"X_neural = np.array(X_neural) # [348, 7168]\n",
"y_phoneme_logits = np.array(y_phoneme_logits) # [348, 41]\n",
"\n",
"print(f\" 神经特征矩阵: {X_neural.shape}\")\n",
"print(f\" 音素logits矩阵: {y_phoneme_logits.shape}\")\n",
"\n",
"# 从音素logits得到分类标签\n",
"y_phoneme_class = np.argmax(y_phoneme_logits, axis=1) # 选择最大值的索引\n",
"print(f\" 音素类别标签: {y_phoneme_class.shape} (值范围: {y_phoneme_class.min()}-{y_phoneme_class.max()})\")\n",
"\n",
"# 显示音素分布\n",
"from collections import Counter\n",
"phoneme_dist = Counter(y_phoneme_class)\n",
"print(f\" 音素分布: {len(phoneme_dist)} 个不同音素\")\n",
"print(f\" 样本最多的前5个音素:\")\n",
"for phoneme_id, count in phoneme_dist.most_common(5):\n",
" print(f\" 音素 {phoneme_id}: {count} 次\")\n",
"\n",
"# 训练测试集切分\n",
"print(f\"\\n🔄 训练测试集切分:\")\n",
"X_train_neu, X_test_neu, y_logits_train, y_logits_test, y_class_train, y_class_test = train_test_split(\n",
" X_neural, y_phoneme_logits, y_phoneme_class,\n",
" test_size=0.2, random_state=42, stratify=y_phoneme_class\n",
")\n",
"\n",
"print(f\" 训练集: {X_train_neu.shape[0]} 样本\")\n",
"print(f\" 测试集: {X_test_neu.shape[0]} 样本\")\n",
"print(f\" 音素类别数: {len(np.unique(y_phoneme_class))}\")\n",
"\n",
"# 检查类别分布\n",
"train_dist = Counter(y_class_train)\n",
"test_dist = Counter(y_class_test)\n",
"print(f\" 训练集音素分布: {len(train_dist)} 个不同音素\")\n",
"print(f\" 测试集音素分布: {len(test_dist)} 个不同音素\")\n",
"\n",
"print(\"\\n\" + \"=\"*70)\n",
"print(\"🌲 随机森林回归 + 分类\")\n",
"print(\"=\"*70)\n",
"\n",
"# 方案1: 多输出回归 (预测41维音素logits)\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.multioutput import MultiOutputRegressor\n",
"\n",
"print(\"📊 方案1: 多输出回归 (神经特征 → 音素logits)\")\n",
"rf_regressor = MultiOutputRegressor(\n",
" RandomForestRegressor(\n",
" n_estimators=100,\n",
" max_depth=10,\n",
" random_state=42,\n",
" n_jobs=-1\n",
" )\n",
")\n",
"\n",
"print(\"🚀 训练回归模型...\")\n",
"start_time = time.time()\n",
"rf_regressor.fit(X_train_neu, y_logits_train)\n",
"regression_time = time.time() - start_time\n",
"print(f\"✅ 回归训练完成!耗时: {regression_time:.2f} 秒\")\n",
"\n",
"# 预测音素logits\n",
"print(\"🔮 预测音素置信度...\")\n",
"y_logits_pred_train = rf_regressor.predict(X_train_neu)\n",
"y_logits_pred_test = rf_regressor.predict(X_test_neu)\n",
"\n",
"# 从预测的logits得到分类结果\n",
"y_class_pred_train = np.argmax(y_logits_pred_train, axis=1)\n",
"y_class_pred_test = np.argmax(y_logits_pred_test, axis=1)\n",
"\n",
"# 评估分类性能\n",
"from sklearn.metrics import accuracy_score, classification_report\n",
"\n",
"train_acc = accuracy_score(y_class_train, y_class_pred_train)\n",
"test_acc = accuracy_score(y_class_test, y_class_pred_test)\n",
"\n",
"print(f\"\\n📊 音素分类性能评估:\")\n",
"print(f\" 训练集准确率: {train_acc:.4f} ({train_acc*100:.2f}%)\")\n",
"print(f\" 测试集准确率: {test_acc:.4f} ({test_acc*100:.2f}%)\")\n",
"print(f\" 过拟合程度: {(train_acc - test_acc)*100:.2f}%\")\n",
"\n",
"# 评估回归性能\n",
"from sklearn.metrics import mean_squared_error, r2_score\n",
"\n",
"mse_train = mean_squared_error(y_logits_train, y_logits_pred_train)\n",
"mse_test = mean_squared_error(y_logits_test, y_logits_pred_test)\n",
"r2_train = r2_score(y_logits_train, y_logits_pred_train)\n",
"r2_test = r2_score(y_logits_test, y_logits_pred_test)\n",
"\n",
"print(f\"\\n📈 音素logits回归性能:\")\n",
"print(f\" 训练集 MSE: {mse_train:.6f}\")\n",
"print(f\" 测试集 MSE: {mse_test:.6f}\")\n",
"print(f\" 训练集 R²: {r2_train:.4f}\")\n",
"print(f\" 测试集 R²: {r2_test:.4f}\")\n",
"\n",
"print(f\"\\n✨ 任务修正完成!现在是正确的音素分类任务\")"
]
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}