{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Environment Setup & Utils"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 4.5 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 84.2 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 78.5 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 41.1 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.1 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 2.1 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 30.7 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 12.8 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 7.8 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 79.7 MB/s eta 0:00:00\n",
"Installing collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n",
" Attempting uninstall: nvidia-nvjitlink-cu12\n",
" Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
" Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
" Attempting uninstall: nvidia-curand-cu12\n",
" Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
" Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
" Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
" Attempting uninstall: nvidia-cufft-cu12\n",
" Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
" Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
" Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
" Attempting uninstall: nvidia-cuda-runtime-cu12\n",
" Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
" Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cuda-cupti-cu12\n",
" Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cublas-cu12\n",
" Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
" Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
" Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
" Attempting uninstall: nvidia-cusparse-cu12\n",
" Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
" Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
" Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
" Attempting uninstall: nvidia-cudnn-cu12\n",
" Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
" Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
" Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
" Attempting uninstall: nvidia-cusolver-cu12\n",
" Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
" Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
" Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
"Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n",
"Collecting jupyter==1.1.1\n",
" Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)\n",
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
"Collecting pandas==2.3.0\n",
" Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 91.2/91.2 kB 4.6 MB/s eta 0:00:00\n",
"Collecting matplotlib==3.10.1\n",
" Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
"Collecting scipy==1.15.2\n",
" Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.0/62.0 kB 3.0 MB/s eta 0:00:00\n",
"Collecting scikit-learn==1.6.1\n",
" Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)\n",
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
"Collecting g2p_en==2.1.0\n",
" Downloading g2p_en-2.1.0-py3-none-any.whl.metadata (4.5 kB)\n",
"Collecting h5py==3.13.0\n",
" Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)\n",
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
"Collecting transformers==4.53.0\n",
" Downloading transformers-4.53.0-py3-none-any.whl.metadata (39 kB)\n",
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
"Collecting bitsandbytes==0.46.0\n",
" Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)\n",
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
"Collecting distance>=0.1.3 (from g2p_en==2.1.0)\n",
" Downloading Distance-0.1.3.tar.gz (180 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 180.3/180.3 kB 9.7 MB/s eta 0:00:00\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
"Downloading jupyter-1.1.1-py2.py3-none-any.whl (2.7 kB)\n",
"Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 84.0 MB/s eta 0:00:00\n",
"Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.6/8.6 MB 103.4 MB/s eta 0:00:00\n",
"Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.6 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 37.6/37.6 MB 44.8 MB/s eta 0:00:00\n",
"Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 75.2 MB/s eta 0:00:00\n",
"Downloading g2p_en-2.1.0-py3-none-any.whl (3.1 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 80.6 MB/s eta 0:00:00\n",
"Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 89.1 MB/s eta 0:00:00\n",
"Downloading transformers-4.53.0-py3-none-any.whl (10.8 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.8/10.8 MB 105.1 MB/s eta 0:00:00\n",
"Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 MB 17.9 MB/s eta 0:00:00\n",
"Building wheels for collected packages: distance\n",
" Building wheel for distance (setup.py): started\n",
" Building wheel for distance (setup.py): finished with status 'done'\n",
" Created wheel for distance: filename=Distance-0.1.3-py3-none-any.whl size=16256 sha256=3bf3b151b4d15c4f5e442085e48cff050eec4c6d192a307e2b09a77daa84a5dc\n",
" Stored in directory: /root/.cache/pip/wheels/fb/cd/9c/3ab5d666e3bcacc58900b10959edd3816cc9557c7337986322\n",
"Successfully built distance\n",
"Installing collected packages: distance, jupyter, scipy, transformers, scikit-learn, pandas, matplotlib, h5py, g2p_en, bitsandbytes\n",
" Attempting uninstall: scipy\n",
" Found existing installation: scipy 1.15.3\n",
" Uninstalling scipy-1.15.3:\n",
" Successfully uninstalled scipy-1.15.3\n",
" Attempting uninstall: transformers\n",
" Found existing installation: transformers 4.52.4\n",
" Uninstalling transformers-4.52.4:\n",
" Successfully uninstalled transformers-4.52.4\n",
" Attempting uninstall: scikit-learn\n",
" Found existing installation: scikit-learn 1.2.2\n",
" Uninstalling scikit-learn-1.2.2:\n",
" Successfully uninstalled scikit-learn-1.2.2\n",
" Attempting uninstall: pandas\n",
" Found existing installation: pandas 2.2.3\n",
" Uninstalling pandas-2.2.3:\n",
" Successfully uninstalled pandas-2.2.3\n",
" Attempting uninstall: matplotlib\n",
" Found existing installation: matplotlib 3.7.2\n",
" Uninstalling matplotlib-3.7.2:\n",
" Successfully uninstalled matplotlib-3.7.2\n",
" Attempting uninstall: h5py\n",
" Found existing installation: h5py 3.14.0\n",
" Uninstalling h5py-3.14.0:\n",
" Successfully uninstalled h5py-3.14.0\n",
"Successfully installed bitsandbytes-0.46.0 distance-0.1.3 g2p_en-2.1.0 h5py-3.13.0 jupyter-1.1.1 matplotlib-3.10.1 pandas-2.3.0 scikit-learn-1.6.1 scipy-1.15.2 transformers-4.53.0\n",
"Requirement already satisfied: PyDrive2 in /usr/local/lib/python3.11/dist-packages (1.21.3)\n",
"Requirement already satisfied: google-api-python-client>=1.12.5 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (2.173.0)\n",
"Requirement already satisfied: oauth2client>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (4.1.3)\n",
"Requirement already satisfied: PyYAML>=3.0 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (6.0.2)\n",
"Collecting cryptography<44 (from PyDrive2)\n",
" Downloading cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (5.4 kB)\n",
"Collecting pyOpenSSL<=24.2.1,>=19.1.0 (from PyDrive2)\n",
" Downloading pyOpenSSL-24.2.1-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.11/dist-packages (from cryptography<44->PyDrive2) (1.17.1)\n",
"Requirement already satisfied: httplib2<1.0.0,>=0.19.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (0.22.0)\n",
"Requirement already satisfied: google-auth!=2.24.0,!=2.25.0,<3.0.0,>=1.32.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (2.40.3)\n",
"Requirement already satisfied: google-auth-httplib2<1.0.0,>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (0.2.0)\n",
"Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (1.34.1)\n",
"Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (4.2.0)\n",
"Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (0.6.1)\n",
"Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (0.4.2)\n",
"Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (4.9.1)\n",
"Requirement already satisfied: six>=1.6.1 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (1.17.0)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.12->cryptography<44->PyDrive2) (2.22)\n",
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (1.70.0)\n",
"Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<4.0.0dev,>=3.19.5 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.20.3)\n",
"Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2.32.4)\n",
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from google-auth!=2.24.0,!=2.25.0,<3.0.0,>=1.32.0->google-api-python-client>=1.12.5->PyDrive2) (5.5.2)\n",
"Requirement already satisfied: pyparsing!=3.0.0,!=3.0.1,!=3.0.2,!=3.0.3,<4,>=2.4.2 in /usr/local/lib/python3.11/dist-packages (from httplib2<1.0.0,>=0.19.0->google-api-python-client>=1.12.5->PyDrive2) (3.0.9)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.4.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2025.6.15)\n",
"Downloading cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl (4.0 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.0/4.0 MB 49.5 MB/s eta 0:00:00\n",
"Downloading pyOpenSSL-24.2.1-py3-none-any.whl (58 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.4/58.4 kB 3.4 MB/s eta 0:00:00\n",
"Installing collected packages: cryptography, pyOpenSSL\n",
" Attempting uninstall: cryptography\n",
" Found existing installation: cryptography 44.0.3\n",
" Uninstalling cryptography-44.0.3:\n",
" Successfully uninstalled cryptography-44.0.3\n",
" Attempting uninstall: pyOpenSSL\n",
" Found existing installation: pyOpenSSL 25.1.0\n",
" Uninstalling pyOpenSSL-25.1.0:\n",
" Successfully uninstalled pyOpenSSL-25.1.0\n",
"Successfully installed cryptography-43.0.3 pyOpenSSL-24.2.1\n",
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Installing collected packages: nejm_b2txt_utils\n",
" Running setup.py develop for nejm_b2txt_utils\n",
"Successfully installed nejm_b2txt_utils-0.0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cloning into 'nejm-brain-to-text'...\n",
"ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.\n",
"gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.15.2 which is incompatible.\n",
"dask-cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n",
"cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n",
"datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.5.1 which is incompatible.\n",
"ydata-profiling 4.16.1 requires matplotlib<=3.10,>=3.5, but you have matplotlib 3.10.1 which is incompatible.\n",
"category-encoders 2.7.0 requires scikit-learn<1.6.0,>=1.0.0, but you have scikit-learn 1.6.1 which is incompatible.\n",
"cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
"google-colab 1.0.0 requires google-auth==2.38.0, but you have google-auth 2.40.3 which is incompatible.\n",
"google-colab 1.0.0 requires notebook==6.5.7, but you have notebook 6.5.4 which is incompatible.\n",
"google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.0 which is incompatible.\n",
"google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.\n",
"google-colab 1.0.0 requires tornado==6.4.2, but you have tornado 6.5.1 which is incompatible.\n",
"dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.\n",
"pandas-gbq 0.29.1 requires google-api-core<3.0.0,>=2.10.2, but you have google-api-core 1.34.1 which is incompatible.\n",
"bigframes 2.8.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.\n",
"bigframes 2.8.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.\n"
]
}
],
"source": [
"%%bash\n",
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
"\n",
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
"ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data\n",
"\n",
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
"\n",
"pip install \\\n",
"    jupyter==1.1.1 \\\n",
"    \"numpy>=1.26.0,<2.1.0\" \\\n",
"    pandas==2.3.0 \\\n",
"    matplotlib==3.10.1 \\\n",
"    scipy==1.15.2 \\\n",
"    scikit-learn==1.6.1 \\\n",
"    tqdm==4.67.1 \\\n",
"    g2p_en==2.1.0 \\\n",
"    h5py==3.13.0 \\\n",
"    omegaconf==2.3.0 \\\n",
"    editdistance==0.8.1 \\\n",
"    huggingface-hub==0.33.1 \\\n",
"    transformers==4.53.0 \\\n",
"    tokenizers==0.21.2 \\\n",
"    accelerate==1.8.1 \\\n",
"    bitsandbytes==0.46.0\n",
"\n",
"pip install PyDrive2\n",
"\n",
"cd /kaggle/working/nejm-brain-to-text/\n",
"pip install -e .\n"
]
},
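{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check after the install above (a minimal sketch, assuming the Kaggle GPU runtime): confirm which torch build is actually in use and that a CUDA device is visible before any model code runs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Report the installed torch build and whether a CUDA device is visible.\n",
"print('torch:', torch.__version__)\n",
"print('CUDA available:', torch.cuda.is_available())\n",
"if torch.cuda.is_available():\n",
"    print('device:', torch.cuda.get_device_name(0))"
]
},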
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from g2p_en import G2p\n",
"import pandas as pd\n",
"from nejm_b2txt_utils.general_utils import *\n",
"\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n"
]
}
],
"source": [
"%cd model_training/\n",
"import torch  # needed below for autocast / no_grad\n",
"from data_augmentations import gauss_smooth\n",
"\n",
"# single decoding step function that also returns smoothed input\n",
"# smooths data and puts it through the model, returning both logits and smoothed input.\n",
"def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n",
"\n",
"    # Use autocast for efficiency\n",
"    with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n",
"\n",
"        smoothed_x = gauss_smooth(\n",
"            inputs = x,\n",
"            device = device,\n",
"            smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
"            smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
"            padding = 'valid',\n",
"        )\n",
"\n",
"        with torch.no_grad():\n",
"            logits, _ = model(\n",
"                x = smoothed_x,\n",
"                day_idx = torch.tensor([input_layer], device=device),\n",
"                states = None, # no initial states\n",
"                return_state = True,\n",
"            )\n",
"\n",
"    # convert both logits and smoothed input from bfloat16 to float32\n",
"    logits = logits.float().cpu().numpy()\n",
"    smoothed_input = smoothed_x.float().cpu().numpy()\n",
"\n",
"    # # original order is [BLANK, phonemes..., SIL]\n",
"    # # rearrange so the order is [BLANK, SIL, phonemes...]\n",
"    # logits = rearrange_speech_logits_pt(logits)\n",
"\n",
"    return logits, smoothed_input\n",
"\n",
"\n",
"import h5py\n",
"def load_h5py_file(file_path, b2txt_csv_df):\n",
"    data = {\n",
"        'neural_features': [],\n",
"        'n_time_steps': [],\n",
"        'seq_class_ids': [],\n",
"        'seq_len': [],\n",
"        'transcriptions': [],\n",
"        'sentence_label': [],\n",
"        'session': [],\n",
"        'block_num': [],\n",
"        'trial_num': [],\n",
"        'corpus': [],\n",
"    }\n",
"    # Open the hdf5 file for that day\n",
"    with h5py.File(file_path, 'r') as f:\n",
"\n",
"        keys = list(f.keys())\n",
"\n",
"        # For each trial in the selected trials in that day\n",
"        for key in keys:\n",
"            g = f[key]\n",
"\n",
"            neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
"            n_time_steps = g.attrs['n_time_steps']\n",
"            seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
"            seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
"            transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
"            sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
"            session = g.attrs['session']\n",
"            block_num = g.attrs['block_num']\n",
"            trial_num = g.attrs['trial_num']\n",
"\n",
"            # match this trial up with the csv to get the corpus name\n",
"            year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
"            date = f'{year}-{month}-{day}'\n",
"            row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
"            corpus_name = row['Corpus'].values[0]\n",
"\n",
"            data['neural_features'].append(neural_features)\n",
"            data['n_time_steps'].append(n_time_steps)\n",
"            data['seq_class_ids'].append(seq_class_ids)\n",
"            data['seq_len'].append(seq_len)\n",
"            data['transcriptions'].append(transcription)\n",
"            data['sentence_label'].append(sentence_label)\n",
"            data['session'].append(session)\n",
"            data['block_num'].append(block_num)\n",
"            data['trial_num'].append(trial_num)\n",
"            data['corpus'].append(corpus_name)\n",
"    return data"
]
},
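{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, a hedged usage sketch of `runSingleDecodingStepWithSmoothedInput`. The names `model`, `model_args`, and `device` are placeholders: they would come from loading the pretrained RNN baseline under `data/t15_pretrained_rnn_baseline`, which this notebook has not done at this point, so the sketch is left commented out."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical usage; `model`, `model_args`, and `device` must be defined first.\n",
"# x is one trial's neural features with a leading batch dimension:\n",
"#   x = torch.tensor(trial_features, device=device, dtype=torch.float32).unsqueeze(0)\n",
"# logits, smoothed = runSingleDecodingStepWithSmoothedInput(\n",
"#     x, input_layer=0, model=model, model_args=model_args, device=device)\n",
"# logits has shape (1, T', n_classes); a greedy readout is logits[0].argmax(axis=-1),\n",
"# whose indices map into LOGIT_TO_PHONEME defined below."
]
},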
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
"    'BLANK',\n",
"    'AA', 'AE', 'AH', 'AO', 'AW',\n",
"    'AY', 'B', 'CH', 'D', 'DH',\n",
"    'EH', 'ER', 'EY', 'F', 'G',\n",
"    'HH', 'IH', 'IY', 'JH', 'K',\n",
"    'L', 'M', 'N', 'NG', 'OW',\n",
"    'OY', 'P', 'R', 'S', 'SH',\n",
"    'T', 'TH', 'UH', 'UW', 'V',\n",
"    'W', 'Y', 'Z', 'ZH',\n",
"    ' | ',\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Data Analysis and Preprocessing"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Data Preparation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"data = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
"                      b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"data2 = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.13/data_train.hdf5',\n",
"                       b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))"
]
},
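{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick look at what the labels actually contain (a minimal check against the trial dicts loaded above): `seq_class_ids` holds phoneme class ids that index into `LOGIT_TO_PHONEME` (zero-padded past `seq_len`), and `transcriptions` holds the sentence as zero-padded ASCII codes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"i = 0  # first trial of the 2023.08.11 session\n",
"\n",
"# Phoneme ids -> phoneme labels (only the first seq_len entries are real).\n",
"ids = data['seq_class_ids'][i][:data['seq_len'][i]]\n",
"print([LOGIT_TO_PHONEME[j] for j in ids])\n",
"\n",
"# Transcription is stored as zero-padded ASCII codes.\n",
"codes = data['transcriptions'][i]\n",
"print(bytes(codes[codes != 0].astype(np.uint8)).decode('ascii'))\n",
"print(data['sentence_label'][i])  # sentence label attribute, if present"
]
},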
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **Task overview**: machine learning for a pattern-recognition problem over high-dimensional signals"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The labels in our dataset carry no timestamps, so what we are doing here is semi-supervised learning"
]
},
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n",
|
||
" -0.7091646 , -0.11018186],\n",
|
||
" [-0.5859305 , -0.78699756, -0.64687246, ..., 0.3122117 ,\n",
|
||
" 1.7943763 , -0.76884896],\n",
|
||
" [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n",
|
||
" -0.8481289 , -0.7648201 ],\n",
|
||
" ...,\n",
|
||
" [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n",
|
||
" 0.9710176 , 2.5397465 ],\n",
|
||
" [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n",
|
||
" -0.68723625, 0.10479005],\n",
|
||
" [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n",
|
||
" 0.7417906 , -0.7008622 ]], dtype=float32),\n",
|
||
" 'n_time_steps': 321,\n",
|
||
" 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0], dtype=int32),\n",
|
||
" 'seq_len': 14,\n",
|
||
" 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n",
|
||
" 101, 114, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
|
||
" 0, 0, 0, 0, 0, 0], dtype=int32),\n",
|
||
" 'sentence_label': 'Bring it closer.',\n",
|
||
" 'session': 't15.2023.08.11',\n",
|
||
" 'block_num': 2,\n",
|
||
" 'trial_num': 0,\n",
|
||
" 'corpus': '50-Word'}"
|
||
]
|
||
},
|
||
"execution_count": 9,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"def data_patch(data, index):\n",
|
||
" data_patch = {}\n",
|
||
" data_patch['neural_features'] = data['neural_features'][index]\n",
|
||
" data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
|
||
" data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
|
||
" data_patch['seq_len'] = data['seq_len'][index]\n",
|
||
" data_patch['transcriptions'] = data['transcriptions'][index]\n",
|
||
" data_patch['sentence_label'] = data['sentence_label'][index]\n",
|
||
" data_patch['session'] = data['session'][index]\n",
|
||
" data_patch['block_num'] = data['block_num'][index]\n",
|
||
" data_patch['trial_num'] = data['trial_num'][index]\n",
|
||
" data_patch['corpus'] = data['corpus'][index]\n",
|
||
" return data_patch"
|
||
]
|
||
},
|
||
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"d1 = data_patch(data, 0)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transcriptions non-zero length: 16\n",
"Seq class ids non-zero length: 14\n",
"Seq len: 14\n"
]
}
],
"source": [
"trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
"seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
"seq_len = d1['seq_len']\n",
"print(f\"Transcriptions non-zero length: {trans_len}\")\n",
"print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
"print(f\"Seq len: {seq_len}\")"
]
},
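{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `transcriptions` array stores the sentence as zero-padded ASCII codes, which is why its non-zero length (16) equals the character count of `sentence_label` rather than `seq_len`. A small sketch of recovering the text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"trans = d1['transcriptions']\n",
"text = bytes(trans[trans != 0].tolist()).decode('ascii')\n",
"print(text)  # expected: 'Bring it closer.'"
]
},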
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of feature sequences: 14\n",
"Shape of first sequence: (22, 512)\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"def create_time_windows(trial):\n",
"    # Split the trial's time axis into seq_len equal-width windows\n",
"    n_time_steps = trial['n_time_steps']\n",
"    seq_len = trial['seq_len']\n",
"    edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
"    windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
"\n",
"    # Extract the feature slice for each window\n",
"    feature_sequences = []\n",
"    for start, end in windows:\n",
"        seq = trial['neural_features'][start:end, :]\n",
"        feature_sequences.append(seq)\n",
"\n",
"    return feature_sequences\n",
"\n",
"# Example usage\n",
"feature_sequences = create_time_windows(d1)\n",
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
]
},
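{
"cell_type": "markdown",
"metadata": {},
"source": [
"Pairing the equal-width windows with `seq_class_ids` gives crude (window, phoneme) training pairs under the uniform-duration assumption above. A sketch, assuming `create_time_windows`, `d1`, and `LOGIT_TO_PHONEME` from earlier cells; real phoneme durations vary, so treat these strictly as pseudo-labels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def make_pseudo_labeled_pairs(trial):\n",
"    # one window per phoneme; labels taken in order from seq_class_ids\n",
"    windows = create_time_windows(trial)\n",
"    labels = trial['seq_class_ids'][:trial['seq_len']]\n",
"    return list(zip(windows, labels))\n",
"\n",
"pairs = make_pseudo_labeled_pairs(d1)\n",
"seq, label = pairs[0]\n",
"print(seq.shape, int(label), LOGIT_TO_PHONEME[int(label)])"
]
},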
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 45, Val: 41, Test: 41\n",
"Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n",
"Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n",
"Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n"
]
}
],
"source": [
"import os\n",
"\n",
"def scan_hdf5_files(base_path):\n",
"    # Walk the data directory and bucket HDF5 files by split\n",
"    train_files = []\n",
"    val_files = []\n",
"    test_files = []\n",
"    for root, dirs, files in os.walk(base_path):\n",
"        for file in files:\n",
"            if file.endswith('.hdf5'):\n",
"                abs_path = os.path.abspath(os.path.join(root, file))\n",
"                if 'data_train.hdf5' in file:\n",
"                    train_files.append(abs_path)\n",
"                elif 'data_val.hdf5' in file:\n",
"                    val_files.append(abs_path)\n",
"                elif 'data_test.hdf5' in file:\n",
"                    test_files.append(abs_path)\n",
"    return train_files, val_files, test_files\n",
"\n",
"# Example usage\n",
"FILE_PATH = 'data/hdf5_data_final'\n",
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
"print(\"Train files (first 3):\", train_list[:3])\n",
"print(\"Val files (first 3):\", val_list[:3])\n",
"print(\"Test files (first 3):\", test_list[:3])"
]
},
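{
"cell_type": "markdown",
"metadata": {},
"source": [
"With the file lists in hand, the per-session dictionaries can be loaded and merged into one pool. A sketch, assuming `load_h5py_file` from above and the relative CSV path; it only loads the first two files because pulling all 45 sessions into RAM at once would be heavy:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"csv_df = pd.read_csv('data/t15_copyTaskData_description.csv')\n",
"merged = {}\n",
"for path in train_list[:2]:  # limited for memory; extend as needed\n",
"    session_data = load_h5py_file(file_path=path, b2txt_csv_df=csv_df)\n",
"    for k, v in session_data.items():\n",
"        merged.setdefault(k, []).extend(v)\n",
"print({k: len(v) for k, v in merged.items()})"
]
},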
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🔗 Dataset Batch-Processing Workflow"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n"
]
}
],
"source": [
"%cd model_training/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================================================================\n",
"🚀 RNN batch data processing tool - enhanced memory management build\n",
"======================================================================\n",
"🔧 Creating RNN data processor (enhanced memory management)...\n",
"🔧 Initializing RNN data processor (enhanced memory management)...\n",
"   Model path: ../data/t15_pretrained_rnn_baseline\n",
"   Data directory: ../data/hdf5_data_final\n",
"   Compute device: cpu\n",
"   Batch save interval: 25 trials\n",
"   Memory warning threshold: 80%\n",
"   Initial memory state: RAM: 0.52GB (3.8%)\n",
"📋 Model configuration:\n",
"   Number of sessions: 45\n",
"   Neural feature dim: 512\n",
"   Patch size: 14\n",
"   Patch stride: 4\n",
"   Number of output classes: 41\n",
"🔄 Loading model... current memory: RAM: 0.52GB (3.8%)\n",
"✅ Model loaded; memory after cleanup: RAM: 1.18GB (8.6%)\n",
"📊 CSV data loaded: 265 records\n",
"✅ Initialization complete! Current memory: RAM: 1.18GB (8.6%)\n",
"✅ RNN data processor created!\n",
"🧠 Memory management features enabled:\n",
"   - Batch save interval: 25 trials\n",
"   - Automatic memory monitoring and cleanup\n",
"   - Immediate GPU memory release\n",
"   - Garbage-collection tuning\n"
]
}
],
"source": [
"# 🚀 RNN数据批量处理工具 - 完整版(增强内存管理 + 自动上传)\n",
|
||
"import os\n",
|
||
"import torch\n",
|
||
"import numpy as np\n",
|
||
"import pandas as pd\n",
|
||
"from omegaconf import OmegaConf\n",
|
||
"import time\n",
|
||
"from tqdm import tqdm\n",
|
||
"import h5py\n",
|
||
"from pathlib import Path\n",
|
||
"import gc # 垃圾回收\n",
|
||
"import psutil # 内存监控\n",
|
||
"\n",
|
||
"# 导入模型相关模块\n",
|
||
"import sys\n",
|
||
"sys.path.append('../model_training')\n",
|
||
"from rnn_model import GRUDecoder\n",
|
||
"from evaluate_model_helpers import *\n",
|
||
"from data_augmentations import gauss_smooth\n",
|
||
"\n",
|
||
"print(\"=\"*70)\n",
|
||
"print(\"🚀 RNN数据批量处理工具 - 增强内存管理 + 自动上传版本\")\n",
|
||
"print(\"=\"*70)\n",
|
||
"\n",
|
||
"class MemoryManager:\n",
|
||
" \"\"\"内存管理器 - 监控和清理内存\"\"\"\n",
|
||
" \n",
|
||
" @staticmethod\n",
|
||
" def get_memory_info():\n",
|
||
" \"\"\"获取内存使用情况\"\"\"\n",
|
||
" process = psutil.Process()\n",
|
||
" memory_info = process.memory_info()\n",
|
||
" memory_percent = process.memory_percent()\n",
|
||
" \n",
|
||
" # GPU内存(如果可用)\n",
|
||
" gpu_memory = \"\"\n",
|
||
" if torch.cuda.is_available():\n",
|
||
" gpu_allocated = torch.cuda.memory_allocated() / 1024**3\n",
|
||
" gpu_reserved = torch.cuda.memory_reserved() / 1024**3\n",
|
||
" gpu_memory = f\" | GPU: {gpu_allocated:.2f}GB allocated, {gpu_reserved:.2f}GB reserved\"\n",
|
||
" \n",
|
||
" return f\"RAM: {memory_info.rss / 1024**3:.2f}GB ({memory_percent:.1f}%){gpu_memory}\"\n",
|
||
" \n",
|
||
" @staticmethod\n",
|
||
" def clear_memory():\n",
|
||
" \"\"\"清理内存\"\"\"\n",
|
||
" # 清理Python垃圾回收\n",
|
||
" collected = gc.collect()\n",
|
||
" \n",
|
||
" # 清理GPU内存\n",
|
||
" if torch.cuda.is_available():\n",
|
||
" torch.cuda.empty_cache()\n",
|
||
" torch.cuda.synchronize()\n",
|
||
" \n",
|
||
" return collected\n",
|
||
" \n",
|
||
" @staticmethod\n",
|
||
" def memory_warning_check():\n",
|
||
" \"\"\"检查内存使用情况并发出警告\"\"\"\n",
|
||
" memory_percent = psutil.virtual_memory().percent\n",
|
||
" if memory_percent > 85:\n",
|
||
" print(f\" 内存使用率过高: {memory_percent:.1f}%\")\n",
|
||
" return True\n",
|
||
" return False\n",
|
||
"\n",
|
||
"class RNNDataProcessor:\n",
|
||
" \"\"\"\n",
|
||
" RNN数据批量处理器 - 生成RNN输入输出拼接数据\n",
|
||
" 增强版本:优化内存管理,支持大数据集处理,自动上传到WebDAV\n",
|
||
" \n",
|
||
" 核心功能:\n",
|
||
" 1. 加载预训练RNN模型\n",
|
||
" 2. 处理原始神经数据(高斯平滑 + patch操作)\n",
|
||
" 3. 获取RNN输出(40类置信度分数)\n",
|
||
" 4. 拼接处理后的输入和输出\n",
|
||
" 5. 批量保存所有session数据\n",
|
||
" 6. 自动上传到WebDAV并删除本地文件\n",
|
||
" 7. 内存管理和监控\n",
|
||
" \"\"\"\n",
|
||
" \n",
|
||
" def __init__(self, model_path, data_dir, csv_path, device='auto', \n",
|
||
" batch_save_interval=50, memory_threshold=80, \n",
|
||
" enable_auto_upload=True, webdav_uploader=None):\n",
|
||
" \"\"\"\n",
|
||
" 初始化处理器\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" model_path: 预训练RNN模型路径\n",
|
||
" data_dir: 数据目录路径 \n",
|
||
" csv_path: 数据描述CSV文件路径\n",
|
||
" device: 计算设备 ('auto', 'cpu', 'cuda:0'等)\n",
|
||
" batch_save_interval: 批量保存间隔(每N个试验保存一次)\n",
|
||
" memory_threshold: 内存警告阈值(百分比)\n",
|
||
" enable_auto_upload: 是否启用自动上传\n",
|
||
" webdav_uploader: WebDAV上传器实例\n",
|
||
" \"\"\"\n",
|
||
" self.model_path = model_path\n",
|
||
" self.data_dir = data_dir\n",
|
||
" self.csv_path = csv_path\n",
|
||
" self.batch_save_interval = batch_save_interval\n",
|
||
" self.memory_threshold = memory_threshold\n",
|
||
" self.enable_auto_upload = enable_auto_upload\n",
|
||
" self.webdav_uploader = webdav_uploader\n",
|
||
" \n",
|
||
" # 设备选择\n",
|
||
" if device == 'auto':\n",
|
||
" self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
||
" else:\n",
|
||
" self.device = torch.device(device)\n",
|
||
" \n",
|
||
" print(f\"🔧 初始化RNN数据处理器(增强内存管理 + 自动上传版)...\")\n",
|
||
" print(f\" 模型路径: {model_path}\")\n",
|
||
" print(f\" 数据目录: {data_dir}\")\n",
|
||
" print(f\" 计算设备: {self.device}\")\n",
|
||
" print(f\" 批量保存间隔: {batch_save_interval} 个试验\")\n",
|
||
" print(f\" 内存警告阈值: {memory_threshold}%\")\n",
|
||
" print(f\" 自动上传: {'✅ 启用' if enable_auto_upload else '❌ 禁用'}\")\n",
|
||
" \n",
|
||
" # 初始内存状态\n",
|
||
" print(f\" 初始内存状态: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" # 加载配置和模型\n",
|
||
" self._load_config()\n",
|
||
" self._load_model()\n",
|
||
" self._load_csv()\n",
|
||
" \n",
|
||
" print(f\" 初始化完成!当前内存: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" def _load_config(self):\n",
|
||
" \"\"\"加载模型配置\"\"\"\n",
|
||
" config_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
|
||
" if not os.path.exists(config_path):\n",
|
||
" raise FileNotFoundError(f\"配置文件不存在: {config_path}\")\n",
|
||
" \n",
|
||
" self.model_args = OmegaConf.load(config_path)\n",
|
||
" \n",
|
||
" print(f\" 模型配置:\")\n",
|
||
" print(f\" Sessions数量: {len(self.model_args['dataset']['sessions'])}\")\n",
|
||
" print(f\" 神经特征维度: {self.model_args['model']['n_input_features']}\")\n",
|
||
" print(f\" Patch size: {self.model_args['model']['patch_size']}\")\n",
|
||
" print(f\" Patch stride: {self.model_args['model']['patch_stride']}\")\n",
|
||
" print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n",
|
||
" \n",
|
||
" def _load_model(self):\n",
|
||
" \"\"\"加载预训练RNN模型\"\"\"\n",
|
||
" try:\n",
|
||
" print(f\" 加载模型... 当前内存: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" # 创建模型\n",
|
||
" self.model = GRUDecoder(\n",
|
||
" neural_dim=self.model_args['model']['n_input_features'],\n",
|
||
" n_units=self.model_args['model']['n_units'], \n",
|
||
" n_days=len(self.model_args['dataset']['sessions']),\n",
|
||
" n_classes=self.model_args['dataset']['n_classes'],\n",
|
||
" rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
|
||
" input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
|
||
" n_layers=self.model_args['model']['n_layers'],\n",
|
||
" patch_size=self.model_args['model']['patch_size'],\n",
|
||
" patch_stride=self.model_args['model']['patch_stride'],\n",
|
||
" )\n",
|
||
" \n",
|
||
" # 加载权重\n",
|
||
" checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
|
||
" try:\n",
|
||
" checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
|
||
" except TypeError:\n",
|
||
" checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
|
||
" \n",
|
||
" # 清理键名\n",
|
||
" for key in list(checkpoint['model_state_dict'].keys()):\n",
|
||
" checkpoint['model_state_dict'][key.replace(\"module.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
|
||
" checkpoint['model_state_dict'][key.replace(\"_orig_mod.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
|
||
" \n",
|
||
" self.model.load_state_dict(checkpoint['model_state_dict'])\n",
|
||
" self.model.to(self.device)\n",
|
||
" self.model.eval()\n",
|
||
" \n",
|
||
" # 立即清理checkpoint内存\n",
|
||
" del checkpoint\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" \n",
|
||
" print(f\" 模型加载成功,内存清理后: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" except Exception as e:\n",
|
||
" print(f\" 模型加载失败: {e}\")\n",
|
||
" raise\n",
|
||
" \n",
|
||
" def _load_csv(self):\n",
|
||
" \"\"\"加载数据描述文件\"\"\"\n",
|
||
" if not os.path.exists(self.csv_path):\n",
|
||
" raise FileNotFoundError(f\"CSV文件不存在: {self.csv_path}\")\n",
|
||
" \n",
|
||
" self.csv_df = pd.read_csv(self.csv_path)\n",
|
||
" print(f\"📊 CSV数据加载完成: {len(self.csv_df)} 条记录\")\n",
|
||
" \n",
|
||
" def _process_single_trial(self, neural_data, session_idx):\n",
|
||
" \"\"\"\n",
|
||
" 处理单个试验数据(优化内存使用)\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" neural_data: 原始神经数据 [time_steps, features]\n",
|
||
" session_idx: 会话索引\n",
|
||
" \n",
|
||
" 返回:\n",
|
||
" dict: 包含拼接数据和统计信息\n",
|
||
" \"\"\"\n",
|
||
" try:\n",
|
||
" # 添加batch维度\n",
|
||
" neural_input = np.expand_dims(neural_data, axis=0)\n",
|
||
" neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)\n",
|
||
" \n",
|
||
" # 高斯平滑\n",
|
||
" with torch.autocast(device_type=\"cuda\" if self.device.type == \"cuda\" else \"cpu\", \n",
|
||
" enabled=self.model_args.get('use_amp', False), dtype=torch.bfloat16):\n",
|
||
" \n",
|
||
" smoothed_data = gauss_smooth(\n",
|
||
" inputs=neural_tensor,\n",
|
||
" device=self.device,\n",
|
||
" smooth_kernel_std=self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
|
||
" smooth_kernel_size=self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
|
||
" padding='valid',\n",
|
||
" )\n",
|
||
" \n",
|
||
" # Patch操作(复制模型内部逻辑)\n",
|
||
" processed_data = smoothed_data\n",
|
||
" if self.model.patch_size > 0:\n",
|
||
" processed_data = processed_data.unsqueeze(1) # [batch, 1, time, features]\n",
|
||
" processed_data = processed_data.permute(0, 3, 1, 2) # [batch, features, 1, time]\n",
|
||
" \n",
|
||
" # 滑动窗口提取\n",
|
||
" patches = processed_data.unfold(3, self.model.patch_size, self.model.patch_stride)\n",
|
||
" patches = patches.squeeze(2) # [batch, features, patches, patch_size]\n",
|
||
" patches = patches.permute(0, 2, 3, 1) # [batch, patches, patch_size, features]\n",
|
||
" \n",
|
||
" # 展平最后两个维度\n",
|
||
" processed_data = patches.reshape(patches.size(0), patches.size(1), -1)\n",
|
||
" \n",
|
||
" # RNN推理\n",
|
||
" with torch.no_grad():\n",
|
||
" logits, _ = self.model(\n",
|
||
" x=smoothed_data,\n",
|
||
" day_idx=torch.tensor([session_idx], device=self.device),\n",
|
||
" states=None,\n",
|
||
" return_state=True,\n",
|
||
" )\n",
|
||
" \n",
|
||
" # 转换为numpy并立即释放GPU内存\n",
|
||
" processed_features = processed_data.float().cpu().numpy()[0] # [time_steps, processed_features]\n",
|
||
" confidence_scores = logits.float().cpu().numpy()[0] # [time_steps, 40]\n",
|
||
" \n",
|
||
" # 立即清理GPU张量\n",
|
||
" del neural_tensor, smoothed_data, processed_data, logits\n",
|
||
" if torch.cuda.is_available():\n",
|
||
" torch.cuda.empty_cache()\n",
|
||
" \n",
|
||
" # 拼接数据\n",
|
||
" concatenated = np.concatenate([processed_features, confidence_scores], axis=1)\n",
|
||
" \n",
|
||
" return {\n",
|
||
" 'concatenated_data': concatenated,\n",
|
||
" 'processed_features': processed_features,\n",
|
||
" 'confidence_scores': confidence_scores,\n",
|
||
" 'original_time_steps': neural_data.shape[0],\n",
|
||
" 'processed_time_steps': concatenated.shape[0],\n",
|
||
" 'feature_reduction_ratio': concatenated.shape[0] / neural_data.shape[0]\n",
|
||
" }\n",
|
||
" \n",
|
||
" except Exception as e:\n",
|
||
" # 确保GPU内存清理\n",
|
||
" if torch.cuda.is_available():\n",
|
||
" torch.cuda.empty_cache()\n",
|
||
" raise e\n",
|
||
" \n",
|
||
" def _upload_and_cleanup_file(self, filepath):\n",
|
||
" \"\"\"上传文件到WebDAV并删除本地文件\"\"\"\n",
|
||
" if not self.enable_auto_upload or not self.webdav_uploader:\n",
|
||
" return False\n",
|
||
" \n",
|
||
" try:\n",
|
||
" # 上传文件\n",
|
||
" filename = os.path.basename(filepath)\n",
|
||
" success = self.webdav_uploader.upload_file(\n",
|
||
" str(filepath), \n",
|
||
" '/移动云盘/DATA/rnn_processed_data/'\n",
|
||
" )\n",
|
||
" \n",
|
||
" if success:\n",
|
||
" # 删除本地文件\n",
|
||
" os.remove(filepath)\n",
|
||
" print(f\" 📤 上传并删除: {filename}\")\n",
|
||
" return True\n",
|
||
" else:\n",
|
||
" print(f\" ❌ 上传失败,保留本地文件: {filename}\")\n",
|
||
" return False\n",
|
||
" \n",
|
||
" except Exception as e:\n",
|
||
" print(f\" ⚠️ 上传过程出错: {e}\")\n",
|
||
" return False\n",
|
||
" \n",
|
||
" def _save_batch_data(self, results, session_name, data_type, save_path, batch_idx=None):\n",
|
||
" \"\"\"\n",
|
||
" 保存批次数据(减少内存占用 + 自动上传)\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" results: 结果数据\n",
|
||
" session_name: 会话名称\n",
|
||
" data_type: 数据类型\n",
|
||
" save_path: 保存路径\n",
|
||
" batch_idx: 批次索引(可选)\n",
|
||
" \"\"\"\n",
|
||
" if not results['concatenated_data']:\n",
|
||
" return\n",
|
||
" \n",
|
||
" # 生成文件名\n",
|
||
" if batch_idx is not None:\n",
|
||
" filename = f\"{session_name}_{data_type}_rnn_processed_batch{batch_idx}.npz\"\n",
|
||
" else:\n",
|
||
" filename = f\"{session_name}_{data_type}_rnn_processed.npz\"\n",
|
||
" \n",
|
||
" filepath = save_path / filename\n",
|
||
" \n",
|
||
" save_data = {\n",
|
||
" 'concatenated_data': np.array(results['concatenated_data'], dtype=object),\n",
|
||
" 'processed_features': np.array(results['processed_features'], dtype=object),\n",
|
||
" 'confidence_scores': np.array(results['confidence_scores'], dtype=object),\n",
|
||
" 'trial_metadata': np.array(results['trial_metadata'], dtype=object),\n",
|
||
" }\n",
|
||
" \n",
|
||
" # 保存文件\n",
|
||
" np.savez_compressed(str(filepath), **save_data)\n",
|
||
" print(f\" 💾 保存批次: {filename} ({len(results['concatenated_data'])} 个试验)\")\n",
|
||
" \n",
|
||
" # 自动上传并删除本地文件\n",
|
||
" if self.enable_auto_upload:\n",
|
||
" self._upload_and_cleanup_file(filepath)\n",
|
||
" \n",
|
||
" # 清理结果数据释放内存\n",
|
||
" for key in results:\n",
|
||
" if isinstance(results[key], list):\n",
|
||
" results[key].clear()\n",
|
||
" \n",
|
||
" # 强制垃圾回收\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" \n",
|
||
" def process_session(self, session_name, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n",
|
||
" \"\"\"\n",
|
||
" 处理单个session的数据(优化内存管理 + 自动上传)\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" session_name: 会话名称\n",
|
||
" data_types: 要处理的数据类型列表\n",
|
||
" save_dir: 保存目录\n",
|
||
" \n",
|
||
" 返回:\n",
|
||
" dict: 处理结果摘要\n",
|
||
" \"\"\"\n",
|
||
" print(f\"\\n 处理会话: {session_name}\")\n",
|
||
" print(f\" 开始时内存: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" session_idx = self.model_args['dataset']['sessions'].index(session_name)\n",
|
||
" session_results = {}\n",
|
||
" \n",
|
||
" # 确保保存目录存在\n",
|
||
" save_path = Path(save_dir)\n",
|
||
" save_path.mkdir(parents=True, exist_ok=True)\n",
|
||
" \n",
|
||
" for data_type in data_types:\n",
|
||
" data_file = os.path.join(self.data_dir, session_name, f'data_{data_type}.hdf5')\n",
|
||
" \n",
|
||
" if not os.path.exists(data_file):\n",
|
||
" print(f\" {data_type} 数据文件不存在,跳过\")\n",
|
||
" continue\n",
|
||
" \n",
|
||
" print(f\" 处理 {data_type} 数据...\")\n",
|
||
" \n",
|
||
" try:\n",
|
||
" # 加载数据\n",
|
||
" data = load_h5py_file(data_file, self.csv_df)\n",
|
||
" num_trials = len(data['neural_features'])\n",
|
||
" \n",
|
||
" if num_trials == 0:\n",
|
||
" print(f\" {data_type} 数据为空\")\n",
|
||
" continue\n",
|
||
" \n",
|
||
" # 处理所有试验(批量保存策略)\n",
|
||
" results = {\n",
|
||
" 'concatenated_data': [],\n",
|
||
" 'processed_features': [],\n",
|
||
" 'confidence_scores': [],\n",
|
||
" 'trial_metadata': [],\n",
|
||
" 'processing_stats': []\n",
|
||
" }\n",
|
||
" \n",
|
||
" batch_count = 0\n",
|
||
" total_processed = 0\n",
|
||
" uploaded_files = 0\n",
|
||
" \n",
|
||
" for trial_idx in tqdm(range(num_trials), desc=f\" {data_type}\", leave=False):\n",
|
||
" # 检查内存使用情况\n",
|
||
" if trial_idx % 10 == 0: # 每10个trial检查一次\n",
|
||
" if MemoryManager.memory_warning_check():\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" print(f\" 🧹 内存清理: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" neural_data = data['neural_features'][trial_idx]\n",
|
||
" \n",
|
||
" # 处理单个试验\n",
|
||
" trial_result = self._process_single_trial(neural_data, session_idx)\n",
|
||
" \n",
|
||
" # 保存结果\n",
|
||
" results['concatenated_data'].append(trial_result['concatenated_data'])\n",
|
||
" results['processed_features'].append(trial_result['processed_features'])\n",
|
||
" results['confidence_scores'].append(trial_result['confidence_scores'])\n",
|
||
" \n",
|
||
" # 保存元数据\n",
|
||
" metadata = {\n",
|
||
" 'session': session_name,\n",
|
||
" 'data_type': data_type,\n",
|
||
" 'trial_idx': trial_idx,\n",
|
||
" 'block_num': data.get('block_num', [None])[trial_idx],\n",
|
||
" 'trial_num': data.get('trial_num', [None])[trial_idx],\n",
|
||
" **{k: v for k, v in trial_result.items() if k != 'concatenated_data'}\n",
|
||
" }\n",
|
||
" \n",
|
||
" # 添加真实标签(如果可用)\n",
|
||
" if data_type in ['train', 'val'] and 'sentence_label' in data:\n",
|
||
" metadata.update({\n",
|
||
" 'sentence_label': data['sentence_label'][trial_idx],\n",
|
||
" 'seq_class_ids': data['seq_class_ids'][trial_idx],\n",
|
||
" 'seq_len': data['seq_len'][trial_idx]\n",
|
||
" })\n",
|
||
" \n",
|
||
" results['trial_metadata'].append(metadata)\n",
|
||
" results['processing_stats'].append(trial_result)\n",
|
||
" total_processed += 1\n",
|
||
" \n",
|
||
" # 批量保存策略\n",
|
||
" if (trial_idx + 1) % self.batch_save_interval == 0 or trial_idx == num_trials - 1:\n",
|
||
" self._save_batch_data(results, session_name, data_type, save_path, batch_count)\n",
|
||
" if self.enable_auto_upload:\n",
|
||
" uploaded_files += 1\n",
|
||
" batch_count += 1\n",
|
||
" \n",
|
||
" # 强制内存清理\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" \n",
|
||
" # 统计信息\n",
|
||
" if total_processed > 0:\n",
|
||
" print(f\" {data_type} 处理完成:\")\n",
|
||
" print(f\" 试验数: {total_processed}\")\n",
|
||
" print(f\" 保存批次数: {batch_count}\")\n",
|
||
" if self.enable_auto_upload:\n",
|
||
" print(f\" 上传文件数: {uploaded_files}\")\n",
|
||
" print(f\" 最终内存: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" session_results[data_type] = {\n",
|
||
" 'total_trials': total_processed,\n",
|
||
" 'batch_count': batch_count,\n",
|
||
" 'uploaded_files': uploaded_files if self.enable_auto_upload else 0,\n",
|
||
" 'files': [f\"{session_name}_{data_type}_rnn_processed_batch{i}.npz\" for i in range(batch_count)]\n",
|
||
" }\n",
|
||
" \n",
|
||
" # 清理大型数据对象\n",
|
||
" del data\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" \n",
|
||
" except Exception as e:\n",
|
||
" print(f\" {data_type} 处理失败: {e}\")\n",
|
||
" # 确保内存清理\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" continue\n",
|
||
" \n",
|
||
" print(f\" 会话完成时内存: {MemoryManager.get_memory_info()}\")\n",
|
||
" return session_results\n",
|
||
" \n",
|
||
" def process_all_sessions(self, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n",
|
||
" \"\"\"\n",
|
||
" 批量处理所有sessions(优化内存管理 + 自动上传)\n",
|
||
" \n",
|
||
" 参数:\n",
|
||
" data_types: 要处理的数据类型\n",
|
||
" save_dir: 保存目录\n",
|
||
" \n",
|
||
" 返回:\n",
|
||
" dict: 所有处理结果摘要\n",
|
||
" \"\"\"\n",
|
||
" print(f\"\\n 开始批量处理所有会话(增强内存管理 + 自动上传)...\")\n",
|
||
" print(f\" 目标数据类型: {data_types}\")\n",
|
||
" print(f\" 保存目录: {save_dir}\")\n",
|
||
" print(f\" 批量保存间隔: {self.batch_save_interval}\")\n",
|
||
" print(f\" 自动上传: {'✅ 启用' if self.enable_auto_upload else '❌ 禁用'}\")\n",
|
||
" print(f\" 初始内存状态: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" save_path = Path(save_dir)\n",
|
||
" save_path.mkdir(parents=True, exist_ok=True)\n",
|
||
" \n",
|
||
" all_results = {}\n",
|
||
" sessions = self.model_args['dataset']['sessions']\n",
|
||
" \n",
|
||
" start_time = time.time()\n",
|
||
" total_uploaded_files = 0\n",
|
||
" \n",
|
||
" for i, session in enumerate(sessions):\n",
|
||
" print(f\"\\n 进度: {i+1}/{len(sessions)} - {session}\")\n",
|
||
" \n",
|
||
" try:\n",
|
||
" session_results = self.process_session(session, data_types, save_dir)\n",
|
||
" \n",
|
||
" if session_results:\n",
|
||
" all_results[session] = session_results\n",
|
||
" \n",
|
||
" # 统计上传文件数\n",
|
||
" session_uploaded = sum(\n",
|
||
" type_data.get('uploaded_files', 0) \n",
|
||
" for type_data in session_results.values()\n",
|
||
" )\n",
|
||
" total_uploaded_files += session_uploaded\n",
|
||
" \n",
|
||
" print(f\" 会话 {session} 完成\")\n",
|
||
" if self.enable_auto_upload:\n",
|
||
" print(f\" 本会话上传文件: {session_uploaded}\")\n",
|
||
" else:\n",
|
||
" print(f\" 会话 {session} 无有效数据\")\n",
|
||
" \n",
|
||
" # 每处理几个session进行一次深度内存清理\n",
|
||
" if (i + 1) % 5 == 0:\n",
|
||
" print(f\" 深度内存清理...\")\n",
|
||
" collected = MemoryManager.clear_memory()\n",
|
||
" print(f\" 回收对象数: {collected}, 当前内存: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" except Exception as e:\n",
|
||
" print(f\" 会话 {session} 处理失败: {e}\")\n",
|
||
" # 确保内存清理\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" continue\n",
|
||
" \n",
|
||
" # 生成总结\n",
|
||
" end_time = time.time()\n",
|
||
" processing_time = end_time - start_time\n",
|
||
" \n",
|
||
" total_trials = sum(\n",
|
||
" session_data[data_type]['total_trials']\n",
|
||
" for session_data in all_results.values()\n",
|
||
" for data_type in session_data.keys()\n",
|
||
" )\n",
|
||
" \n",
|
||
" total_files = sum(\n",
|
||
" session_data[data_type]['batch_count']\n",
|
||
" for session_data in all_results.values()\n",
|
||
" for data_type in session_data.keys()\n",
|
||
" )\n",
|
||
" \n",
|
||
" print(f\"\\n 批量处理完成!\")\n",
|
||
" print(f\"⏱ 总耗时: {processing_time/60:.2f} 分钟\")\n",
|
||
" print(f\" 处理统计:\")\n",
|
||
" print(f\" 成功会话: {len(all_results)}/{len(sessions)}\")\n",
|
||
" print(f\" 总试验数: {total_trials}\")\n",
|
||
" print(f\" 生成文件总数: {total_files}\")\n",
|
||
" if self.enable_auto_upload:\n",
|
||
" print(f\" 📤 上传文件总数: {total_uploaded_files}\")\n",
|
||
" print(f\" 💾 本地保留文件: {total_files - total_uploaded_files}\")\n",
|
||
" print(f\" 最终内存状态: {MemoryManager.get_memory_info()}\")\n",
|
||
" print(f\" 数据保存在: {save_dir}\")\n",
|
||
" \n",
|
||
" # 保存总结信息\n",
|
||
" summary = {\n",
|
||
" 'processing_time': processing_time,\n",
|
||
" 'total_sessions': len(all_results),\n",
|
||
" 'total_trials': total_trials,\n",
|
||
" 'total_files': total_files,\n",
|
||
" 'uploaded_files': total_uploaded_files if self.enable_auto_upload else 0,\n",
|
||
" 'auto_upload_enabled': self.enable_auto_upload,\n",
|
||
" 'data_types': data_types,\n",
|
||
" 'sessions': list(all_results.keys()),\n",
|
||
" 'batch_save_interval': self.batch_save_interval,\n",
|
||
" 'memory_management': True,\n",
|
||
" 'model_config': {\n",
|
||
" 'patch_size': self.model_args['model']['patch_size'],\n",
|
||
" 'patch_stride': self.model_args['model']['patch_stride'],\n",
|
||
" 'smooth_kernel_size': self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
|
||
" 'smooth_kernel_std': self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
|
||
" }\n",
|
||
" }\n",
|
||
" \n",
|
||
" import json\n",
|
||
" with open(save_path / 'processing_summary.json', 'w') as f:\n",
|
||
" json.dump(summary, f, indent=2)\n",
|
||
" \n",
|
||
" return all_results\n",
|
||
"\n",
|
||
"# 检查WebDAV上传器是否可用\n",
|
||
"try:\n",
|
||
" # 如果之前的上传器可用,直接使用\n",
|
||
" if 'uploader' in globals():\n",
|
||
" webdav_uploader_instance = uploader\n",
|
||
" print(\"🔗 使用现有的WebDAV上传器\")\n",
|
||
" else:\n",
|
||
" # 创建简单的WebDAV上传器\n",
|
||
" from webdav3.client import Client\n",
|
||
" \n",
|
||
" class SimpleWebDAVUploader:\n",
|
||
" def __init__(self):\n",
|
||
" self.client = Client({\n",
|
||
" 'webdav_hostname': 'http://zchens.cn:5244/dav/',\n",
|
||
" 'webdav_login': 'admin',\n",
|
||
" 'webdav_password': 'Zccns20050420',\n",
|
||
" 'webdav_timeout': 30\n",
|
||
" })\n",
|
||
" \n",
|
||
" def upload_file(self, local_file, remote_dir):\n",
|
||
" try:\n",
|
||
" filename = os.path.basename(local_file)\n",
|
||
" remote_path = remote_dir.rstrip('/') + '/' + filename\n",
|
||
" \n",
|
||
" if not self.client.check(remote_dir):\n",
|
||
" self.client.mkdir(remote_dir)\n",
|
||
" \n",
|
||
" self.client.upload_sync(remote_path=remote_path, local_path=local_file)\n",
|
||
" return True\n",
|
||
" except Exception as e:\n",
|
||
" print(f\"上传失败: {e}\")\n",
|
||
" return False\n",
|
||
" \n",
|
||
" webdav_uploader_instance = SimpleWebDAVUploader()\n",
|
||
" print(\"🔗 创建新的WebDAV上传器\")\n",
|
||
" \n",
|
||
"except Exception as e:\n",
|
||
" webdav_uploader_instance = None\n",
|
||
" print(f\"⚠️ WebDAV上传器不可用: {e}\")\n",
|
||
"\n",
|
||
"# 创建处理器实例(增强内存管理 + 自动上传)\n",
|
||
"print(\"🔧 创建RNN数据处理器(增强内存管理 + 自动上传版)...\")\n",
|
||
"\n",
|
||
"try:\n",
|
||
" processor = RNNDataProcessor(\n",
|
||
" model_path='../data/t15_pretrained_rnn_baseline',\n",
|
||
" data_dir='../data/hdf5_data_final',\n",
|
||
" csv_path='../data/t15_copyTaskData_description.csv',\n",
|
||
" device='auto',\n",
|
||
" batch_save_interval=25, # 每25个试验保存一次,减少内存积累\n",
|
||
" memory_threshold=80, # 80%内存使用率时警告\n",
|
||
" enable_auto_upload=True, # 启用自动上传\n",
|
||
" webdav_uploader=webdav_uploader_instance # 传入WebDAV上传器\n",
|
||
" )\n",
|
||
" \n",
|
||
" print(f\" RNN数据处理器创建成功!\")\n",
|
||
" print(f\" 功能特性:\")\n",
|
||
" print(f\" - 批量保存间隔: 25个试验\")\n",
|
||
" print(f\" - 自动内存监控和清理\")\n",
|
||
" print(f\" - GPU内存即时释放\")\n",
|
||
" print(f\" - 垃圾回收优化\")\n",
|
||
" print(f\" - 📤 自动上传到WebDAV\")\n",
|
||
" print(f\" - 🗑️ 自动删除本地文件\")\n",
|
||
" \n",
|
||
"except Exception as e:\n",
|
||
" print(f\" 处理器创建失败: {e}\")\n",
|
||
" processor = None"
|
||
]
|
||
},
|
||
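{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the patch arithmetic the processor replicates: with `patch_size=14`, `patch_stride=4`, and 512 input features, each patch flattens to 512 * 14 = 7168 values, and appending the 41 class logits gives 7209 concatenated columns (the run log below reports the 7209 total but splits it as 7169 + 40, an off-by-one in its display). The sketch reruns the unfold steps on random stand-in data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"T, n_feat, patch_size, patch_stride = 100, 512, 14, 4\n",
"x = torch.randn(1, T, n_feat)           # [batch, time, features]\n",
"x = x.unsqueeze(1).permute(0, 3, 1, 2)  # [batch, features, 1, time]\n",
"patches = x.unfold(3, patch_size, patch_stride).squeeze(2)  # [batch, features, n_patches, patch_size]\n",
"patches = patches.permute(0, 2, 3, 1)   # [batch, n_patches, patch_size, features]\n",
"flat = patches.reshape(1, patches.size(1), -1)\n",
"print(flat.shape)           # n_patches = (100 - 14) // 4 + 1 = 22, flat dim = 14 * 512 = 7168\n",
"print(flat.shape[-1] + 41)  # + n_classes logits -> 7209 concatenated columns"
]
},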
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"======================================================================\n",
|
||
"🎯 RNN数据批量处理 - 使用示例(内存优化版)\n",
|
||
"======================================================================\n",
|
||
"\n",
|
||
"📋 可用的处理方法:\n",
|
||
"1️⃣ 单session处理: processor.process_session('session_name')\n",
|
||
"2️⃣ 批量处理所有: processor.process_all_sessions()\n",
|
||
"\n",
|
||
"🧠 内存管理特性:\n",
|
||
" ✅ 自动批量保存 (每25个试验)\n",
|
||
" ✅ 实时内存监控和清理\n",
|
||
" ✅ GPU内存即时释放\n",
|
||
" ✅ 垃圾回收优化\n",
|
||
"\n",
|
||
"📊 可用会话数量: 45\n",
|
||
"📝 前5个会话: ['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25']\n",
|
||
"💡 问题会话 t15.2023.09.01 在位置: 6\n",
|
||
"\n",
|
||
"🔍 当前系统状态:\n",
|
||
" RAM: 1.18GB (8.6%)\n",
|
||
"\n",
|
||
"🧪 快速测试: 处理会话 't15.2023.08.13' 的训练数据...\n",
|
||
" 这将测试内存管理功能...\n",
|
||
"\n",
|
||
"🔄 处理会话: t15.2023.08.13\n",
|
||
" 开始时内存: RAM: 1.18GB (8.6%)\n",
|
||
" 📁 处理 train 数据...\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" train: 0%| | 0/348 [00:00<?, ?it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"⚠️ 内存使用率过高: 96.0%\n",
|
||
" 🧹 内存清理: RAM: 1.39GB (10.2%)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" train: 3%|▎ | 10/348 [00:06<03:27, 1.63it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"⚠️ 内存使用率过高: 93.7%\n",
|
||
" 🧹 内存清理: RAM: 1.56GB (11.4%)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" train: 6%|▌ | 20/348 [00:13<02:59, 1.82it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"⚠️ 内存使用率过高: 93.8%\n",
|
||
" 🧹 内存清理: RAM: 1.65GB (12.0%)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" train: 7%|▋ | 24/348 [00:15<03:22, 1.60it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" 💾 保存批次: t15.2023.08.13_train_rnn_processed_batch0.npz (25 个试验)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" train: 9%|▊ | 30/348 [00:41<09:06, 1.72s/it]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"⚠️ 内存使用率过高: 93.3%\n",
|
||
" 🧹 内存清理: RAM: 1.55GB (11.3%)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" train: 11%|█▏ | 40/348 [00:47<03:49, 1.34it/s]"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"⚠️ 内存使用率过高: 93.7%\n",
|
||
" 🧹 内存清理: RAM: 1.63GB (11.9%)\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" "
|
||
]
|
||
},
|
||
{
|
||
"ename": "KeyboardInterrupt",
|
||
"evalue": "",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[1;32mIn[2], line 36\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m 这将测试内存管理功能...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 35\u001b[0m \u001b[38;5;66;03m# 处理单个session(仅train数据进行测试)\u001b[39;00m\n\u001b[1;32m---> 36\u001b[0m single_result \u001b[38;5;241m=\u001b[39m \u001b[43mprocessor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_session\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_session\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m single_result \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m single_result:\n\u001b[0;32m 39\u001b[0m train_info \u001b[38;5;241m=\u001b[39m single_result[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
|
||
"Cell \u001b[1;32mIn[1], line 400\u001b[0m, in \u001b[0;36mRNNDataProcessor.process_session\u001b[1;34m(self, session_name, data_types, save_dir)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[38;5;66;03m# 批量保存策略\u001b[39;00m\n\u001b[0;32m 399\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (trial_idx \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_save_interval \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m trial_idx \u001b[38;5;241m==\u001b[39m num_trials \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m--> 400\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_batch_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresults\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msession_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_count\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 401\u001b[0m batch_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 403\u001b[0m \u001b[38;5;66;03m# 强制内存清理\u001b[39;00m\n",
|
||
"Cell \u001b[1;32mIn[1], line 296\u001b[0m, in \u001b[0;36mRNNDataProcessor._save_batch_data\u001b[1;34m(self, results, session_name, data_type, save_path, batch_idx)\u001b[0m\n\u001b[0;32m 287\u001b[0m filepath \u001b[38;5;241m=\u001b[39m save_path \u001b[38;5;241m/\u001b[39m filename\n\u001b[0;32m 289\u001b[0m save_data \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m 290\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 291\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprocessed_features\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprocessed_features\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 292\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconfidence_scores\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconfidence_scores\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 293\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrial_metadata\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrial_metadata\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 294\u001b[0m }\n\u001b[1;32m--> 296\u001b[0m np\u001b[38;5;241m.\u001b[39msavez_compressed(\u001b[38;5;28mstr\u001b[39m(filepath), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39msave_data)\n\u001b[0;32m 297\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m 💾 保存批次: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m 个试验)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 299\u001b[0m \u001b[38;5;66;03m# 清理结果数据释放内存\u001b[39;00m\n",
|
||
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:753\u001b[0m, in \u001b[0;36msavez_compressed\u001b[1;34m(file, *args, **kwds)\u001b[0m\n\u001b[0;32m 689\u001b[0m \u001b[38;5;129m@array_function_dispatch\u001b[39m(_savez_compressed_dispatcher)\n\u001b[0;32m 690\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21msavez_compressed\u001b[39m(file, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[0;32m 691\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 692\u001b[0m \u001b[38;5;124;03m Save several arrays into a single file in compressed ``.npz`` format.\u001b[39;00m\n\u001b[0;32m 693\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 751\u001b[0m \n\u001b[0;32m 752\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m--> 753\u001b[0m \u001b[43m_savez\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
|
||
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:786\u001b[0m, in \u001b[0;36m_savez\u001b[1;34m(file, args, kwds, compress, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[0;32m 784\u001b[0m \u001b[38;5;66;03m# always force zip64, gh-10776\u001b[39;00m\n\u001b[0;32m 785\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m zipf\u001b[38;5;241m.\u001b[39mopen(fname, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m, force_zip64\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m fid:\n\u001b[1;32m--> 786\u001b[0m \u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 787\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_pickle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_pickle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 788\u001b[0m \u001b[43m \u001b[49m\u001b[43mpickle_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpickle_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 790\u001b[0m zipf\u001b[38;5;241m.\u001b[39mclose()\n",
|
||
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\format.py:746\u001b[0m, in \u001b[0;36mwrite_array\u001b[1;34m(fp, array, version, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[0;32m 744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pickle_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 745\u001b[0m pickle_kwargs \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m--> 746\u001b[0m pickle\u001b[38;5;241m.\u001b[39mdump(array, fp, protocol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_kwargs)\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m array\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mf_contiguous \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m array\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mc_contiguous:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m isfileobj(fp):\n",
|
||
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\zipfile.py:1142\u001b[0m, in \u001b[0;36m_ZipWriteFile.write\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m 1140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_crc \u001b[38;5;241m=\u001b[39m crc32(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_crc)\n\u001b[0;32m 1141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compressor:\n\u001b[1;32m-> 1142\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compressor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompress\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1143\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_size \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(data)\n\u001b[0;32m 1144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fileobj\u001b[38;5;241m.\u001b[39mwrite(data)\n",
|
||
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 使用示例和批量处理(增强内存管理版)\n",
|
||
"\n",
|
||
"print(\"=\"*70)\n",
|
||
"print(\"RNN数据批量处理 - 使用示例(内存优化版)\")\n",
|
||
"print(\"=\"*70)\n",
|
||
"\n",
|
||
"if processor is not None:\n",
|
||
" \n",
|
||
" # 方法1: 处理单个session (推荐用于测试)\n",
|
||
" print(\"\\n可用的处理方法:\")\n",
|
||
" print(\"1. 单session处理: processor.process_session('session_name')\")\n",
|
||
" print(\"2. 批量处理所有: processor.process_all_sessions()\")\n",
|
||
" print(\"\\n内存管理特性:\")\n",
|
||
" print(\" 自动批量保存 (每25个试验)\")\n",
|
||
" print(\" 实时内存监控和清理\")\n",
|
||
" print(\" GPU内存即时释放\")\n",
|
||
" print(\" 垃圾回收优化\")\n",
|
||
" \n",
|
||
" # 显示可用的sessions\n",
|
||
" sessions = processor.model_args['dataset']['sessions']\n",
|
||
" print(f\"\\n可用会话数量: {len(sessions)}\")\n",
|
||
" print(f\"前5个会话: {sessions[:5]}\")\n",
|
||
" print(f\"问题会话 t15.2023.09.01 在位置: {sessions.index('t15.2023.09.01') if 't15.2023.09.01' in sessions else '未找到'}\")\n",
|
||
" \n",
|
||
" # 内存状态检查\n",
|
||
" print(f\"\\n当前系统状态:\")\n",
|
||
" print(f\" {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" # 快速测试 - 处理第一个session的部分数据\n",
|
||
" test_session = sessions[1] # 't15.2023.08.11'\n",
|
||
" \n",
|
||
" print(f\"\\n快速测试: 处理会话 '{test_session}' 的训练数据...\")\n",
|
||
" print(f\" 这将测试内存管理功能...\")\n",
|
||
" \n",
|
||
" # 处理单个session(仅train数据进行测试)\n",
|
||
" single_result = processor.process_session(test_session, ['train'])\n",
|
||
" \n",
|
||
" if single_result and 'train' in single_result:\n",
|
||
" train_info = single_result['train']\n",
|
||
" \n",
|
||
" print(f\"\\n内存管理测试完成!结果概览:\")\n",
|
||
" print(f\" 处理的试验数: {train_info['total_trials']}\")\n",
|
||
" print(f\" 保存的批次数: {train_info['batch_count']}\")\n",
|
||
" print(f\" 生成的文件: {len(train_info['files'])}\")\n",
|
||
" print(f\" 内存管理状态: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" # 加载一个批次文件来验证\n",
|
||
" if train_info['files']:\n",
|
||
" first_file = Path('./rnn_processed_data') / train_info['files'][0]\n",
|
||
" if first_file.exists():\n",
|
||
" test_data = np.load(str(first_file), allow_pickle=True)\n",
|
||
" sample_data = test_data['concatenated_data'][0]\n",
|
||
" \n",
|
||
" print(f\"\\n数据验证 (第一批次):\")\n",
|
||
" print(f\" 批次文件: {train_info['files'][0]}\")\n",
|
||
" print(f\" 样本数据形状: {sample_data.shape}\")\n",
|
||
" print(f\" 特征维度详情:\")\n",
|
||
" print(f\" - 处理后的神经特征: {sample_data.shape[1] - 41} 维\")\n",
|
||
" print(f\" - RNN置信度分数: 41 维\")\n",
|
||
" print(f\" - 总拼接特征: {sample_data.shape[1]} 维\")\n",
|
||
" print(f\" - 时间步数: {sample_data.shape[0]}\")\n",
|
||
" \n",
|
||
" # 显示一些样本元数据\n",
|
||
" sample_metadata = test_data['trial_metadata'][0]\n",
|
||
" print(f\" 样本元数据:\")\n",
|
||
" print(f\" - 原始时间步: {sample_metadata['original_time_steps']}\")\n",
|
||
" print(f\" - 处理后时间步: {sample_metadata['processed_time_steps']}\")\n",
|
||
" print(f\" - 时间压缩比: {sample_metadata['feature_reduction_ratio']:.3f}\")\n",
|
||
" \n",
|
||
" if 'sentence_label' in sample_metadata:\n",
|
||
" print(f\" - 句子标签: {sample_metadata['sentence_label']}\")\n",
|
||
" \n",
|
||
" # 清理测试数据\n",
|
||
" del test_data, sample_data, sample_metadata\n",
|
||
" MemoryManager.clear_memory()\n",
|
||
" \n",
|
||
" print(f\"\\n处理大数据集建议:\")\n",
|
||
" print(f\" 使用增强版处理器,自动内存管理\")\n",
|
||
" print(f\" 数据自动分批保存,避免内存溢出\") \n",
|
||
" print(f\" 可安全处理 t15.2023.09.01 等大批次\")\n",
|
||
" print(f\" 要批量处理所有数据,运行:\")\n",
|
||
" print(f\" results = processor.process_all_sessions()\")\n",
|
||
" \n",
|
||
"else:\n",
|
||
" print(\"处理器未创建成功,请检查上面的错误信息\")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"======================================================================\n",
|
||
"🚀 批量处理选项\n",
|
||
"======================================================================\n",
|
||
"📊 批量处理配置:\n",
|
||
" 启用批量处理: True\n",
|
||
" 保存目录: ./rnn_processed_data\n",
|
||
" 数据类型: ['train', 'val', 'test']\n",
|
||
" 总会话数: 45\n",
|
||
"\n",
|
||
"🚀 开始批量处理所有数据...\n",
|
||
"⚠️ 这可能需要较长时间(预计30-60分钟)\n",
|
||
"\n",
|
||
"🚀 开始批量处理所有会话...\n",
|
||
" 目标数据类型: ['train', 'val', 'test']\n",
|
||
" 保存目录: ./rnn_processed_data\n",
|
||
"\n",
|
||
"📊 进度: 1/45\n",
|
||
"\n",
|
||
"🔄 处理会话: t15.2023.08.11\n",
|
||
" 📁 处理 train 数据...\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" ✅ train 处理完成:\n",
|
||
" 试验数: 288\n",
|
||
" 时间步范围: 30-251\n",
|
||
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
|
||
" 平均时间压缩比: 0.240\n",
|
||
" ⚠️ val 数据文件不存在,跳过\n",
|
||
" ⚠️ test 数据文件不存在,跳过\n",
|
||
" 💾 保存: t15.2023.08.11_train_rnn_processed.npz\n",
|
||
"\n",
|
||
"📊 进度: 2/45\n",
|
||
"\n",
|
||
"🔄 处理会话: t15.2023.08.13\n",
|
||
" 📁 处理 train 数据...\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" ✅ train 处理完成:\n",
|
||
" 试验数: 348\n",
|
||
" 时间步范围: 55-352\n",
|
||
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
|
||
" 平均时间压缩比: 0.243\n",
|
||
" 📁 处理 val 数据...\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" ✅ val 处理完成:\n",
|
||
" 试验数: 35\n",
|
||
" 时间步范围: 90-296\n",
|
||
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
|
||
" 平均时间压缩比: 0.243\n",
|
||
" 📁 处理 test 数据...\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" \r"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" ✅ test 处理完成:\n",
|
||
" 试验数: 35\n",
|
||
" 时间步范围: 80-238\n",
|
||
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
|
||
" 平均时间压缩比: 0.242\n",
|
||
" 💾 保存: t15.2023.08.13_train_rnn_processed.npz\n",
|
||
" 💾 保存: t15.2023.08.13_val_rnn_processed.npz\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 🚀 批量处理所有数据 (增强内存管理 + 自动上传版本)\n",
|
||
"\n",
|
||
"print(\"=\"*70)\n",
|
||
"print(\"批量处理选项 - 内存优化 + 自动上传版\")\n",
|
||
"print(\"=\"*70)\n",
|
||
"\n",
|
||
"# 设置参数\n",
|
||
"ENABLE_FULL_PROCESSING = True # 设为True开始批量处理 \n",
|
||
"SAVE_DIR = \"./rnn_processed_data\" # 保存目录\n",
|
||
"DATA_TYPES = ['train', 'val', 'test'] # 要处理的数据类型\n",
|
||
"\n",
|
||
"print(f\" 批量处理配置 (内存优化 + 自动上传版):\")\n",
|
||
"print(f\" 启用批量处理: {ENABLE_FULL_PROCESSING}\")\n",
|
||
"print(f\" 保存目录: {SAVE_DIR}\")\n",
|
||
"print(f\" 数据类型: {DATA_TYPES}\")\n",
|
||
"print(f\" 总会话数: {len(processor.model_args['dataset']['sessions'])}\")\n",
|
||
"print(f\" 批量保存策略: 每{processor.batch_save_interval}个试验保存一次\")\n",
|
||
"print(f\" 内存监控阈值: {processor.memory_threshold}%\")\n",
|
||
"print(f\" 📤 自动上传: {'✅ 启用' if processor.enable_auto_upload else '❌ 禁用'}\")\n",
|
||
"print(f\" 🗑️ 自动清理: {'✅ 启用' if processor.enable_auto_upload else '❌ 禁用'}\")\n",
|
||
"\n",
|
||
"# 显示新功能优势\n",
|
||
"print(f\"\\n 🆕 新增功能优势:\")\n",
|
||
"print(f\" 📤 处理完成后自动上传到WebDAV\")\n",
|
||
"print(f\" 🗑️ 上传成功后自动删除本地文件\")\n",
|
||
"print(f\" 💾 节省本地存储空间\")\n",
|
||
"print(f\" ☁️ 数据安全备份到云端\")\n",
|
||
"print(f\" 🔄 无需手动管理文件传输\")\n",
|
||
"\n",
|
||
"print(f\"\\n 🔧 技术特性:\")\n",
|
||
"print(f\" 自动分批保存,避免内存积累\")\n",
|
||
"print(f\" 实时GPU内存清理\")\n",
|
||
"print(f\" 垃圾回收优化\")\n",
|
||
"print(f\" 内存使用监控和警告\")\n",
|
||
"print(f\" 可处理 t15.2023.09.01 等大数据集\")\n",
|
||
"\n",
|
||
"if ENABLE_FULL_PROCESSING and processor is not None:\n",
|
||
" print(f\"\\n 🚀 开始批量处理所有数据(内存优化 + 自动上传版)...\")\n",
|
||
" print(f\" 这可能需要较长时间(预计30-60分钟)\")\n",
|
||
" print(f\" 内存不足问题已解决,可安全处理大数据集\")\n",
|
||
" print(f\" 📤 文件将自动上传到WebDAV并删除本地副本\")\n",
|
||
" print(f\" 开始时内存状态: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" # 记录处理开始时间\n",
|
||
" import time\n",
|
||
" start_processing_time = time.time()\n",
|
||
" \n",
|
||
" # 批量处理\n",
|
||
" all_results = processor.process_all_sessions(\n",
|
||
" data_types=DATA_TYPES,\n",
|
||
" save_dir=SAVE_DIR\n",
|
||
" )\n",
|
||
" \n",
|
||
" # 计算处理时间\n",
|
||
" end_processing_time = time.time()\n",
|
||
" total_processing_time = end_processing_time - start_processing_time\n",
|
||
" \n",
|
||
" print(f\"\\n 🎉 批量处理完成!\")\n",
|
||
" print(f\" ⏱ 总处理时间: {total_processing_time/60:.2f} 分钟\")\n",
|
||
" print(f\" 最终内存状态: {MemoryManager.get_memory_info()}\")\n",
|
||
" \n",
|
||
" # 详细统计\n",
|
||
" total_files = 0\n",
|
||
" total_uploaded = 0\n",
|
||
" for session_name, session_data in all_results.items():\n",
|
||
" for data_type, type_data in session_data.items():\n",
|
||
" total_files += type_data['batch_count']\n",
|
||
" total_uploaded += type_data.get('uploaded_files', 0)\n",
|
||
" \n",
|
||
" print(f\"\\n 📊 处理统计详情:\")\n",
|
||
" print(f\" 成功处理的会话: {len(all_results)}\")\n",
|
||
" print(f\" 生成文件总数: {total_files}\")\n",
|
||
" print(f\" 📤 成功上传文件: {total_uploaded}\")\n",
|
||
" print(f\" 💾 本地保留文件: {total_files - total_uploaded}\")\n",
|
||
" print(f\" 🔄 上传成功率: {(total_uploaded/total_files*100) if total_files > 0 else 0:.1f}%\")\n",
|
||
" \n",
|
||
" # 存储空间统计\n",
|
||
" try:\n",
|
||
" import os\n",
|
||
" local_size = 0\n",
|
||
" if os.path.exists(SAVE_DIR):\n",
|
||
" for root, dirs, files in os.walk(SAVE_DIR):\n",
|
||
" for file in files:\n",
|
||
" local_size += os.path.getsize(os.path.join(root, file))\n",
|
||
" \n",
|
||
" print(f\" 💾 剩余本地文件大小: {local_size / 1024**3:.2f} GB\")\n",
|
||
" except:\n",
|
||
" pass\n",
|
||
" \n",
|
||
" print(f\"\\n ☁️ WebDAV云端存储:\")\n",
|
||
" print(f\" 远程路径: /移动云盘/DATA/rnn_processed_data/\")\n",
|
||
" print(f\" 文件格式: session_name_datatype_rnn_processed_batchN.npz\")\n",
|
||
" print(f\" 例如: t15.2023.08.13_train_rnn_processed_batch0.npz\")\n",
|
||
" \n",
|
||
" if total_uploaded > 0:\n",
|
||
" print(f\"\\n ✅ 自动上传工作流成功!\")\n",
|
||
" print(f\" 所有处理完的文件已自动上传到云端\")\n",
|
||
" print(f\" 本地存储空间得到有效管理\")\n",
|
||
" print(f\" 数据安全性和可访问性得到保障\")\n",
|
||
" \n",
|
||
"else:\n",
|
||
" print(f\"\\n 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\")\n",
|
||
" print(f\" 或者手动运行: processor.process_all_sessions()\")\n",
|
||
"\n",
|
||
"print(f\"\\n 📋 数据使用说明(自动上传版本):\")\n",
|
||
"print(f\" 🔄 处理流程:\")\n",
|
||
"print(f\" 1. 处理原始神经数据 → RNN输出\")\n",
|
||
"print(f\" 2. 保存到本地 (.npz 文件)\")\n",
|
||
"print(f\" 3. 自动上传到WebDAV云端\")\n",
|
||
"print(f\" 4. 删除本地文件,释放存储空间\")\n",
|
||
"print(f\"\")\n",
|
||
"print(f\" 📤 云端文件结构:\")\n",
|
||
"print(f\" /移动云盘/DATA/rnn_processed_data/\")\n",
|
||
"print(f\" ├── t15.2023.08.11_train_rnn_processed_batch0.npz\")\n",
|
||
"print(f\" ├── t15.2023.08.11_train_rnn_processed_batch1.npz\")\n",
|
||
"print(f\" ├── t15.2023.08.11_val_rnn_processed_batch0.npz\")\n",
|
||
"print(f\" └── ...\")\n",
|
||
"print(f\"\")\n",
|
||
"print(f\" 📥 下载和使用:\")\n",
|
||
"print(f\" # 从WebDAV下载文件\")\n",
|
||
"print(f\" uploader.client.download_sync(\")\n",
|
||
"print(f\" remote_path='/移动云盘/DATA/rnn_processed_data/filename.npz',\")\n",
|
||
"print(f\" local_path='./filename.npz'\")\n",
|
||
"print(f\" )\")\n",
|
||
"print(f\" \")\n",
|
||
"print(f\" # 加载数据\")\n",
|
||
"print(f\" data = np.load('filename.npz', allow_pickle=True)\")\n",
|
||
"print(f\" features = data['concatenated_data']\")\n",
|
||
"print(f\" metadata = data['trial_metadata']\")\n",
|
||
"print(f\"\")\n",
|
||
"print(f\" 🎯 优势总结:\")\n",
|
||
"print(f\" ✅ 解决了 t15.2023.09.01 内存不足问题\")\n",
|
||
"print(f\" ✅ 数据自动分批,便于后续加载\")\n",
|
||
"print(f\" ✅ 处理速度优化,内存使用稳定\")\n",
|
||
"print(f\" ✅ 错误恢复能力强,单个批次失败不影响整体\")\n",
|
||
"print(f\" 🆕 自动上传,无需手动管理文件\")\n",
|
||
"print(f\" 🆕 本地存储空间自动释放\")\n",
|
||
"print(f\" 🆕 云端数据安全备份\")"
|
||
]
|
||
},
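{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal loading sketch for the usage notes printed above: it opens one processed batch file with `np.load`. The `concatenated_data` / `trial_metadata` keys follow the printed instructions; the local path is an assumption (the file may already have been uploaded and deleted)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: load one processed batch file (path is an assumption;\n",
"# key names follow the usage notes printed by the cell above)\n",
"import numpy as np\n",
"\n",
"batch_path = './rnn_processed_data/t15.2023.08.13_train_rnn_processed_batch0.npz'\n",
"data = np.load(batch_path, allow_pickle=True)\n",
"print('Arrays in file:', list(data.keys()))\n",
"features = data['concatenated_data']  # per-trial RNN output features\n",
"metadata = data['trial_metadata']     # per-trial metadata records\n",
"print('Trials in this batch:', len(features))"
]
},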
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📁 Cloud file management tools\n",
"\n",
"import os\n",
"\n",
"def list_cloud_files(remote_dir='/移动云盘/DATA/rnn_processed_data/'):\n",
"    \"\"\"List all processed files on the cloud drive\"\"\"\n",
"    if 'uploader' not in globals():\n",
"        print(\"❌ WebDAV uploader not initialized\")\n",
"        return []\n",
"\n",
"    try:\n",
"        files = uploader.client.list(remote_dir)\n",
"        rnn_files = [f for f in files if f.endswith('.npz')]\n",
"\n",
"        print(f\"☁️ Cloud file list ({len(rnn_files)} files):\")\n",
"        print(f\"📍 Path: {remote_dir}\")\n",
"        print(\"-\" * 60)\n",
"\n",
"        for i, file in enumerate(rnn_files, 1):\n",
"            print(f\"{i:3d}. {file}\")\n",
"\n",
"        return rnn_files\n",
"\n",
"    except Exception as e:\n",
"        print(f\"❌ Failed to fetch the file list: {e}\")\n",
"        return []\n",
"\n",
"def download_cloud_file(filename, local_dir='./downloaded_data/'):\n",
"    \"\"\"Download a single file from the cloud drive\"\"\"\n",
"    if 'uploader' not in globals():\n",
"        print(\"❌ WebDAV uploader not initialized\")\n",
"        return False\n",
"\n",
"    try:\n",
"        # Make sure the local directory exists\n",
"        os.makedirs(local_dir, exist_ok=True)\n",
"\n",
"        remote_path = f'/移动云盘/DATA/rnn_processed_data/{filename}'\n",
"        local_path = os.path.join(local_dir, filename)\n",
"\n",
"        uploader.client.download_sync(remote_path=remote_path, local_path=local_path)\n",
"\n",
"        print(f\"✅ Downloaded: {filename}\")\n",
"        print(f\"📁 Saved to: {local_path}\")\n",
"\n",
"        # Show basic file info\n",
"        if os.path.exists(local_path):\n",
"            file_size = os.path.getsize(local_path) / 1024**2  # MB\n",
"            print(f\"📊 File size: {file_size:.2f} MB\")\n",
"\n",
"        return True\n",
"\n",
"    except Exception as e:\n",
"        print(f\"❌ Download failed: {e}\")\n",
"        return False\n",
"\n",
"def download_session_files(session_name, data_types=['train', 'val', 'test'], local_dir='./downloaded_data/'):\n",
"    \"\"\"Download all files belonging to a given session\"\"\"\n",
"    files = list_cloud_files()\n",
"\n",
"    session_files = []\n",
"    for file in files:\n",
"        if file.startswith(session_name):\n",
"            for data_type in data_types:\n",
"                if f'_{data_type}_' in file:\n",
"                    session_files.append(file)\n",
"                    break\n",
"\n",
"    if not session_files:\n",
"        print(f\"❌ No files found for session {session_name}\")\n",
"        return False\n",
"\n",
"    print(f\"\\n📥 Downloading files for session {session_name}...\")\n",
"    success_count = 0\n",
"\n",
"    for file in session_files:\n",
"        if download_cloud_file(file, local_dir):\n",
"            success_count += 1\n",
"\n",
"    print(f\"\\n✅ Done: {success_count}/{len(session_files)} files downloaded\")\n",
"    return success_count == len(session_files)\n",
"\n",
"def check_local_storage():\n",
"    \"\"\"Check local storage usage\"\"\"\n",
"    print(\"💾 Local storage check:\")\n",
"\n",
"    # Check the processed-data directory\n",
"    if os.path.exists('./rnn_processed_data'):\n",
"        total_size = 0\n",
"        file_count = 0\n",
"\n",
"        for root, dirs, files in os.walk('./rnn_processed_data'):\n",
"            for file in files:\n",
"                if file.endswith('.npz'):\n",
"                    filepath = os.path.join(root, file)\n",
"                    total_size += os.path.getsize(filepath)\n",
"                    file_count += 1\n",
"\n",
"        print(f\"  📁 ./rnn_processed_data/\")\n",
"        print(f\"    📊 Files: {file_count}\")\n",
"        print(f\"    📊 Total size: {total_size / 1024**3:.2f} GB\")\n",
"\n",
"        if file_count > 0:\n",
"            print(f\"    💡 Tip: these files are fully processed and can be deleted to free space\")\n",
"            print(f\"    Use: rm -rf ./rnn_processed_data/\")\n",
"    else:\n",
"        print(f\"  ✅ ./rnn_processed_data/ does not exist or is empty\")\n",
"\n",
"    # Check the download directory\n",
"    if os.path.exists('./downloaded_data'):\n",
"        download_size = 0\n",
"        download_count = 0\n",
"\n",
"        for root, dirs, files in os.walk('./downloaded_data'):\n",
"            for file in files:\n",
"                if file.endswith('.npz'):\n",
"                    filepath = os.path.join(root, file)\n",
"                    download_size += os.path.getsize(filepath)\n",
"                    download_count += 1\n",
"\n",
"        print(f\"  📁 ./downloaded_data/\")\n",
"        print(f\"    📊 Downloaded files: {download_count}\")\n",
"        print(f\"    📊 Downloaded size: {download_size / 1024**3:.2f} GB\")\n",
"\n",
"# Usage examples\n",
"print(\"📁 Cloud file management tools loaded!\")\n",
"print(\"\\n🛠 Available functions:\")\n",
"print(\"• list_cloud_files()  # list all cloud files\")\n",
"print(\"• download_cloud_file('filename.npz')  # download a single file\")\n",
"print(\"• download_session_files('t15.2023.08.13')  # download all files for a session\")\n",
"print(\"• check_local_storage()  # check local storage usage\")\n",
"\n",
"print(\"\\n💡 Examples:\")\n",
"print(\"# See which files are on the cloud drive\")\n",
"print(\"files = list_cloud_files()\")\n",
"print(\"\")\n",
"print(\"# Download a specific file\")\n",
"print(\"download_cloud_file('t15.2023.08.13_train_rnn_processed_batch0.npz')\")\n",
"print(\"\")\n",
"print(\"# Download all data for one session\")\n",
"print(\"download_session_files('t15.2023.08.13')\")\n",
"print(\"\")\n",
"print(\"# Check local storage\")\n",
"print(\"check_local_storage()\")"
]
},
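{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check using the helpers defined above. `check_local_storage()` touches only the local filesystem; `list_cloud_files()` is attempted only if the `uploader` created in the WebDAV cells further down already exists."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: exercise the management helpers defined above\n",
"check_local_storage()\n",
"\n",
"# The uploader is created in a later cell; only list cloud files if it exists\n",
"if 'uploader' in globals():\n",
"    files = list_cloud_files()\n",
"else:\n",
"    print('uploader not initialized yet - run the WebDAV cells below first')"
]
},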
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 📤 WebDAV file upload tools"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 📚 Choosing a WebDAV library - off-the-shelf solutions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ webdavclient3 available\n",
"Fixed WebDAV upload functions:\n",
"1. Single-file upload automatically appends the file name to the remote path\n",
"2. Directory upload filters out unwanted files such as .git\n",
"3. Detailed upload/skip information is shown\n",
"\n",
"Testing single-file upload...\n",
"Single-file upload result: {'success': True, 'local_file': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\brain-to-text-25\\\\client_secrets.json', 'remote_path': '/移动云盘/DATA/client_secrets.json', 'library': 'webdavclient3'}\n",
"\n",
"Testing directory upload (with filtering)...\n",
"Directory upload result: {'success': True, 'local_dir': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\data-kaggle', 'remote_dir': '/移动云盘/DATA/data-kaggle', 'uploaded_files': [], 'skipped_files': [], 'library': 'webdavclient3'}\n",
"Uploaded 0 files\n",
"Skipped 0 files\n"
]
}
],
"source": [
"# 📋 WebDAV library install guide\n",
"\n",
"print(\"📦 Install the WebDAV client library:\")\n",
"print(\"pip install webdavclient3\")\n",
"print(\"\")\n",
"print(\"💡 If it is already installed, the simplified upload tools below are ready to use!\")\n",
"print(\"   All the complex code has been condensed into easy-to-use classes and functions.\")\n",
"\n",
"# Check installation status\n",
"try:\n",
"    from webdav3.client import Client\n",
"    print(\"✅ webdavclient3 is installed and available\")\n",
"except ImportError:\n",
"    print(\"❌ Needs installing: pip install webdavclient3\")"
]
},
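{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the raw webdavclient3 API before the wrapper class below: build a `Client` from an options dict, then `check()` and `list()` a remote path. The hostname, credentials, and remote path are the ones this notebook already uses; adjust them for your own server."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: raw webdavclient3 usage (same server settings as the uploader below)\n",
"from webdav3.client import Client\n",
"\n",
"options = {\n",
"    'webdav_hostname': 'http://zchens.cn:5244/dav/',\n",
"    'webdav_login': 'admin',\n",
"    'webdav_password': 'Zccns20050420',\n",
"    'webdav_timeout': 30\n",
"}\n",
"client = Client(options)\n",
"\n",
"remote_dir = '/移动云盘/DATA/'\n",
"if client.check(remote_dir):        # True if the path exists on the server\n",
"    print(client.list(remote_dir))  # list the entries in the directory\n",
"else:\n",
"    print(f'{remote_dir} not found on the server')"
]
},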
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ WebDAV connection established: http://zchens.cn:5244/dav/\n",
"\n",
"📖 Usage:\n",
"# Upload a single file\n",
"uploader.upload_file(r'F:\\path\\to\\file.txt')\n",
"\n",
"# Upload a directory (auto-filters .git and friends)\n",
"uploader.upload_dir(r'F:\\path\\to\\dir', '/移动云盘/DATA/target_dir')\n",
"\n",
"# Upload a directory (no filtering)\n",
"uploader.upload_dir(r'F:\\path\\to\\dir', '/移动云盘/DATA/target_dir', exclude_git=False)\n"
]
}
],
"source": [
"# 📤 Simplified WebDAV upload tools\n",
"\n",
"from webdav3.client import Client\n",
"import os\n",
"import fnmatch\n",
"\n",
"class WebDAVUploader:\n",
"    \"\"\"A simplified WebDAV uploader\"\"\"\n",
"\n",
"    def __init__(self, url='http://zchens.cn:5244/dav/', username='admin', password='Zccns20050420'):\n",
"        \"\"\"Initialize the WebDAV connection\"\"\"\n",
"        self.client = Client({\n",
"            'webdav_hostname': url,\n",
"            'webdav_login': username,\n",
"            'webdav_password': password,\n",
"            'webdav_timeout': 30\n",
"        })\n",
"        print(f\"✅ WebDAV connection established: {url}\")\n",
"\n",
"    def _ensure_remote_dir(self, remote_dir):\n",
"        \"\"\"Create a remote directory (and its parents) if missing.\n",
"        mkdir creates one level at a time, so walk the path segments.\"\"\"\n",
"        path = ''\n",
"        for part in remote_dir.strip('/').split('/'):\n",
"            path += '/' + part\n",
"            if not self.client.check(path):\n",
"                self.client.mkdir(path)\n",
"\n",
"    def upload_file(self, local_file, remote_dir='/移动云盘/DATA/'):\n",
"        \"\"\"Upload a single file\"\"\"\n",
"        try:\n",
"            filename = os.path.basename(local_file)\n",
"            remote_path = remote_dir.rstrip('/') + '/' + filename\n",
"\n",
"            # Make sure the remote directory exists\n",
"            self._ensure_remote_dir(remote_dir)\n",
"\n",
"            self.client.upload_sync(remote_path=remote_path, local_path=local_file)\n",
"            print(f\"✅ File uploaded: {filename} -> {remote_path}\")\n",
"            return True\n",
"\n",
"        except Exception as e:\n",
"            print(f\"❌ File upload failed: {e}\")\n",
"            return False\n",
"\n",
"    def upload_dir(self, local_dir, remote_dir, exclude_git=True):\n",
"        \"\"\"Upload a directory, auto-filtering unwanted files\"\"\"\n",
"        exclude_patterns = ['.git*', '__pycache__*', '*.pyc', '.vscode*'] if exclude_git else []\n",
"\n",
"        try:\n",
"            uploaded = 0\n",
"            skipped = 0\n",
"\n",
"            for root, dirs, files in os.walk(local_dir):\n",
"                # Filter directories in place so os.walk skips them\n",
"                if exclude_git:\n",
"                    dirs[:] = [d for d in dirs if not any(fnmatch.fnmatch(d, p) for p in exclude_patterns)]\n",
"\n",
"                for file in files:\n",
"                    # Skip excluded files\n",
"                    if exclude_git and any(fnmatch.fnmatch(file, p) for p in exclude_patterns):\n",
"                        skipped += 1\n",
"                        continue\n",
"\n",
"                    local_file = os.path.join(root, file)\n",
"                    rel_path = os.path.relpath(local_file, local_dir)\n",
"                    remote_file = remote_dir.rstrip('/') + '/' + rel_path.replace('\\\\', '/')\n",
"\n",
"                    # Make sure the remote parent directory exists\n",
"                    remote_file_dir = '/'.join(remote_file.split('/')[:-1])\n",
"                    self._ensure_remote_dir(remote_file_dir)\n",
"\n",
"                    self.client.upload_sync(remote_path=remote_file, local_path=local_file)\n",
"                    uploaded += 1\n",
"\n",
"            print(f\"✅ Directory upload complete: uploaded {uploaded} files, skipped {skipped} files\")\n",
"            return True\n",
"\n",
"        except Exception as e:\n",
"            print(f\"❌ Directory upload failed: {e}\")\n",
"            return False\n",
"\n",
"# Create the uploader instance\n",
"uploader = WebDAVUploader()\n",
"\n",
"print(\"\\n📖 Usage:\")\n",
"print(\"# Upload a single file\")\n",
"print(\"uploader.upload_file(r'F:\\\\path\\\\to\\\\file.txt')\")\n",
"print(\"\")\n",
"print(\"# Upload a directory (auto-filters .git and friends)\")\n",
"print(\"uploader.upload_dir(r'F:\\\\path\\\\to\\\\dir', '/移动云盘/DATA/target_dir')\")\n",
"print(\"\")\n",
"print(\"# Upload a directory (no filtering)\")\n",
"print(\"uploader.upload_dir(r'F:\\\\path\\\\to\\\\dir', '/移动云盘/DATA/target_dir', exclude_git=False)\")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==================================================\n",
"WebDAV upload examples\n",
"==================================================\n",
"\n",
"1️⃣ Upload a single file:\n",
"✅ File uploaded: client_secrets.json -> /移动云盘/DATA/client_secrets.json\n",
"\n",
"2️⃣ Upload a directory (filtering .git and friends):\n",
"✅ Directory upload complete: uploaded 0 files, skipped 0 files\n",
"\n",
"3️⃣ Upload RNN processing results:\n",
"❌ Directory upload failed: HTTPConnectionPool(host='127.0.0.1', port=7897): Read timed out. (read timeout=30)\n",
"⚠️ RNN results upload failed; see the error above\n",
"\n",
"✨ Upload tasks complete! Your files should now be in the cloud drive.\n"
]
}
],
"source": [
"# 🚀 Quick usage examples\n",
"\n",
"print(\"=\"*50)\n",
"print(\"WebDAV upload examples\")\n",
"print(\"=\"*50)\n",
"\n",
"# Example 1: upload a single file\n",
"print(\"\\n1️⃣ Upload a single file:\")\n",
"uploader.upload_file(r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\brain-to-text-25\\client_secrets.json')\n",
"\n",
"# Example 2: upload a directory (auto-filters .git)\n",
"print(\"\\n2️⃣ Upload a directory (filtering .git and friends):\")\n",
"uploader.upload_dir(\n",
"    r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\data-kaggle', \n",
"    '/移动云盘/DATA/data-kaggle-clean'\n",
")\n",
"\n",
"# Example 3: upload the processed-data directory\n",
"print(\"\\n3️⃣ Upload RNN processing results:\")\n",
"if os.path.exists('./rnn_processed_data'):\n",
"    ok = uploader.upload_dir(\n",
"        './rnn_processed_data',\n",
"        '/移动云盘/DATA/rnn_processed_data'\n",
"    )\n",
"    if ok:\n",
"        print(\"RNN results uploaded!\")\n",
"    else:\n",
"        print(\"⚠️ RNN results upload failed; see the error above\")\n",
"else:\n",
"    print(\"⚠️ RNN processed-data directory not found\")\n",
"\n",
"print(\"\\n✨ Upload tasks complete! Your files should now be in the cloud drive.\")"
]
},
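{
"cell_type": "markdown",
"metadata": {},
"source": [
"The recorded output above shows a read timeout (the connection went through a local proxy on 127.0.0.1:7897). Below is a minimal retry sketch for such transient failures, assuming the `uploader` defined above: it simply re-invokes `upload_file` with an exponential backoff. The retry count and delays are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: retry a flaky upload with exponential backoff (parameters are illustrative)\n",
"import time\n",
"\n",
"def upload_with_retry(local_file, remote_dir='/移动云盘/DATA/', retries=3, base_delay=5):\n",
"    \"\"\"Retry uploader.upload_file on failure, doubling the delay each attempt.\"\"\"\n",
"    for attempt in range(1, retries + 1):\n",
"        if uploader.upload_file(local_file, remote_dir):\n",
"            return True\n",
"        if attempt < retries:\n",
"            delay = base_delay * 2 ** (attempt - 1)\n",
"            print(f'Retrying ({attempt}/{retries - 1}) in {delay}s...')\n",
"            time.sleep(delay)\n",
"    return False\n",
"\n",
"# Example (path is an assumption):\n",
"# upload_with_retry('./rnn_processed_data/some_batch.npz')"
]
},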
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🎯 Convenience functions defined!\n",
"\n",
"Available shortcuts:\n",
"• quick_upload_file('path/to/file')  # upload a single file\n",
"• quick_upload_project('project_dir')  # upload a whole project\n",
"• quick_upload_results()  # upload all result files\n",
"\n",
"For example:\n",
"quick_upload_file('data.csv')\n",
"quick_upload_project(r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text')\n",
"quick_upload_results()\n"
]
}
],
"source": [
"# 🎯 Convenience functions - one-call upload of common files\n",
"\n",
"def quick_upload_file(file_path, remote_dir='/移动云盘/DATA/'):\n",
"    \"\"\"Quickly upload a single file\"\"\"\n",
"    return uploader.upload_file(file_path, remote_dir)\n",
"\n",
"def quick_upload_project(project_dir, remote_name=None):\n",
"    \"\"\"Quickly upload a whole project directory (auto-filters .git and friends)\"\"\"\n",
"    if remote_name is None:\n",
"        remote_name = os.path.basename(project_dir.rstrip('/\\\\'))\n",
"\n",
"    remote_dir = f'/移动云盘/DATA/{remote_name}'\n",
"    return uploader.upload_dir(project_dir, remote_dir, exclude_git=True)\n",
"\n",
"def quick_upload_results():\n",
"    \"\"\"Quickly upload all result files\"\"\"\n",
"    results = []\n",
"\n",
"    # Upload RNN processing results\n",
"    if os.path.exists('./rnn_processed_data'):\n",
"        print(\"📊 Uploading RNN processing results...\")\n",
"        results.append(uploader.upload_dir('./rnn_processed_data', '/移动云盘/DATA/rnn_processed_data'))\n",
"\n",
"    # Upload notebook files\n",
"    notebook_files = [f for f in os.listdir('.') if f.endswith('.ipynb')]\n",
"    for nb in notebook_files:\n",
"        print(f\"📓 Uploading notebook: {nb}\")\n",
"        results.append(uploader.upload_file(nb, '/移动云盘/DATA/notebooks/'))\n",
"\n",
"    success_count = sum(results)\n",
"    print(f\"\\n✅ Done! {success_count}/{len(results)} items uploaded successfully\")\n",
"    return all(results)\n",
"\n",
"# Usage examples\n",
"print(\"🎯 Convenience functions defined!\")\n",
"print(\"\\nAvailable shortcuts:\")\n",
"print(\"• quick_upload_file('path/to/file')  # upload a single file\")\n",
"print(\"• quick_upload_project('project_dir')  # upload a whole project\")\n",
"print(\"• quick_upload_results()  # upload all result files\")\n",
"print(\"\\nFor example:\")\n",
"print(\"quick_upload_file('data.csv')\")\n",
"print(\"quick_upload_project(r'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text')\")\n",
"print(\"quick_upload_results()\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "b2txt25",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}