{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 环境配置 与 utils" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu126\n", "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n", "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n", "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n", " Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n", "Requirement 
already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 4.5 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 84.2 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 78.5 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 41.1 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.1 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 2.1 MB/s eta 0:00:00\n", "Downloading 
https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 30.7 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 12.8 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 7.8 MB/s eta 0:00:00\n", "Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 79.7 MB/s eta 0:00:00\n", "Installing collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n", " Attempting uninstall: nvidia-nvjitlink-cu12\n", " Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n", " Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n", " Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n", " Attempting uninstall: nvidia-curand-cu12\n", " Found existing installation: nvidia-curand-cu12 10.3.6.82\n", " Uninstalling nvidia-curand-cu12-10.3.6.82:\n", " Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n", " Attempting uninstall: nvidia-cufft-cu12\n", " Found existing installation: nvidia-cufft-cu12 11.2.3.61\n", " Uninstalling nvidia-cufft-cu12-11.2.3.61:\n", " Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n", " Attempting uninstall: nvidia-cuda-runtime-cu12\n", " Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n", " Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n", " Attempting uninstall: nvidia-cuda-nvrtc-cu12\n", " Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n", " Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n", " Attempting uninstall: nvidia-cuda-cupti-cu12\n", " Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n", " Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n", " Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n", " Attempting uninstall: nvidia-cublas-cu12\n", " Found existing installation: nvidia-cublas-cu12 12.5.3.2\n", " Uninstalling nvidia-cublas-cu12-12.5.3.2:\n", " Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n", " Attempting uninstall: nvidia-cusparse-cu12\n", " Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n", " Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n", " Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n", " Attempting uninstall: nvidia-cudnn-cu12\n", " Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n", " Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n", " Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n", " Attempting uninstall: nvidia-cusolver-cu12\n", " Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n", " Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n", " Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n", "Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 
nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n", "Collecting jupyter==1.1.1\n", " Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)\n", "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", "Collecting pandas==2.3.0\n", " Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 91.2/91.2 kB 4.6 MB/s eta 0:00:00\n", "Collecting matplotlib==3.10.1\n", " Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n", "Collecting scipy==1.15.2\n", " Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.0/62.0 kB 3.0 MB/s eta 0:00:00\n", "Collecting scikit-learn==1.6.1\n", " Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)\n", "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n", "Collecting g2p_en==2.1.0\n", " Downloading g2p_en-2.1.0-py3-none-any.whl.metadata (4.5 kB)\n", "Collecting h5py==3.13.0\n", " Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)\n", "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n", "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n", "Collecting transformers==4.53.0\n", " Downloading transformers-4.53.0-py3-none-any.whl.metadata (39 kB)\n", "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n", "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n", "Collecting bitsandbytes==0.46.0\n", " Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)\n", "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n", "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n", "Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n", "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n", "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n", "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n", "Requirement 
already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n", "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n", "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n", "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n", "Collecting distance>=0.1.3 (from g2p_en==2.1.0)\n", " Downloading Distance-0.1.3.tar.gz (180 kB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 180.3/180.3 kB 9.7 MB/s eta 0:00:00\n", " Preparing metadata (setup.py): started\n", " Preparing metadata (setup.py): finished with status 'done'\n", "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n", "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n", "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: 
mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n", "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n", "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n", "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) 
(1.8.0)\n", "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n", "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n", "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n", "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n", "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n", "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n", "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n", "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n", "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n", "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n", "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n", "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n", "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n", "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n", "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n", "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n", "Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n", "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n", "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n", "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n", "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n", "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: jupyterlab-pygments 
in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n", "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n", "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n", "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n", "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n", "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n", "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) 
(4.3.8)\n", "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n", "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n", "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n", "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n", "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n", "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n", "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n", "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n", "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n", "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n", "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in 
/usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n", "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n", "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n", "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n", "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n", "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n", "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n", "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n", "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n", "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n", "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n", "Downloading jupyter-1.1.1-py2.py3-none-any.whl (2.7 kB)\n", "Downloading 
pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 84.0 MB/s eta 0:00:00\n", "Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.6/8.6 MB 103.4 MB/s eta 0:00:00\n", "Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.6 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 37.6/37.6 MB 44.8 MB/s eta 0:00:00\n", "Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 75.2 MB/s eta 0:00:00\n", "Downloading g2p_en-2.1.0-py3-none-any.whl (3.1 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 80.6 MB/s eta 0:00:00\n", "Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 89.1 MB/s eta 0:00:00\n", "Downloading transformers-4.53.0-py3-none-any.whl (10.8 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.8/10.8 MB 105.1 MB/s eta 0:00:00\n", "Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 MB 17.9 MB/s eta 0:00:00\n", "Building wheels for collected packages: distance\n", " Building wheel for distance (setup.py): started\n", " Building wheel for distance (setup.py): finished with status 'done'\n", " Created wheel for distance: filename=Distance-0.1.3-py3-none-any.whl size=16256 sha256=3bf3b151b4d15c4f5e442085e48cff050eec4c6d192a307e2b09a77daa84a5dc\n", " Stored in directory: /root/.cache/pip/wheels/fb/cd/9c/3ab5d666e3bcacc58900b10959edd3816cc9557c7337986322\n", "Successfully built distance\n", "Installing collected packages: distance, jupyter, scipy, transformers, scikit-learn, pandas, matplotlib, h5py, g2p_en, bitsandbytes\n", " Attempting uninstall: scipy\n", " Found existing installation: scipy 1.15.3\n", " Uninstalling scipy-1.15.3:\n", " Successfully uninstalled scipy-1.15.3\n", " Attempting uninstall: transformers\n", " Found existing installation: transformers 4.52.4\n", " Uninstalling transformers-4.52.4:\n", " Successfully uninstalled transformers-4.52.4\n", " Attempting uninstall: scikit-learn\n", " Found existing installation: scikit-learn 1.2.2\n", " Uninstalling scikit-learn-1.2.2:\n", " Successfully uninstalled scikit-learn-1.2.2\n", " Attempting uninstall: pandas\n", " Found existing installation: pandas 2.2.3\n", " Uninstalling pandas-2.2.3:\n", " Successfully uninstalled pandas-2.2.3\n", " Attempting uninstall: matplotlib\n", " Found existing installation: matplotlib 3.7.2\n", " Uninstalling matplotlib-3.7.2:\n", " Successfully uninstalled matplotlib-3.7.2\n", " Attempting uninstall: h5py\n", " Found existing installation: h5py 3.14.0\n", " Uninstalling h5py-3.14.0:\n", " Successfully uninstalled h5py-3.14.0\n", "Successfully installed bitsandbytes-0.46.0 distance-0.1.3 g2p_en-2.1.0 h5py-3.13.0 jupyter-1.1.1 matplotlib-3.10.1 pandas-2.3.0 scikit-learn-1.6.1 scipy-1.15.2 transformers-4.53.0\n", "Requirement already satisfied: PyDrive2 in /usr/local/lib/python3.11/dist-packages (1.21.3)\n", "Requirement already satisfied: google-api-python-client>=1.12.5 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (2.173.0)\n", "Requirement already satisfied: oauth2client>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (4.1.3)\n", "Requirement already 
satisfied: PyYAML>=3.0 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (6.0.2)\n", "Collecting cryptography<44 (from PyDrive2)\n", " Downloading cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (5.4 kB)\n", "Collecting pyOpenSSL<=24.2.1,>=19.1.0 (from PyDrive2)\n", " Downloading pyOpenSSL-24.2.1-py3-none-any.whl.metadata (13 kB)\n", "Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.11/dist-packages (from cryptography<44->PyDrive2) (1.17.1)\n", "Requirement already satisfied: httplib2<1.0.0,>=0.19.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (0.22.0)\n", "Requirement already satisfied: google-auth!=2.24.0,!=2.25.0,<3.0.0,>=1.32.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (2.40.3)\n", "Requirement already satisfied: google-auth-httplib2<1.0.0,>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (0.2.0)\n", "Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (1.34.1)\n", "Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (4.2.0)\n", "Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (0.6.1)\n", "Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (0.4.2)\n", "Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (4.9.1)\n", "Requirement already satisfied: six>=1.6.1 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (1.17.0)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.12->cryptography<44->PyDrive2) (2.22)\n", "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (1.70.0)\n", "Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<4.0.0dev,>=3.19.5 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.20.3)\n", "Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2.32.4)\n", "Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from google-auth!=2.24.0,!=2.25.0,<3.0.0,>=1.32.0->google-api-python-client>=1.12.5->PyDrive2) (5.5.2)\n", "Requirement already satisfied: pyparsing!=3.0.0,!=3.0.1,!=3.0.2,!=3.0.3,<4,>=2.4.2 in /usr/local/lib/python3.11/dist-packages (from httplib2<1.0.0,>=0.19.0->google-api-python-client>=1.12.5->PyDrive2) (3.0.9)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.4.2)\n", "Requirement already satisfied: 
idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2025.6.15)\n", "Downloading cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl (4.0 MB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.0/4.0 MB 49.5 MB/s eta 0:00:00\n", "Downloading pyOpenSSL-24.2.1-py3-none-any.whl (58 kB)\n", " ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.4/58.4 kB 3.4 MB/s eta 0:00:00\n", "Installing collected packages: cryptography, pyOpenSSL\n", " Attempting uninstall: cryptography\n", " Found existing installation: cryptography 44.0.3\n", " Uninstalling cryptography-44.0.3:\n", " Successfully uninstalled cryptography-44.0.3\n", " Attempting uninstall: pyOpenSSL\n", " Found existing installation: pyOpenSSL 25.1.0\n", " Uninstalling pyOpenSSL-25.1.0:\n", " Successfully uninstalled pyOpenSSL-25.1.0\n", "Successfully installed cryptography-43.0.3 pyOpenSSL-24.2.1\n", "Obtaining file:///kaggle/working/nejm-brain-to-text\n", " Preparing metadata (setup.py): started\n", " Preparing metadata (setup.py): finished with status 'done'\n", "Installing collected packages: nejm_b2txt_utils\n", " Running setup.py develop for nejm_b2txt_utils\n", "Successfully installed nejm_b2txt_utils-0.0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Cloning into 'nejm-brain-to-text'...\n", "ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", "bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.\n", "gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.15.2 which is incompatible.\n", "dask-cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n", "cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n", "datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.5.1 which is incompatible.\n", "ydata-profiling 4.16.1 requires matplotlib<=3.10,>=3.5, but you have matplotlib 3.10.1 which is incompatible.\n", "category-encoders 2.7.0 requires scikit-learn<1.6.0,>=1.0.0, but you have scikit-learn 1.6.1 which is incompatible.\n", "cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.\n", "google-colab 1.0.0 requires google-auth==2.38.0, but you have google-auth 2.40.3 which is incompatible.\n", "google-colab 1.0.0 requires notebook==6.5.7, but you have notebook 6.5.4 which is incompatible.\n", "google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.0 which is incompatible.\n", "google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.\n", "google-colab 1.0.0 requires tornado==6.4.2, but you have tornado 6.5.1 which is incompatible.\n", "dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.\n", "pandas-gbq 0.29.1 requires google-api-core<3.0.0,>=2.10.2, but you have google-api-core 1.34.1 which is incompatible.\n", "bigframes 2.8.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.\n", "bigframes 2.8.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.\n" ] } ], "source": [ "%%bash\n", "rm -rf /kaggle/working/nejm-brain-to-text/\n", "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n", "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n", "\n", "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data\n", "\n", "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", "\n", "pip install \\\n", " jupyter==1.1.1 \\\n", " \"numpy>=1.26.0,<2.1.0\" \\\n", " pandas==2.3.0 \\\n", " matplotlib==3.10.1 \\\n", " scipy==1.15.2 \\\n", " scikit-learn==1.6.1 \\\n", " tqdm==4.67.1 \\\n", " g2p_en==2.1.0 \\\n", " h5py==3.13.0 \\\n", " omegaconf==2.3.0 \\\n", " editdistance==0.8.1 \\\n", " huggingface-hub==0.33.1 \\\n", " transformers==4.53.0 \\\n", " tokenizers==0.21.2 \\\n", " accelerate==1.8.1 \\\n", " bitsandbytes==0.46.0\n", " \n", "pip install PyDrive2\n", "\n", "cd /kaggle/working/nejm-brain-to-text/\n", "pip install -e .\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text\n", "import numpy as np\n", "import os\n", "import pickle\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from g2p_en 
import G2p\n", "import pandas as pd\n", "from nejm_b2txt_utils.general_utils import *\n", "\n", "matplotlib.rcParams['pdf.fonttype'] = 42\n", "matplotlib.rcParams['ps.fonttype'] = 42\n", "matplotlib.rcParams['font.family'] = 'sans-serif'\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text/model_training\n" ] } ], "source": [ "%cd model_training/\n", "import torch  # needed below; no earlier cell imports torch\n", "from data_augmentations import gauss_smooth\n", "# single decoding step function that also returns smoothed input\n", "# smooths data and puts it through the model, returning both logits and smoothed input.\n", "def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n", "\n", " # Use autocast for efficiency\n", " with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n", "\n", " smoothed_x = gauss_smooth(\n", " inputs = x, \n", " device = device,\n", " smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n", " smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n", " padding = 'valid',\n", " )\n", "\n", " with torch.no_grad():\n", " logits, _ = model(\n", " x = smoothed_x,\n", " day_idx = torch.tensor([input_layer], device=device),\n", " states = None, # no initial states\n", " return_state = True,\n", " )\n", "\n", " # convert both logits and smoothed input from bfloat16 to float32\n", " logits = logits.float().cpu().numpy()\n", " smoothed_input = smoothed_x.float().cpu().numpy()\n", "\n", " # # original order is [BLANK, phonemes..., SIL]\n", " # # rearrange so the order is [BLANK, SIL, phonemes...]\n", " # logits = rearrange_speech_logits_pt(logits)\n", "\n", " return logits, smoothed_input\n", "\n", "\n", "import h5py\n", "def load_h5py_file(file_path, b2txt_csv_df):\n", " data = {\n", " 'neural_features': [],\n", " 'n_time_steps': [],\n", " 'seq_class_ids': [],\n", " 'seq_len': [],\n", " 'transcriptions': [],\n", " 'sentence_label': [],\n", " 'session': [],\n", " 'block_num': [],\n", " 'trial_num': [],\n", " 'corpus': [],\n", " }\n", " # Open the hdf5 file for that day\n", " with h5py.File(file_path, 'r') as f:\n", "\n", " keys = list(f.keys())\n", "\n", " # For each trial in the selected trials in that day\n", " for key in keys:\n", " g = f[key]\n", "\n", " neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n", " n_time_steps = g.attrs['n_time_steps']\n", " seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n", " seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n", " transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n", " sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n", " session = g.attrs['session']\n", " block_num = g.attrs['block_num']\n", " trial_num = g.attrs['trial_num']\n", "\n", " # match this trial up with the csv to get the corpus name\n", " year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n", " date = f'{year}-{month}-{day}'\n", " row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n", " corpus_name = row['Corpus'].values[0]\n", "\n", " data['neural_features'].append(neural_features)\n", " data['n_time_steps'].append(n_time_steps)\n", " data['seq_class_ids'].append(seq_class_ids)\n", " data['seq_len'].append(seq_len)\n", " data['transcriptions'].append(transcription)\n", " data['sentence_label'].append(sentence_label)\n", " data['session'].append(session)\n", " data['block_num'].append(block_num)\n", " data['trial_num'].append(trial_num)\n", " data['corpus'].append(corpus_name)\n", " return data" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "LOGIT_TO_PHONEME = [\n", " 'BLANK',\n", " 'AA', 'AE', 'AH', 'AO', 'AW',\n", " 'AY', 'B', 'CH', 'D', 'DH',\n", " 'EH', 'ER', 'EY', 'F', 'G',\n", " 'HH', 'IH', 'IY', 'JH', 'K',\n", " 'L', 'M', 'N', 'NG', 'OW',\n", " 'OY', 'P', 'R', 'S', 'SH',\n", " 'T', 'TH', 'UH', 'UW', 'V',\n", " 'W', 'Y', 'Z', 'ZH',\n", " ' | ',\n", "]" ] }, 
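{ "cell_type": "markdown", "metadata": {}, "source": [ "A quick worked example (added for orientation; not part of the original pipeline): `LOGIT_TO_PHONEME` maps the model's 41 CTC output classes to ARPABET phonemes, with index 0 the CTC blank and index 40 the word separator. The ids below are the `seq_class_ids` of the first trial inspected later in this notebook ('Bring it closer.')." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# sanity-check the phoneme table by decoding a known id sequence\n", "# (taken from the first trial of t15.2023.08.11 shown below)\n", "ids = [7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40]\n", "print(' '.join(LOGIT_TO_PHONEME[i] for i in ids))\n", "# -> B R IH NG  |  IH T  |  K L OW S ER  |" ] }, 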
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Data Analysis and Preprocessing" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preparation" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text/" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "data = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n", " b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "data2 = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.13/data_train.hdf5',\n", " b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- **Task overview**: machine learning for a pattern-recognition problem on high-dimensional signals" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The labels in our dataset carry no timestamps, so what follows is a semi-supervised learning setup." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Split each trial's time axis equally across its phonemes, or seed the initial segment lengths from surveyed duration data; then filter out outliers, extract a usable training set, constrain the segment durations, and inspect the per-class sample lengths (`create_time_windows` below implements the equal split)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n", " -0.7091646 , -0.11018186],\n", " ...,\n", " [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n", " 0.7417906 , -0.7008622 ]], dtype=float32),\n", " 'n_time_steps': 321,\n", " 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n", " ..., 0, 0, 0], dtype=int32),\n", " 'seq_len': 14,\n", " 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n", " 101, 114, 46, 0, 0, ..., 0], dtype=int32),\n", " 'sentence_label': 'Bring it closer.',\n", " 'session': 't15.2023.08.11',\n", " 'block_num': 2,\n", " 'trial_num': 0,\n", " 'corpus': '50-Word'}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def data_patch(data, index):\n", " data_patch = {}\n", " data_patch['neural_features'] = data['neural_features'][index]\n", " data_patch['n_time_steps'] = data['n_time_steps'][index]\n", " data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n", " data_patch['seq_len'] = data['seq_len'][index]\n", " data_patch['transcriptions'] = data['transcriptions'][index]\n", " 
data_patch['sentence_label'] = data['sentence_label'][index]\n", " data_patch['session'] = data['session'][index]\n", " data_patch['block_num'] = data['block_num'][index]\n", " data_patch['trial_num'] = data['trial_num'][index]\n", " data_patch['corpus'] = data['corpus'][index]\n", " return data_patch" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "d1 = data_patch(data, 0)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Transcriptions non-zero length: 16\n", "Seq class ids non-zero length: 14\n", "Seq len: 14\n" ] } ], "source": [ "trans_len = len([x for x in d1['transcriptions'] if x != 0])\n", "seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n", "seq_len = d1['seq_len']\n", "print(f\"Transcriptions non-zero length: {trans_len}\")\n", "print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n", "print(f\"Seq len: {seq_len}\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Number of feature sequences: 14\n", "Shape of first sequence: (22, 512)\n" ] } ], "source": [ "def create_time_windows(d1):\n", " import numpy as np\n", " n_time_steps = d1['n_time_steps']\n", " seq_len = d1['seq_len']\n", " # Create equal windows\n", " edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n", " windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n", " \n", " # Extract feature sequences for each window\n", " feature_sequences = []\n", " for start, end in windows:\n", " seq = d1['neural_features'][start:end, :]\n", " feature_sequences.append(seq)\n", " \n", " return feature_sequences\n", "\n", "# Example usage\n", "feature_sequences = create_time_windows(d1)\n", "print(\"Number of feature sequences:\", len(feature_sequences))\n", "print(\"Shape of first sequence:\", feature_sequences[0].shape)\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Train: 45, Val: 41, Test: 41\n", "Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n", "Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n", "Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n" ] } ], "source": [ "import os\n", "\n", "def scan_hdf5_files(base_path):\n", " train_files = []\n", " val_files = []\n", " test_files = []\n", " for root, dirs, files in os.walk(base_path):\n", " for file in files:\n", " if file.endswith('.hdf5'):\n", " abs_path = os.path.abspath(os.path.join(root, file))\n", " if 'data_train.hdf5' in file:\n", " train_files.append(abs_path)\n", " elif 'data_val.hdf5' in file:\n", " val_files.append(abs_path)\n", " elif 'data_test.hdf5' in file:\n", " test_files.append(abs_path)\n", " 
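# sort so the returned lists are deterministic (os.walk enumeration order is filesystem-dependent)\n", "    train_files.sort()\n", "    val_files.sort()\n", "    test_files.sort()\n", "    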
return train_files, val_files, test_files\n", "\n", "# Example usage\n", "FILE_PATH = 'data/hdf5_data_final'\n", "train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n", "print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n", "print(\"Train files (first 3):\", train_list[:3])\n", "print(\"Val files (first 3):\", val_list[:3])\n", "print(\"Test files (first 3):\", test_list[:3])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 🔗 Dataset Batch-Processing Workflow" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text/model_training\n" ] } ], "source": [ "%cd model_training/" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "🚀 RNN数据批量处理工具 - 增强内存管理版本\n", "======================================================================\n", "🔧 创建RNN数据处理器(增强内存管理版)...\n", "🔧 初始化RNN数据处理器(增强内存管理版)...\n", "   模型路径: ../data/t15_pretrained_rnn_baseline\n", "   数据目录: ../data/hdf5_data_final\n", "   计算设备: cpu\n", "   批量保存间隔: 25 个试验\n", "   内存警告阈值: 80%\n", "   初始内存状态: RAM: 0.52GB (3.8%)\n", "📋 模型配置:\n", "   Sessions数量: 45\n", "   神经特征维度: 512\n", "   Patch size: 14\n", "   Patch stride: 4\n", "   输出类别数: 41\n", "🔄 加载模型... 当前内存: RAM: 0.52GB (3.8%)\n", "✅ 模型加载成功,内存清理后: RAM: 1.18GB (8.6%)\n", "📊 CSV数据加载完成: 265 条记录\n", "✅ 初始化完成!当前内存: RAM: 1.18GB (8.6%)\n", "✅ RNN数据处理器创建成功!\n", "🧠 内存管理功能已启用:\n", "   - 批量保存间隔: 25个试验\n", "   - 自动内存监控和清理\n", "   - GPU内存即时释放\n", "   - 垃圾回收优化\n" ] } ], "source": [ "# 🚀 RNN数据批量处理工具 - 完整版(增强内存管理 + 自动上传)\n", "import os\n", "import torch\n", "import numpy as np\n", "import pandas as pd\n", "from omegaconf import OmegaConf\n", "import time\n", "from tqdm import tqdm\n", "import h5py\n", "from pathlib import Path\n", "import gc  # 垃圾回收\n", "import psutil  # 内存监控\n", "\n", "# 导入模型相关模块\n", "import sys\n", "sys.path.append('../model_training')\n", "from rnn_model import GRUDecoder\n", "from evaluate_model_helpers import *\n", "from data_augmentations import gauss_smooth\n", "\n", "print(\"=\"*70)\n", "print(\"🚀 RNN数据批量处理工具 - 增强内存管理 + 自动上传版本\")\n", "print(\"=\"*70)\n", "\n", "class MemoryManager:\n", "    \"\"\"内存管理器 - 监控和清理内存\"\"\"\n", "    \n", "    @staticmethod\n", "    def get_memory_info():\n", "        \"\"\"获取内存使用情况\"\"\"\n", "        process = psutil.Process()\n", "        memory_info = process.memory_info()\n", "        memory_percent = process.memory_percent()\n", "        \n", "        # GPU内存(如果可用)\n", "        gpu_memory = \"\"\n", "        if torch.cuda.is_available():\n", "            gpu_allocated = torch.cuda.memory_allocated() / 1024**3\n", "            gpu_reserved = torch.cuda.memory_reserved() / 1024**3\n", "            gpu_memory = f\" | GPU: {gpu_allocated:.2f}GB allocated, {gpu_reserved:.2f}GB reserved\"\n", "        \n", "        return f\"RAM: {memory_info.rss / 1024**3:.2f}GB ({memory_percent:.1f}%){gpu_memory}\"\n", "    \n", "    @staticmethod\n", "    def clear_memory():\n", "        \"\"\"清理内存\"\"\"\n", "        # 清理Python垃圾回收\n", "        collected = gc.collect()\n", "        \n", "        # 清理GPU内存\n", "        if torch.cuda.is_available():\n", "            torch.cuda.empty_cache()\n", "            torch.cuda.synchronize()\n", "        \n", "        return collected\n", "    \n", "    
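# NOTE: get_memory_info() above reports this process (psutil.Process()),\n", "    # while memory_warning_check() below reads psutil.virtual_memory().percent,\n", "    # i.e. system-wide RAM usage; the two percentages measure different things.\n", "    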
@staticmethod\n", "    def memory_warning_check(threshold=85):\n", "        \"\"\"检查内存使用情况并发出警告(threshold 为系统内存百分比阈值)\"\"\"\n", "        memory_percent = psutil.virtual_memory().percent\n", "        if memory_percent > threshold:\n", "            print(f\"   内存使用率过高: {memory_percent:.1f}%\")\n", "            return True\n", "        return False\n", "\n", "class RNNDataProcessor:\n", "    \"\"\"\n", "    RNN数据批量处理器 - 生成RNN输入输出拼接数据\n", "    增强版本:优化内存管理,支持大数据集处理,自动上传到WebDAV\n", "    \n", "    核心功能:\n", "    1. 加载预训练RNN模型\n", "    2. 处理原始神经数据(高斯平滑 + patch操作)\n", "    3. 获取RNN输出(41类置信度分数)\n", "    4. 拼接处理后的输入和输出\n", "    5. 批量保存所有session数据\n", "    6. 自动上传到WebDAV并删除本地文件\n", "    7. 内存管理和监控\n", "    \"\"\"\n", "    \n", "    def __init__(self, model_path, data_dir, csv_path, device='auto', \n", "                 batch_save_interval=50, memory_threshold=80, \n", "                 enable_auto_upload=True, webdav_uploader=None):\n", "        \"\"\"\n", "        初始化处理器\n", "        \n", "        参数:\n", "            model_path: 预训练RNN模型路径\n", "            data_dir: 数据目录路径 \n", "            csv_path: 数据描述CSV文件路径\n", "            device: 计算设备 ('auto', 'cpu', 'cuda:0'等)\n", "            batch_save_interval: 批量保存间隔(每N个试验保存一次)\n", "            memory_threshold: 内存警告阈值(百分比)\n", "            enable_auto_upload: 是否启用自动上传\n", "            webdav_uploader: WebDAV上传器实例\n", "        \"\"\"\n", "        self.model_path = model_path\n", "        self.data_dir = data_dir\n", "        self.csv_path = csv_path\n", "        self.batch_save_interval = batch_save_interval\n", "        self.memory_threshold = memory_threshold\n", "        self.enable_auto_upload = enable_auto_upload\n", "        self.webdav_uploader = webdav_uploader\n", "        \n", "        # 设备选择\n", "        if device == 'auto':\n", "            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n", "        else:\n", "            self.device = torch.device(device)\n", "        \n", "        print(f\"🔧 初始化RNN数据处理器(增强内存管理 + 自动上传版)...\")\n", "        print(f\"   模型路径: {model_path}\")\n", "        print(f\"   数据目录: {data_dir}\")\n", "        print(f\"   计算设备: {self.device}\")\n", "        print(f\"   批量保存间隔: {batch_save_interval} 个试验\")\n", "        print(f\"   内存警告阈值: {memory_threshold}%\")\n", "        print(f\"   自动上传: {'✅ 启用' if enable_auto_upload else '❌ 禁用'}\")\n", "        \n", "        # 初始内存状态\n", "        print(f\"   初始内存状态: {MemoryManager.get_memory_info()}\")\n", "        \n", "        # 加载配置和模型\n", "        self._load_config()\n", "        self._load_model()\n", "        self._load_csv()\n", "        \n", "        print(f\"   初始化完成!当前内存: {MemoryManager.get_memory_info()}\")\n", "    \n", "    def _load_config(self):\n", "        \"\"\"加载模型配置\"\"\"\n", "        config_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n", "        if not os.path.exists(config_path):\n", "            raise FileNotFoundError(f\"配置文件不存在: {config_path}\")\n", "        \n", "        self.model_args = OmegaConf.load(config_path)\n", "        \n", "        print(f\"   模型配置:\")\n", "        print(f\"   Sessions数量: {len(self.model_args['dataset']['sessions'])}\")\n", "        print(f\"   神经特征维度: {self.model_args['model']['n_input_features']}\")\n", "        print(f\"   Patch size: {self.model_args['model']['patch_size']}\")\n", "        print(f\"   Patch stride: {self.model_args['model']['patch_stride']}\")\n", "        print(f\"   输出类别数: {self.model_args['dataset']['n_classes']}\")\n", "    \n", "    def _load_model(self):\n", "        \"\"\"加载预训练RNN模型\"\"\"\n", "        try:\n", "            print(f\"   加载模型... 
当前内存: {MemoryManager.get_memory_info()}\")\n", "            \n", "            # 创建模型\n", "            self.model = GRUDecoder(\n", "                neural_dim=self.model_args['model']['n_input_features'],\n", "                n_units=self.model_args['model']['n_units'], \n", "                n_days=len(self.model_args['dataset']['sessions']),\n", "                n_classes=self.model_args['dataset']['n_classes'],\n", "                rnn_dropout=self.model_args['model']['rnn_dropout'],\n", "                input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n", "                n_layers=self.model_args['model']['n_layers'],\n", "                patch_size=self.model_args['model']['patch_size'],\n", "                patch_stride=self.model_args['model']['patch_stride'],\n", "            )\n", "            \n", "            # 加载权重\n", "            checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n", "            try:\n", "                checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n", "            except TypeError:\n", "                checkpoint = torch.load(checkpoint_path, map_location=self.device)\n", "            \n", "            # 清理键名:一次替换同时去掉 \"module.\" 和 \"_orig_mod.\" 前缀,\n", "            # 避免对同一个 key 连续 pop 两次引发 KeyError\n", "            for key in list(checkpoint['model_state_dict'].keys()):\n", "                new_key = key.replace(\"module.\", \"\").replace(\"_orig_mod.\", \"\")\n", "                checkpoint['model_state_dict'][new_key] = checkpoint['model_state_dict'].pop(key)\n", "            \n", "            self.model.load_state_dict(checkpoint['model_state_dict'])\n", "            self.model.to(self.device)\n", "            self.model.eval()\n", "            \n", "            # 立即清理checkpoint内存\n", "            del checkpoint\n", "            MemoryManager.clear_memory()\n", "            \n", "            print(f\"   模型加载成功,内存清理后: {MemoryManager.get_memory_info()}\")\n", "            \n", "        except Exception as e:\n", "            print(f\"   模型加载失败: {e}\")\n", "            raise\n", "    \n", "    def _load_csv(self):\n", "        \"\"\"加载数据描述文件\"\"\"\n", "        if not os.path.exists(self.csv_path):\n", "            raise FileNotFoundError(f\"CSV文件不存在: {self.csv_path}\")\n", "        \n", "        self.csv_df = pd.read_csv(self.csv_path)\n", "        print(f\"📊 CSV数据加载完成: {len(self.csv_df)} 条记录\")\n", "    \n", "    def _process_single_trial(self, neural_data, session_idx):\n", "        \"\"\"\n", "        处理单个试验数据(优化内存使用)\n", "        \n", "        参数:\n", "            neural_data: 原始神经数据 [time_steps, features]\n", "            session_idx: 会话索引\n", "        \n", "        返回:\n", "            dict: 包含拼接数据和统计信息\n", "        \"\"\"\n", "        try:\n", "            # 添加batch维度\n", "            neural_input = np.expand_dims(neural_data, axis=0)\n", "            neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)\n", "            \n", "            # 高斯平滑\n", "            with torch.autocast(device_type=\"cuda\" if self.device.type == \"cuda\" else \"cpu\", \n", "                                enabled=self.model_args.get('use_amp', False), dtype=torch.bfloat16):\n", "                \n", "                smoothed_data = gauss_smooth(\n", "                    inputs=neural_tensor,\n", "                    device=self.device,\n", "                    smooth_kernel_std=self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n", "                    smooth_kernel_size=self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n", "                    padding='valid',\n", "                )\n", "                \n", "                # Patch操作(复制模型内部逻辑)\n", "                processed_data = smoothed_data\n", "                if self.model.patch_size > 0:\n", "                    processed_data = processed_data.unsqueeze(1)  # [batch, 1, time, features]\n", "                    processed_data = processed_data.permute(0, 3, 1, 2)  # [batch, features, 1, time]\n", "                    \n", "                    # 滑动窗口提取\n", "                    patches = processed_data.unfold(3, self.model.patch_size, self.model.patch_stride)\n", "                    patches = patches.squeeze(2)  # [batch, features, patches, patch_size]\n", "                    patches = patches.permute(0, 2, 3, 1)  # [batch, patches, patch_size, features]\n", "                    \n", "                    # 展平最后两个维度\n", "                    processed_data = patches.reshape(patches.size(0), patches.size(1), -1)\n", "                \n", "                # RNN推理\n", "                with torch.no_grad():\n", "                    logits, _ = 
self.model(\n", " x=smoothed_data,\n", " day_idx=torch.tensor([session_idx], device=self.device),\n", " states=None,\n", " return_state=True,\n", " )\n", " \n", " # 转换为numpy并立即释放GPU内存\n", " processed_features = processed_data.float().cpu().numpy()[0] # [time_steps, processed_features]\n", " confidence_scores = logits.float().cpu().numpy()[0] # [time_steps, 40]\n", " \n", " # 立即清理GPU张量\n", " del neural_tensor, smoothed_data, processed_data, logits\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " \n", " # 拼接数据\n", " concatenated = np.concatenate([processed_features, confidence_scores], axis=1)\n", " \n", " return {\n", " 'concatenated_data': concatenated,\n", " 'processed_features': processed_features,\n", " 'confidence_scores': confidence_scores,\n", " 'original_time_steps': neural_data.shape[0],\n", " 'processed_time_steps': concatenated.shape[0],\n", " 'feature_reduction_ratio': concatenated.shape[0] / neural_data.shape[0]\n", " }\n", " \n", " except Exception as e:\n", " # 确保GPU内存清理\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " raise e\n", " \n", " def _upload_and_cleanup_file(self, filepath):\n", " \"\"\"上传文件到WebDAV并删除本地文件\"\"\"\n", " if not self.enable_auto_upload or not self.webdav_uploader:\n", " return False\n", " \n", " try:\n", " # 上传文件\n", " filename = os.path.basename(filepath)\n", " success = self.webdav_uploader.upload_file(\n", " str(filepath), \n", " '/移动云盘/DATA/rnn_processed_data/'\n", " )\n", " \n", " if success:\n", " # 删除本地文件\n", " os.remove(filepath)\n", " print(f\" 📤 上传并删除: {filename}\")\n", " return True\n", " else:\n", " print(f\" ❌ 上传失败,保留本地文件: {filename}\")\n", " return False\n", " \n", " except Exception as e:\n", " print(f\" ⚠️ 上传过程出错: {e}\")\n", " return False\n", " \n", " def _save_batch_data(self, results, session_name, data_type, save_path, batch_idx=None):\n", " \"\"\"\n", " 保存批次数据(减少内存占用 + 自动上传)\n", " \n", " 参数:\n", " results: 结果数据\n", " session_name: 会话名称\n", " data_type: 数据类型\n", " save_path: 保存路径\n", " batch_idx: 批次索引(可选)\n", " \"\"\"\n", " if not results['concatenated_data']:\n", " return\n", " \n", " # 生成文件名\n", " if batch_idx is not None:\n", " filename = f\"{session_name}_{data_type}_rnn_processed_batch{batch_idx}.npz\"\n", " else:\n", " filename = f\"{session_name}_{data_type}_rnn_processed.npz\"\n", " \n", " filepath = save_path / filename\n", " \n", " save_data = {\n", " 'concatenated_data': np.array(results['concatenated_data'], dtype=object),\n", " 'processed_features': np.array(results['processed_features'], dtype=object),\n", " 'confidence_scores': np.array(results['confidence_scores'], dtype=object),\n", " 'trial_metadata': np.array(results['trial_metadata'], dtype=object),\n", " }\n", " \n", " # 保存文件\n", " np.savez_compressed(str(filepath), **save_data)\n", " print(f\" 💾 保存批次: {filename} ({len(results['concatenated_data'])} 个试验)\")\n", " \n", " # 自动上传并删除本地文件\n", " if self.enable_auto_upload:\n", " self._upload_and_cleanup_file(filepath)\n", " \n", " # 清理结果数据释放内存\n", " for key in results:\n", " if isinstance(results[key], list):\n", " results[key].clear()\n", " \n", " # 强制垃圾回收\n", " MemoryManager.clear_memory()\n", " \n", " def process_session(self, session_name, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n", " \"\"\"\n", " 处理单个session的数据(优化内存管理 + 自动上传)\n", " \n", " 参数:\n", " session_name: 会话名称\n", " data_types: 要处理的数据类型列表\n", " save_dir: 保存目录\n", " \n", " 返回:\n", " dict: 处理结果摘要\n", " \"\"\"\n", " print(f\"\\n 处理会话: {session_name}\")\n", " print(f\" 开始时内存: 
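{MemoryManager.get_memory_info()}\")\n", "        # Each trial below goes through _process_single_trial(), which returns the\n", "        # patched input features concatenated with the 41-class logit sequence\n", "        # (BLANK + 39 phonemes + word separator); results are flushed to disk every\n", "        # batch_save_interval trials so peak RAM stays bounded. If system RAM is\n", "        # already high, log the details before starting.\n", "        if MemoryManager.memory_warning_check(self.memory_threshold):\n", "            print(f\"   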
{MemoryManager.get_memory_info()}\")\n", " \n", " session_idx = self.model_args['dataset']['sessions'].index(session_name)\n", " session_results = {}\n", " \n", " # 确保保存目录存在\n", " save_path = Path(save_dir)\n", " save_path.mkdir(parents=True, exist_ok=True)\n", " \n", " for data_type in data_types:\n", " data_file = os.path.join(self.data_dir, session_name, f'data_{data_type}.hdf5')\n", " \n", " if not os.path.exists(data_file):\n", " print(f\" {data_type} 数据文件不存在,跳过\")\n", " continue\n", " \n", " print(f\" 处理 {data_type} 数据...\")\n", " \n", " try:\n", " # 加载数据\n", " data = load_h5py_file(data_file, self.csv_df)\n", " num_trials = len(data['neural_features'])\n", " \n", " if num_trials == 0:\n", " print(f\" {data_type} 数据为空\")\n", " continue\n", " \n", " # 处理所有试验(批量保存策略)\n", " results = {\n", " 'concatenated_data': [],\n", " 'processed_features': [],\n", " 'confidence_scores': [],\n", " 'trial_metadata': [],\n", " 'processing_stats': []\n", " }\n", " \n", " batch_count = 0\n", " total_processed = 0\n", " uploaded_files = 0\n", " \n", " for trial_idx in tqdm(range(num_trials), desc=f\" {data_type}\", leave=False):\n", " # 检查内存使用情况\n", " if trial_idx % 10 == 0: # 每10个trial检查一次\n", " if MemoryManager.memory_warning_check():\n", " MemoryManager.clear_memory()\n", " print(f\" 🧹 内存清理: {MemoryManager.get_memory_info()}\")\n", " \n", " neural_data = data['neural_features'][trial_idx]\n", " \n", " # 处理单个试验\n", " trial_result = self._process_single_trial(neural_data, session_idx)\n", " \n", " # 保存结果\n", " results['concatenated_data'].append(trial_result['concatenated_data'])\n", " results['processed_features'].append(trial_result['processed_features'])\n", " results['confidence_scores'].append(trial_result['confidence_scores'])\n", " \n", " # 保存元数据\n", " metadata = {\n", " 'session': session_name,\n", " 'data_type': data_type,\n", " 'trial_idx': trial_idx,\n", " 'block_num': data.get('block_num', [None])[trial_idx],\n", " 'trial_num': data.get('trial_num', [None])[trial_idx],\n", " **{k: v for k, v in trial_result.items() if k != 'concatenated_data'}\n", " }\n", " \n", " # 添加真实标签(如果可用)\n", " if data_type in ['train', 'val'] and 'sentence_label' in data:\n", " metadata.update({\n", " 'sentence_label': data['sentence_label'][trial_idx],\n", " 'seq_class_ids': data['seq_class_ids'][trial_idx],\n", " 'seq_len': data['seq_len'][trial_idx]\n", " })\n", " \n", " results['trial_metadata'].append(metadata)\n", " results['processing_stats'].append(trial_result)\n", " total_processed += 1\n", " \n", " # 批量保存策略\n", " if (trial_idx + 1) % self.batch_save_interval == 0 or trial_idx == num_trials - 1:\n", " self._save_batch_data(results, session_name, data_type, save_path, batch_count)\n", " if self.enable_auto_upload:\n", " uploaded_files += 1\n", " batch_count += 1\n", " \n", " # 强制内存清理\n", " MemoryManager.clear_memory()\n", " \n", " # 统计信息\n", " if total_processed > 0:\n", " print(f\" {data_type} 处理完成:\")\n", " print(f\" 试验数: {total_processed}\")\n", " print(f\" 保存批次数: {batch_count}\")\n", " if self.enable_auto_upload:\n", " print(f\" 上传文件数: {uploaded_files}\")\n", " print(f\" 最终内存: {MemoryManager.get_memory_info()}\")\n", " \n", " session_results[data_type] = {\n", " 'total_trials': total_processed,\n", " 'batch_count': batch_count,\n", " 'uploaded_files': uploaded_files if self.enable_auto_upload else 0,\n", " 'files': [f\"{session_name}_{data_type}_rnn_processed_batch{i}.npz\" for i in range(batch_count)]\n", " }\n", " \n", " # 清理大型数据对象\n", " del data\n", " MemoryManager.clear_memory()\n", " \n", " except 
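MemoryError:\n", "                # narrow guard: on out-of-memory, free what we can and skip just this\n", "                # split, so one oversized dataset does not abort the whole session\n", "                print(f\"   {data_type} 内存不足,清理后跳过\")\n", "                MemoryManager.clear_memory()\n", "                continue\n", "            except 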
Exception as e:\n", " print(f\" {data_type} 处理失败: {e}\")\n", " # 确保内存清理\n", " MemoryManager.clear_memory()\n", " continue\n", " \n", " print(f\" 会话完成时内存: {MemoryManager.get_memory_info()}\")\n", " return session_results\n", " \n", " def process_all_sessions(self, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n", " \"\"\"\n", " 批量处理所有sessions(优化内存管理 + 自动上传)\n", " \n", " 参数:\n", " data_types: 要处理的数据类型\n", " save_dir: 保存目录\n", " \n", " 返回:\n", " dict: 所有处理结果摘要\n", " \"\"\"\n", " print(f\"\\n 开始批量处理所有会话(增强内存管理 + 自动上传)...\")\n", " print(f\" 目标数据类型: {data_types}\")\n", " print(f\" 保存目录: {save_dir}\")\n", " print(f\" 批量保存间隔: {self.batch_save_interval}\")\n", " print(f\" 自动上传: {'✅ 启用' if self.enable_auto_upload else '❌ 禁用'}\")\n", " print(f\" 初始内存状态: {MemoryManager.get_memory_info()}\")\n", " \n", " save_path = Path(save_dir)\n", " save_path.mkdir(parents=True, exist_ok=True)\n", " \n", " all_results = {}\n", " sessions = self.model_args['dataset']['sessions']\n", " \n", " start_time = time.time()\n", " total_uploaded_files = 0\n", " \n", " for i, session in enumerate(sessions):\n", " print(f\"\\n 进度: {i+1}/{len(sessions)} - {session}\")\n", " \n", " try:\n", " session_results = self.process_session(session, data_types, save_dir)\n", " \n", " if session_results:\n", " all_results[session] = session_results\n", " \n", " # 统计上传文件数\n", " session_uploaded = sum(\n", " type_data.get('uploaded_files', 0) \n", " for type_data in session_results.values()\n", " )\n", " total_uploaded_files += session_uploaded\n", " \n", " print(f\" 会话 {session} 完成\")\n", " if self.enable_auto_upload:\n", " print(f\" 本会话上传文件: {session_uploaded}\")\n", " else:\n", " print(f\" 会话 {session} 无有效数据\")\n", " \n", " # 每处理几个session进行一次深度内存清理\n", " if (i + 1) % 5 == 0:\n", " print(f\" 深度内存清理...\")\n", " collected = MemoryManager.clear_memory()\n", " print(f\" 回收对象数: {collected}, 当前内存: {MemoryManager.get_memory_info()}\")\n", " \n", " except Exception as e:\n", " print(f\" 会话 {session} 处理失败: {e}\")\n", " # 确保内存清理\n", " MemoryManager.clear_memory()\n", " continue\n", " \n", " # 生成总结\n", " end_time = time.time()\n", " processing_time = end_time - start_time\n", " \n", " total_trials = sum(\n", " session_data[data_type]['total_trials']\n", " for session_data in all_results.values()\n", " for data_type in session_data.keys()\n", " )\n", " \n", " total_files = sum(\n", " session_data[data_type]['batch_count']\n", " for session_data in all_results.values()\n", " for data_type in session_data.keys()\n", " )\n", " \n", " print(f\"\\n 批量处理完成!\")\n", " print(f\"⏱ 总耗时: {processing_time/60:.2f} 分钟\")\n", " print(f\" 处理统计:\")\n", " print(f\" 成功会话: {len(all_results)}/{len(sessions)}\")\n", " print(f\" 总试验数: {total_trials}\")\n", " print(f\" 生成文件总数: {total_files}\")\n", " if self.enable_auto_upload:\n", " print(f\" 📤 上传文件总数: {total_uploaded_files}\")\n", " print(f\" 💾 本地保留文件: {total_files - total_uploaded_files}\")\n", " print(f\" 最终内存状态: {MemoryManager.get_memory_info()}\")\n", " print(f\" 数据保存在: {save_dir}\")\n", " \n", " # 保存总结信息\n", " summary = {\n", " 'processing_time': processing_time,\n", " 'total_sessions': len(all_results),\n", " 'total_trials': total_trials,\n", " 'total_files': total_files,\n", " 'uploaded_files': total_uploaded_files if self.enable_auto_upload else 0,\n", " 'auto_upload_enabled': self.enable_auto_upload,\n", " 'data_types': data_types,\n", " 'sessions': list(all_results.keys()),\n", " 'batch_save_interval': self.batch_save_interval,\n", " 'memory_management': True,\n", " 'model_config': {\n", " 
'patch_size': self.model_args['model']['patch_size'],\n", " 'patch_stride': self.model_args['model']['patch_stride'],\n", " 'smooth_kernel_size': self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n", " 'smooth_kernel_std': self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n", " }\n", " }\n", " \n", " import json\n", " with open(save_path / 'processing_summary.json', 'w') as f:\n", " json.dump(summary, f, indent=2)\n", " \n", " return all_results\n", "\n", "# 检查WebDAV上传器是否可用\n", "try:\n", " # 如果之前的上传器可用,直接使用\n", " if 'uploader' in globals():\n", " webdav_uploader_instance = uploader\n", " print(\"🔗 使用现有的WebDAV上传器\")\n", " else:\n", " # 创建简单的WebDAV上传器\n", " from webdav3.client import Client\n", " \n", " class SimpleWebDAVUploader:\n", " def __init__(self):\n", " self.client = Client({\n", " 'webdav_hostname': 'http://zchens.cn:5244/dav/',\n", " 'webdav_login': 'admin',\n", " 'webdav_password': 'Zccns20050420',\n", " 'webdav_timeout': 30\n", " })\n", " \n", " def upload_file(self, local_file, remote_dir):\n", " try:\n", " filename = os.path.basename(local_file)\n", " remote_path = remote_dir.rstrip('/') + '/' + filename\n", " \n", " if not self.client.check(remote_dir):\n", " self.client.mkdir(remote_dir)\n", " \n", " self.client.upload_sync(remote_path=remote_path, local_path=local_file)\n", " return True\n", " except Exception as e:\n", " print(f\"上传失败: {e}\")\n", " return False\n", " \n", " webdav_uploader_instance = SimpleWebDAVUploader()\n", " print(\"🔗 创建新的WebDAV上传器\")\n", " \n", "except Exception as e:\n", " webdav_uploader_instance = None\n", " print(f\"⚠️ WebDAV上传器不可用: {e}\")\n", "\n", "# 创建处理器实例(增强内存管理 + 自动上传)\n", "print(\"🔧 创建RNN数据处理器(增强内存管理 + 自动上传版)...\")\n", "\n", "try:\n", " processor = RNNDataProcessor(\n", " model_path='../data/t15_pretrained_rnn_baseline',\n", " data_dir='../data/hdf5_data_final',\n", " csv_path='../data/t15_copyTaskData_description.csv',\n", " device='auto',\n", " batch_save_interval=25, # 每25个试验保存一次,减少内存积累\n", " memory_threshold=80, # 80%内存使用率时警告\n", " enable_auto_upload=True, # 启用自动上传\n", " webdav_uploader=webdav_uploader_instance # 传入WebDAV上传器\n", " )\n", " \n", " print(f\" RNN数据处理器创建成功!\")\n", " print(f\" 功能特性:\")\n", " print(f\" - 批量保存间隔: 25个试验\")\n", " print(f\" - 自动内存监控和清理\")\n", " print(f\" - GPU内存即时释放\")\n", " print(f\" - 垃圾回收优化\")\n", " print(f\" - 📤 自动上传到WebDAV\")\n", " print(f\" - 🗑️ 自动删除本地文件\")\n", " \n", "except Exception as e:\n", " print(f\" 处理器创建失败: {e}\")\n", " processor = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "🎯 RNN数据批量处理 - 使用示例(内存优化版)\n", "======================================================================\n", "\n", "📋 可用的处理方法:\n", "1️⃣ 单session处理: processor.process_session('session_name')\n", "2️⃣ 批量处理所有: processor.process_all_sessions()\n", "\n", "🧠 内存管理特性:\n", " ✅ 自动批量保存 (每25个试验)\n", " ✅ 实时内存监控和清理\n", " ✅ GPU内存即时释放\n", " ✅ 垃圾回收优化\n", "\n", "📊 可用会话数量: 45\n", "📝 前5个会话: ['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25']\n", "💡 问题会话 t15.2023.09.01 在位置: 6\n", "\n", "🔍 当前系统状态:\n", " RAM: 1.18GB (8.6%)\n", "\n", "🧪 快速测试: 处理会话 't15.2023.08.13' 的训练数据...\n", " 这将测试内存管理功能...\n", "\n", "🔄 处理会话: t15.2023.08.13\n", " 开始时内存: RAM: 1.18GB (8.6%)\n", " 📁 处理 train 数据...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " train: 0%| | 0/348 [00:00 36\u001b[0m single_result 
\u001b[38;5;241m=\u001b[39m \u001b[43mprocessor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_session\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_session\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m single_result \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m single_result:\n\u001b[0;32m 39\u001b[0m train_info \u001b[38;5;241m=\u001b[39m single_result[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m]\n", "Cell \u001b[1;32mIn[1], line 400\u001b[0m, in \u001b[0;36mRNNDataProcessor.process_session\u001b[1;34m(self, session_name, data_types, save_dir)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[38;5;66;03m# 批量保存策略\u001b[39;00m\n\u001b[0;32m 399\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (trial_idx \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_save_interval \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m trial_idx \u001b[38;5;241m==\u001b[39m num_trials \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m--> 400\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_batch_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresults\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msession_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_count\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 401\u001b[0m batch_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 403\u001b[0m \u001b[38;5;66;03m# 强制内存清理\u001b[39;00m\n", "Cell \u001b[1;32mIn[1], line 296\u001b[0m, in \u001b[0;36mRNNDataProcessor._save_batch_data\u001b[1;34m(self, results, session_name, data_type, save_path, batch_idx)\u001b[0m\n\u001b[0;32m 287\u001b[0m filepath \u001b[38;5;241m=\u001b[39m save_path \u001b[38;5;241m/\u001b[39m filename\n\u001b[0;32m 289\u001b[0m save_data \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m 290\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 291\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprocessed_features\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprocessed_features\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 292\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconfidence_scores\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconfidence_scores\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 293\u001b[0m 
\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrial_metadata\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrial_metadata\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 294\u001b[0m }\n\u001b[1;32m--> 296\u001b[0m np\u001b[38;5;241m.\u001b[39msavez_compressed(\u001b[38;5;28mstr\u001b[39m(filepath), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39msave_data)\n\u001b[0;32m 297\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m 💾 保存批次: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m 个试验)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 299\u001b[0m \u001b[38;5;66;03m# 清理结果数据释放内存\u001b[39;00m\n", "File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:753\u001b[0m, in \u001b[0;36msavez_compressed\u001b[1;34m(file, *args, **kwds)\u001b[0m\n\u001b[0;32m 689\u001b[0m \u001b[38;5;129m@array_function_dispatch\u001b[39m(_savez_compressed_dispatcher)\n\u001b[0;32m 690\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21msavez_compressed\u001b[39m(file, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[0;32m 691\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 692\u001b[0m \u001b[38;5;124;03m Save several arrays into a single file in compressed ``.npz`` format.\u001b[39;00m\n\u001b[0;32m 693\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 751\u001b[0m \n\u001b[0;32m 752\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m--> 753\u001b[0m \u001b[43m_savez\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", "File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:786\u001b[0m, in \u001b[0;36m_savez\u001b[1;34m(file, args, kwds, compress, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[0;32m 784\u001b[0m \u001b[38;5;66;03m# always force zip64, gh-10776\u001b[39;00m\n\u001b[0;32m 785\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m zipf\u001b[38;5;241m.\u001b[39mopen(fname, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m, force_zip64\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m fid:\n\u001b[1;32m--> 786\u001b[0m \u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 787\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_pickle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_pickle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 788\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mpickle_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpickle_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 790\u001b[0m zipf\u001b[38;5;241m.\u001b[39mclose()\n", "File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\format.py:746\u001b[0m, in \u001b[0;36mwrite_array\u001b[1;34m(fp, array, version, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[0;32m 744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pickle_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 745\u001b[0m pickle_kwargs \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m--> 746\u001b[0m pickle\u001b[38;5;241m.\u001b[39mdump(array, fp, protocol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_kwargs)\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m array\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mf_contiguous \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m array\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mc_contiguous:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m isfileobj(fp):\n", "File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\zipfile.py:1142\u001b[0m, in \u001b[0;36m_ZipWriteFile.write\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m 1140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_crc \u001b[38;5;241m=\u001b[39m crc32(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_crc)\n\u001b[0;32m 1141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compressor:\n\u001b[1;32m-> 1142\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compressor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompress\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1143\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_size \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(data)\n\u001b[0;32m 1144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fileobj\u001b[38;5;241m.\u001b[39mwrite(data)\n", "\u001b[1;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# 使用示例和批量处理(增强内存管理版)\n", "\n", "print(\"=\"*70)\n", "print(\"RNN数据批量处理 - 使用示例(内存优化版)\")\n", "print(\"=\"*70)\n", "\n", "if processor is not None:\n", " \n", " # 方法1: 处理单个session (推荐用于测试)\n", " print(\"\\n可用的处理方法:\")\n", " print(\"1. 单session处理: processor.process_session('session_name')\")\n", " print(\"2. 
批量处理所有: processor.process_all_sessions()\")\n", " print(\"\\n内存管理特性:\")\n", " print(\" 自动批量保存 (每25个试验)\")\n", " print(\" 实时内存监控和清理\")\n", " print(\" GPU内存即时释放\")\n", " print(\" 垃圾回收优化\")\n", " \n", " # 显示可用的sessions\n", " sessions = processor.model_args['dataset']['sessions']\n", " print(f\"\\n可用会话数量: {len(sessions)}\")\n", " print(f\"前5个会话: {sessions[:5]}\")\n", " print(f\"问题会话 t15.2023.09.01 在位置: {sessions.index('t15.2023.09.01') if 't15.2023.09.01' in sessions else '未找到'}\")\n", " \n", " # 内存状态检查\n", " print(f\"\\n当前系统状态:\")\n", " print(f\" {MemoryManager.get_memory_info()}\")\n", " \n", " # 快速测试 - 处理第一个session的部分数据\n", " test_session = sessions[1] # 't15.2023.08.11'\n", " \n", " print(f\"\\n快速测试: 处理会话 '{test_session}' 的训练数据...\")\n", " print(f\" 这将测试内存管理功能...\")\n", " \n", " # 处理单个session(仅train数据进行测试)\n", " single_result = processor.process_session(test_session, ['train'])\n", " \n", " if single_result and 'train' in single_result:\n", " train_info = single_result['train']\n", " \n", " print(f\"\\n内存管理测试完成!结果概览:\")\n", " print(f\" 处理的试验数: {train_info['total_trials']}\")\n", " print(f\" 保存的批次数: {train_info['batch_count']}\")\n", " print(f\" 生成的文件: {len(train_info['files'])}\")\n", " print(f\" 内存管理状态: {MemoryManager.get_memory_info()}\")\n", " \n", " # 加载一个批次文件来验证\n", " if train_info['files']:\n", " first_file = Path('./rnn_processed_data') / train_info['files'][0]\n", " if first_file.exists():\n", " test_data = np.load(str(first_file), allow_pickle=True)\n", " sample_data = test_data['concatenated_data'][0]\n", " \n", " print(f\"\\n数据验证 (第一批次):\")\n", " print(f\" 批次文件: {train_info['files'][0]}\")\n", " print(f\" 样本数据形状: {sample_data.shape}\")\n", " print(f\" 特征维度详情:\")\n", " print(f\" - 处理后的神经特征: {sample_data.shape[1] - 41} 维\")\n", " print(f\" - RNN置信度分数: 41 维\")\n", " print(f\" - 总拼接特征: {sample_data.shape[1]} 维\")\n", " print(f\" - 时间步数: {sample_data.shape[0]}\")\n", " \n", " # 显示一些样本元数据\n", " sample_metadata = test_data['trial_metadata'][0]\n", " print(f\" 样本元数据:\")\n", " print(f\" - 原始时间步: {sample_metadata['original_time_steps']}\")\n", " print(f\" - 处理后时间步: {sample_metadata['processed_time_steps']}\")\n", " print(f\" - 时间压缩比: {sample_metadata['feature_reduction_ratio']:.3f}\")\n", " \n", " if 'sentence_label' in sample_metadata:\n", " print(f\" - 句子标签: {sample_metadata['sentence_label']}\")\n", " \n", " # 清理测试数据\n", " del test_data, sample_data, sample_metadata\n", " MemoryManager.clear_memory()\n", " \n", " print(f\"\\n处理大数据集建议:\")\n", " print(f\" 使用增强版处理器,自动内存管理\")\n", " print(f\" 数据自动分批保存,避免内存溢出\") \n", " print(f\" 可安全处理 t15.2023.09.01 等大批次\")\n", " print(f\" 要批量处理所有数据,运行:\")\n", " print(f\" results = processor.process_all_sessions()\")\n", " \n", "else:\n", " print(\"处理器未创建成功,请检查上面的错误信息\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "======================================================================\n", "🚀 批量处理选项\n", "======================================================================\n", "📊 批量处理配置:\n", " 启用批量处理: True\n", " 保存目录: ./rnn_processed_data\n", " 数据类型: ['train', 'val', 'test']\n", " 总会话数: 45\n", "\n", "🚀 开始批量处理所有数据...\n", "⚠️ 这可能需要较长时间(预计30-60分钟)\n", "\n", "🚀 开始批量处理所有会话...\n", " 目标数据类型: ['train', 'val', 'test']\n", " 保存目录: ./rnn_processed_data\n", "\n", "📊 进度: 1/45\n", "\n", "🔄 处理会话: t15.2023.08.11\n", " 📁 处理 train 数据...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ✅ train 处理完成:\n", 
" 试验数: 288\n", " 时间步范围: 30-251\n", " 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n", " 平均时间压缩比: 0.240\n", " ⚠️ val 数据文件不存在,跳过\n", " ⚠️ test 数据文件不存在,跳过\n", " 💾 保存: t15.2023.08.11_train_rnn_processed.npz\n", "\n", "📊 进度: 2/45\n", "\n", "🔄 处理会话: t15.2023.08.13\n", " 📁 处理 train 数据...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ✅ train 处理完成:\n", " 试验数: 348\n", " 时间步范围: 55-352\n", " 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n", " 平均时间压缩比: 0.243\n", " 📁 处理 val 数据...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ✅ val 处理完成:\n", " 试验数: 35\n", " 时间步范围: 90-296\n", " 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n", " 平均时间压缩比: 0.243\n", " 📁 处理 test 数据...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ " ✅ test 处理完成:\n", " 试验数: 35\n", " 时间步范围: 80-238\n", " 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n", " 平均时间压缩比: 0.242\n", " 💾 保存: t15.2023.08.13_train_rnn_processed.npz\n", " 💾 保存: t15.2023.08.13_val_rnn_processed.npz\n" ] } ], "source": [ "# 🚀 批量处理所有数据 (增强内存管理 + 自动上传版本)\n", "\n", "print(\"=\"*70)\n", "print(\"批量处理选项 - 内存优化 + 自动上传版\")\n", "print(\"=\"*70)\n", "\n", "# 设置参数\n", "ENABLE_FULL_PROCESSING = True # 设为True开始批量处理 \n", "SAVE_DIR = \"./rnn_processed_data\" # 保存目录\n", "DATA_TYPES = ['train', 'val', 'test'] # 要处理的数据类型\n", "\n", "print(f\" 批量处理配置 (内存优化 + 自动上传版):\")\n", "print(f\" 启用批量处理: {ENABLE_FULL_PROCESSING}\")\n", "print(f\" 保存目录: {SAVE_DIR}\")\n", "print(f\" 数据类型: {DATA_TYPES}\")\n", "print(f\" 总会话数: {len(processor.model_args['dataset']['sessions'])}\")\n", "print(f\" 批量保存策略: 每{processor.batch_save_interval}个试验保存一次\")\n", "print(f\" 内存监控阈值: {processor.memory_threshold}%\")\n", "print(f\" 📤 自动上传: {'✅ 启用' if processor.enable_auto_upload else '❌ 禁用'}\")\n", "print(f\" 🗑️ 自动清理: {'✅ 启用' if processor.enable_auto_upload else '❌ 禁用'}\")\n", "\n", "# 显示新功能优势\n", "print(f\"\\n 🆕 新增功能优势:\")\n", "print(f\" 📤 处理完成后自动上传到WebDAV\")\n", "print(f\" 🗑️ 上传成功后自动删除本地文件\")\n", "print(f\" 💾 节省本地存储空间\")\n", "print(f\" ☁️ 数据安全备份到云端\")\n", "print(f\" 🔄 无需手动管理文件传输\")\n", "\n", "print(f\"\\n 🔧 技术特性:\")\n", "print(f\" 自动分批保存,避免内存积累\")\n", "print(f\" 实时GPU内存清理\")\n", "print(f\" 垃圾回收优化\")\n", "print(f\" 内存使用监控和警告\")\n", "print(f\" 可处理 t15.2023.09.01 等大数据集\")\n", "\n", "if ENABLE_FULL_PROCESSING and processor is not None:\n", " print(f\"\\n 🚀 开始批量处理所有数据(内存优化 + 自动上传版)...\")\n", " print(f\" 这可能需要较长时间(预计30-60分钟)\")\n", " print(f\" 内存不足问题已解决,可安全处理大数据集\")\n", " print(f\" 📤 文件将自动上传到WebDAV并删除本地副本\")\n", " print(f\" 开始时内存状态: {MemoryManager.get_memory_info()}\")\n", " \n", " # 记录处理开始时间\n", " import time\n", " start_processing_time = time.time()\n", " \n", " # 批量处理\n", " all_results = processor.process_all_sessions(\n", " data_types=DATA_TYPES,\n", " save_dir=SAVE_DIR\n", " )\n", " \n", " # 计算处理时间\n", " end_processing_time = time.time()\n", " total_processing_time = end_processing_time - start_processing_time\n", " \n", " print(f\"\\n 🎉 批量处理完成!\")\n", " print(f\" ⏱ 总处理时间: {total_processing_time/60:.2f} 分钟\")\n", " print(f\" 最终内存状态: {MemoryManager.get_memory_info()}\")\n", " \n", " # 详细统计\n", " total_files = 0\n", " total_uploaded = 0\n", " for session_name, session_data in all_results.items():\n", " for data_type, type_data in session_data.items():\n", " total_files += type_data['batch_count']\n", " total_uploaded += type_data.get('uploaded_files', 0)\n", " \n", " print(f\"\\n 📊 处理统计详情:\")\n", " print(f\" 成功处理的会话: 
{len(all_results)}\")\n", " print(f\" 生成文件总数: {total_files}\")\n", " print(f\" 📤 成功上传文件: {total_uploaded}\")\n", " print(f\" 💾 本地保留文件: {total_files - total_uploaded}\")\n", " print(f\" 🔄 上传成功率: {(total_uploaded/total_files*100) if total_files > 0 else 0:.1f}%\")\n", " \n", " # 存储空间统计\n", " try:\n", " import os\n", " local_size = 0\n", " if os.path.exists(SAVE_DIR):\n", " for root, dirs, files in os.walk(SAVE_DIR):\n", " for file in files:\n", " local_size += os.path.getsize(os.path.join(root, file))\n", " \n", " print(f\" 💾 剩余本地文件大小: {local_size / 1024**3:.2f} GB\")\n", " except:\n", " pass\n", " \n", " print(f\"\\n ☁️ WebDAV云端存储:\")\n", " print(f\" 远程路径: /移动云盘/DATA/rnn_processed_data/\")\n", " print(f\" 文件格式: session_name_datatype_rnn_processed_batchN.npz\")\n", " print(f\" 例如: t15.2023.08.13_train_rnn_processed_batch0.npz\")\n", " \n", " if total_uploaded > 0:\n", " print(f\"\\n ✅ 自动上传工作流成功!\")\n", " print(f\" 所有处理完的文件已自动上传到云端\")\n", " print(f\" 本地存储空间得到有效管理\")\n", " print(f\" 数据安全性和可访问性得到保障\")\n", " \n", "else:\n", " print(f\"\\n 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\")\n", " print(f\" 或者手动运行: processor.process_all_sessions()\")\n", "\n", "print(f\"\\n 📋 数据使用说明(自动上传版本):\")\n", "print(f\" 🔄 处理流程:\")\n", "print(f\" 1. 处理原始神经数据 → RNN输出\")\n", "print(f\" 2. 保存到本地 (.npz 文件)\")\n", "print(f\" 3. 自动上传到WebDAV云端\")\n", "print(f\" 4. 删除本地文件,释放存储空间\")\n", "print(f\"\")\n", "print(f\" 📤 云端文件结构:\")\n", "print(f\" /移动云盘/DATA/rnn_processed_data/\")\n", "print(f\" ├── t15.2023.08.11_train_rnn_processed_batch0.npz\")\n", "print(f\" ├── t15.2023.08.11_train_rnn_processed_batch1.npz\")\n", "print(f\" ├── t15.2023.08.11_val_rnn_processed_batch0.npz\")\n", "print(f\" └── ...\")\n", "print(f\"\")\n", "print(f\" 📥 下载和使用:\")\n", "print(f\" # 从WebDAV下载文件\")\n", "print(f\" uploader.client.download_sync(\")\n", "print(f\" remote_path='/移动云盘/DATA/rnn_processed_data/filename.npz',\")\n", "print(f\" local_path='./filename.npz'\")\n", "print(f\" )\")\n", "print(f\" \")\n", "print(f\" # 加载数据\")\n", "print(f\" data = np.load('filename.npz', allow_pickle=True)\")\n", "print(f\" features = data['concatenated_data']\")\n", "print(f\" metadata = data['trial_metadata']\")\n", "print(f\"\")\n", "print(f\" 🎯 优势总结:\")\n", "print(f\" ✅ 解决了 t15.2023.09.01 内存不足问题\")\n", "print(f\" ✅ 数据自动分批,便于后续加载\")\n", "print(f\" ✅ 处理速度优化,内存使用稳定\")\n", "print(f\" ✅ 错误恢复能力强,单个批次失败不影响整体\")\n", "print(f\" 🆕 自动上传,无需手动管理文件\")\n", "print(f\" 🆕 本地存储空间自动释放\")\n", "print(f\" 🆕 云端数据安全备份\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 📁 云端文件管理工具\n", "\n", "def list_cloud_files(remote_dir='/移动云盘/DATA/rnn_processed_data/'):\n", " \"\"\"列出云端的所有处理文件\"\"\"\n", " if 'uploader' not in globals():\n", " print(\"❌ WebDAV上传器未初始化\")\n", " return []\n", " \n", " try:\n", " files = uploader.client.list(remote_dir)\n", " rnn_files = [f for f in files if f.endswith('.npz')]\n", " \n", " print(f\"☁️ 云端文件列表 ({len(rnn_files)} 个文件):\")\n", " print(f\"📍 路径: {remote_dir}\")\n", " print(\"-\" * 60)\n", " \n", " for i, file in enumerate(rnn_files, 1):\n", " print(f\"{i:3d}. 
{file}\")\n", " \n", " return rnn_files\n", " \n", " except Exception as e:\n", " print(f\"❌ 获取文件列表失败: {e}\")\n", " return []\n", "\n", "def download_cloud_file(filename, local_dir='./downloaded_data/'):\n", " \"\"\"从云端下载单个文件\"\"\"\n", " if 'uploader' not in globals():\n", " print(\"❌ WebDAV上传器未初始化\")\n", " return False\n", " \n", " try:\n", " # 确保本地目录存在\n", " os.makedirs(local_dir, exist_ok=True)\n", " \n", " remote_path = f'/移动云盘/DATA/rnn_processed_data/{filename}'\n", " local_path = os.path.join(local_dir, filename)\n", " \n", " uploader.client.download_sync(remote_path=remote_path, local_path=local_path)\n", " \n", " print(f\"✅ 下载成功: {filename}\")\n", " print(f\"📁 保存到: {local_path}\")\n", " \n", " # 显示文件信息\n", " if os.path.exists(local_path):\n", " file_size = os.path.getsize(local_path) / 1024**2 # MB\n", " print(f\"📊 文件大小: {file_size:.2f} MB\")\n", " \n", " return True\n", " \n", " except Exception as e:\n", " print(f\"❌ 下载失败: {e}\")\n", " return False\n", "\n", "def download_session_files(session_name, data_types=['train', 'val', 'test'], local_dir='./downloaded_data/'):\n", " \"\"\"下载指定会话的所有文件\"\"\"\n", " files = list_cloud_files()\n", " \n", " session_files = []\n", " for file in files:\n", " if file.startswith(session_name):\n", " for data_type in data_types:\n", " if f'_{data_type}_' in file:\n", " session_files.append(file)\n", " break\n", " \n", " if not session_files:\n", " print(f\"❌ 未找到会话 {session_name} 的文件\")\n", " return False\n", " \n", " print(f\"\\n📥 下载会话 {session_name} 的文件...\")\n", " success_count = 0\n", " \n", " for file in session_files:\n", " if download_cloud_file(file, local_dir):\n", " success_count += 1\n", " \n", " print(f\"\\n✅ 下载完成: {success_count}/{len(session_files)} 个文件\")\n", " return success_count == len(session_files)\n", "\n", "def check_local_storage():\n", " \"\"\"检查本地存储使用情况\"\"\"\n", " print(\"💾 本地存储检查:\")\n", " \n", " # 检查处理数据目录\n", " if os.path.exists('./rnn_processed_data'):\n", " total_size = 0\n", " file_count = 0\n", " \n", " for root, dirs, files in os.walk('./rnn_processed_data'):\n", " for file in files:\n", " if file.endswith('.npz'):\n", " filepath = os.path.join(root, file)\n", " total_size += os.path.getsize(filepath)\n", " file_count += 1\n", " \n", " print(f\" 📁 ./rnn_processed_data/\")\n", " print(f\" 📊 文件数量: {file_count}\")\n", " print(f\" 📊 总大小: {total_size / 1024**3:.2f} GB\")\n", " \n", " if file_count > 0:\n", " print(f\" 💡 建议: 这些文件已处理完成,可以删除以释放空间\")\n", " print(f\" 使用: rm -rf ./rnn_processed_data/\")\n", " else:\n", " print(f\" ✅ ./rnn_processed_data/ 目录不存在或为空\")\n", " \n", " # 检查下载目录\n", " if os.path.exists('./downloaded_data'):\n", " download_size = 0\n", " download_count = 0\n", " \n", " for root, dirs, files in os.walk('./downloaded_data'):\n", " for file in files:\n", " if file.endswith('.npz'):\n", " filepath = os.path.join(root, file)\n", " download_size += os.path.getsize(filepath)\n", " download_count += 1\n", " \n", " print(f\" 📁 ./downloaded_data/\")\n", " print(f\" 📊 下载文件数: {download_count}\")\n", " print(f\" 📊 下载大小: {download_size / 1024**3:.2f} GB\")\n", "\n", "# 使用示例\n", "print(\"📁 云端文件管理工具已加载!\")\n", "print(\"\\n🛠 可用函数:\")\n", "print(\"• list_cloud_files() # 列出所有云端文件\")\n", "print(\"• download_cloud_file('filename.npz') # 下载单个文件\")\n", "print(\"• download_session_files('t15.2023.08.13') # 下载指定会话的所有文件\")\n", "print(\"• check_local_storage() # 检查本地存储使用情况\")\n", "\n", "print(\"\\n💡 使用示例:\")\n", "print(\"# 查看云端有哪些文件\")\n", "print(\"files = list_cloud_files()\")\n", "print(\"\")\n", "print(\"# 下载特定文件\")\n", 
"print(\"download_cloud_file('t15.2023.08.13_train_rnn_processed_batch0.npz')\")\n", "print(\"\")\n", "print(\"# 下载整个会话的数据\")\n", "print(\"download_session_files('t15.2023.08.13')\")\n", "print(\"\")\n", "print(\"# 检查本地存储\")\n", "print(\"check_local_storage()\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 📤 WebDAV文件上传工具" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 📚 WebDAV库选择 - 现成的解决方案" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ webdavclient3 可用\n", "修复后的WebDAV上传函数:\n", "1. 单个文件上传会自动添加文件名到远程路径\n", "2. 目录上传会过滤掉 .git 等不需要的文件\n", "3. 显示详细的上传和跳过信息\n", "\n", "测试单个文件上传...\n", "单个文件上传结果: {'success': True, 'local_file': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\brain-to-text-25\\\\client_secrets.json', 'remote_path': '/移动云盘/DATA/client_secrets.json', 'library': 'webdavclient3'}\n", "\n", "测试目录上传(带过滤)...\n", "单个文件上传结果: {'success': True, 'local_file': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\brain-to-text-25\\\\client_secrets.json', 'remote_path': '/移动云盘/DATA/client_secrets.json', 'library': 'webdavclient3'}\n", "\n", "测试目录上传(带过滤)...\n", "目录上传结果: {'success': True, 'local_dir': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\data-kaggle', 'remote_dir': '/移动云盘/DATA/data-kaggle', 'uploaded_files': [], 'skipped_files': [], 'library': 'webdavclient3'}\n", "上传了 0 个文件\n", "跳过了 0 个文件\n", "目录上传结果: {'success': True, 'local_dir': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\data-kaggle', 'remote_dir': '/移动云盘/DATA/data-kaggle', 'uploaded_files': [], 'skipped_files': [], 'library': 'webdavclient3'}\n", "上传了 0 个文件\n", "跳过了 0 个文件\n" ] } ], "source": [ "# 📋 WebDAV库安装指南\n", "\n", "print(\"📦 安装WebDAV客户端库:\")\n", "print(\"pip install webdavclient3\")\n", "print(\"\")\n", "print(\"💡 如果已安装,下面的简化版上传工具就可以直接使用了!\")\n", "print(\" 所有复杂的代码都已简化为易用的类和函数。\")\n", "\n", "# 检查安装状态\n", "try:\n", " from webdav3.client import Client\n", " print(\"✅ webdavclient3 已安装并可用\")\n", "except ImportError:\n", " print(\"❌ 需要安装: pip install webdavclient3\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ WebDAV连接已建立: http://zchens.cn:5244/dav/\n", "\n", "📖 使用方法:\n", "# 上传单个文件\n", "uploader.upload_file(r'F:\\path\\to\\file.txt')\n", "\n", "# 上传目录(自动过滤.git等文件)\n", "uploader.upload_dir(r'F:\\path\\to\\dir', '/移动云盘/DATA/目标目录')\n", "\n", "# 上传目录(不过滤任何文件)\n", "uploader.upload_dir(r'F:\\path\\to\\dir', '/移动云盘/DATA/目标目录', exclude_git=False)\n" ] } ], "source": [ "# 📤 简化版WebDAV上传工具\n", "\n", "from webdav3.client import Client\n", "import os\n", "import fnmatch\n", "\n", "class WebDAVUploader:\n", " \"\"\"简化的WebDAV上传器\"\"\"\n", " \n", " def __init__(self, url='http://zchens.cn:5244/dav/', username='admin', password='Zccns20050420'):\n", " \"\"\"初始化WebDAV连接\"\"\"\n", " self.client = Client({\n", " 'webdav_hostname': url,\n", " 'webdav_login': username,\n", " 'webdav_password': password,\n", " 'webdav_timeout': 30\n", " })\n", " print(f\"✅ WebDAV连接已建立: {url}\")\n", " \n", " def upload_file(self, local_file, remote_dir='/移动云盘/DATA/'):\n", " \"\"\"上传单个文件\"\"\"\n", " try:\n", " filename = os.path.basename(local_file)\n", " remote_path = remote_dir.rstrip('/') + '/' + filename\n", " \n", " # 确保远程目录存在\n", " if not self.client.check(remote_dir):\n", " self.client.mkdir(remote_dir)\n", " \n", " self.client.upload_sync(remote_path=remote_path, local_path=local_file)\n", " print(f\"✅ 文件上传成功: {filename} -> {remote_path}\")\n", " 
return True\n", "            \n", "        except Exception as e:\n", "            print(f\"❌ 文件上传失败: {e}\")\n", "            return False\n", "    \n", "    def upload_dir(self, local_dir, remote_dir, exclude_git=True):\n", "        \"\"\"上传目录(自动过滤不需要的文件)\"\"\"\n", "        exclude_patterns = ['.git*', '__pycache__*', '*.pyc', '.vscode*'] if exclude_git else []\n", "        \n", "        try:\n", "            uploaded = 0\n", "            skipped = 0\n", "            \n", "            for root, dirs, files in os.walk(local_dir):\n", "                # 过滤目录\n", "                if exclude_git:\n", "                    dirs[:] = [d for d in dirs if not any(fnmatch.fnmatch(d, p) for p in exclude_patterns)]\n", "                \n", "                for file in files:\n", "                    # 检查是否跳过文件\n", "                    if exclude_git and any(fnmatch.fnmatch(file, p) for p in exclude_patterns):\n", "                        skipped += 1\n", "                        continue\n", "                    \n", "                    local_file = os.path.join(root, file)\n", "                    rel_path = os.path.relpath(local_file, local_dir)\n", "                    remote_file = remote_dir.rstrip('/') + '/' + rel_path.replace('\\\\', '/')\n", "                    \n", "                    # 确保远程目录存在\n", "                    remote_file_dir = '/'.join(remote_file.split('/')[:-1])\n", "                    if not self.client.check(remote_file_dir):\n", "                        self.client.mkdir(remote_file_dir)\n", "                    \n", "                    self.client.upload_sync(remote_path=remote_file, local_path=local_file)\n", "                    uploaded += 1\n", "            \n", "            print(f\"✅ 目录上传完成: 上传 {uploaded} 个文件,跳过 {skipped} 个文件\")\n", "            return True\n", "            \n", "        except Exception as e:\n", "            print(f\"❌ 目录上传失败: {e}\")\n", "            return False\n", "\n", "# 创建上传器实例\n", "uploader = WebDAVUploader()\n", "\n", "print(\"\\n📖 使用方法:\")\n", "print(\"# 上传单个文件\")\n", "print(\"uploader.upload_file(r'F:\\\\path\\\\to\\\\file.txt')\")\n", "print(\"\")\n", "print(\"# 上传目录(自动过滤.git等文件)\") \n", "print(\"uploader.upload_dir(r'F:\\\\path\\\\to\\\\dir', '/移动云盘/DATA/目标目录')\")\n", "print(\"\")\n", "print(\"# 上传目录(不过滤任何文件)\")\n", "print(\"uploader.upload_dir(r'F:\\\\path\\\\to\\\\dir', '/移动云盘/DATA/目标目录', exclude_git=False)\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", "WebDAV上传示例\n", "==================================================\n", "\n", "1️⃣ 上传单个文件:\n", "✅ 文件上传成功: client_secrets.json -> /移动云盘/DATA/client_secrets.json\n", "\n", "2️⃣ 上传目录(过滤.git等文件):\n", "✅ 目录上传完成: 上传 0 个文件,跳过 0 个文件\n", "\n", "3️⃣ 上传RNN处理结果:\n", "❌ 目录上传失败: HTTPConnectionPool(host='127.0.0.1', port=7897): Read timed out. 
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"==================================================\n",
"WebDAV upload examples\n",
"==================================================\n",
"\n",
"1️⃣ Upload a single file:\n",
"✅ File uploaded: client_secrets.json -> /移动云盘/DATA/client_secrets.json\n",
"\n",
"2️⃣ Upload a directory (filtering .git etc.):\n",
"✅ Directory upload finished: 0 files uploaded, 0 skipped\n",
"\n",
"3️⃣ Upload the RNN processing results:\n",
"❌ Directory upload failed: HTTPConnectionPool(host='127.0.0.1', port=7897): Read timed out. (read timeout=30)\n",
"RNN upload did not finish - see the error above\n",
"\n",
"✨ Upload run finished! Check the cloud drive for your files.\n"
] } ], "source": [
"# 🚀 Quick usage examples\n",
"\n",
"print(\"=\"*50)\n",
"print(\"WebDAV upload examples\")\n",
"print(\"=\"*50)\n",
"\n",
"# Example 1: upload a single file\n",
"print(\"\\n1️⃣ Upload a single file:\")\n",
"uploader.upload_file(r'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\brain-to-text-25\\\\client_secrets.json')\n",
"\n",
"# Example 2: upload a directory (automatically filtering .git)\n",
"print(\"\\n2️⃣ Upload a directory (filtering .git etc.):\")\n",
"uploader.upload_dir(\n",
"    r'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\data-kaggle',\n",
"    '/移动云盘/DATA/data-kaggle-clean'\n",
")\n",
"\n",
"# Example 3: upload the processed-data directory,\n",
"# reporting success only if upload_dir actually returned True\n",
"print(\"\\n3️⃣ Upload the RNN processing results:\")\n",
"if os.path.exists('./rnn_processed_data'):\n",
"    ok = uploader.upload_dir(\n",
"        './rnn_processed_data',\n",
"        '/移动云盘/DATA/rnn_processed_data'\n",
"    )\n",
"    if ok:\n",
"        print(\"RNN processing results uploaded!\")\n",
"    else:\n",
"        print(\"RNN upload did not finish - see the error above\")\n",
"else:\n",
"    print(\"⚠️ RNN processed data directory not found\")\n",
"\n",
"print(\"\\n✨ Upload run finished! Check the cloud drive for your files.\")"
] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [
"🎯 Convenience functions defined!\n",
"\n",
"Available shortcuts:\n",
"• quick_upload_file('file_path')      # upload a single file\n",
"• quick_upload_project('project_dir') # upload a whole project\n",
"• quick_upload_results()              # upload all result files\n",
"\n",
"For example:\n",
"quick_upload_file('data.csv')\n",
"quick_upload_project(r'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text')\n",
"quick_upload_results()\n"
] } ], "source": [
"# 🎯 Convenience functions - one-call uploads for common cases\n",
"\n",
"def quick_upload_file(file_path, remote_dir='/移动云盘/DATA/'):\n",
"    \"\"\"Quickly upload a single file\"\"\"\n",
"    return uploader.upload_file(file_path, remote_dir)\n",
"\n",
"def quick_upload_project(project_dir, remote_name=None):\n",
"    \"\"\"Quickly upload a whole project directory (automatically filters .git etc.)\"\"\"\n",
"    if remote_name is None:\n",
"        remote_name = os.path.basename(project_dir.rstrip('/\\\\\\\\'))\n",
"\n",
"    remote_dir = f'/移动云盘/DATA/{remote_name}'\n",
"    return uploader.upload_dir(project_dir, remote_dir, exclude_git=True)\n",
"\n",
"def quick_upload_results():\n",
"    \"\"\"Quickly upload all result files\"\"\"\n",
"    results = []\n",
"\n",
"    # Upload the RNN processing results\n",
"    if os.path.exists('./rnn_processed_data'):\n",
"        print(\"📊 Uploading RNN processing results...\")\n",
"        results.append(uploader.upload_dir('./rnn_processed_data', '/移动云盘/DATA/rnn_processed_data'))\n",
"\n",
"    # Upload the notebook files in the working directory\n",
"    notebook_files = [f for f in os.listdir('.') if f.endswith('.ipynb')]\n",
"    for nb in notebook_files:\n",
"        print(f\"📓 Uploading notebook: {nb}\")\n",
"        results.append(uploader.upload_file(nb, '/移动云盘/DATA/notebooks/'))\n",
"\n",
"    success_count = sum(results)\n",
"    print(f\"\\n✅ Done! {success_count}/{len(results)} items uploaded successfully\")\n",
"    return all(results)\n",
"\n",
"# Usage hints\n",
"print(\"🎯 Convenience functions defined!\")\n",
"print(\"\\nAvailable shortcuts:\")\n",
"print(\"• quick_upload_file('file_path')      # upload a single file\")\n",
"print(\"• quick_upload_project('project_dir') # upload a whole project\")\n",
"print(\"• quick_upload_results()              # upload all result files\")\n",
"print(\"\\nFor example:\")\n",
"print(\"quick_upload_file('data.csv')\")\n",
"print(\"quick_upload_project(r'F:\\\\\\\\BRAIN-TO-TEXT\\\\\\\\nejm-brain-to-text')\")\n",
"print(\"quick_upload_results()\")"
] },
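{ "cell_type": "markdown", "metadata": {}, "source": [ "Because an upload can die midway on a proxy error or timeout, it is worth listing the remote target directory afterwards. The cell below is a small verification sketch built on the same webdavclient3 `check()` and `list()` calls the uploader already uses; `verify_remote_dir` is a helper name introduced here, not part of the original tool." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# ✅ Verification sketch: list a remote directory to confirm uploads arrived.\n",
"def verify_remote_dir(remote_dir='/移动云盘/DATA/'):\n",
"    \"\"\"Print the entries of a remote WebDAV directory, if it exists.\"\"\"\n",
"    try:\n",
"        if not uploader.client.check(remote_dir):\n",
"            print(f\"⚠️ Remote directory does not exist: {remote_dir}\")\n",
"            return []\n",
"        entries = uploader.client.list(remote_dir)\n",
"        print(f\"📂 {remote_dir} contains {len(entries)} entries:\")\n",
"        for entry in entries:\n",
"            print('  •', entry)\n",
"        return entries\n",
"    except Exception as e:\n",
"        print(f\"❌ Could not list {remote_dir}: {e}\")\n",
"        return []\n",
"\n",
"# e.g. verify_remote_dir('/移动云盘/DATA/notebooks/')"
] },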
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 4 }