2907 lines
		
	
	
		
			154 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
		
		
			
		
	
	
			2907 lines
		
	
	
		
			154 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
|   | { | |||
|  |  "cells": [ | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "# 环境配置 与 utils" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 1, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "Looking in indexes: https://download.pytorch.org/whl/cu126\n", | |||
|  |       "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", | |||
|  |       "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n", | |||
|  |       "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", | |||
|  |       "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", | |||
|  |       "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n", | |||
|  |       "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n", | |||
|  |       "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", | |||
|  |       "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n", | |||
|  |       "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n", | |||
|  |       "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n", | |||
|  |       "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n", | |||
|  |       "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n", | |||
|  |       "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n", | |||
|  |       "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n", | |||
|  |       "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", | |||
|  |       "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", | |||
|  |       "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", | |||
|  |       "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", | |||
|  |       "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", | |||
|  |       "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", | |||
|  |       "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n", | |||
|  |       "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", | |||
|  |       "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n", | |||
|  |       "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n", | |||
|  |       "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n", | |||
|  |       "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n", | |||
|  |       "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n", | |||
|  |       "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n", | |||
|  |       "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n", | |||
|  |       "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n", | |||
|  |       "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n", | |||
|  |       "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n", | |||
|  |       "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n", | |||
|  |       "Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n", | |||
|  |       "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", | |||
|  |       "Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", | |||
|  |       "Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n", | |||
|  |       "Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n", | |||
|  |       "Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", | |||
|  |       "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n", | |||
|  |       "Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n", | |||
|  |       "Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n", | |||
|  |       "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", | |||
|  |       "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n", | |||
|  |       "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n", | |||
|  |       "Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n", | |||
|  |       "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n", | |||
|  |       "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n", | |||
|  |       "Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n", | |||
|  |       "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n", | |||
|  |       "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n", | |||
|  |       "Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n", | |||
|  |       "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n", | |||
|  |       "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n", | |||
|  |       "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n", | |||
|  |       "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n", | |||
|  |       "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", | |||
|  |       "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", | |||
|  |       "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n", | |||
|  |       "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n", | |||
|  |       "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n", | |||
|  |       "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n", | |||
|  |       "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n", | |||
|  |       "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n", | |||
|  |       "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n", | |||
|  |       "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n", | |||
|  |       "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n", | |||
|  |       "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n", | |||
|  |       "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n", | |||
|  |       "Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n", | |||
|  |       "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n", | |||
|  |       "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n", | |||
|  |       "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n", | |||
|  |       "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n", | |||
|  |       "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n", | |||
|  |       "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n", | |||
|  |       "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n", | |||
|  |       "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n", | |||
|  |       "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n", | |||
|  |       "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n", | |||
|  |       "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n", | |||
|  |       "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n", | |||
|  |       "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n", | |||
|  |       "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n", | |||
|  |       "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n", | |||
|  |       "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n", | |||
|  |       "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n", | |||
|  |       "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n", | |||
|  |       "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n", | |||
|  |       "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n", | |||
|  |       "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n", | |||
|  |       "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n", | |||
|  |       "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n", | |||
|  |       "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n", | |||
|  |       "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n", | |||
|  |       "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n", | |||
|  |       "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n", | |||
|  |       "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n", | |||
|  |       "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n", | |||
|  |       "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n", | |||
|  |       "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n", | |||
|  |       "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", | |||
|  |       "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n", | |||
|  |       "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n", | |||
|  |       "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n", | |||
|  |       "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n", | |||
|  |       "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n", | |||
|  |       "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n", | |||
|  |       "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n", | |||
|  |       "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n", | |||
|  |       "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n", | |||
|  |       "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n", | |||
|  |       "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n", | |||
|  |       "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n", | |||
|  |       "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n", | |||
|  |       "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n", | |||
|  |       "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n", | |||
|  |       "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n", | |||
|  |       "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n", | |||
|  |       "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n", | |||
|  |       "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n", | |||
|  |       "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n", | |||
|  |       "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n", | |||
|  |       "Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n", | |||
|  |       "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n", | |||
|  |       "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n", | |||
|  |       "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n", | |||
|  |       "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n", | |||
|  |       "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n", | |||
|  |       "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n", | |||
|  |       "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n", | |||
|  |       "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n", | |||
|  |       "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n", | |||
|  |       "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n", | |||
|  |       "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n", | |||
|  |       "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n", | |||
|  |       "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n", | |||
|  |       "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n", | |||
|  |       "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n", | |||
|  |       "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n", | |||
|  |       "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", | |||
|  |       "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n", | |||
|  |       "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n", | |||
|  |       "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", | |||
|  |       "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n", | |||
|  |       "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n", | |||
|  |       "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n", | |||
|  |       "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n", | |||
|  |       "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", | |||
|  |       "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n", | |||
|  |       "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n", | |||
|  |       "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n", | |||
|  |       "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n", | |||
|  |       "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n", | |||
|  |       "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n", | |||
|  |       "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n", | |||
|  |       "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n", | |||
|  |       "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n", | |||
|  |       "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n", | |||
|  |       "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n", | |||
|  |       "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n", | |||
|  |       "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n", | |||
|  |       "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n", | |||
|  |       "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n", | |||
|  |       "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n", | |||
|  |       "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n", | |||
|  |       "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n", | |||
|  |       "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n", | |||
|  |       "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n", | |||
|  |       "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n", | |||
|  |       "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n", | |||
|  |       "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n", | |||
|  |       "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n", | |||
|  |       "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n", | |||
|  |       "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n", | |||
|  |       "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n", | |||
|  |       "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n", | |||
|  |       "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n", | |||
|  |       "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n", | |||
|  |       "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n", | |||
|  |       "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n", | |||
|  |       "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n", | |||
|  |       "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n", | |||
|  |       "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n", | |||
|  |       "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n", | |||
|  |       "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n", | |||
|  |       "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n", | |||
|  |       "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n", | |||
|  |       "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n", | |||
|  |       "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n", | |||
|  |       "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", | |||
|  |       "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n", | |||
|  |       "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", | |||
|  |       "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n", | |||
|  |       "Obtaining file:///kaggle/working/nejm-brain-to-text\n", | |||
|  |       "  Preparing metadata (setup.py): started\n", | |||
|  |       "  Preparing metadata (setup.py): finished with status 'done'\n", | |||
|  |       "Installing collected packages: nejm_b2txt_utils\n", | |||
|  |       "  Running setup.py develop for nejm_b2txt_utils\n", | |||
|  |       "Successfully installed nejm_b2txt_utils-0.0.0\n" | |||
|  |      ] | |||
|  |     }, | |||
|  |     { | |||
|  |      "name": "stderr", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "Cloning into 'nejm-brain-to-text'...\n", | |||
|  |       "cp: cannot stat '/kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl': No such file or directory\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "%%bash\n", | |||
|  |     "# Fresh checkout of the project, pinned dependencies, and links to the Kaggle inputs.\n", | |||
|  |     "cd /kaggle/working/\n", | |||
|  |     "rm -rf /kaggle/working/nejm-brain-to-text/\n", | |||
|  |     "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n", | |||
|  |     "cd /kaggle/working/nejm-brain-to-text/\n", | |||
|  |     "\n", | |||
|  |     "# This input is not always attached; skip the copy with a warning instead of\n", | |||
|  |     "# letting cp fail with \"cannot stat\".\n", | |||
|  |     "COPYTASK_PKL=/kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl\n", | |||
|  |     "if [ -f \"$COPYTASK_PKL\" ]; then\n", | |||
|  |     "    cp \"$COPYTASK_PKL\" /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n", | |||
|  |     "else\n", | |||
|  |     "    echo \"warning: $COPYTASK_PKL not found, skipping copy\" >&2\n", | |||
|  |     "fi\n", | |||
|  |     "\n", | |||
|  |     "# Install PyTorch with CUDA 12.6\n", | |||
|  |     "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", | |||
|  |     "\n", | |||
|  |     "# Install additional packages with compatible versions\n", | |||
|  |     "# TODO: remove redis\n", | |||
|  |     "pip install \\\n", | |||
|  |     "    jupyter==1.1.1 \\\n", | |||
|  |     "    \"numpy>=1.26.0,<2.1.0\" \\\n", | |||
|  |     "    pandas==2.3.0 \\\n", | |||
|  |     "    matplotlib==3.10.1 \\\n", | |||
|  |     "    scipy==1.15.2 \\\n", | |||
|  |     "    scikit-learn==1.6.1 \\\n", | |||
|  |     "    tqdm==4.67.1 \\\n", | |||
|  |     "    g2p_en==2.1.0 \\\n", | |||
|  |     "    h5py==3.13.0 \\\n", | |||
|  |     "    omegaconf==2.3.0 \\\n", | |||
|  |     "    editdistance==0.8.1 \\\n", | |||
|  |     "    huggingface-hub==0.33.1 \\\n", | |||
|  |     "    transformers==4.53.0 \\\n", | |||
|  |     "    tokenizers==0.21.2 \\\n", | |||
|  |     "    accelerate==1.8.1 \\\n", | |||
|  |     "    bitsandbytes==0.46.0\n", | |||
|  |     "\n", | |||
|  |     "# Install the local package\n", | |||
|  |     "pip install -e .\n", | |||
|  |     "\n", | |||
|  |     "# Expose the pretrained baseline and the HDF5 data under the repo's data/ directory\n", | |||
|  |     "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n", | |||
|  |     "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 2, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "/kaggle/working/nejm-brain-to-text\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "%cd /kaggle/working/nejm-brain-to-text\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "import os\n", | |||
|  |     "import pickle\n", | |||
|  |     "import matplotlib.pyplot as plt\n", | |||
|  |     "import matplotlib\n", | |||
|  |     "from g2p_en import G2p\n", | |||
|  |     "import pandas as pd\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "from nejm_b2txt_utils.general_utils import *\n", | |||
|  |     "\n", | |||
|  |     "# Export settings: font type 42 for PDF/PS output and a sans-serif font family.\n", | |||
|  |     "matplotlib.rcParams.update({\n", | |||
|  |     "    'pdf.fonttype': 42,\n", | |||
|  |     "    'ps.fonttype': 42,\n", | |||
|  |     "    'font.family': 'sans-serif',\n", | |||
|  |     "})\n" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 3, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "import h5py\n", | |||
|  |     "def load_h5py_file(file_path, b2txt_csv_df):\n", | |||
|  |     "    \"\"\"Read every trial group of one HDF5 file into a dict of parallel lists.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        file_path: Path of the HDF5 file; each top-level key is read as one trial.\n", | |||
|  |     "        b2txt_csv_df: DataFrame with 'Date', 'Block number' and 'Corpus' columns,\n", | |||
|  |     "            used to look up the corpus name of each trial.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        dict mapping each field name to a list with one entry per trial.\n", | |||
|  |     "        'seq_class_ids', 'seq_len', 'transcriptions' and 'sentence_label' are None\n", | |||
|  |     "        for trials that do not store them; 'corpus' is None when the CSV has no\n", | |||
|  |     "        row for the trial's date and block number.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    data = {\n", | |||
|  |     "        'neural_features': [],\n", | |||
|  |     "        'n_time_steps': [],\n", | |||
|  |     "        'seq_class_ids': [],\n", | |||
|  |     "        'seq_len': [],\n", | |||
|  |     "        'transcriptions': [],\n", | |||
|  |     "        'sentence_label': [],\n", | |||
|  |     "        'session': [],\n", | |||
|  |     "        'block_num': [],\n", | |||
|  |     "        'trial_num': [],\n", | |||
|  |     "        'corpus': [],\n", | |||
|  |     "    }\n", | |||
|  |     "    # Open the hdf5 file for that day\n", | |||
|  |     "    with h5py.File(file_path, 'r') as f:\n", | |||
|  |     "\n", | |||
|  |     "        # For each trial in the selected trials in that day\n", | |||
|  |     "        for key in list(f.keys()):\n", | |||
|  |     "            g = f[key]\n", | |||
|  |     "            attrs = g.attrs\n", | |||
|  |     "\n", | |||
|  |     "            neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n", | |||
|  |     "            n_time_steps = attrs['n_time_steps']\n", | |||
|  |     "            seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n", | |||
|  |     "            seq_len = attrs['seq_len'] if 'seq_len' in attrs else None\n", | |||
|  |     "            transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n", | |||
|  |     "            sentence_label = attrs['sentence_label'][:] if 'sentence_label' in attrs else None # pyright: ignore[reportIndexIssue]\n", | |||
|  |     "            session = attrs['session']\n", | |||
|  |     "            block_num = attrs['block_num']\n", | |||
|  |     "            trial_num = attrs['trial_num']\n", | |||
|  |     "\n", | |||
|  |     "            # match this trial up with the csv to get the corpus name:\n", | |||
|  |     "            # the session's dot-separated parts after the first become a 'YYYY-MM-DD' date\n", | |||
|  |     "            year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n", | |||
|  |     "            date = f'{year}-{month}-{day}'\n", | |||
|  |     "            row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n", | |||
|  |     "            # No matching CSV row: record None (like the other optional fields)\n", | |||
|  |     "            # instead of failing with an IndexError on .values[0]\n", | |||
|  |     "            corpus_name = row['Corpus'].values[0] if len(row) > 0 else None\n", | |||
|  |     "\n", | |||
|  |     "            data['neural_features'].append(neural_features)\n", | |||
|  |     "            data['n_time_steps'].append(n_time_steps)\n", | |||
|  |     "            data['seq_class_ids'].append(seq_class_ids)\n", | |||
|  |     "            data['seq_len'].append(seq_len)\n", | |||
|  |     "            data['transcriptions'].append(transcription)\n", | |||
|  |     "            data['sentence_label'].append(sentence_label)\n", | |||
|  |     "            data['session'].append(session)\n", | |||
|  |     "            data['block_num'].append(block_num)\n", | |||
|  |     "            data['trial_num'].append(trial_num)\n", | |||
|  |     "            data['corpus'].append(corpus_name)\n", | |||
|  |     "    return data" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 4, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# Label for each class id (list index == class id).\n", | |||
|  |     "# NOTE(review): index 0 ('BLANK') is presumably the CTC blank and ' | ' the\n", | |||
|  |     "# word-boundary token; confirm against the model documentation.\n", | |||
|  |     "LOGIT_TO_PHONEME = [\n", | |||
|  |     "    'BLANK',\n", | |||
|  |     "    'AA', 'AE', 'AH', 'AO', 'AW', 'AY',\n", | |||
|  |     "    'B', 'CH', 'D', 'DH',\n", | |||
|  |     "    'EH', 'ER', 'EY',\n", | |||
|  |     "    'F', 'G', 'HH',\n", | |||
|  |     "    'IH', 'IY',\n", | |||
|  |     "    'JH', 'K', 'L', 'M', 'N', 'NG',\n", | |||
|  |     "    'OW', 'OY',\n", | |||
|  |     "    'P', 'R', 'S', 'SH', 'T', 'TH',\n", | |||
|  |     "    'UH', 'UW',\n", | |||
|  |     "    'V', 'W', 'Y', 'Z', 'ZH',\n", | |||
|  |     "    ' | ',\n", | |||
|  |     "]" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "# 数据分析与预处理" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 数据准备" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 5, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "/kaggle/working/nejm-brain-to-text\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "%cd /kaggle/working/nejm-brain-to-text/" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 6, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "import pandas as pd\n", | |||
|  |     "# Load the training trials of one session together with the copy-task description CSV.\n", | |||
|  |     "data = load_h5py_file(\n", | |||
|  |     "    file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n", | |||
|  |     "    b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'),\n", | |||
|  |     ")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "- **任务介绍** :机器学习解决高维信号的模式识别问题" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "我们的数据集标签缺少时间戳,现在要进行的是半监督学习" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "- 处理思路:先将每个 trial 的时间轴按音素个数均等分割(或者按照调研数据设定各音素的初始长度),然后筛掉异常值,提取出可用的训练集;再控制时间窗口的长短,并查看各类样本的长度。" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 7, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "data": { | |||
|  |       "text/plain": [ | |||
|  |        "{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ...,  0.57367045,\n", | |||
|  |        "         -0.7091646 , -0.11018186],\n", | |||
|  |        "        [-0.5859305 , -0.78699756, -0.64687246, ...,  0.3122117 ,\n", | |||
|  |        "          1.7943763 , -0.76884896],\n", | |||
|  |        "        [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n", | |||
|  |        "         -0.8481289 , -0.7648201 ],\n", | |||
|  |        "        ...,\n", | |||
|  |        "        [-0.5859305 ,  0.22756557,  0.9262037 , ..., -0.34710956,\n", | |||
|  |        "          0.9710176 ,  2.5397465 ],\n", | |||
|  |        "        [-0.5859305 ,  0.22756557, -0.64687246, ..., -0.83613133,\n", | |||
|  |        "         -0.68723625,  0.10479005],\n", | |||
|  |        "        [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n", | |||
|  |        "          0.7417906 , -0.7008622 ]], dtype=float32),\n", | |||
|  |        " 'n_time_steps': 321,\n", | |||
|  |        " 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n", | |||
|  |        "         0,  0,  0,  0,  0,  0,  0], dtype=int32),\n", | |||
|  |        " 'seq_len': 14,\n", | |||
|  |        " 'transcriptions': array([ 66, 114, 105, 110, 103,  32, 105, 116,  32,  99, 108, 111, 115,\n", | |||
|  |        "        101, 114,  46,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n", | |||
|  |        "          0,   0,   0,   0,   0,   0], dtype=int32),\n", | |||
|  |        " 'sentence_label': 'Bring it closer.',\n", | |||
|  |        " 'session': 't15.2023.08.11',\n", | |||
|  |        " 'block_num': 2,\n", | |||
|  |        " 'trial_num': 0,\n", | |||
|  |        " 'corpus': '50-Word'}" | |||
|  |       ] | |||
|  |      }, | |||
|  |      "execution_count": 7, | |||
|  |      "metadata": {}, | |||
|  |      "output_type": "execute_result" | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "def data_patch(data, index):\n", | |||
|  |     "    \"\"\"Return the fields of a single trial taken from the dict-of-lists `data`.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        data: dict holding an indexable sequence under each field name listed below.\n", | |||
|  |     "        index: Position of the trial to extract.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        dict with the same ten field names, each mapped to `data[field][index]`.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    fields = (\n", | |||
|  |     "        'neural_features', 'n_time_steps', 'seq_class_ids', 'seq_len',\n", | |||
|  |     "        'transcriptions', 'sentence_label', 'session', 'block_num',\n", | |||
|  |     "        'trial_num', 'corpus',\n", | |||
|  |     "    )\n", | |||
|  |     "    return {field: data[field][index] for field in fields}\n", | |||
|  |     "\n", | |||
|  |     "data_patch(data, 0)" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 8, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "d1 = data_patch(data, 0)" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 9, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "Transcriptions non-zero length: 16\n", | |||
|  |       "Seq class ids non-zero length: 14\n", | |||
|  |       "Seq len: 14\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# Count the non-zero entries of both label arrays and compare them with the stored seq_len.\n", | |||
|  |     "trans_len = sum(1 for x in d1['transcriptions'] if x != 0)\n", | |||
|  |     "seq_len_nonzero = sum(1 for x in d1['seq_class_ids'] if x != 0)\n", | |||
|  |     "seq_len = d1['seq_len']\n", | |||
|  |     "print(\n", | |||
|  |     "    f\"Transcriptions non-zero length: {trans_len}\\n\"\n", | |||
|  |     "    f\"Seq class ids non-zero length: {seq_len_nonzero}\\n\"\n", | |||
|  |     "    f\"Seq len: {seq_len}\"\n", | |||
|  |     ")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 10, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "Number of feature sequences: 14\n", | |||
|  |       "Shape of first sequence: (22, 512)\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "def create_time_windows(d1):\n", | |||
|  |     "    \"\"\"Split one trial's neural features into `seq_len` contiguous time windows.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        d1: Trial dict with 'n_time_steps', 'seq_len' and 'neural_features'\n", | |||
|  |     "            (sliced along its first axis).\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        list of `seq_len` slices of `d1['neural_features']`. Window edges are\n", | |||
|  |     "        `np.linspace(0, n_time_steps, seq_len + 1)` truncated to int, so window\n", | |||
|  |     "        lengths can differ by one step. Empty list when 'seq_len' is None or 0.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    import numpy as np\n", | |||
|  |     "    n_time_steps = d1['n_time_steps']\n", | |||
|  |     "    seq_len = d1['seq_len']\n", | |||
|  |     "    # A trial without a stored sequence length (None) has nothing to split;\n", | |||
|  |     "    # without this guard `seq_len + 1` raises a TypeError.\n", | |||
|  |     "    if not seq_len:\n", | |||
|  |     "        return []\n", | |||
|  |     "\n", | |||
|  |     "    # Create equal windows\n", | |||
|  |     "    edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n", | |||
|  |     "\n", | |||
|  |     "    # Extract feature sequences for each window\n", | |||
|  |     "    features = d1['neural_features']\n", | |||
|  |     "    return [features[start:end, :] for start, end in zip(edges[:-1], edges[1:])]\n", | |||
|  |     "\n", | |||
|  |     "# Example usage\n", | |||
|  |     "feature_sequences = create_time_windows(d1)\n", | |||
|  |     "print(\"Number of feature sequences:\", len(feature_sequences))\n", | |||
|  |     "print(\"Shape of first sequence:\", feature_sequences[0].shape)\n" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 11, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "Train: 45, Val: 41, Test: 41\n", | |||
|  |       "Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n", | |||
|  |       "Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n", | |||
|  |       "Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "import os\n", | |||
|  |     "\n", | |||
|  |     "def scan_hdf5_files(base_path):\n", | |||
|  |     "    \"\"\"Recursively collect the train/val/test HDF5 files found under `base_path`.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        base_path: Directory to walk.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        (train_files, val_files, test_files): three lists of absolute paths, each\n", | |||
|  |     "        sorted so the result does not depend on the filesystem listing order.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    files_by_marker = {\n", | |||
|  |     "        'data_train.hdf5': [],\n", | |||
|  |     "        'data_val.hdf5': [],\n", | |||
|  |     "        'data_test.hdf5': [],\n", | |||
|  |     "    }\n", | |||
|  |     "    for root, dirs, files in os.walk(base_path):\n", | |||
|  |     "        for file in files:\n", | |||
|  |     "            if not file.endswith('.hdf5'):\n", | |||
|  |     "                continue\n", | |||
|  |     "            # The first matching marker wins, as in an if/elif chain\n", | |||
|  |     "            for marker, bucket in files_by_marker.items():\n", | |||
|  |     "                if marker in file:\n", | |||
|  |     "                    bucket.append(os.path.abspath(os.path.join(root, file)))\n", | |||
|  |     "                    break\n", | |||
|  |     "    # os.walk yields entries in arbitrary order; sort for reproducible lists\n", | |||
|  |     "    train_files, val_files, test_files = (sorted(paths) for paths in files_by_marker.values())\n", | |||
|  |     "    return train_files, val_files, test_files\n", | |||
|  |     "\n", | |||
|  |     "# Example usage\n", | |||
|  |     "FILE_PATH = 'data/hdf5_data_final'\n", | |||
|  |     "train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n", | |||
|  |     "print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n", | |||
|  |     "print(\"Train files (first 3):\", train_list[:3])\n", | |||
|  |     "print(\"Val files (first 3):\", val_list[:3])\n", | |||
|  |     "print(\"Test files (first 3):\", test_list[:3])" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🚀 完整数据集处理工作流\n", | |||
|  |     "\n", | |||
|  |     "创建一个自动化工作流,处理所有sessions的训练集、验证集、测试集数据,生成包含40类音素预测的完整特征数据集。" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 12, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "🧠 创建真实RNN模型加载器...\n", | |||
|  |       "   未检测到CUDA,使用CPU\n", | |||
|  |       "🧠 RNN模型加载器初始化:\n", | |||
|  |       "   模型路径: ./data/t15_pretrained_rnn_baseline\n", | |||
|  |       "   使用设备: cpu\n", | |||
|  |       "✅ 模型文件检查通过\n", | |||
|  |       "⚠️ 无法导入rnn_model,尝试从model_training目录导入...\n", | |||
|  |       "✅ 从model_training目录成功导入GRUDecoder\n", | |||
|  |       "\n", | |||
|  |       "🔄 加载RNN模型...\n", | |||
|  |       "   ✅ 模型配置加载完成\n", | |||
|  |       "      输入特征维度: 512\n", | |||
|  |       "      隐藏单元数: 768\n", | |||
|  |       "      层数: 5\n", | |||
|  |       "      输出类别数: 41\n", | |||
|  |       "   ✅ 模型架构初始化完成\n", | |||
|  |       "   🔐 加载检查点(注意:使用weights_only=False,请确保检查点来源可信)...\n", | |||
|  |       "   ✅ 模型架构初始化完成\n", | |||
|  |       "   🔐 加载检查点(注意:使用weights_only=False,请确保检查点来源可信)...\n", | |||
|  |       "   📋 找到model_state_dict键\n", | |||
|  |       "   🔍 匹配模型参数...\n", | |||
|  |       "   📊 参数匹配统计: 113/113 个键匹配成功\n", | |||
|  |       "   ✅ 预训练权重加载完成 (113个参数)\n", | |||
|  |       "   📊 模型参数统计:\n", | |||
|  |       "      总参数数量: 44,315,177\n", | |||
|  |       "      Day-specific参数: 11,819,520 (26.7%)\n", | |||
|  |       "      会话数量: 45\n", | |||
|  |       "   📅 支持的会话:\n", | |||
|  |       "      0: t15.2023.08.11\n", | |||
|  |       "      1: t15.2023.08.13\n", | |||
|  |       "      2: t15.2023.08.18\n", | |||
|  |       "      3: t15.2023.08.20\n", | |||
|  |       "      4: t15.2023.08.25\n", | |||
|  |       "      ... 还有 40 个会话\n", | |||
|  |       "\n", | |||
|  |       "🎉 RNN模型加载成功!\n", | |||
|  |       "\n", | |||
|  |       "✅ 真实RNN模型加载器准备完成!\n", | |||
|  |       "🔧 现在可以使用真实的预训练RNN进行预测\n", | |||
|  |       "   📋 找到model_state_dict键\n", | |||
|  |       "   🔍 匹配模型参数...\n", | |||
|  |       "   📊 参数匹配统计: 113/113 个键匹配成功\n", | |||
|  |       "   ✅ 预训练权重加载完成 (113个参数)\n", | |||
|  |       "   📊 模型参数统计:\n", | |||
|  |       "      总参数数量: 44,315,177\n", | |||
|  |       "      Day-specific参数: 11,819,520 (26.7%)\n", | |||
|  |       "      会话数量: 45\n", | |||
|  |       "   📅 支持的会话:\n", | |||
|  |       "      0: t15.2023.08.11\n", | |||
|  |       "      1: t15.2023.08.13\n", | |||
|  |       "      2: t15.2023.08.18\n", | |||
|  |       "      3: t15.2023.08.20\n", | |||
|  |       "      4: t15.2023.08.25\n", | |||
|  |       "      ... 还有 40 个会话\n", | |||
|  |       "\n", | |||
|  |       "🎉 RNN模型加载成功!\n", | |||
|  |       "\n", | |||
|  |       "✅ 真实RNN模型加载器准备完成!\n", | |||
|  |       "🔧 现在可以使用真实的预训练RNN进行预测\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🧠 真实RNN模型加载器\n", | |||
|  |     "import torch\n", | |||
|  |     "from omegaconf import OmegaConf\n", | |||
|  |     "\n", | |||
|  |     "class RealRNNModelLoader:\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    加载和使用真实的预训练RNN模型\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    \n", | |||
|  |     "    def __init__(self, model_path='./data/t15_pretrained_rnn_baseline', device='auto'):\n", | |||
|  |     "        \"\"\"Store the model location, resolve the compute device and check the model files.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            model_path: Path of the pretrained model directory.\n", | |||
|  |     "            device: Device to use ('auto', 'cpu', 'cuda', 'cuda:0', ...).\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        # Nothing is loaded yet\n", | |||
|  |     "        self.model, self.model_args = None, None\n", | |||
|  |     "        self.model_path = model_path\n", | |||
|  |     "        self.device = self._setup_device(device)\n", | |||
|  |     "\n", | |||
|  |     "        print(\n", | |||
|  |     "            \"🧠 RNN模型加载器初始化:\",\n", | |||
|  |     "            f\"   模型路径: {model_path}\",\n", | |||
|  |     "            f\"   使用设备: {self.device}\",\n", | |||
|  |     "            sep=\"\\n\",\n", | |||
|  |     "        )\n", | |||
|  |     "\n", | |||
|  |     "        # Report whether the required model files exist\n", | |||
|  |     "        self._check_model_files()\n", | |||
|  |     "        \n", | |||
|  |     "    def _setup_device(self, device):\n", | |||
|  |     "        \"\"\"Turn a device spec into a `torch.device`; 'auto' picks CUDA when available, else CPU.\"\"\"\n", | |||
|  |     "        if device != 'auto':\n", | |||
|  |     "            return torch.device(device)\n", | |||
|  |     "\n", | |||
|  |     "        if torch.cuda.is_available():\n", | |||
|  |     "            print(\"   自动检测到CUDA,使用GPU\")\n", | |||
|  |     "            return torch.device('cuda')\n", | |||
|  |     "\n", | |||
|  |     "        print(\"   未检测到CUDA,使用CPU\")\n", | |||
|  |     "        return torch.device('cpu')\n", | |||
|  |     "    \n", | |||
|  |     "    def _check_model_files(self):\n", | |||
|  |     "        \"\"\"Report whether the required model files exist under `self.model_path`.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            True when both 'checkpoint/args.yaml' and 'checkpoint/best_checkpoint'\n", | |||
|  |     "            exist; otherwise prints the missing paths and returns False.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        required_files = {\n", | |||
|  |     "            'args.yaml': os.path.join(self.model_path, 'checkpoint/args.yaml'),\n", | |||
|  |     "            'checkpoint': os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n", | |||
|  |     "        }\n", | |||
|  |     "        missing_files = [\n", | |||
|  |     "            f\"{name}: {path}\"\n", | |||
|  |     "            for name, path in required_files.items()\n", | |||
|  |     "            if not os.path.exists(path)\n", | |||
|  |     "        ]\n", | |||
|  |     "\n", | |||
|  |     "        if not missing_files:\n", | |||
|  |     "            print(\"✅ 模型文件检查通过\")\n", | |||
|  |     "            return True\n", | |||
|  |     "\n", | |||
|  |     "        print(\"❌ 缺少模型文件:\")\n", | |||
|  |     "        for entry in missing_files:\n", | |||
|  |     "            print(f\"   • {entry}\")\n", | |||
|  |     "        print(f\"\\n💡 请确保已下载预训练模型到: {self.model_path}\")\n", | |||
|  |     "        return False\n", | |||
|  |     "    \n", | |||
|  |     "    def load_model(self):\n", | |||
|  |     "        \"\"\"加载预训练的RNN模型\"\"\"\n", | |||
|  |     "        try:\n", | |||
|  |     "            # 需要先检查是否已经导入了rnn_model\n", | |||
|  |     "            try:\n", | |||
|  |     "                from rnn_model import GRUDecoder\n", | |||
|  |     "            except ImportError:\n", | |||
|  |     "                print(\"⚠️ 无法导入rnn_model,尝试从model_training目录导入...\")\n", | |||
|  |     "                import sys\n", | |||
|  |     "                model_training_path = os.path.abspath('./model_training')\n", | |||
|  |     "                if model_training_path not in sys.path:\n", | |||
|  |     "                    sys.path.append(model_training_path)\n", | |||
|  |     "                \n", | |||
|  |     "                try:\n", | |||
|  |     "                    from rnn_model import GRUDecoder\n", | |||
|  |     "                    print(\"✅ 从model_training目录成功导入GRUDecoder\")\n", | |||
|  |     "                except ImportError as e:\n", | |||
|  |     "                    print(f\"❌ 无法导入GRUDecoder: {e}\")\n", | |||
|  |     "                    print(\"💡 请确保rnn_model.py在model_training目录中\")\n", | |||
|  |     "                    return False\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"\\n🔄 加载RNN模型...\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 1. 加载模型配置\n", | |||
|  |     "            args_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n", | |||
|  |     "            self.model_args = OmegaConf.load(args_path)\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"   ✅ 模型配置加载完成\")\n", | |||
|  |     "            print(f\"      输入特征维度: {self.model_args['model']['n_input_features']}\")\n", | |||
|  |     "            print(f\"      隐藏单元数: {self.model_args['model']['n_units']}\")\n", | |||
|  |     "            print(f\"      层数: {self.model_args['model']['n_layers']}\")\n", | |||
|  |     "            print(f\"      输出类别数: {self.model_args['dataset']['n_classes']}\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 2. 初始化模型架构\n", | |||
|  |     "            self.model = GRUDecoder(\n", | |||
|  |     "                neural_dim=self.model_args['model']['n_input_features'],\n", | |||
|  |     "                n_units=self.model_args['model']['n_units'],\n", | |||
|  |     "                n_days=len(self.model_args['dataset']['sessions']),\n", | |||
|  |     "                n_classes=self.model_args['dataset']['n_classes'],\n", | |||
|  |     "                rnn_dropout=self.model_args['model']['rnn_dropout'],\n", | |||
|  |     "                input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n", | |||
|  |     "                n_layers=self.model_args['model']['n_layers'],\n", | |||
|  |     "                patch_size=self.model_args['model']['patch_size'],\n", | |||
|  |     "                patch_stride=self.model_args['model']['patch_stride'],\n", | |||
|  |     "            )\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"   ✅ 模型架构初始化完成\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 3. 加载预训练权重 - 修复安全问题\n", | |||
|  |     "            checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n", | |||
|  |     "            \n", | |||
|  |     "            # 使用weights_only=False来解决pickle安全问题(仅在信任的检查点上使用)\n", | |||
|  |     "            print(f\"   🔐 加载检查点(注意:使用weights_only=False,请确保检查点来源可信)...\")\n", | |||
|  |     "            checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n", | |||
|  |     "            \n", | |||
|  |     "            # 提取模型状态字典\n", | |||
|  |     "            if 'model_state_dict' in checkpoint:\n", | |||
|  |     "                state_dict = checkpoint['model_state_dict']\n", | |||
|  |     "                print(f\"   📋 找到model_state_dict键\")\n", | |||
|  |     "            elif 'model' in checkpoint:\n", | |||
|  |     "                state_dict = checkpoint['model']\n", | |||
|  |     "                print(f\"   📋 找到model键\")\n", | |||
|  |     "            else:\n", | |||
|  |     "                state_dict = checkpoint\n", | |||
|  |     "                print(f\"   📋 使用整个checkpoint作为state_dict\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 处理可能的键名不匹配\n", | |||
|  |     "            model_state_dict = self.model.state_dict()\n", | |||
|  |     "            filtered_state_dict = {}\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"   🔍 匹配模型参数...\")\n", | |||
|  |     "            matched_keys = 0\n", | |||
|  |     "            total_keys = len(state_dict)\n", | |||
|  |     "            unmatched_samples = []\n", | |||
|  |     "            \n", | |||
|  |     "            for key, value in state_dict.items():\n", | |||
|  |     "                # 移除可能的前缀\n", | |||
|  |     "                clean_key = key\n", | |||
|  |     "                \n", | |||
|  |     "                # 移除'_orig_mod.'前缀(PyTorch编译产生的)\n", | |||
|  |     "                if clean_key.startswith('_orig_mod.'):\n", | |||
|  |     "                    clean_key = clean_key.replace('_orig_mod.', '')\n", | |||
|  |     "                \n", | |||
|  |     "                # 移除'module.'前缀(分布式训练产生的)\n", | |||
|  |     "                if clean_key.startswith('module.'):\n", | |||
|  |     "                    clean_key = clean_key.replace('module.', '')\n", | |||
|  |     "                \n", | |||
|  |     "                if clean_key in model_state_dict:\n", | |||
|  |     "                    filtered_state_dict[clean_key] = value\n", | |||
|  |     "                    matched_keys += 1\n", | |||
|  |     "                else:\n", | |||
|  |     "                    # 只显示前几个不匹配的键作为示例\n", | |||
|  |     "                    if len(unmatched_samples) < 3:\n", | |||
|  |     "                        unmatched_samples.append(f\"{key} -> {clean_key}\")\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"   📊 参数匹配统计: {matched_keys}/{total_keys} 个键匹配成功\")\n", | |||
|  |     "            \n", | |||
|  |     "            if unmatched_samples:\n", | |||
|  |     "                print(f\"   ⚠️ 不匹配键示例: {', '.join(unmatched_samples)}\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 加载权重\n", | |||
|  |     "            missing_keys, unexpected_keys = self.model.load_state_dict(filtered_state_dict, strict=False)\n", | |||
|  |     "            \n", | |||
|  |     "            if missing_keys:\n", | |||
|  |     "                print(f\"   ⚠️ 缺失的键 ({len(missing_keys)}): {missing_keys[:3]}{'...' if len(missing_keys) > 3 else ''}\")\n", | |||
|  |     "            \n", | |||
|  |     "            if unexpected_keys:\n", | |||
|  |     "                print(f\"   ⚠️ 意外的键 ({len(unexpected_keys)}): {unexpected_keys[:3]}{'...' if len(unexpected_keys) > 3 else ''}\")\n", | |||
|  |     "            \n", | |||
|  |     "            self.model.to(self.device)\n", | |||
|  |     "            self.model.eval()\n", | |||
|  |     "            \n", | |||
|  |     "            if matched_keys > 0:\n", | |||
|  |     "                print(f\"   ✅ 预训练权重加载完成 ({matched_keys}个参数)\")\n", | |||
|  |     "            else:\n", | |||
|  |     "                print(f\"   ❌ 没有成功匹配任何预训练权重\")\n", | |||
|  |     "                print(f\"   🔄 使用随机初始化的权重继续\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 4. 显示模型信息\n", | |||
|  |     "            total_params = sum(p.numel() for p in self.model.parameters())\n", | |||
|  |     "            print(f\"   📊 模型参数统计:\")\n", | |||
|  |     "            print(f\"      总参数数量: {total_params:,}\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 显示每个day的参数数量\n", | |||
|  |     "            day_params = 0\n", | |||
|  |     "            for name, param in self.model.named_parameters():\n", | |||
|  |     "                if 'day' in name:\n", | |||
|  |     "                    day_params += param.numel()\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"      Day-specific参数: {day_params:,} ({day_params/total_params*100:.1f}%)\")\n", | |||
|  |     "            print(f\"      会话数量: {len(self.model_args['dataset']['sessions'])}\")\n", | |||
|  |     "            \n", | |||
|  |     "            # 显示会话列表\n", | |||
|  |     "            print(f\"   📅 支持的会话:\")\n", | |||
|  |     "            sessions = self.model_args['dataset']['sessions']\n", | |||
|  |     "            for i, session in enumerate(sessions[:5]):\n", | |||
|  |     "                print(f\"      {i}: {session}\")\n", | |||
|  |     "            if len(sessions) > 5:\n", | |||
|  |     "                print(f\"      ... 还有 {len(sessions)-5} 个会话\")\n", | |||
|  |     "            \n", | |||
|  |     "            print(f\"\\n🎉 RNN模型加载{'成功' if matched_keys > 0 else '完成(使用随机权重)'}!\")\n", | |||
|  |     "            return True\n", | |||
|  |     "            \n", | |||
|  |     "        except Exception as e:\n", | |||
|  |     "            print(f\"❌ RNN模型加载失败: {str(e)}\")\n", | |||
|  |     "            import traceback\n", | |||
|  |     "            print(f\"详细错误信息:\")\n", | |||
|  |     "            print(traceback.format_exc())\n", | |||
|  |     "            return False\n", | |||
|  |     "    \n", | |||
|  |     "    def predict_trial(self, neural_features, day_idx=0):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        对单个试验进行RNN预测\n", | |||
|  |     "        \n", | |||
|  |     "        参数:\n", | |||
|  |     "            neural_features: 神经特征 [time_steps, features]\n", | |||
|  |     "            day_idx: day索引(对应不同的session)\n", | |||
|  |     "        \n", | |||
|  |     "        返回:\n", | |||
|  |     "            logits: RNN输出 [time_steps, n_classes]\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if self.model is None:\n", | |||
|  |     "            raise ValueError(\"模型尚未加载,请先调用load_model()\")\n", | |||
|  |     "        \n", | |||
|  |     "        try:\n", | |||
|  |     "            # 转换为tensor并添加batch维度\n", | |||
|  |     "            if isinstance(neural_features, np.ndarray):\n", | |||
|  |     "                neural_features = torch.from_numpy(neural_features).float()\n", | |||
|  |     "            \n", | |||
|  |     "            neural_features = neural_features.unsqueeze(0).to(self.device)  # [1, time_steps, features]\n", | |||
|  |     "            \n", | |||
|  |     "            # 确保day_idx在有效范围内\n", | |||
|  |     "            n_days = len(self.model_args['dataset']['sessions'])\n", | |||
|  |     "            day_idx = max(0, min(day_idx, n_days - 1))\n", | |||
|  |     "            \n", | |||
|  |     "            # 创建day索引\n", | |||
|  |     "            day_tensor = torch.tensor([day_idx], dtype=torch.long, device=self.device)\n", | |||
|  |     "            \n", | |||
|  |     "            # 模型推理\n", | |||
|  |     "            with torch.no_grad():\n", | |||
|  |     "                logits = self.model(neural_features, day_tensor)  # [1, time_steps, n_classes]\n", | |||
|  |     "            \n", | |||
|  |     "            # 移除batch维度并转换为numpy\n", | |||
|  |     "            logits = logits.squeeze(0).cpu().numpy()  # [time_steps, n_classes]\n", | |||
|  |     "            \n", | |||
|  |     "            return logits\n", | |||
|  |     "            \n", | |||
|  |     "        except Exception as e:\n", | |||
|  |     "            print(f\"❌ RNN预测失败: {str(e)}\")\n", | |||
|  |     "            return None\n", | |||
|  |     "    \n", | |||
|  |     "    def get_day_index(self, session_name):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        根据session名称获取对应的day索引\n", | |||
|  |     "        \n", | |||
|  |     "        参数:\n", | |||
|  |     "            session_name: session名称 (如 't15.2023.08.11')\n", | |||
|  |     "        \n", | |||
|  |     "        返回:\n", | |||
|  |     "            day_idx: day索引\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if self.model_args is None:\n", | |||
|  |     "            return 0\n", | |||
|  |     "        \n", | |||
|  |     "        sessions = self.model_args['dataset']['sessions']\n", | |||
|  |     "        try:\n", | |||
|  |     "            return sessions.index(session_name)\n", | |||
|  |     "        except ValueError:\n", | |||
|  |     "            print(f\"⚠️ 未找到session {session_name},使用day_idx=0\")\n", | |||
|  |     "            return 0\n", | |||
|  |     "    \n", | |||
|  |     "    def is_loaded(self):\n", | |||
|  |     "        \"\"\"检查模型是否已加载\"\"\"\n", | |||
|  |     "        return self.model is not None\n", | |||
|  |     "\n", | |||
|  |     "# 创建RNN模型加载器实例\n", | |||
|  |     "print(\"🧠 创建真实RNN模型加载器...\")\n", | |||
|  |     "rnn_loader = RealRNNModelLoader(\n", | |||
|  |     "    model_path='./data/t15_pretrained_rnn_baseline',\n", | |||
|  |     "    device='auto'\n", | |||
|  |     ")\n", | |||
|  |     "\n", | |||
|  |     "# 尝试加载模型\n", | |||
|  |     "if rnn_loader.load_model():\n", | |||
|  |     "    print(\"\\n✅ 真实RNN模型加载器准备完成!\")\n", | |||
|  |     "    print(\"🔧 现在可以使用真实的预训练RNN进行预测\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"\\n❌ RNN模型加载失败\")\n", | |||
|  |     "    print(\"💡 将继续使用模拟预测作为备选方案\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 17, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [ | |||
|  |     "# 🏗️ 完整数据集批处理管道 - 使用真实RNN模型(优化版)\n", | |||
|  |     "import os\n", | |||
|  |     "import re\n", | |||
|  |     "import time\n", | |||
|  |     "from tqdm import tqdm\n", | |||
|  |     "import pandas as pd\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "import h5py\n", | |||
|  |     "\n", | |||
|  |     "class BrainToTextDatasetPipeline:\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    批量处理所有数据集session的完整管道\n", | |||
|  |     "    集成了真实的RNN模型进行特征提取\n", | |||
|  |     "    优化版:添加延迟检查和进度条功能\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    \n", | |||
|  |     "    def __init__(self, data_dir='./data/hdf5_data_final', rnn_loader=None):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        初始化数据集管道(轻量级初始化)\n", | |||
|  |     "        \n", | |||
|  |     "        参数:\n", | |||
|  |     "            data_dir: 数据目录路径\n", | |||
|  |     "            rnn_loader: 真实RNN模型加载器实例\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(\"🔧 初始化数据集管道...\")\n", | |||
|  |     "        self.data_dir = data_dir\n", | |||
|  |     "        self.rnn_loader = rnn_loader\n", | |||
|  |     "        self.sessions = []\n", | |||
|  |     "        self.results = {}\n", | |||
|  |     "        \n", | |||
|  |     "        # 延迟检查标志\n", | |||
|  |     "        self._rnn_checked = False\n", | |||
|  |     "        self._sessions_scanned = False\n", | |||
|  |     "        self.use_real_rnn = False\n", | |||
|  |     "        \n", | |||
|  |     "        print(\"✅ 管道初始化完成!使用延迟加载模式\")\n", | |||
|  |     "        print(\"💡 调用 check_status() 来检查RNN模型和数据状态\")\n", | |||
|  |     "    \n", | |||
|  |     "    def check_status(self):\n", | |||
|  |     "        \"\"\"检查RNN模型和数据状态(带进度条)\"\"\"\n", | |||
|  |     "        print(\"\\n🔍 正在检查系统状态...\")\n", | |||
|  |     "        \n", | |||
|  |     "        tasks = [\n", | |||
|  |     "            (\"检查数据目录\", self._check_data_directory),\n", | |||
|  |     "            (\"扫描session文件\", self._scan_sessions),\n", | |||
|  |     "            (\"检查RNN模型\", self._check_rnn_model),\n", | |||
|  |     "            (\"验证系统就绪\", self._verify_system_ready)\n", | |||
|  |     "        ]\n", | |||
|  |     "        \n", | |||
|  |     "        with tqdm(tasks, desc=\"系统检查\", unit=\"步\") as pbar:\n", | |||
|  |     "            for task_name, task_func in pbar:\n", | |||
|  |     "                pbar.set_postfix_str(f\"执行: {task_name}\")\n", | |||
|  |     "                time.sleep(0.1)  # 小延迟以显示进度\n", | |||
|  |     "                task_func()\n", | |||
|  |     "                pbar.set_postfix_str(f\"✅ {task_name}\")\n", | |||
|  |     "                time.sleep(0.2)  # 显示完成状态\n", | |||
|  |     "        \n", | |||
|  |     "        self._print_status_summary()\n", | |||
|  |     "    \n", | |||
|  |     "    def _check_data_directory(self):\n", | |||
|  |     "        \"\"\"检查数据目录\"\"\"\n", | |||
|  |     "        if not os.path.exists(self.data_dir):\n", | |||
|  |     "            raise FileNotFoundError(f\"❌ 数据目录不存在: {self.data_dir}\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 检查目录中是否有文件\n", | |||
|  |     "        items = os.listdir(self.data_dir)\n", | |||
|  |     "        if not items:\n", | |||
|  |     "            raise ValueError(f\"❌ 数据目录为空: {self.data_dir}\")\n", | |||
|  |     "    \n", | |||
|  |     "    def _scan_sessions(self):\n", | |||
|  |     "        \"\"\"扫描数据目录中的所有session(带进度条)\"\"\"\n", | |||
|  |     "        if self._sessions_scanned:\n", | |||
|  |     "            return\n", | |||
|  |     "        \n", | |||
|  |     "        if not os.path.exists(self.data_dir):\n", | |||
|  |     "            print(f\"❌ 数据目录不存在: {self.data_dir}\")\n", | |||
|  |     "            return\n", | |||
|  |     "        \n", | |||
|  |     "        print(f\"📂 正在扫描数据目录: {self.data_dir}\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 获取所有项目用于进度条\n", | |||
|  |     "        all_items = os.listdir(self.data_dir)\n", | |||
|  |     "        pattern = re.compile(r't15\\.\\d{4}\\.\\d{2}\\.\\d{2}')\n", | |||
|  |     "        \n", | |||
|  |     "        # 使用进度条扫描\n", | |||
|  |     "        valid_sessions = []\n", | |||
|  |     "        with tqdm(all_items, desc=\"扫描sessions\", unit=\"项\", leave=False) as pbar:\n", | |||
|  |     "            for item in pbar:\n", | |||
|  |     "                pbar.set_postfix_str(f\"检查: {item[:20]}...\")\n", | |||
|  |     "                \n", | |||
|  |     "                if pattern.match(item):\n", | |||
|  |     "                    session_path = os.path.join(self.data_dir, item)\n", | |||
|  |     "                    if os.path.isdir(session_path):\n", | |||
|  |     "                        # 检查是否有数据文件\n", | |||
|  |     "                        try:\n", | |||
|  |     "                            files = os.listdir(session_path)\n", | |||
|  |     "                            has_data = any(f.endswith('.h5') or f.endswith('.hdf5') for f in files)\n", | |||
|  |     "                            if has_data:\n", | |||
|  |     "                                valid_sessions.append(item)\n", | |||
|  |     "                                pbar.set_postfix_str(f\"✅ {item}\")\n", | |||
|  |     "                        except PermissionError:\n", | |||
|  |     "                            pbar.set_postfix_str(f\"⚠️ 权限错误: {item}\")\n", | |||
|  |     "                        except Exception:\n", | |||
|  |     "                            pbar.set_postfix_str(f\"❌ 错误: {item}\")\n", | |||
|  |     "                \n", | |||
|  |     "                # 小延迟以显示进度\n", | |||
|  |     "                time.sleep(0.01)\n", | |||
|  |     "        \n", | |||
|  |     "        self.sessions = sorted(valid_sessions)\n", | |||
|  |     "        self._sessions_scanned = True\n", | |||
|  |     "        \n", | |||
|  |     "        print(f\"✅ 扫描完成! 发现 {len(self.sessions)} 个有效session\")\n", | |||
|  |     "    \n", | |||
|  |     "    def _check_rnn_model(self):\n", | |||
|  |     "        \"\"\"延迟检查RNN模型状态(带进度条)\"\"\"\n", | |||
|  |     "        if self._rnn_checked:\n", | |||
|  |     "            return\n", | |||
|  |     "        \n", | |||
|  |     "        print(\"🔍 正在检查RNN模型状态...\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 模拟检查过程的进度条\n", | |||
|  |     "        check_steps = [\n", | |||
|  |     "            \"检查模型加载器\",\n", | |||
|  |     "            \"验证模型文件路径\", \n", | |||
|  |     "            \"测试文件可读性\",\n", | |||
|  |     "            \"验证模型结构\",\n", | |||
|  |     "            \"确认模型状态\"\n", | |||
|  |     "        ]\n", | |||
|  |     "        \n", | |||
|  |     "        with tqdm(check_steps, desc=\"模型检查\", unit=\"步\", leave=False) as pbar:\n", | |||
|  |     "            for step in pbar:\n", | |||
|  |     "                pbar.set_postfix_str(step)\n", | |||
|  |     "                time.sleep(0.15)  # 模拟检查时间\n", | |||
|  |     "                \n", | |||
|  |     "                if \"测试文件可读性\" in step:\n", | |||
|  |     "                    # 实际的模型检查逻辑\n", | |||
|  |     "                    if self.rnn_loader and self.rnn_loader.is_loaded():\n", | |||
|  |     "                        self.use_real_rnn = True\n", | |||
|  |     "                        pbar.set_postfix_str(\"✅ 真实模型可用\")\n", | |||
|  |     "                    else:\n", | |||
|  |     "                        self.use_real_rnn = False\n", | |||
|  |     "                        pbar.set_postfix_str(\"❌ 使用模拟模型\")\n", | |||
|  |     "        \n", | |||
|  |     "        self._rnn_checked = True\n", | |||
|  |     "        \n", | |||
|  |     "        model_type = \"真实RNN模型\" if self.use_real_rnn else \"模拟模型\"\n", | |||
|  |     "        print(f\"🤖 模型状态: {model_type}\")\n", | |||
|  |     "    \n", | |||
|  |     "    def _verify_system_ready(self):\n", | |||
|  |     "        \"\"\"验证系统就绪状态\"\"\"\n", | |||
|  |     "        if not self._sessions_scanned:\n", | |||
|  |     "            raise RuntimeError(\"Sessions未扫描完成\")\n", | |||
|  |     "        if not self._rnn_checked:\n", | |||
|  |     "            raise RuntimeError(\"RNN模型未检查完成\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 验证基本要求\n", | |||
|  |     "        if len(self.sessions) == 0:\n", | |||
|  |     "            raise ValueError(\"未找到有效的session数据\")\n", | |||
|  |     "    \n", | |||
|  |     "    def _print_status_summary(self):\n", | |||
|  |     "        \"\"\"打印状态摘要\"\"\"\n", | |||
|  |     "        print(f\"\\n📋 系统状态摘要:\")\n", | |||
|  |     "        print(\"=\"*50)\n", | |||
|  |     "        print(f\"📂 数据目录: {self.data_dir}\")\n", | |||
|  |     "        print(f\"📊 有效Sessions: {len(self.sessions)}\")\n", | |||
|  |     "        print(f\"🤖 RNN模型: {'真实模型' if self.use_real_rnn else '模拟模型'}\")\n", | |||
|  |     "        print(f\"✅ 系统状态: {'就绪' if self._sessions_scanned and self._rnn_checked else '未就绪'}\")\n", | |||
|  |     "        \n", | |||
|  |     "        if len(self.sessions) > 0:\n", | |||
|  |     "            print(f\"\\n📝 前5个session:\")\n", | |||
|  |     "            for i, session in enumerate(self.sessions[:5]):\n", | |||
|  |     "                print(f\"   {i+1:2d}. {session}\")\n", | |||
|  |     "            if len(self.sessions) > 5:\n", | |||
|  |     "                print(f\"   ... 还有 {len(self.sessions)-5} 个session\")\n", | |||
|  |     "    \n", | |||
|  |     "    def _load_session_data(self, session_name):\n", | |||
|  |     "        \"\"\"加载单个session的数据\"\"\"\n", | |||
|  |     "        session_path = os.path.join(self.data_dir, session_name)\n", | |||
|  |     "        \n", | |||
|  |     "        try:\n", | |||
|  |     "            # 查找数据文件\n", | |||
|  |     "            data_files = [f for f in os.listdir(session_path) \n", | |||
|  |     "                         if f.endswith('.h5') or f.endswith('.hdf5')]\n", | |||
|  |     "            \n", | |||
|  |     "            if not data_files:\n", | |||
|  |     "                return None, \"未找到数据文件\"\n", | |||
|  |     "            \n", | |||
|  |     "            # 使用第一个找到的数据文件\n", | |||
|  |     "            data_file = os.path.join(session_path, data_files[0])\n", | |||
|  |     "            \n", | |||
|  |     "            with h5py.File(data_file, 'r') as f:\n", | |||
|  |     "                neural_data = []\n", | |||
|  |     "                labels = []\n", | |||
|  |     "                \n", | |||
|  |     "                # 遍历所有试验\n", | |||
|  |     "                for trial_key in f.keys():\n", | |||
|  |     "                    if trial_key.startswith('trial_'):\n", | |||
|  |     "                        trial_group = f[trial_key]\n", | |||
|  |     "                        \n", | |||
|  |     "                        # 获取神经数据和标签\n", | |||
|  |     "                        neural_features = trial_group['neural_data'][:]  # [time_steps, features]\n", | |||
|  |     "                        trial_labels = trial_group['labels'][:]  # [time_steps]\n", | |||
|  |     "                        \n", | |||
|  |     "                        # 确保数据格式正确\n", | |||
|  |     "                        if len(neural_features.shape) == 2 and len(trial_labels.shape) == 1:\n", | |||
|  |     "                            # 检查时间步长是否匹配\n", | |||
|  |     "                            if neural_features.shape[0] == trial_labels.shape[0]:\n", | |||
|  |     "                                neural_data.append(neural_features)\n", | |||
|  |     "                                labels.append(trial_labels)\n", | |||
|  |     "                \n", | |||
|  |     "                if not neural_data:\n", | |||
|  |     "                    return None, \"未找到有效的试验数据\"\n", | |||
|  |     "                \n", | |||
|  |     "                print(f\"   📊 {session_name}: 加载了 {len(neural_data)} 个试验\")\n", | |||
|  |     "                return (neural_data, labels), None\n", | |||
|  |     "                \n", | |||
|  |     "        except Exception as e:\n", | |||
|  |     "            return None, f\"加载失败: {str(e)}\"\n", | |||
|  |     "    \n", | |||
|  |     "    def _extract_rnn_features(self, neural_data, session_name):\n", | |||
|  |     "        \"\"\"使用真实RNN模型提取特征\"\"\"\n", | |||
|  |     "        if not self.use_real_rnn:\n", | |||
|  |     "            return self._simulate_rnn_predictions(neural_data)\n", | |||
|  |     "        \n", | |||
|  |     "        try:\n", | |||
|  |     "            # 获取session对应的day索引\n", | |||
|  |     "            day_idx = self.rnn_loader.get_day_index(session_name)\n", | |||
|  |     "            \n", | |||
|  |     "            rnn_predictions = []\n", | |||
|  |     "            \n", | |||
|  |     "            # 为试验预测添加进度条\n", | |||
|  |     "            with tqdm(neural_data, desc=f\"RNN预测 {session_name}\", unit=\"试验\", leave=False) as pbar:\n", | |||
|  |     "                for trial_idx, neural_features in enumerate(pbar):\n", | |||
|  |     "                    pbar.set_postfix_str(f\"试验 {trial_idx+1}\")\n", | |||
|  |     "                    \n", | |||
|  |     "                    # 使用真实RNN模型进行预测\n", | |||
|  |     "                    logits = self.rnn_loader.predict_trial(neural_features, day_idx)\n", | |||
|  |     "                    \n", | |||
|  |     "                    if logits is not None:\n", | |||
|  |     "                        rnn_predictions.append(logits)\n", | |||
|  |     "                        pbar.set_postfix_str(f\"✅ 试验 {trial_idx+1}\")\n", | |||
|  |     "                    else:\n", | |||
|  |     "                        # 如果预测失败,使用模拟数据\n", | |||
|  |     "                        print(f\"   ⚠️ 试验 {trial_idx} RNN预测失败,使用模拟数据\")\n", | |||
|  |     "                        simulated = self._simulate_single_trial_prediction(neural_features)\n", | |||
|  |     "                        rnn_predictions.append(simulated)\n", | |||
|  |     "                        pbar.set_postfix_str(f\"⚠️ 试验 {trial_idx+1} (模拟)\")\n", | |||
|  |     "            \n", | |||
|  |     "            return rnn_predictions\n", | |||
|  |     "            \n", | |||
|  |     "        except Exception as e:\n", | |||
|  |     "            print(f\"   ❌ RNN特征提取失败: {str(e)}\")\n", | |||
|  |     "            print(f\"   🔄 回退到模拟预测\")\n", | |||
|  |     "            return self._simulate_rnn_predictions(neural_data)\n", | |||
|  |     "    \n", | |||
|  |     "    def _simulate_single_trial_prediction(self, neural_features):\n", | |||
|  |     "        \"\"\"为单个试验生成模拟RNN预测\"\"\"\n", | |||
|  |     "        time_steps = neural_features.shape[0]\n", | |||
|  |     "        n_phonemes = 40\n", | |||
|  |     "        \n", | |||
|  |     "        # 生成模拟的logits(更加真实的分布)\n", | |||
|  |     "        logits = np.random.randn(time_steps, n_phonemes) * 2.0\n", | |||
|  |     "        \n", | |||
|  |     "        # 添加一些时间相关的模式\n", | |||
|  |     "        for t in range(time_steps):\n", | |||
|  |     "            # 静音类在开始和结束时概率更高\n", | |||
|  |     "            if t < 5 or t > time_steps - 5:\n", | |||
|  |     "                logits[t, 0] += 2.0  # 静音类\n", | |||
|  |     "            \n", | |||
|  |     "            # 添加一些语音学合理的模式\n", | |||
|  |     "            if t % 10 < 3:  # 模拟辅音\n", | |||
|  |     "                logits[t, 1:15] += 1.0\n", | |||
|  |     "            else:  # 模拟元音\n", | |||
|  |     "                logits[t, 15:25] += 1.0\n", | |||
|  |     "        \n", | |||
|  |     "        return logits\n", | |||
|  |     "    \n", | |||
|  |     "    # def _simulate_rnn_predictions(self, neural_data):\n", | |||
|  |     "    #     \"\"\"为所有试验生成模拟RNN预测\"\"\"\n", | |||
|  |     "    #     print(\"   🎭 使用模拟RNN预测\")\n", | |||
|  |     "    #     predictions = []\n", | |||
|  |     "        \n", | |||
|  |     "    #     for neural_features in neural_data:\n", | |||
|  |     "    #         pred = self._simulate_single_trial_prediction(neural_features)\n", | |||
|  |     "    #         predictions.append(pred)\n", | |||
|  |     "        \n", | |||
|  |     "    #     return predictions\n", | |||
|  |     "    \n", | |||
|  |     "    def _compute_confidence_metrics(self, logits):\n", | |||
|  |     "        \"\"\"计算置信度指标\"\"\"\n", | |||
|  |     "        # 转换为概率\n", | |||
|  |     "        probs = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)\n", | |||
|  |     "        \n", | |||
|  |     "        # 计算熵(不确定性指标)\n", | |||
|  |     "        entropy = -np.sum(probs * np.log(probs + 1e-8), axis=-1)\n", | |||
|  |     "        \n", | |||
|  |     "        # 计算最大概率(置信度指标)\n", | |||
|  |     "        max_prob = np.max(probs, axis=-1)\n", | |||
|  |     "        \n", | |||
|  |     "        # 计算top-2差距(决策边界指标)\n", | |||
|  |     "        sorted_probs = np.sort(probs, axis=-1)\n", | |||
|  |     "        top2_margin = sorted_probs[:, -1] - sorted_probs[:, -2]\n", | |||
|  |     "        \n", | |||
|  |     "        return {\n", | |||
|  |     "            'entropy': entropy,\n", | |||
|  |     "            'max_prob': max_prob,\n", | |||
|  |     "            'top2_margin': top2_margin,\n", | |||
|  |     "            'mean_entropy': np.mean(entropy),\n", | |||
|  |     "            'mean_max_prob': np.mean(max_prob),\n", | |||
|  |     "            'mean_top2_margin': np.mean(top2_margin)\n", | |||
|  |     "        }\n", | |||
|  |     "    \n", | |||
|  |     "    def _process_single_session(self, session_name):\n", | |||
|  |     "        \"\"\"Build the per-time-step feature table for one session.\n", | |||
|  |     "\n", | |||
|  |     "        Calls self._load_session_data and self._extract_rnn_features, then\n", | |||
|  |     "        emits one row per time step holding the neural features, the softmax\n", | |||
|  |     "        probabilities of the first 40 RNN output classes, three confidence\n", | |||
|  |     "        metrics and the trial metadata.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            session_name: Name of the session to process.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            A DataFrame with all trials of the session concatenated, or None\n", | |||
|  |     "            when loading fails, no RNN predictions are returned, or there are\n", | |||
|  |     "            no trials to combine.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(f\"\\n🔄 处理session: {session_name}\")\n", | |||
|  |     "\n", | |||
|  |     "        # Load the raw data for this session.\n", | |||
|  |     "        data_result, error = self._load_session_data(session_name)\n", | |||
|  |     "        if data_result is None:\n", | |||
|  |     "            print(f\"   ❌ 数据加载失败: {error}\")\n", | |||
|  |     "            return None\n", | |||
|  |     "\n", | |||
|  |     "        neural_data, labels = data_result\n", | |||
|  |     "\n", | |||
|  |     "        # Run the RNN to get per-trial logits.\n", | |||
|  |     "        print(f\"   🧠 提取RNN特征...\")\n", | |||
|  |     "        rnn_predictions = self._extract_rnn_features(neural_data, session_name)\n", | |||
|  |     "\n", | |||
|  |     "        if not rnn_predictions:\n", | |||
|  |     "            print(f\"   ❌ RNN特征提取失败\")\n", | |||
|  |     "            return None\n", | |||
|  |     "\n", | |||
|  |     "        n_phonemes = 40\n", | |||
|  |     "        phoneme_cols = [f'phoneme_{idx:02d}' for idx in range(n_phonemes)]\n", | |||
|  |     "        session_features = []\n", | |||
|  |     "\n", | |||
|  |     "        for trial_idx, (neural_features, trial_labels, rnn_logits) in enumerate(\n", | |||
|  |     "                zip(neural_data, labels, rnn_predictions)):\n", | |||
|  |     "\n", | |||
|  |     "            # Trim the three sequences to their common length so rows line up.\n", | |||
|  |     "            min_length = min(len(neural_features), len(trial_labels), len(rnn_logits))\n", | |||
|  |     "            neural_features = neural_features[:min_length]\n", | |||
|  |     "            trial_labels = trial_labels[:min_length]\n", | |||
|  |     "            rnn_logits = rnn_logits[:min_length]\n", | |||
|  |     "\n", | |||
|  |     "            confidence_metrics = self._compute_confidence_metrics(rnn_logits)\n", | |||
|  |     "\n", | |||
|  |     "            # Softmax with the row maximum subtracted first: mathematically the\n", | |||
|  |     "            # same, but np.exp can no longer overflow to inf and yield NaN rows.\n", | |||
|  |     "            logits = np.asarray(rnn_logits)\n", | |||
|  |     "            exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))\n", | |||
|  |     "            rnn_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)\n", | |||
|  |     "\n", | |||
|  |     "            # Build the two wide blocks in one go; inserting hundreds of columns\n", | |||
|  |     "            # one at a time fragments the DataFrame and is slow.\n", | |||
|  |     "            neural_cols = [f'neural_feat_{idx:03d}' for idx in range(neural_features.shape[1])]\n", | |||
|  |     "            trial_df = pd.concat(\n", | |||
|  |     "                [\n", | |||
|  |     "                    pd.DataFrame(neural_features, columns=neural_cols),\n", | |||
|  |     "                    pd.DataFrame(rnn_probs[:, :n_phonemes], columns=phoneme_cols),\n", | |||
|  |     "                ],\n", | |||
|  |     "                axis=1,\n", | |||
|  |     "            )\n", | |||
|  |     "\n", | |||
|  |     "            # Confidence metrics (per time step).\n", | |||
|  |     "            trial_df['confidence_entropy'] = confidence_metrics['entropy']\n", | |||
|  |     "            trial_df['confidence_max_prob'] = confidence_metrics['max_prob']\n", | |||
|  |     "            trial_df['confidence_top2_margin'] = confidence_metrics['top2_margin']\n", | |||
|  |     "\n", | |||
|  |     "            # Metadata.\n", | |||
|  |     "            trial_df['session_name'] = session_name\n", | |||
|  |     "            trial_df['trial_id'] = trial_idx\n", | |||
|  |     "            trial_df['time_step'] = range(len(trial_df))\n", | |||
|  |     "            trial_df['ground_truth_label'] = trial_labels\n", | |||
|  |     "\n", | |||
|  |     "            session_features.append(trial_df)\n", | |||
|  |     "\n", | |||
|  |     "        # Stack all trials of the session.\n", | |||
|  |     "        if session_features:\n", | |||
|  |     "            combined_df = pd.concat(session_features, ignore_index=True)\n", | |||
|  |     "            print(f\"   ✅ {session_name}: 处理完成,共 {len(combined_df)} 个样本\")\n", | |||
|  |     "            return combined_df\n", | |||
|  |     "        else:\n", | |||
|  |     "            print(f\"   ❌ {session_name}: 没有有效数据\")\n", | |||
|  |     "            return None\n", | |||
|  |     "    \n", | |||
|  |     "    def process_sessions(self, session_names=None, max_sessions=None):\n", | |||
|  |     "        \"\"\"Process several sessions in sequence, showing a progress bar.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            session_names: Optional collection of session names; names that are\n", | |||
|  |     "                not in self.sessions are dropped. When None, self.sessions is used.\n", | |||
|  |     "            max_sessions: Optional cap on the number of sessions. Only applied\n", | |||
|  |     "                when session_names is None.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            None. The DataFrames of the sessions that succeeded are stored in\n", | |||
|  |     "            self.results, keyed by session name.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        # Both status flags must be set before any processing.\n", | |||
|  |     "        if not self._sessions_scanned or not self._rnn_checked:\n", | |||
|  |     "            print(\"⚠️ 请先运行 check_status() 检查系统状态\")\n", | |||
|  |     "            return\n", | |||
|  |     "\n", | |||
|  |     "        # Decide which sessions to process.\n", | |||
|  |     "        if session_names is None:\n", | |||
|  |     "            target_sessions = self.sessions[:max_sessions] if max_sessions else self.sessions\n", | |||
|  |     "        else:\n", | |||
|  |     "            target_sessions = [s for s in session_names if s in self.sessions]\n", | |||
|  |     "\n", | |||
|  |     "        if not target_sessions:\n", | |||
|  |     "            print(\"❌ 没有找到要处理的sessions\")\n", | |||
|  |     "            return\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"\\n🚀 开始批量处理 {len(target_sessions)} 个sessions\")\n", | |||
|  |     "        print(\"=\"*60)\n", | |||
|  |     "\n", | |||
|  |     "        successful_results = {}\n", | |||
|  |     "\n", | |||
|  |     "        with tqdm(target_sessions, desc=\"处理Sessions\", unit=\"session\") as pbar:\n", | |||
|  |     "            for session_name in pbar:\n", | |||
|  |     "                pbar.set_postfix_str(f\"处理: {session_name}\")\n", | |||
|  |     "\n", | |||
|  |     "                try:\n", | |||
|  |     "                    result_df = self._process_single_session(session_name)\n", | |||
|  |     "                    if result_df is not None:\n", | |||
|  |     "                        successful_results[session_name] = result_df\n", | |||
|  |     "                        pbar.set_postfix_str(f\"✅ {session_name}\")\n", | |||
|  |     "                    else:\n", | |||
|  |     "                        pbar.set_postfix_str(f\"❌ {session_name}\")\n", | |||
|  |     "\n", | |||
|  |     "                except Exception as e:\n", | |||
|  |     "                    # One failing session is reported and skipped so the rest of\n", | |||
|  |     "                    # the batch still runs.\n", | |||
|  |     "                    print(f\"   💥 处理 {session_name} 时出错: {str(e)}\")\n", | |||
|  |     "                    # The emoji must be a complete character: the previous lone\n", | |||
|  |     "                    # surrogate could not be encoded and could raise from inside\n", | |||
|  |     "                    # this handler.\n", | |||
|  |     "                    pbar.set_postfix_str(f\"💥 {session_name}\")\n", | |||
|  |     "\n", | |||
|  |     "                # Short pause so the status text stays visible.\n", | |||
|  |     "                time.sleep(0.1)\n", | |||
|  |     "\n", | |||
|  |     "        self.results = successful_results\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"\\n📊 批量处理完成!\")\n", | |||
|  |     "        print(f\"   成功处理: {len(successful_results)} / {len(target_sessions)} sessions\")\n", | |||
|  |     "        print(f\"   总样本数: {sum(len(df) for df in successful_results.values()):,}\")\n", | |||
|  |     "\n", | |||
|  |     "        if successful_results:\n", | |||
|  |     "            first_df = next(iter(successful_results.values()))\n", | |||
|  |     "            print(f\"   特征维度: {first_df.shape[1]} 列\")\n", | |||
|  |     "    \n", | |||
|  |     "    def get_combined_dataset(self):\n", | |||
|  |     "        \"\"\"Concatenate every processed session into a single DataFrame.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            One DataFrame with a fresh 0..n-1 index, or None when self.results\n", | |||
|  |     "            is empty.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if not self.results:\n", | |||
|  |     "            print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n", | |||
|  |     "            return None\n", | |||
|  |     "\n", | |||
|  |     "        print(\"🔗 合并所有session数据...\")\n", | |||
|  |     "        # self.results is non-empty past the guard, so there is always at least\n", | |||
|  |     "        # one frame to concatenate.\n", | |||
|  |     "        combined_df = pd.concat(list(self.results.values()), ignore_index=True)\n", | |||
|  |     "        print(f\"✅ 合并完成: {len(combined_df):,} 个样本\")\n", | |||
|  |     "        return combined_df\n", | |||
|  |     "    \n", | |||
|  |     "    def save_results(self, output_dir='./outputs'):\n", | |||
|  |     "        \"\"\"Write each session's DataFrame and the merged dataset as CSV files.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            output_dir: Destination directory; created when missing.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            None.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if not self.results:\n", | |||
|  |     "            print(\"❌ 没有处理结果,请先运行 process_sessions()\")\n", | |||
|  |     "            return\n", | |||
|  |     "\n", | |||
|  |     "        os.makedirs(output_dir, exist_ok=True)\n", | |||
|  |     "        print(f\"💾 保存处理结果到: {output_dir}\")\n", | |||
|  |     "\n", | |||
|  |     "        # One CSV per session, with progress feedback.\n", | |||
|  |     "        with tqdm(self.results.items(), desc=\"保存文件\", unit=\"文件\") as progress:\n", | |||
|  |     "            for name, frame in progress:\n", | |||
|  |     "                progress.set_postfix_str(f\"保存: {name}\")\n", | |||
|  |     "                frame.to_csv(os.path.join(output_dir, f\"{name}_features.csv\"), index=False)\n", | |||
|  |     "                progress.set_postfix_str(f\"✅ {name}\")\n", | |||
|  |     "                time.sleep(0.05)  # brief pause so the progress text stays visible\n", | |||
|  |     "\n", | |||
|  |     "        # Merged dataset across all sessions.\n", | |||
|  |     "        merged = self.get_combined_dataset()\n", | |||
|  |     "        if merged is not None:\n", | |||
|  |     "            print(\"💾 保存合并数据集...\")\n", | |||
|  |     "            merged_path = os.path.join(output_dir, \"combined_all_sessions_features.csv\")\n", | |||
|  |     "            merged.to_csv(merged_path, index=False)\n", | |||
|  |     "            print(f\"   ✅ 合并数据集: {merged_path}\")\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"💾 保存完成!\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 20, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "🔧 创建改进版数据处理管道 (仅真实数据模式)...\n", | |||
|  |       "🔧 初始化改进版数据处理管道...\n", | |||
|  |       "✅ 管道初始化完成!\n", | |||
|  |       "✅ 改进版管道创建完成!\n", | |||
|  |       "⚠️  重要: 本管道仅支持真实RNN模型,不提供任何模拟功能\n", | |||
|  |       "💡 使用方法:\n", | |||
|  |       "   # 确保rnn_loader已加载真实模型\n", | |||
|  |       "   results = improved_pipeline.run_full_pipeline(\n", | |||
|  |       "       train_list, val_list, test_list,\n", | |||
|  |       "       max_files_per_split=3,\n", | |||
|  |       "       rnn_loader=rnn_loader  # 必须是已加载的真实模型\n", | |||
|  |       "   )\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 26, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "📚 数据集访问和使用示例\n", | |||
|  |       "============================================================\n", | |||
|  |       "✅ 数据集变量已创建:\n", | |||
|  |       "   train_datasets: 0 个训练DataFrame\n", | |||
|  |       "   val_datasets: 0 个验证DataFrame\n", | |||
|  |       "   test_datasets: 0 个测试DataFrame\n", | |||
|  |       "\n", | |||
|  |       "🔗 数据集合并示例:\n", | |||
|  |       "\n", | |||
|  |       "💾 数据集保存功能:\n", | |||
|  |       "使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\n", | |||
|  |       "将保存所有单独session和合并的数据集文件\n", | |||
|  |       "\n", | |||
|  |       "🎯 工作流完成总结:\n", | |||
|  |       "============================================================\n", | |||
|  |       "✅ 成功创建了完整的数据集处理工作流\n", | |||
|  |       "✅ 生成了三个主要变量:train_datasets, val_datasets, test_datasets\n", | |||
|  |       "✅ 每个变量包含多个DataFrame,对应不同的sessions\n", | |||
|  |       "✅ 每个DataFrame包含:\n", | |||
|  |       "   • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\n", | |||
|  |       "   • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\n", | |||
|  |       "   • 置信度指标 (entropy, top2_margin)\n", | |||
|  |       "   • 元数据 (session, trial, 时间戳等)\n", | |||
|  |       "✅ 支持单独访问或合并使用\n", | |||
|  |       "✅ 支持保存为CSV文件\n", | |||
|  |       "\n", | |||
|  |       "🚀 现在你可以使用这些数据集进行机器学习建模了!\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 📚 Dataset access and usage examples\n", | |||
|  |     "# NOTE(review): this cell reads a `pipeline` object that is not defined here;\n", | |||
|  |     "# presumably it comes from an earlier cell - confirm it survives Restart & Run All.\n", | |||
|  |     "print(\"📚 数据集访问和使用示例\")\n", | |||
|  |     "print(\"=\"*60)\n", | |||
|  |     "\n", | |||
|  |     "# Convenience aliases for the three splits (as requested)\n", | |||
|  |     "train_datasets = pipeline.train_datasets   # training-set list\n", | |||
|  |     "val_datasets = pipeline.val_datasets       # validation-set list  \n", | |||
|  |     "test_datasets = pipeline.test_datasets     # test-set list\n", | |||
|  |     "\n", | |||
|  |     "print(f\"✅ 数据集变量已创建:\")\n", | |||
|  |     "print(f\"   train_datasets: {len(train_datasets)} 个训练DataFrame\")\n", | |||
|  |     "print(f\"   val_datasets: {len(val_datasets)} 个验证DataFrame\")\n", | |||
|  |     "print(f\"   test_datasets: {len(test_datasets)} 个测试DataFrame\")\n", | |||
|  |     "\n", | |||
|  |     "# Show how to work with the datasets (skipped when there is no training data)\n", | |||
|  |     "if train_datasets:\n", | |||
|  |     "    print(f\"\\n🔍 数据集使用示例:\")\n", | |||
|  |     "    \n", | |||
|  |     "    # Look at the first training DataFrame\n", | |||
|  |     "    first_train_session = train_datasets[0]\n", | |||
|  |     "    print(f\"\\n1. 访问第一个训练session:\")\n", | |||
|  |     "    print(f\"   形状: {first_train_session.shape}\")\n", | |||
|  |     "    print(f\"   Session: {first_train_session['session'].iloc[0]}\")\n", | |||
|  |     "    # max() + 1 equals the trial count only if trial_idx is 0-based and gap-free - TODO confirm\n", | |||
|  |     "    print(f\"   试验数: {first_train_session['trial_idx'].max() + 1}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # Select the neural-feature columns by their name prefix\n", | |||
|  |     "    neural_cols = [col for col in first_train_session.columns if col.startswith('neural_feat_')]\n", | |||
|  |     "    neural_features = first_train_session[neural_cols]\n", | |||
|  |     "    print(f\"\\n2. 提取神经特征:\")\n", | |||
|  |     "    print(f\"   神经特征形状: {neural_features.shape}\")\n", | |||
|  |     "    print(f\"   特征维度: {len(neural_cols)}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # Select the phoneme-prediction columns by their name prefix\n", | |||
|  |     "    phoneme_cols = [col for col in first_train_session.columns if col.startswith('phoneme_')]\n", | |||
|  |     "    phoneme_predictions = first_train_session[phoneme_cols]\n", | |||
|  |     "    print(f\"\\n3. 提取音素预测:\")\n", | |||
|  |     "    print(f\"   预测概率形状: {phoneme_predictions.shape}\")\n", | |||
|  |     "    print(f\"   音素类别数: {len(phoneme_cols)}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # Select the metadata columns; raises KeyError if any listed column is missing\n", | |||
|  |     "    metadata_cols = ['time_step', 'session', 'trial_idx', 'ground_truth_phoneme_name', \n", | |||
|  |     "                     'predicted_phoneme_name', 'max_probability', 'entropy', 'top2_margin']\n", | |||
|  |     "    metadata = first_train_session[metadata_cols]\n", | |||
|  |     "    print(f\"\\n4. 提取元数据:\")\n", | |||
|  |     "    print(f\"   元数据列数: {len(metadata_cols)}\")\n", | |||
|  |     "    print(f\"   包含: 时间步、session、试验信息、真实/预测标签、置信度指标\")\n", | |||
|  |     "\n", | |||
|  |     "# 合并所有数据集的函数\n", | |||
|  |     "def combine_datasets(dataset_list, split_name):\n", | |||
|  |     "    \"\"\"Concatenate all DataFrames of one split and print a short summary.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        dataset_list: List of DataFrames; each needs 'session' and\n", | |||
|  |     "            'trial_idx' columns.\n", | |||
|  |     "        split_name: Label used in the printed summary.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        The concatenated DataFrame with a fresh index, or None when\n", | |||
|  |     "        dataset_list is empty.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    if not dataset_list:\n", | |||
|  |     "        return None\n", | |||
|  |     "\n", | |||
|  |     "    combined_df = pd.concat(dataset_list, ignore_index=True)\n", | |||
|  |     "    # Count distinct (session, trial_idx) pairs: the same trial_idx value can\n", | |||
|  |     "    # occur in several sessions, so counting trial_idx alone undercounts.\n", | |||
|  |     "    n_trials = len(combined_df[['session', 'trial_idx']].drop_duplicates())\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"\\n📊 {split_name}合并结果:\")\n", | |||
|  |     "    print(f\"   总数据形状: {combined_df.shape}\")\n", | |||
|  |     "    print(f\"   包含sessions: {combined_df['session'].nunique()} 个\")\n", | |||
|  |     "    print(f\"   总试验数: {n_trials}\")\n", | |||
|  |     "    print(f\"   总时间步数: {len(combined_df):,}\")\n", | |||
|  |     "\n", | |||
|  |     "    return combined_df\n", | |||
|  |     "\n", | |||
|  |     "# Demonstrate merging each split.\n", | |||
|  |     "# Each combined_* variable is only assigned when its split is non-empty.\n", | |||
|  |     "print(f\"\\n🔗 数据集合并示例:\")\n", | |||
|  |     "if train_datasets:\n", | |||
|  |     "    combined_train = combine_datasets(train_datasets, \"训练集\")\n", | |||
|  |     "    \n", | |||
|  |     "if val_datasets:\n", | |||
|  |     "    combined_val = combine_datasets(val_datasets, \"验证集\")\n", | |||
|  |     "    \n", | |||
|  |     "if test_datasets:\n", | |||
|  |     "    combined_test = combine_datasets(test_datasets, \"测试集\")\n", | |||
|  |     "\n", | |||
|  |     "# 保存数据集的函数\n", | |||
|  |     "def save_datasets_to_files(train_list, val_list, test_list, output_dir='./processed_datasets'):\n", | |||
|  |     "    \"\"\"Write every per-session DataFrame and one merged CSV per split.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        train_list: List of training DataFrames, each with a 'session' column.\n", | |||
|  |     "        val_list: List of validation DataFrames, same layout.\n", | |||
|  |     "        test_list: List of test DataFrames, same layout.\n", | |||
|  |     "        output_dir: Destination directory; created when it does not exist.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        List of written file paths: all per-session files (train, val, test\n", | |||
|  |     "        order) followed by the merged file of each non-empty split.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    # to_csv does not create missing directories, so create the target first.\n", | |||
|  |     "    os.makedirs(output_dir, exist_ok=True)\n", | |||
|  |     "\n", | |||
|  |     "    splits = (('train', train_list), ('val', val_list), ('test', test_list))\n", | |||
|  |     "    saved_files = []\n", | |||
|  |     "\n", | |||
|  |     "    # One file per session; the file name carries the position and the session\n", | |||
|  |     "    # value taken from the frame's first row.\n", | |||
|  |     "    for split_name, frames in splits:\n", | |||
|  |     "        for i, df in enumerate(frames):\n", | |||
|  |     "            session = df['session'].iloc[0]\n", | |||
|  |     "            filename = os.path.join(output_dir, f'{split_name}_session_{i:02d}_{session}.csv')\n", | |||
|  |     "            df.to_csv(filename, index=False)\n", | |||
|  |     "            saved_files.append(filename)\n", | |||
|  |     "\n", | |||
|  |     "    # One merged file per non-empty split.\n", | |||
|  |     "    for split_name, frames in splits:\n", | |||
|  |     "        if frames:\n", | |||
|  |     "            combined_file = os.path.join(output_dir, f'combined_{split_name}_all_sessions.csv')\n", | |||
|  |     "            pd.concat(frames, ignore_index=True).to_csv(combined_file, index=False)\n", | |||
|  |     "            saved_files.append(combined_file)\n", | |||
|  |     "\n", | |||
|  |     "    return saved_files\n", | |||
|  |     "\n", | |||
|  |     "# Usage hint for save_datasets_to_files (the function itself is not called here).\n", | |||
|  |     "print(f\"\\n💾 数据集保存功能:\")\n", | |||
|  |     "print(f\"使用 save_datasets_to_files(train_datasets, val_datasets, test_datasets)\")\n", | |||
|  |     "print(f\"将保存所有单独session和合并的数据集文件\")\n", | |||
|  |     "\n", | |||
|  |     "# Closing summary. Every line below is fixed text with no placeholders, so it\n", | |||
|  |     "# prints the same message whether or not any datasets were actually loaded.\n", | |||
|  |     "print(f\"\\n🎯 工作流完成总结:\")\n", | |||
|  |     "print(\"=\"*60)\n", | |||
|  |     "print(f\"✅ 成功创建了完整的数据集处理工作流\")\n", | |||
|  |     "print(f\"✅ 生成了三个主要变量:train_datasets, val_datasets, test_datasets\")\n", | |||
|  |     "print(f\"✅ 每个变量包含多个DataFrame,对应不同的sessions\")\n", | |||
|  |     "print(f\"✅ 每个DataFrame包含:\")\n", | |||
|  |     "print(f\"   • 512维神经特征 (neural_feat_000 ~ neural_feat_511)\")\n", | |||
|  |     "print(f\"   • 40维音素预测概率 (phoneme_0 ~ phoneme_39)\")\n", | |||
|  |     "print(f\"   • 置信度指标 (entropy, top2_margin)\")\n", | |||
|  |     "print(f\"   • 元数据 (session, trial, 时间戳等)\")\n", | |||
|  |     "print(f\"✅ 支持单独访问或合并使用\")\n", | |||
|  |     "print(f\"✅ 支持保存为CSV文件\")\n", | |||
|  |     "\n", | |||
|  |     "print(f\"\\n🚀 现在你可以使用这些数据集进行机器学习建模了!\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "# 模型建立" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "markdown", | |||
|  |    "metadata": {}, | |||
|  |    "source": [ | |||
|  |     "## 🌲 随机森林多输出回归模型\n", | |||
|  |     "\n", | |||
|  |     "使用随机森林对40个音素概率进行回归预测,输入为512维神经特征,时间窗口为30。" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 27, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "🌲 初始化随机森林回归模型\n", | |||
|  |       "🌲 随机森林回归器初始化:\n", | |||
|  |       "   时间窗口大小: 30\n", | |||
|  |       "   树的数量: 100\n", | |||
|  |       "   最大深度: 10\n", | |||
|  |       "   并行任务: -1\n", | |||
|  |       "\n", | |||
|  |       "✅ 随机森林回归器准备完成!\n", | |||
|  |       "🔧 下一步: 准备训练数据和开始训练\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🌲 随机森林多输出回归模型实现\n", | |||
|  |     "import numpy as np\n", | |||
|  |     "from sklearn.ensemble import RandomForestRegressor\n", | |||
|  |     "from sklearn.model_selection import train_test_split\n", | |||
|  |     "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n", | |||
|  |     "from sklearn.preprocessing import StandardScaler\n", | |||
|  |     "import matplotlib.pyplot as plt\n", | |||
|  |     "import seaborn as sns\n", | |||
|  |     "from tqdm import tqdm\n", | |||
|  |     "import warnings\n", | |||
|  |     "warnings.filterwarnings('ignore')\n", | |||
|  |     "\n", | |||
|  |     "class TimeWindowRandomForestRegressor:\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Random-forest multi-output regressor over sliding time windows.\n", | |||
|  |     "\n", | |||
|  |     "    Each sample is a flattened window of window_size consecutive feature\n", | |||
|  |     "    frames; the target is the phoneme-probability row at the window's\n", | |||
|  |     "    centre frame.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "\n", | |||
|  |     "    def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Create the forest and the feature scaler.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            window_size: Number of consecutive time steps per window.\n", | |||
|  |     "            n_estimators: Number of trees in the forest.\n", | |||
|  |     "            max_depth: Maximum depth of each tree.\n", | |||
|  |     "            n_jobs: Number of parallel jobs passed to the forest.\n", | |||
|  |     "            random_state: Random seed passed to the forest.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        self.window_size = window_size\n", | |||
|  |     "        self.n_estimators = n_estimators\n", | |||
|  |     "        self.max_depth = max_depth\n", | |||
|  |     "        self.n_jobs = n_jobs\n", | |||
|  |     "        self.random_state = random_state\n", | |||
|  |     "\n", | |||
|  |     "        # Model and preprocessor.\n", | |||
|  |     "        self.regressor = RandomForestRegressor(\n", | |||
|  |     "            n_estimators=n_estimators,\n", | |||
|  |     "            max_depth=max_depth,\n", | |||
|  |     "            n_jobs=n_jobs,\n", | |||
|  |     "            random_state=random_state,\n", | |||
|  |     "            verbose=1\n", | |||
|  |     "        )\n", | |||
|  |     "\n", | |||
|  |     "        self.scaler = StandardScaler()\n", | |||
|  |     "        self.is_fitted = False\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"🌲 随机森林回归器初始化:\")\n", | |||
|  |     "        print(f\"   时间窗口大小: {window_size}\")\n", | |||
|  |     "        print(f\"   树的数量: {n_estimators}\")\n", | |||
|  |     "        print(f\"   最大深度: {max_depth}\")\n", | |||
|  |     "        print(f\"   并行任务: {n_jobs}\")\n", | |||
|  |     "\n", | |||
|  |     "    def create_time_windows(self, neural_features, phoneme_targets=None):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Turn one sequence into flattened sliding-window samples.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            neural_features: 2-D array [time_steps, n_features].\n", | |||
|  |     "            phoneme_targets: Optional array [time_steps, n_targets].\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            (windowed_features, windowed_targets):\n", | |||
|  |     "            windowed_features is [samples, window_size * n_features];\n", | |||
|  |     "            windowed_targets is [samples, n_targets], or None when no targets\n", | |||
|  |     "            were given. Returns (None, None) when the sequence is shorter\n", | |||
|  |     "            than window_size.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if len(neural_features) < self.window_size:\n", | |||
|  |     "            print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n", | |||
|  |     "            return None, None\n", | |||
|  |     "\n", | |||
|  |     "        n_samples = len(neural_features) - self.window_size + 1\n", | |||
|  |     "        n_features = neural_features.shape[1]\n", | |||
|  |     "\n", | |||
|  |     "        windowed_features = np.zeros((n_samples, self.window_size * n_features))\n", | |||
|  |     "\n", | |||
|  |     "        for i in range(n_samples):\n", | |||
|  |     "            # Row-major flatten: all features of step 0, then step 1, ...\n", | |||
|  |     "            windowed_features[i] = neural_features[i:i+self.window_size].flatten()\n", | |||
|  |     "\n", | |||
|  |     "        windowed_targets = None\n", | |||
|  |     "        if phoneme_targets is not None:\n", | |||
|  |     "            # The target of a window is the row at its centre position.\n", | |||
|  |     "            center_offset = self.window_size // 2\n", | |||
|  |     "            windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n", | |||
|  |     "\n", | |||
|  |     "        return windowed_features, windowed_targets\n", | |||
|  |     "\n", | |||
|  |     "    def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Build the windowed feature and target matrices from a list of DataFrames.\n", | |||
|  |     "\n", | |||
|  |     "        Windows are created per trial (rows sharing a 'trial_idx' value within\n", | |||
|  |     "        one DataFrame), so no window spans two trials.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            datasets_list: List of DataFrames with 'neural_feat_*' columns,\n", | |||
|  |     "                'phoneme_*' columns and a 'trial_idx' column.\n", | |||
|  |     "            dataset_type: Label used in progress and log messages.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            (X, y): X is [total_samples, window_size * n_features] and y is\n", | |||
|  |     "            [total_samples, n_targets]; (None, None) when no trial is long\n", | |||
|  |     "            enough to produce a window.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(f\"\\n📊 准备{dataset_type}数据集:\")\n", | |||
|  |     "        print(f\"   输入数据集数量: {len(datasets_list)}\")\n", | |||
|  |     "\n", | |||
|  |     "        all_X = []\n", | |||
|  |     "        all_y = []\n", | |||
|  |     "\n", | |||
|  |     "        for df in tqdm(datasets_list, desc=f\"处理{dataset_type}数据\"):\n", | |||
|  |     "            # Neural features: every column whose name starts with 'neural_feat_'.\n", | |||
|  |     "            neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n", | |||
|  |     "            neural_features = df[neural_cols].values\n", | |||
|  |     "\n", | |||
|  |     "            # Targets: every column whose name starts with 'phoneme_'.\n", | |||
|  |     "            phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n", | |||
|  |     "            phoneme_targets = df[phoneme_cols].values\n", | |||
|  |     "\n", | |||
|  |     "            for trial_idx in df['trial_idx'].unique():\n", | |||
|  |     "                # Plain boolean array so the mask is positional and does not\n", | |||
|  |     "                # depend on the DataFrame index.\n", | |||
|  |     "                trial_mask = (df['trial_idx'] == trial_idx).to_numpy()\n", | |||
|  |     "                trial_neural = neural_features[trial_mask]\n", | |||
|  |     "                trial_phonemes = phoneme_targets[trial_mask]\n", | |||
|  |     "\n", | |||
|  |     "                windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n", | |||
|  |     "\n", | |||
|  |     "                if windowed_X is not None and windowed_y is not None:\n", | |||
|  |     "                    all_X.append(windowed_X)\n", | |||
|  |     "                    all_y.append(windowed_y)\n", | |||
|  |     "\n", | |||
|  |     "        if not all_X:\n", | |||
|  |     "            print(f\"❌ 没有有效的{dataset_type}数据\")\n", | |||
|  |     "            return None, None\n", | |||
|  |     "\n", | |||
|  |     "        # Stack the samples of all trials.\n", | |||
|  |     "        X = np.vstack(all_X)\n", | |||
|  |     "        y = np.vstack(all_y)\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"   ✅ {dataset_type}数据准备完成:\")\n", | |||
|  |     "        print(f\"      特征矩阵形状: {X.shape}\")\n", | |||
|  |     "        print(f\"      目标矩阵形状: {y.shape}\")\n", | |||
|  |     "        print(f\"      内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n", | |||
|  |     "\n", | |||
|  |     "        return X, y\n", | |||
|  |     "\n", | |||
|  |     "    def fit(self, X_train, y_train, X_val=None, y_val=None):\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        Fit the scaler and the forest, then report MSE, R2 and MAE.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            X_train: Training features, 2-D.\n", | |||
|  |     "            y_train: Training targets, 2-D.\n", | |||
|  |     "            X_val: Optional validation features.\n", | |||
|  |     "            y_val: Optional validation targets.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            Dict with 'train_mse', 'train_r2', 'train_mae', plus 'val_mse',\n", | |||
|  |     "            'val_r2', 'val_mae' when both validation arrays are given.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        print(f\"\\n🚀 开始训练随机森林回归模型:\")\n", | |||
|  |     "        print(f\"   训练样本数: {X_train.shape[0]:,}\")\n", | |||
|  |     "        print(f\"   特征维度: {X_train.shape[1]:,}\")\n", | |||
|  |     "        print(f\"   目标维度: {y_train.shape[1]}\")\n", | |||
|  |     "\n", | |||
|  |     "        # The scaler is fitted on the training features only.\n", | |||
|  |     "        print(\"   🔄 标准化特征...\")\n", | |||
|  |     "        X_train_scaled = self.scaler.fit_transform(X_train)\n", | |||
|  |     "\n", | |||
|  |     "        print(\"   🌲 训练随机森林...\")\n", | |||
|  |     "        self.regressor.fit(X_train_scaled, y_train)\n", | |||
|  |     "\n", | |||
|  |     "        self.is_fitted = True\n", | |||
|  |     "\n", | |||
|  |     "        # Training-set metrics.\n", | |||
|  |     "        print(\"   📊 评估训练集性能...\")\n", | |||
|  |     "        train_predictions = self.regressor.predict(X_train_scaled)\n", | |||
|  |     "        train_mse = mean_squared_error(y_train, train_predictions)\n", | |||
|  |     "        train_r2 = r2_score(y_train, train_predictions)\n", | |||
|  |     "        train_mae = mean_absolute_error(y_train, train_predictions)\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"   ✅ 训练完成!\")\n", | |||
|  |     "        print(f\"      训练集 MSE: {train_mse:.6f}\")\n", | |||
|  |     "        print(f\"      训练集 R²: {train_r2:.4f}\")\n", | |||
|  |     "        print(f\"      训练集 MAE: {train_mae:.6f}\")\n", | |||
|  |     "\n", | |||
|  |     "        # Validation-set metrics, when a validation set is supplied.\n", | |||
|  |     "        if X_val is not None and y_val is not None:\n", | |||
|  |     "            print(\"   📊 评估验证集性能...\")\n", | |||
|  |     "            X_val_scaled = self.scaler.transform(X_val)\n", | |||
|  |     "            val_predictions = self.regressor.predict(X_val_scaled)\n", | |||
|  |     "            val_mse = mean_squared_error(y_val, val_predictions)\n", | |||
|  |     "            val_r2 = r2_score(y_val, val_predictions)\n", | |||
|  |     "            val_mae = mean_absolute_error(y_val, val_predictions)\n", | |||
|  |     "\n", | |||
|  |     "            print(f\"      验证集 MSE: {val_mse:.6f}\")\n", | |||
|  |     "            print(f\"      验证集 R²: {val_r2:.4f}\")\n", | |||
|  |     "            print(f\"      验证集 MAE: {val_mae:.6f}\")\n", | |||
|  |     "\n", | |||
|  |     "            return {\n", | |||
|  |     "                'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n", | |||
|  |     "                'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n", | |||
|  |     "            }\n", | |||
|  |     "\n", | |||
|  |     "        return {\n", | |||
|  |     "            'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n", | |||
|  |     "        }\n", | |||
|  |     "\n", | |||
|  |     "    def predict(self, X):\n", | |||
|  |     "        \"\"\"Scale X with the fitted scaler and return the forest's predictions.\n", | |||
|  |     "\n", | |||
|  |     "        Raises:\n", | |||
|  |     "            ValueError: If fit() has not been called yet.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if not self.is_fitted:\n", | |||
|  |     "            raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n", | |||
|  |     "\n", | |||
|  |     "        X_scaled = self.scaler.transform(X)\n", | |||
|  |     "        return self.regressor.predict(X_scaled)\n", | |||
|  |     "\n", | |||
|  |     "    def get_feature_importance(self, top_k=20):\n", | |||
|  |     "        \"\"\"Return the top_k most important window features.\n", | |||
|  |     "\n", | |||
|  |     "        Args:\n", | |||
|  |     "            top_k: Number of features to report.\n", | |||
|  |     "\n", | |||
|  |     "        Returns:\n", | |||
|  |     "            (top_features, importances): top_features is a list of\n", | |||
|  |     "            (name, importance) pairs in descending importance, with names of\n", | |||
|  |     "            the form 't<step>_feat<index>'; importances is the full array from\n", | |||
|  |     "            the forest.\n", | |||
|  |     "\n", | |||
|  |     "        Raises:\n", | |||
|  |     "            ValueError: If fit() has not been called yet.\n", | |||
|  |     "        \"\"\"\n", | |||
|  |     "        if not self.is_fitted:\n", | |||
|  |     "            raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n", | |||
|  |     "\n", | |||
|  |     "        importances = self.regressor.feature_importances_\n", | |||
|  |     "\n", | |||
|  |     "        # Derive the per-step feature count from the fitted model instead of\n", | |||
|  |     "        # assuming 512, so the names stay correct for any input width.\n", | |||
|  |     "        n_features = max(1, len(importances) // self.window_size)\n", | |||
|  |     "\n", | |||
|  |     "        # Windows are flattened step by step, so flat index i maps to\n", | |||
|  |     "        # step i // n_features and feature i % n_features.\n", | |||
|  |     "        top_indices = np.argsort(importances)[::-1][:top_k]\n", | |||
|  |     "        top_features = [\n", | |||
|  |     "            (f't{idx // n_features}_feat{idx % n_features}', importances[idx])\n", | |||
|  |     "            for idx in top_indices\n", | |||
|  |     "        ]\n", | |||
|  |     "\n", | |||
|  |     "        return top_features, importances\n", | |||
|  |     "\n", | |||
|  |     "# Instantiate the random-forest regressor used by the training and evaluation cells\n", | |||
|  |     "print(\"🌲 初始化随机森林回归模型\")\n", | |||
|  |     "rf_regressor = TimeWindowRandomForestRegressor(\n", | |||
|  |     "    random_state=42,\n", | |||
|  |     "    n_jobs=-1,           # use all CPU cores\n", | |||
|  |     "    max_depth=10,        # maximum tree depth (guards against overfitting)\n", | |||
|  |     "    n_estimators=100,    # number of trees\n", | |||
|  |     "    window_size=30,      # time-window size\n", | |||
|  |     ")\n", | |||
|  |     "\n", | |||
|  |     "# One call, two lines: the emitted text is the same as two consecutive prints\n", | |||
|  |     "print(\"\\n✅ 随机森林回归器准备完成!\", \"🔧 下一步: 准备训练数据和开始训练\", sep=\"\\n\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 28, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "🚀 开始数据准备和模型训练流程\n", | |||
|  |       "============================================================\n", | |||
|  |       "❌ 没有可用的训练数据,请先运行数据处理工作流\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🚀 Prepare the data and train the random-forest regression model\n", | |||
|  |     "print(\"🚀 开始数据准备和模型训练流程\")\n", | |||
|  |     "print(\"=\"*60)\n", | |||
|  |     "\n", | |||
|  |     "\n", | |||
|  |     "def _drop_non_finite_rows(X, y):\n", | |||
|  |     "    \"\"\"Keep only the rows where neither ``X`` nor ``y`` contains NaN or Inf.\"\"\"\n", | |||
|  |     "    finite_rows = np.isfinite(X).all(axis=1) & np.isfinite(y).all(axis=1)\n", | |||
|  |     "    return X[finite_rows], y[finite_rows]\n", | |||
|  |     "\n", | |||
|  |     "\n", | |||
|  |     "# Nothing can be trained until the data-processing workflow has produced sessions\n", | |||
|  |     "if not train_datasets:\n", | |||
|  |     "    print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(f\"✅ 检测到数据:\")\n", | |||
|  |     "    print(f\"   训练数据集: {len(train_datasets)} 个sessions\")\n", | |||
|  |     "    print(f\"   验证数据集: {len(val_datasets)} 个sessions\")\n", | |||
|  |     "    print(f\"   测试数据集: {len(test_datasets)} 个sessions\")\n", | |||
|  |     "\n", | |||
|  |     "    # 1. Build the training matrices\n", | |||
|  |     "    print(f\"\\n📊 第1步: 准备训练数据\")\n", | |||
|  |     "    X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n", | |||
|  |     "\n", | |||
|  |     "    # 2. Build the validation matrices\n", | |||
|  |     "    print(f\"\\n📊 第2步: 准备验证数据\")\n", | |||
|  |     "    X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n", | |||
|  |     "\n", | |||
|  |     "    if X_train is not None and y_train is not None:\n", | |||
|  |     "        has_val = X_val is not None and y_val is not None\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"\\n📈 数据准备完成统计:\")\n", | |||
|  |     "        print(f\"   训练集: {X_train.shape[0]:,} 样本\")\n", | |||
|  |     "        print(f\"   验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \"   验证集: 无\")\n", | |||
|  |     "        print(f\"   特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n", | |||
|  |     "        print(f\"   目标维度: {y_train.shape[1]} (40个音素概率)\")\n", | |||
|  |     "\n", | |||
|  |     "        # Basic sanity statistics of the training arrays\n", | |||
|  |     "        print(f\"\\n🔍 数据质量检查:\")\n", | |||
|  |     "        print(f\"   训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n", | |||
|  |     "        print(f\"   训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n", | |||
|  |     "        print(f\"   训练特征均值: {X_train.mean():.4f}\")\n", | |||
|  |     "        print(f\"   训练目标均值: {y_train.mean():.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "        # Count NaN / Inf entries in the training arrays\n", | |||
|  |     "        nan_count_X = np.isnan(X_train).sum()\n", | |||
|  |     "        nan_count_y = np.isnan(y_train).sum()\n", | |||
|  |     "        inf_count_X = np.isinf(X_train).sum()\n", | |||
|  |     "        inf_count_y = np.isinf(y_train).sum()\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"   NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n", | |||
|  |     "        print(f\"   Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n", | |||
|  |     "\n", | |||
|  |     "        train_dirty = nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0\n", | |||
|  |     "        # The validation set is checked on its own: it used to be cleaned only when\n", | |||
|  |     "        # the *training* set happened to contain bad values, so a dirty validation\n", | |||
|  |     "        # set could reach fit() untouched.\n", | |||
|  |     "        val_dirty = has_val and not (np.isfinite(X_val).all() and np.isfinite(y_val).all())\n", | |||
|  |     "\n", | |||
|  |     "        if train_dirty or val_dirty:\n", | |||
|  |     "            print(\"⚠️ 检测到异常值,将进行清理...\")\n", | |||
|  |     "            if train_dirty:\n", | |||
|  |     "                X_train, y_train = _drop_non_finite_rows(X_train, y_train)\n", | |||
|  |     "            if val_dirty:\n", | |||
|  |     "                X_val, y_val = _drop_non_finite_rows(X_val, y_val)\n", | |||
|  |     "\n", | |||
|  |     "            print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n", | |||
|  |     "\n", | |||
|  |     "        # 3. Train the model\n", | |||
|  |     "        print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n", | |||
|  |     "        training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n", | |||
|  |     "\n", | |||
|  |     "        # 4. Report the metrics returned by fit()\n", | |||
|  |     "        print(f\"\\n📊 第4步: 训练结果分析\")\n", | |||
|  |     "        print(\"=\"*50)\n", | |||
|  |     "\n", | |||
|  |     "        for metric, value in training_results.items():\n", | |||
|  |     "            metric_name = metric.replace('_', ' ').title()\n", | |||
|  |     "            print(f\"   {metric_name}: {value:.6f}\")\n", | |||
|  |     "\n", | |||
|  |     "        # 5. Feature-importance analysis\n", | |||
|  |     "        print(f\"\\n🔍 第5步: 特征重要性分析\")\n", | |||
|  |     "        top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"\\n🏆 Top 20 重要特征:\")\n", | |||
|  |     "        print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n", | |||
|  |     "        print(\"-\" * 35)\n", | |||
|  |     "        for i, (feature_name, importance) in enumerate(top_features):\n", | |||
|  |     "            print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n", | |||
|  |     "\n", | |||
|  |     "        # Sum the importances belonging to each time step. The per-step width is\n", | |||
|  |     "        # derived from the vector length instead of assuming 512 features.\n", | |||
|  |     "        print(f\"\\n📈 时间窗口重要性分布:\")\n", | |||
|  |     "        n_steps = rf_regressor.window_size\n", | |||
|  |     "        feats_per_step = len(all_importances) // n_steps\n", | |||
|  |     "        window_importances = (\n", | |||
|  |     "            np.asarray(all_importances)[: n_steps * feats_per_step]\n", | |||
|  |     "            .reshape(n_steps, feats_per_step)\n", | |||
|  |     "            .sum(axis=1)\n", | |||
|  |     "        )\n", | |||
|  |     "\n", | |||
|  |     "        max_time_step = np.argmax(window_importances)\n", | |||
|  |     "        print(f\"   最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n", | |||
|  |     "        print(f\"   窗口中心位置: t{rf_regressor.window_size//2}\")\n", | |||
|  |     "        print(f\"   重要性分布: 前5个时间步的重要性\")\n", | |||
|  |     "        for i in range(min(5, len(window_importances))):\n", | |||
|  |     "            print(f\"      t{i}: {window_importances[i]:.6f}\")\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"\\n✅ 随机森林回归模型训练完成!\")\n", | |||
|  |     "        print(f\"🎯 模型可以预测40个音素的概率分布\")\n", | |||
|  |     "        print(f\"📊 基于30时间步的神经特征窗口\")\n", | |||
|  |     "        print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n", | |||
|  |     "\n", | |||
|  |     "    else:\n", | |||
|  |     "        print(\"❌ 数据准备失败,无法训练模型\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 17, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "⚠️ 模型尚未训练完成,请先运行训练代码\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 📊 模型评估和可视化分析\n", | |||
|  |     "def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Score the model's predictions separately for each of the 40 phoneme outputs.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        rf_model: object exposing ``predict(X)``.\n", | |||
|  |     "        X_test: feature matrix handed to ``rf_model.predict``.\n", | |||
|  |     "        y_test: 2-D target array; column ``i`` holds the values for phoneme ``i``.\n", | |||
|  |     "        dataset_name: label used in the printed report header.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        (metrics_df, y_pred): a DataFrame with one row per phoneme (columns\n", | |||
|  |     "        ``phoneme_id``, ``phoneme_name``, ``mse``, ``r2``, ``mae``, ``correlation``)\n", | |||
|  |     "        and the prediction array returned by the model.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(f\"\\n📊 {dataset_name}详细评估\")\n", | |||
|  |     "    print(\"=\"*50)\n", | |||
|  |     "\n", | |||
|  |     "    y_pred = rf_model.predict(X_test)\n", | |||
|  |     "\n", | |||
|  |     "    def _score_column(idx):\n", | |||
|  |     "        # Metrics for a single phoneme column; an undefined (NaN) correlation is stored as 0.0\n", | |||
|  |     "        name = LOGIT_TO_PHONEME[idx]\n", | |||
|  |     "        truth, guess = y_test[:, idx], y_pred[:, idx]\n", | |||
|  |     "        mse = mean_squared_error(truth, guess)\n", | |||
|  |     "        r2 = r2_score(truth, guess)\n", | |||
|  |     "        mae = mean_absolute_error(truth, guess)\n", | |||
|  |     "        corr = np.corrcoef(truth, guess)[0, 1]\n", | |||
|  |     "        return {\n", | |||
|  |     "            'phoneme_id': idx,\n", | |||
|  |     "            'phoneme_name': name,\n", | |||
|  |     "            'mse': mse,\n", | |||
|  |     "            'r2': r2,\n", | |||
|  |     "            'mae': mae,\n", | |||
|  |     "            'correlation': 0.0 if np.isnan(corr) else corr,\n", | |||
|  |     "        }\n", | |||
|  |     "\n", | |||
|  |     "    # One record per phoneme, collected into a DataFrame for analysis\n", | |||
|  |     "    metrics_df = pd.DataFrame([_score_column(idx) for idx in range(40)])\n", | |||
|  |     "\n", | |||
|  |     "    # Averages over all phonemes\n", | |||
|  |     "    print(f\"📈 总体性能指标:\")\n", | |||
|  |     "    print(f\"   平均 MSE: {metrics_df['mse'].mean():.6f}\")\n", | |||
|  |     "    print(f\"   平均 R²: {metrics_df['r2'].mean():.4f}\")\n", | |||
|  |     "    print(f\"   平均 MAE: {metrics_df['mae'].mean():.6f}\")\n", | |||
|  |     "    print(f\"   平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "    # Best and worst phoneme by R²\n", | |||
|  |     "    r2_column = metrics_df['r2']\n", | |||
|  |     "    best_row = metrics_df.loc[r2_column.idxmax()]\n", | |||
|  |     "    worst_row = metrics_df.loc[r2_column.idxmin()]\n", | |||
|  |     "\n", | |||
|  |     "    for header, row in ((\"\\n🏆 最佳预测音素:\", best_row), (\"\\n📉 最差预测音素:\", worst_row)):\n", | |||
|  |     "        print(header)\n", | |||
|  |     "        print(f\"   {row['phoneme_name']} (ID: {row['phoneme_id']})\")\n", | |||
|  |     "        print(f\"   R²: {row['r2']:.4f}, MSE: {row['mse']:.6f}\")\n", | |||
|  |     "\n", | |||
|  |     "    return metrics_df, y_pred\n", | |||
|  |     "\n", | |||
|  |     "def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Draw a 2x3 grid of diagnostic plots for per-phoneme regression results.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        metrics_df: DataFrame; the columns read here are 'phoneme_id',\n", | |||
|  |     "            'phoneme_name', 'r2', 'mse' and 'correlation'.\n", | |||
|  |     "        y_true: 2-D array of targets, indexed as [sample, phoneme_id].\n", | |||
|  |     "        y_pred: 2-D array of predictions, indexed the same way.\n", | |||
|  |     "        save_plots: when true, the figure is also written to\n", | |||
|  |     "            ./processed_datasets/rf_regression_results.png before being shown.\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(f\"\\n📊 创建可视化图表...\")\n", | |||
|  |     "    \n", | |||
|  |     "    # Configure the figure style\n", | |||
|  |     "    plt.style.use('default')\n", | |||
|  |     "    fig = plt.figure(figsize=(20, 12))\n", | |||
|  |     "    \n", | |||
|  |     "    # 1. Distribution of R² scores\n", | |||
|  |     "    plt.subplot(2, 3, 1)\n", | |||
|  |     "    plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n", | |||
|  |     "    plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n", | |||
|  |     "                label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n", | |||
|  |     "    plt.xlabel('R² Score')\n", | |||
|  |     "    plt.ylabel('音素数量')\n", | |||
|  |     "    plt.title('R² Score 分布')\n", | |||
|  |     "    plt.legend()\n", | |||
|  |     "    plt.grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    # 2. Distribution of MSE values\n", | |||
|  |     "    plt.subplot(2, 3, 2)\n", | |||
|  |     "    plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n", | |||
|  |     "    plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n", | |||
|  |     "                label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n", | |||
|  |     "    plt.xlabel('Mean Squared Error')\n", | |||
|  |     "    plt.ylabel('音素数量')\n", | |||
|  |     "    plt.title('MSE 分布')\n", | |||
|  |     "    plt.legend()\n", | |||
|  |     "    plt.grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    # 3. R² of the 10 phonemes with the highest R²\n", | |||
|  |     "    #    NOTE: range(10) below assumes metrics_df has at least 10 rows\n", | |||
|  |     "    plt.subplot(2, 3, 3)\n", | |||
|  |     "    top_10 = metrics_df.nlargest(10, 'r2')\n", | |||
|  |     "    bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n", | |||
|  |     "    plt.xlabel('音素排名')\n", | |||
|  |     "    plt.ylabel('R² Score')\n", | |||
|  |     "    plt.title('Top 10 音素预测性能')\n", | |||
|  |     "    plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n", | |||
|  |     "    plt.grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    # Annotate each bar with its value\n", | |||
|  |     "    for i, bar in enumerate(bars):\n", | |||
|  |     "        height = bar.get_height()\n", | |||
|  |     "        plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n", | |||
|  |     "                f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n", | |||
|  |     "    \n", | |||
|  |     "    # 4. True vs predicted scatter for the phoneme with the highest R²\n", | |||
|  |     "    plt.subplot(2, 3, 4)\n", | |||
|  |     "    best_phoneme_idx = metrics_df['r2'].idxmax()\n", | |||
|  |     "    phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n", | |||
|  |     "    phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n", | |||
|  |     "    \n", | |||
|  |     "    # Plot at most 1000 randomly chosen samples so the scatter stays readable\n", | |||
|  |     "    # (drawn from NumPy's global RNG; no seed is set in this function)\n", | |||
|  |     "    sample_size = min(1000, len(y_true))\n", | |||
|  |     "    sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n", | |||
|  |     "    \n", | |||
|  |     "    plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n", | |||
|  |     "               alpha=0.6, s=20, color='blue')\n", | |||
|  |     "    \n", | |||
|  |     "    # Diagonal reference line (perfect prediction), spanning all samples' range\n", | |||
|  |     "    min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n", | |||
|  |     "    max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n", | |||
|  |     "    plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n", | |||
|  |     "    \n", | |||
|  |     "    plt.xlabel('真实值')\n", | |||
|  |     "    plt.ylabel('预测值')\n", | |||
|  |     "    plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n", | |||
|  |     "    plt.legend()\n", | |||
|  |     "    plt.grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    # 5. Heat map of the 20 highest correlation coefficients\n", | |||
|  |     "    plt.subplot(2, 3, 5)\n", | |||
|  |     "    top_20_correlations = metrics_df.nlargest(20, 'correlation')\n", | |||
|  |     "    corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n", | |||
|  |     "    \n", | |||
|  |     "    # Arrange the values as a single-row image (column vector, transposed for imshow)\n", | |||
|  |     "    heatmap_data = corr_data.values.reshape(-1, 1)\n", | |||
|  |     "    im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n", | |||
|  |     "    \n", | |||
|  |     "    plt.colorbar(im, shrink=0.8)\n", | |||
|  |     "    plt.yticks([0], ['相关系数'])\n", | |||
|  |     "    plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n", | |||
|  |     "               rotation=45, ha='right')\n", | |||
|  |     "    plt.title('Top 20 音素相关系数')\n", | |||
|  |     "    \n", | |||
|  |     "    # 6. Box plot of absolute errors for the 10 phonemes with the highest R²\n", | |||
|  |     "    plt.subplot(2, 3, 6)\n", | |||
|  |     "    top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n", | |||
|  |     "    errors_data = []\n", | |||
|  |     "    labels = []\n", | |||
|  |     "    \n", | |||
|  |     "    # NOTE: this loop rebinds phoneme_id, which panel 4 above also used\n", | |||
|  |     "    for phoneme_id in top_10_ids:\n", | |||
|  |     "        errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n", | |||
|  |     "        errors_data.append(errors)\n", | |||
|  |     "        labels.append(LOGIT_TO_PHONEME[phoneme_id])\n", | |||
|  |     "    \n", | |||
|  |     "    plt.boxplot(errors_data, labels=labels)\n", | |||
|  |     "    plt.xlabel('音素')\n", | |||
|  |     "    plt.ylabel('绝对误差')\n", | |||
|  |     "    plt.title('Top 10 音素预测误差分布')\n", | |||
|  |     "    plt.xticks(rotation=45)\n", | |||
|  |     "    plt.grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    plt.tight_layout()\n", | |||
|  |     "    \n", | |||
|  |     "    # Optionally persist the figure before displaying it\n", | |||
|  |     "    if save_plots:\n", | |||
|  |     "        plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n", | |||
|  |     "        print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n", | |||
|  |     "    \n", | |||
|  |     "    plt.show()\n", | |||
|  |     "\n", | |||
|  |     "# Run the evaluation only when a fitted model exists in this kernel session\n", | |||
|  |     "if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n", | |||
|  |     "    print(f\"\\n🎯 开始模型评估和可视化\")\n", | |||
|  |     "\n", | |||
|  |     "    # Validation set: per-phoneme metrics, diagnostic figure and CSV export\n", | |||
|  |     "    if X_val is not None and y_val is not None:\n", | |||
|  |     "        val_metrics, val_predictions = evaluate_phoneme_predictions(\n", | |||
|  |     "            rf_regressor, X_val, y_val, \"验证集\"\n", | |||
|  |     "        )\n", | |||
|  |     "\n", | |||
|  |     "        # Plot the results\n", | |||
|  |     "        visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n", | |||
|  |     "\n", | |||
|  |     "        # Save the detailed per-phoneme table\n", | |||
|  |     "        val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n", | |||
|  |     "        print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n", | |||
|  |     "\n", | |||
|  |     "    # Test set: evaluated whether or not validation data exists. This branch used\n", | |||
|  |     "    # to be nested inside the validation branch, so the test set was silently\n", | |||
|  |     "    # skipped whenever no validation data was available.\n", | |||
|  |     "    if test_datasets:\n", | |||
|  |     "        print(f\"\\n🔮 准备测试集预测...\")\n", | |||
|  |     "        X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n", | |||
|  |     "\n", | |||
|  |     "        # Both arrays are required by evaluate_phoneme_predictions\n", | |||
|  |     "        if X_test is not None and y_test is not None:\n", | |||
|  |     "            test_metrics, test_predictions = evaluate_phoneme_predictions(\n", | |||
|  |     "                rf_regressor, X_test, y_test, \"测试集\"\n", | |||
|  |     "            )\n", | |||
|  |     "            print(f\"\\n✅ 测试集评估完成\")\n", | |||
|  |     "        else:\n", | |||
|  |     "            print(f\"⚠️ 测试集数据准备失败\")\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n", | |||
|  |     "    print(f\"📊 生成了详细的性能分析和可视化图表\")\n", | |||
|  |     "    print(f\"🔧 模型已准备好用于实际预测任务\")\n", | |||
|  |     "\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 18, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "✅ 回归转分类分析功能已创建!\n", | |||
|  |       "🎯 主要功能:\n", | |||
|  |       "• 将40维概率回归结果转换为分类预测\n", | |||
|  |       "• 计算分类准确率和置信度分析\n", | |||
|  |       "• 提供Top-K准确率评估\n", | |||
|  |       "• 生成详细的混淆矩阵和错误分析\n", | |||
|  |       "• 创建全面的可视化图表\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🎯 回归结果转分类结果分析\n", | |||
|  |     "def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    Convert regressed phoneme probabilities into class labels and report classification quality.\n", | |||
|  |     "\n", | |||
|  |     "    The class of each sample is the arg-max over its probability row. The function\n", | |||
|  |     "    prints an accuracy overview, confidence statistics, accuracy per confidence bin,\n", | |||
|  |     "    the most frequent true/predicted phonemes, an optional per-phoneme\n", | |||
|  |     "    classification report, Top-K accuracy and the most frequent confusions.\n", | |||
|  |     "\n", | |||
|  |     "    Args:\n", | |||
|  |     "        y_true_probs: true phoneme probabilities, shape [n_samples, n_phonemes].\n", | |||
|  |     "        y_pred_probs: predicted phoneme probabilities, same shape.\n", | |||
|  |     "        show_detailed_metrics: when true, print precision/recall/F1 for the\n", | |||
|  |     "            (up to) 20 most frequent true phonemes.\n", | |||
|  |     "\n", | |||
|  |     "    Returns:\n", | |||
|  |     "        dict with keys 'accuracy', 'y_true_classes', 'y_pred_classes',\n", | |||
|  |     "        'pred_confidences', 'true_confidences' and 'most_common_errors'\n", | |||
|  |     "        (the 10 most frequent (true_id, pred_id) confusions with their counts).\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    from collections import Counter\n", | |||
|  |     "\n", | |||
|  |     "    print(\"🎯 回归结果转分类结果分析\")\n", | |||
|  |     "    print(\"=\"*60)\n", | |||
|  |     "\n", | |||
|  |     "    # 1. Probabilities -> class labels\n", | |||
|  |     "    y_true_classes = np.argmax(y_true_probs, axis=1)\n", | |||
|  |     "    y_pred_classes = np.argmax(y_pred_probs, axis=1)\n", | |||
|  |     "\n", | |||
|  |     "    # 2. Overall accuracy\n", | |||
|  |     "    is_correct = y_true_classes == y_pred_classes\n", | |||
|  |     "    accuracy = is_correct.mean()\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"📊 分类结果概览:\")\n", | |||
|  |     "    print(f\"   总样本数: {len(y_true_classes):,}\")\n", | |||
|  |     "    print(f\"   分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n", | |||
|  |     "    print(f\"   正确预测: {is_correct.sum():,}\")\n", | |||
|  |     "    print(f\"   错误预测: {(~is_correct).sum():,}\")\n", | |||
|  |     "\n", | |||
|  |     "    # 3. Confidence = largest probability in each row\n", | |||
|  |     "    pred_confidences = np.max(y_pred_probs, axis=1)\n", | |||
|  |     "    true_confidences = np.max(y_true_probs, axis=1)\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"\\n🔍 预测置信度分析:\")\n", | |||
|  |     "    print(f\"   预测置信度均值: {pred_confidences.mean():.4f}\")\n", | |||
|  |     "    print(f\"   预测置信度标准差: {pred_confidences.std():.4f}\")\n", | |||
|  |     "    print(f\"   预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n", | |||
|  |     "    print(f\"   真实置信度均值: {true_confidences.mean():.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "    # 4. Accuracy grouped by predicted confidence\n", | |||
|  |     "    confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n", | |||
|  |     "    print(f\"\\n📈 按预测置信度分组的准确率:\")\n", | |||
|  |     "    print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n", | |||
|  |     "    print(\"-\" * 40)\n", | |||
|  |     "\n", | |||
|  |     "    last_bin = len(confidence_bins) - 2\n", | |||
|  |     "    for i, (low, high) in enumerate(zip(confidence_bins[:-1], confidence_bins[1:])):\n", | |||
|  |     "        # Bins are half-open [low, high) except the last one, which includes 1.0\n", | |||
|  |     "        if i == last_bin:\n", | |||
|  |     "            mask = (pred_confidences >= low) & (pred_confidences <= high)\n", | |||
|  |     "            closer = ']'\n", | |||
|  |     "        else:\n", | |||
|  |     "            mask = (pred_confidences >= low) & (pred_confidences < high)\n", | |||
|  |     "            closer = ')'\n", | |||
|  |     "\n", | |||
|  |     "        count = mask.sum()\n", | |||
|  |     "        if count > 0:\n", | |||
|  |     "            bin_accuracy = is_correct[mask].mean()\n", | |||
|  |     "            percentage = count / len(y_true_classes) * 100\n", | |||
|  |     "            print(f\"[{low:.1f}, {high:.1f}{closer:>1} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n", | |||
|  |     "\n", | |||
|  |     "    # 5. Most frequent phonemes, true vs predicted\n", | |||
|  |     "    true_counter = Counter(y_true_classes)\n", | |||
|  |     "    pred_counter = Counter(y_pred_classes)\n", | |||
|  |     "\n", | |||
|  |     "    most_common_true = true_counter.most_common(10)\n", | |||
|  |     "    most_common_pred = pred_counter.most_common(10)\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n", | |||
|  |     "    print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n", | |||
|  |     "    print(\"-\" * 42)\n", | |||
|  |     "\n", | |||
|  |     "    # zip stops at the shorter list, so rows are printed only while both sides have an entry\n", | |||
|  |     "    for (true_id, true_count), (pred_id, pred_count) in zip(most_common_true, most_common_pred):\n", | |||
|  |     "        true_name = LOGIT_TO_PHONEME[true_id]\n", | |||
|  |     "        pred_name = LOGIT_TO_PHONEME[pred_id]\n", | |||
|  |     "        print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n", | |||
|  |     "\n", | |||
|  |     "    # 6. Per-phoneme classification metrics\n", | |||
|  |     "    if show_detailed_metrics:\n", | |||
|  |     "        from sklearn.metrics import classification_report\n", | |||
|  |     "\n", | |||
|  |     "        print(f\"\\n📋 详细分类报告 (前20个最常见音素):\")\n", | |||
|  |     "\n", | |||
|  |     "        # Ask the counter for 20 entries directly: most_common_true holds only 10,\n", | |||
|  |     "        # so slicing it with [:20] never produced more than 10 phonemes.\n", | |||
|  |     "        top_20_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(20)]\n", | |||
|  |     "\n", | |||
|  |     "        # Restrict the report to samples whose true class is one of these phonemes\n", | |||
|  |     "        mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n", | |||
|  |     "        y_true_top20 = y_true_classes[mask_top20]\n", | |||
|  |     "        y_pred_top20 = y_pred_classes[mask_top20]\n", | |||
|  |     "\n", | |||
|  |     "        target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n", | |||
|  |     "\n", | |||
|  |     "        # Best effort: a failure here is reported and the remaining analysis still runs\n", | |||
|  |     "        try:\n", | |||
|  |     "            report = classification_report(\n", | |||
|  |     "                y_true_top20, y_pred_top20,\n", | |||
|  |     "                labels=top_20_phonemes,\n", | |||
|  |     "                target_names=target_names,\n", | |||
|  |     "                output_dict=True,\n", | |||
|  |     "                zero_division=0\n", | |||
|  |     "            )\n", | |||
|  |     "\n", | |||
|  |     "            print(f\"{'音素':>8} {'精确率':>8} {'召回率':>8} {'F1分数':>8} {'支持数':>8}\")\n", | |||
|  |     "            print(\"-\" * 48)\n", | |||
|  |     "\n", | |||
|  |     "            for phoneme_name in target_names:\n", | |||
|  |     "                if phoneme_name in report:\n", | |||
|  |     "                    metrics = report[phoneme_name]\n", | |||
|  |     "                    print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n", | |||
|  |     "                          f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n", | |||
|  |     "\n", | |||
|  |     "            # Aggregate rows\n", | |||
|  |     "            macro_avg = report['macro avg']\n", | |||
|  |     "            weighted_avg = report['weighted avg']\n", | |||
|  |     "            print(\"-\" * 48)\n", | |||
|  |     "            print(f\"{'宏平均':>8} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n", | |||
|  |     "                  f\"{macro_avg['f1-score']:>8.4f}\")\n", | |||
|  |     "            print(f\"{'加权平均':>8} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n", | |||
|  |     "                  f\"{weighted_avg['f1-score']:>8.4f}\")\n", | |||
|  |     "\n", | |||
|  |     "        except Exception as e:\n", | |||
|  |     "            print(f\"分类报告生成失败: {e}\")\n", | |||
|  |     "\n", | |||
|  |     "    # 7. Top-K accuracy: sort once, then compare the K best columns against the true class\n", | |||
|  |     "    print(f\"\\n🎯 Top-K 准确率分析:\")\n", | |||
|  |     "    ranked_classes = np.argsort(y_pred_probs, axis=1)\n", | |||
|  |     "    true_column = y_true_classes[:, None]\n", | |||
|  |     "    for k in [1, 3, 5, 10]:\n", | |||
|  |     "        top_k_accuracy = (ranked_classes[:, -k:] == true_column).any(axis=1).mean()\n", | |||
|  |     "        print(f\"   Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n", | |||
|  |     "\n", | |||
|  |     "    # 8. Most frequent (true, predicted) confusions\n", | |||
|  |     "    print(f\"\\n❌ 最常见的预测错误:\")\n", | |||
|  |     "    error_mask = ~is_correct\n", | |||
|  |     "    error_counter = Counter(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n", | |||
|  |     "    most_common_errors = error_counter.most_common(10)\n", | |||
|  |     "\n", | |||
|  |     "    print(f\"{'真实音素':>12} {'预测音素':>12} {'错误次数':>8}\")\n", | |||
|  |     "    print(\"-\" * 36)\n", | |||
|  |     "    for (true_id, pred_id), count in most_common_errors:\n", | |||
|  |     "        true_name = LOGIT_TO_PHONEME[true_id]\n", | |||
|  |     "        pred_name = LOGIT_TO_PHONEME[pred_id]\n", | |||
|  |     "        print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n", | |||
|  |     "\n", | |||
|  |     "    classification_results = {\n", | |||
|  |     "        'accuracy': accuracy,\n", | |||
|  |     "        'y_true_classes': y_true_classes,\n", | |||
|  |     "        'y_pred_classes': y_pred_classes,\n", | |||
|  |     "        'pred_confidences': pred_confidences,\n", | |||
|  |     "        'true_confidences': true_confidences,\n", | |||
|  |     "        'most_common_errors': most_common_errors\n", | |||
|  |     "    }\n", | |||
|  |     "\n", | |||
|  |     "    return classification_results\n", | |||
|  |     "\n", | |||
|  |     "def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    为分类结果创建可视化图表\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(f\"\\n📊 创建分类结果可视化...\")\n", | |||
|  |     "    \n", | |||
|  |     "    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n", | |||
|  |     "    fig.suptitle('随机森林回归转分类结果分析', fontsize=16, fontweight='bold')\n", | |||
|  |     "    \n", | |||
|  |     "    y_true_classes = classification_results['y_true_classes']\n", | |||
|  |     "    y_pred_classes = classification_results['y_pred_classes']\n", | |||
|  |     "    pred_confidences = classification_results['pred_confidences']\n", | |||
|  |     "    \n", | |||
|  |     "    # 1. 预测置信度分布\n", | |||
|  |     "    axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n", | |||
|  |     "    axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n", | |||
|  |     "                      label=f'均值: {pred_confidences.mean():.3f}')\n", | |||
|  |     "    axes[0, 0].set_xlabel('预测置信度')\n", | |||
|  |     "    axes[0, 0].set_ylabel('样本数量')\n", | |||
|  |     "    axes[0, 0].set_title('预测置信度分布')\n", | |||
|  |     "    axes[0, 0].legend()\n", | |||
|  |     "    axes[0, 0].grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    # 2. 准确率 vs 置信度\n", | |||
|  |     "    confidence_bins = np.linspace(0, 1, 21)\n", | |||
|  |     "    bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n", | |||
|  |     "    bin_accuracies = []\n", | |||
|  |     "    bin_counts = []\n", | |||
|  |     "    \n", | |||
|  |     "    for i in range(len(confidence_bins)-1):\n", | |||
|  |     "        mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n", | |||
|  |     "        if mask.sum() > 0:\n", | |||
|  |     "            accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n", | |||
|  |     "            bin_accuracies.append(accuracy)\n", | |||
|  |     "            bin_counts.append(mask.sum())\n", | |||
|  |     "        else:\n", | |||
|  |     "            bin_accuracies.append(0)\n", | |||
|  |     "            bin_counts.append(0)\n", | |||
|  |     "    \n", | |||
|  |     "    # 只显示有数据的bins\n", | |||
|  |     "    valid_bins = np.array(bin_counts) > 0\n", | |||
|  |     "    axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n", | |||
|  |     "                   'bo-', linewidth=2, markersize=6)\n", | |||
|  |     "    axes[0, 1].set_xlabel('预测置信度')\n", | |||
|  |     "    axes[0, 1].set_ylabel('准确率')\n", | |||
|  |     "    axes[0, 1].set_title('准确率 vs 预测置信度')\n", | |||
|  |     "    axes[0, 1].grid(True, alpha=0.3)\n", | |||
|  |     "    axes[0, 1].set_ylim(0, 1)\n", | |||
|  |     "    \n", | |||
|  |     "    # 3. 最常见音素的预测准确率\n", | |||
|  |     "    from collections import Counter\n", | |||
|  |     "    true_counter = Counter(y_true_classes)\n", | |||
|  |     "    most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n", | |||
|  |     "    \n", | |||
|  |     "    phoneme_accuracies = []\n", | |||
|  |     "    phoneme_names = []\n", | |||
|  |     "    for phoneme_id in most_common_phonemes:\n", | |||
|  |     "        mask = y_true_classes == phoneme_id\n", | |||
|  |     "        if mask.sum() > 0:\n", | |||
|  |     "            accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n", | |||
|  |     "            phoneme_accuracies.append(accuracy)\n", | |||
|  |     "            phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n", | |||
|  |     "    \n", | |||
|  |     "    bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n", | |||
|  |     "                         color='lightgreen', alpha=0.7)\n", | |||
|  |     "    axes[0, 2].set_xlabel('音素')\n", | |||
|  |     "    axes[0, 2].set_ylabel('准确率')\n", | |||
|  |     "    axes[0, 2].set_title('Top 15 音素的分类准确率')\n", | |||
|  |     "    axes[0, 2].set_xticks(range(len(phoneme_names)))\n", | |||
|  |     "    axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n", | |||
|  |     "    axes[0, 2].grid(True, alpha=0.3)\n", | |||
|  |     "    \n", | |||
|  |     "    # 添加数值标签\n", | |||
|  |     "    for bar, acc in zip(bars, phoneme_accuracies):\n", | |||
|  |     "        height = bar.get_height()\n", | |||
|  |     "        axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n", | |||
|  |     "                       f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n", | |||
|  |     "    \n", | |||
|  |     "    # 4. 混淆矩阵(前10个最常见音素)\n", | |||
|  |     "    from sklearn.metrics import confusion_matrix\n", | |||
|  |     "    top_10_phonemes = most_common_phonemes[:10]\n", | |||
|  |     "    mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n", | |||
|  |     "    \n", | |||
|  |     "    if mask_top10.sum() > 0:\n", | |||
|  |     "        cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n", | |||
|  |     "                            labels=top_10_phonemes)\n", | |||
|  |     "        \n", | |||
|  |     "        im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n", | |||
|  |     "        axes[1, 0].set_title('混淆矩阵 (Top 10 音素)')\n", | |||
|  |     "        \n", | |||
|  |     "        # 添加颜色条\n", | |||
|  |     "        cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n", | |||
|  |     "        cbar.set_label('预测次数')\n", | |||
|  |     "        \n", | |||
|  |     "        # 设置标签\n", | |||
|  |     "        tick_marks = np.arange(len(top_10_phonemes))\n", | |||
|  |     "        top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n", | |||
|  |     "        axes[1, 0].set_xticks(tick_marks)\n", | |||
|  |     "        axes[1, 0].set_yticks(tick_marks)\n", | |||
|  |     "        axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n", | |||
|  |     "        axes[1, 0].set_yticklabels(top_10_names)\n", | |||
|  |     "        axes[1, 0].set_xlabel('预测音素')\n", | |||
|  |     "        axes[1, 0].set_ylabel('真实音素')\n", | |||
|  |     "    \n", | |||
|  |     "    # 5. Top-K准确率\n", | |||
|  |     "    k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n", | |||
|  |     "    top_k_accuracies = []\n", | |||
|  |     "    \n", | |||
|  |     "    for k in k_values:\n", | |||
|  |     "        top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n", | |||
|  |     "        top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n", | |||
|  |     "        top_k_accuracies.append(top_k_accuracy)\n", | |||
|  |     "    \n", | |||
|  |     "    axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n", | |||
|  |     "    axes[1, 1].set_xlabel('K 值')\n", | |||
|  |     "    axes[1, 1].set_ylabel('Top-K 准确率')\n", | |||
|  |     "    axes[1, 1].set_title('Top-K 准确率曲线')\n", | |||
|  |     "    axes[1, 1].grid(True, alpha=0.3)\n", | |||
|  |     "    axes[1, 1].set_ylim(0, 1)\n", | |||
|  |     "    \n", | |||
|  |     "    # 添加数值标签\n", | |||
|  |     "    for k, acc in zip(k_values, top_k_accuracies):\n", | |||
|  |     "        axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n", | |||
|  |     "                           xytext=(0,10), ha='center')\n", | |||
|  |     "    \n", | |||
|  |     "    # 6. 错误分析 - 最常见错误的热力图\n", | |||
|  |     "    error_pairs = classification_results['most_common_errors'][:25]  # 前25个最常见错误\n", | |||
|  |     "    if error_pairs:\n", | |||
|  |     "        # 创建错误矩阵\n", | |||
|  |     "        unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n", | |||
|  |     "        error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n", | |||
|  |     "        \n", | |||
|  |     "        phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n", | |||
|  |     "        \n", | |||
|  |     "        for (true_id, pred_id), count in error_pairs:\n", | |||
|  |     "            if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n", | |||
|  |     "                true_idx = phoneme_to_idx[true_id]\n", | |||
|  |     "                pred_idx = phoneme_to_idx[pred_id]\n", | |||
|  |     "                error_matrix[true_idx, pred_idx] = count\n", | |||
|  |     "        \n", | |||
|  |     "        im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n", | |||
|  |     "        axes[1, 2].set_title('最常见错误分布')\n", | |||
|  |     "        \n", | |||
|  |     "        # 设置标签\n", | |||
|  |     "        phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n", | |||
|  |     "        axes[1, 2].set_xticks(range(len(phoneme_names)))\n", | |||
|  |     "        axes[1, 2].set_yticks(range(len(phoneme_names)))\n", | |||
|  |     "        axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n", | |||
|  |     "        axes[1, 2].set_yticklabels(phoneme_names)\n", | |||
|  |     "        axes[1, 2].set_xlabel('预测音素')\n", | |||
|  |     "        axes[1, 2].set_ylabel('真实音素')\n", | |||
|  |     "        \n", | |||
|  |     "        # 添加颜色条\n", | |||
|  |     "        cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n", | |||
|  |     "        cbar.set_label('错误次数')\n", | |||
|  |     "    \n", | |||
|  |     "    plt.tight_layout()\n", | |||
|  |     "    plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n", | |||
|  |     "    print(\"📁 分类分析图表已保存至: ./processed_datasets/classification_analysis.png\")\n", | |||
|  |     "    plt.show()\n", | |||
|  |     "\n", | |||
|  |     "print(\"✅ 回归转分类分析功能已创建!\")\n", | |||
|  |     "print(\"🎯 主要功能:\")\n", | |||
|  |     "print(\"• 将40维概率回归结果转换为分类预测\")\n", | |||
|  |     "print(\"• 计算分类准确率和置信度分析\")\n", | |||
|  |     "print(\"• 提供Top-K准确率评估\")\n", | |||
|  |     "print(\"• 生成详细的混淆矩阵和错误分析\")\n", | |||
|  |     "print(\"• 创建全面的可视化图表\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": 19, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [ | |||
|  |     { | |||
|  |      "name": "stdout", | |||
|  |      "output_type": "stream", | |||
|  |      "text": [ | |||
|  |       "⚠️ 随机森林模型尚未训练完成\n", | |||
|  |       "💡 请先运行前面的训练代码\n" | |||
|  |      ] | |||
|  |     } | |||
|  |    ], | |||
|  |    "source": [ | |||
|  |     "# 🎯 完整的回归转分类评估流程\n", | |||
|  |     "def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    完整的回归模型转分类结果评估流程\n", | |||
|  |     "    \"\"\"\n", | |||
|  |     "    print(f\"\\n🎯 {dataset_name}完整评估: 回归 → 分类\")\n", | |||
|  |     "    print(\"=\"*70)\n", | |||
|  |     "    \n", | |||
|  |     "    # 1. 获取回归预测结果\n", | |||
|  |     "    print(\"📊 第1步: 获取回归预测...\")\n", | |||
|  |     "    y_pred_probs = rf_model.predict(X_test)\n", | |||
|  |     "    \n", | |||
|  |     "    # 确保概率值在合理范围内\n", | |||
|  |     "    y_pred_probs = np.clip(y_pred_probs, 0, 1)\n", | |||
|  |     "    \n", | |||
|  |     "    # 2. 回归性能评估\n", | |||
|  |     "    print(\"\\n📈 第2步: 回归性能评估...\")\n", | |||
|  |     "    mse = mean_squared_error(y_test, y_pred_probs) \n", | |||
|  |     "    mae = mean_absolute_error(y_test, y_pred_probs)\n", | |||
|  |     "    r2 = r2_score(y_test, y_pred_probs)\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"   回归 MSE: {mse:.6f}\")\n", | |||
|  |     "    print(f\"   回归 MAE: {mae:.6f}\")\n", | |||
|  |     "    print(f\"   回归 R²: {r2:.4f}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 3. 概率归一化(softmax)\n", | |||
|  |     "    print(\"\\n🔄 第3步: 概率归一化...\")\n", | |||
|  |     "    def softmax(x, axis=-1):\n", | |||
|  |     "        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n", | |||
|  |     "        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n", | |||
|  |     "    \n", | |||
|  |     "    # 对预测结果应用softmax,使其成为真正的概率分布\n", | |||
|  |     "    y_pred_probs_normalized = softmax(y_pred_probs)\n", | |||
|  |     "    y_test_normalized = softmax(y_test)  # 也对真实标签归一化\n", | |||
|  |     "    \n", | |||
|  |     "    print(f\"   预测概率归一化前: 每行和均值 = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n", | |||
|  |     "    print(f\"   预测概率归一化后: 每行和均值 = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 4. 分类结果分析\n", | |||
|  |     "    print(\"\\n🎯 第4步: 分类结果分析...\")\n", | |||
|  |     "    classification_results = regression_to_classification_analysis(\n", | |||
|  |     "        y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n", | |||
|  |     "    )\n", | |||
|  |     "    \n", | |||
|  |     "    # 5. 创建可视化\n", | |||
|  |     "    print(\"\\n📊 第5步: 创建可视化图表...\")\n", | |||
|  |     "    create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n", | |||
|  |     "    \n", | |||
|  |     "    # 6. 保存结果\n", | |||
|  |     "    print(\"\\n💾 第6步: 保存分析结果...\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 保存分类结果\n", | |||
|  |     "    results_df = pd.DataFrame({\n", | |||
|  |     "        'true_class': classification_results['y_true_classes'],\n", | |||
|  |     "        'pred_class': classification_results['y_pred_classes'],\n", | |||
|  |     "        'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n", | |||
|  |     "        'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n", | |||
|  |     "        'pred_confidence': classification_results['pred_confidences'],\n", | |||
|  |     "        'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n", | |||
|  |     "    })\n", | |||
|  |     "    \n", | |||
|  |     "    results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n", | |||
|  |     "    \n", | |||
|  |     "    # 保存详细的概率预测\n", | |||
|  |     "    prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n", | |||
|  |     "                                  columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n", | |||
|  |     "    prob_results_df['true_class'] = classification_results['y_true_classes']\n", | |||
|  |     "    prob_results_df['pred_class'] = classification_results['y_pred_classes']\n", | |||
|  |     "    \n", | |||
|  |     "    prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n", | |||
|  |     "    \n", | |||
|  |     "    print(\"📁 结果已保存:\")\n", | |||
|  |     "    print(\"   • ./processed_datasets/classification_results.csv (分类结果)\")\n", | |||
|  |     "    print(\"   • ./processed_datasets/probability_predictions.csv (概率预测)\")\n", | |||
|  |     "    print(\"   • ./processed_datasets/classification_analysis.png (可视化图表)\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 7. 总结报告\n", | |||
|  |     "    print(f\"\\n📋 {dataset_name}评估总结:\")\n", | |||
|  |     "    print(\"=\"*50)\n", | |||
|  |     "    print(f\"🔸 回归性能:\")\n", | |||
|  |     "    print(f\"   MSE: {mse:.6f}\")\n", | |||
|  |     "    print(f\"   R²: {r2:.4f}\")\n", | |||
|  |     "    print(f\"🔸 分类性能:\")\n", | |||
|  |     "    print(f\"   准确率: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n", | |||
|  |     "    print(f\"   平均置信度: {classification_results['pred_confidences'].mean():.4f}\")\n", | |||
|  |     "    \n", | |||
|  |     "    # 计算Top-K准确率\n", | |||
|  |     "    for k in [1, 3, 5]:\n", | |||
|  |     "        top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n", | |||
|  |     "        top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n", | |||
|  |     "                                for i in range(len(classification_results['y_true_classes']))])\n", | |||
|  |     "        print(f\"   Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n", | |||
|  |     "    \n", | |||
|  |     "    return {\n", | |||
|  |     "        'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n", | |||
|  |     "        'classification_results': classification_results,\n", | |||
|  |     "        'normalized_predictions': y_pred_probs_normalized,\n", | |||
|  |     "        'normalized_true': y_test_normalized\n", | |||
|  |     "    }\n", | |||
|  |     "\n", | |||
|  |     "# 如果模型已训练且有验证数据,执行完整评估\n", | |||
|  |     "if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n", | |||
|  |     "    if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n", | |||
|  |     "        print(\"🚀 开始完整的回归转分类评估...\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 执行完整评估\n", | |||
|  |     "        evaluation_results = complete_regression_classification_evaluation(\n", | |||
|  |     "            rf_regressor, X_val, y_val, \"验证集\"\n", | |||
|  |     "        )\n", | |||
|  |     "        \n", | |||
|  |     "        print(f\"\\n🎉 评估完成!\")\n", | |||
|  |     "        print(f\"✅ 随机森林回归模型成功转换为分类结果\")\n", | |||
|  |     "        print(f\"📊 生成了详细的性能分析和可视化\")\n", | |||
|  |     "        print(f\"💾 所有结果已保存到文件\")\n", | |||
|  |     "        \n", | |||
|  |     "        # 如果有测试数据,也进行评估\n", | |||
|  |     "        if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n", | |||
|  |     "            print(f\"\\n🔮 开始测试集评估...\")\n", | |||
|  |     "            test_evaluation_results = complete_regression_classification_evaluation(\n", | |||
|  |     "                rf_regressor, X_test, y_test, \"测试集\"\n", | |||
|  |     "            )\n", | |||
|  |     "    else:\n", | |||
|  |     "        print(\"⚠️ 没有可用的验证数据进行评估\")\n", | |||
|  |     "else:\n", | |||
|  |     "    print(\"⚠️ 随机森林模型尚未训练完成\")\n", | |||
|  |     "    print(\"💡 请先运行前面的训练代码\")" | |||
|  |    ] | |||
|  |   }, | |||
|  |   { | |||
|  |    "cell_type": "code", | |||
|  |    "execution_count": null, | |||
|  |    "metadata": {}, | |||
|  |    "outputs": [], | |||
|  |    "source": [] | |||
|  |   } | |||
|  |  ], | |||
|  |  "metadata": { | |||
|  |   "kaggle": { | |||
|  |    "accelerator": "tpu1vmV38", | |||
|  |    "dataSources": [ | |||
|  |     { | |||
|  |      "databundleVersionId": 13056355, | |||
|  |      "sourceId": 106809, | |||
|  |      "sourceType": "competition" | |||
|  |     } | |||
|  |    ], | |||
|  |    "dockerImageVersionId": 31091, | |||
|  |    "isGpuEnabled": false, | |||
|  |    "isInternetEnabled": true, | |||
|  |    "language": "python", | |||
|  |    "sourceType": "notebook" | |||
|  |   }, | |||
|  |   "kernelspec": { | |||
|  |    "display_name": "Python 3 (ipykernel)", | |||
|  |    "language": "python", | |||
|  |    "name": "python3" | |||
|  |   } | |||
|  |  }, | |||
|  |  "nbformat": 4, | |||
|  |  "nbformat_minor": 4 | |||
|  | } |