Files
b2txt25/analyses/figure_2.ipynb

296 lines
872 KiB
Plaintext
Raw Normal View History

2024-08-14 12:00:20 -07:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# imports and initialization"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pickle\n",
"\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from g2p_en import G2p\n",
"\n",
"# import the project helpers by name (instead of `import *`) so it is clear\n",
"# where the names used in the cells below come from\n",
"from nejm_b2txt_utils.general_utils import LOGIT_PHONE_DEF, calculate_aggregate_error_rate\n",
"\n",
"# use Type 42 (TrueType) fonts in exported PDF / PostScript figures\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n",
"\n",
"# grapheme-to-phoneme converter from g2p_en\n",
"g2p = G2p()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# load Copy Task evaluation data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# read the pickled Copy Task evaluation results into `dat`\n",
"# NOTE: unpickling can execute arbitrary code - only load files from a trusted source\n",
"with open('../data/t15_copyTask.pkl', 'rb') as pkl_file:\n",
"    dat = pickle.load(pkl_file)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# plot decoded phoneme logits from an example trial"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cue words: they can pretty much get everybody there\n",
"cue phonemes: DH EY | K AE N | P R IH T IY | M AH CH | G EH T | EH V R IY B AA D IY | DH EH R | \n",
"decoded phonemes (raw): DH EY | K AE N | B R IH T IY | M AH N CH | G EH T | EH V R IY B AA D IY | DH EH R | \n",
"decoded words: they can pretty much get everybody there\n",
"decoded phonemes: DH EY | K AE N | P R IH T IY | M AH CH | G EH T | EH V R IY B AA D IY | DH EH R | \n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABRIAAAHDCAYAAABYjOYrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3xT9frA8U+aNGm6d0vpZJWlgGwHW1D0unCjiNctXtd14b0/r+NecVzXdeMeuHAvRFHAyQZFRlld0L1H0qZNzu+P04SWtnQlTZM879eLF22a8W2anJzznGdoFEVREEIIIYQQQgghhBBCiKPwc/cChBBCCCGEEEIIIYQQfZ8EEoUQQgghhBBCCCGEEB2SQKIQQgghhBBCCCGEEKJDEkgUQgghhBBCCCGEEEJ0SAKJQgghhBBCCCGEEEKIDkkgUQghhBBCCCGEEEII0SEJJAohhBBCCCGEEEIIITokgUQhhBBCCCGEEEIIIUSHJJAohBBCCCGEEEIIIYTokAQShRBCuM3ChQtJTU11y2OvWbMGjUbDmjVr3PL4XfH666+j0WjIyspy91KEk6SmpnL66ae7exnCx02bNo2RI0e6exk+yb5d37Rpk8sfa9q0aUybNs3lj7Nw4UKCg4Nd/jhCCCHcSwKJQgjhpewHKfZ/AQEBDBkyhBtuuIHCwkKnP57JZOLee+/1iMBcb/GW52TVqlXMmDGDsLAwQkJCGDt2LO+//36L69TU1HDzzTeTmJiIwWBg2LBhPP/8825ace/RaDS8/vrrgBocvPfeex0/27lzJ/fee68EgPuwd955hyeffLLV5Xl5edx7771s27at19fky+R57/s88XPtaNtpIYQQXadz9wKEEEK41v33309aWhp1dXX8/PPPPP/883z99df8+eefBAYGOu1xTCYT9913H0CnMx9eeuklbDab09bQ13TnOelrXnvtNa644gpOPvlkHnzwQbRaLRkZGeTm5jquY7VamTNnDps2bWLRokUMHjyYlStXcv3111NeXs7dd9/txt/AfXbu3Ml9993HtGnT3JZ5K47unXfe4c8//+Tmm29ucXleXh733XcfqampjB492i1r80XyvPd93vC5JoQQomckkCiEEF7u1FNPZdy4cQBceeWVREVF8fjjj/PZZ59x0UUXuXVt/v7+bn18cXRZWVksWrSIv/3tbzz11FPtXu/jjz/m119/5ZVXXuGvf/0rANdddx3nnnsuDzzwAFdeeSWxsbG9tWzhw2prawkKCnL3MoQPamxsxGazodfr3b0Ur6YoCnV1dRiNRncvRQghfJaUNgshhI+ZMWMGAJmZmYB68PPAAw8wcOBADAYDqamp3H333dTX17e43aZNm5gzZw7R0dEYjUbS0tIcQaOsrCxiYmIAuO+++xzl1B2VDx3ZIzErKwuNRsN///tfnn32WQYMGEBgYCCzZ88mNzcXRVF44IEHSExMxGg0cuaZZ1JWVtbiPu2957799ltGjx5NQEAAw4cP5+OPP+7wufnpp58477zzSE5OxmAwkJSUxC233ILZbG617uDgYA4dOsRZZ51FcHAwMTEx3HbbbVit1h49Jzt27GDGjBkYjUYSExP597//3WbW5meffcZpp51GQkICBoOBgQMH8sADDzgeH+Bf//oX/v7+FBcXt7r91VdfTXh4OHV1de2u5YUXXsBqtXL//fcDavmyoihtPm8AF154YYvLL7zwQurq6vjss8+O+jsfzdtvv82ECRMIDAwkIiKCKVOm8O233zp+3t5zmpqaysKFC1tcVlFRwc0330xSUhIGg4FBgwbx8MMPt3p+8/Pz2b17Nw0NDd1e9+uvv855550HwPTp0x1//yPLAX/++WcmTJhAQEAAAwYM4M0332x1Xx2tW1EUUlNTOfPMM1vdtq6ujrCwMK655poO19zRc92Z1xwc7ru3c+dOpk+fTmBgIP379+eRRx7pcA2g/k1vuOEGli1bRnp6Og
EBAYwdO5Yff/yxxfXuvfdeNBoNO3fu5OKLLyYiIoITTzyxxe8zduxYjEYjkZGRXHjhhS0yaadNm8ZXX31Fdna24++TmprKmjVrGD9+PACXX36542evv/56j99T9m1HTk4Op59+OsHBwfTv359nn30WgO3btzNjxgyCgoJISUnhnXfeafN3PlJ7fVRXrFjB1KlTCQkJITQ0lPHjx7e6T6Dbfyuz2cyNN95IdHQ0ISEhnHHGGRw6dKjN9+WhQ4f461//SlxcHAaDgREjRvDqq686fn605/1oOrrfwsJCdDqdI4uuuYyMDDQaDc8884zjss5sJ5p/Vj355JOOz88NGzYQFBTETTfd1OqxDh48iFarZcmSJS0uN5lMXHPNNURFRREaGsqCBQsoLy9vdfvnnnuOESNGYDAYSEhIYNGiRVRUVLS63tKlSxk4cCBGo5EJEyY4ts92NTU1XV5j89+7M59rR/tstLPZbDz55JOMGDGCgIAA4uLiuOaaa1r97vbP9JUrVzJu3DiMRiMvvvgi0PltuhBCCOeSQKIQQviY/fv3AxAVFQWoWYr33HMPxx13HE888QRTp05lyZIlLYJCRUVFzJ49m6ysLO666y6efvpp5s+fz7p16wCIiYlx9MM7++yzeeutt3jrrbc455xzurXGZcuW8dxzz/G3v/2Nv//976xdu5bzzz+ff/7zn3zzzTfceeedXH311XzxxRfcdtttrW6/d+9eLrjgAk499VSWLFmCTqfjvPPO47vvvjvq4y5fvhyTycR1113H008/zZw5c3j66adZsGBBq+vay3mjoqL473//y9SpU3nsscdYunRpt5+TgoICpk+fzrZt27jrrru4+eabefPNN9vMBnz99dcJDg7m1ltv5amnnmLs2LHcc8893HXXXY7rXHrppTQ2NrbqZ2ixWPjwww+ZN28eAQEB7a5n1apVDB06lK+//prExERCQkKIiori//7v/1ocqNXX16PValtl4thL5zdv3tzuYxzNfffdx6WXXoq/vz/3338/9913H0lJSfzwww9dvi+TycTUqVN5++23WbBgAf/73/844YQTWLx4MbfeemuL6y5evJhhw4Zx6NChbq0bYMqUKdx4440A3H333Y6//7BhwxzX2bdvH+eeey4nn3wyjz32GBERESxcuJAdO3Z0ad0ajYZLLrmEFStWtAqsf/HFF1RVVXHJJZccdb2dea4785qzKy8v55RTTmHUqFE89thjDB06lDvvvJMVK1Z06vlbu3YtN998M5dccgn3338/paWlnHLKKfz555+trnveeedhMpl48MEHueqqqwD4z3/+w4IFCxg8eDCPP/44N998M99//z1TpkxxBF/+8Y9/MHr0aKKjox1/nyeffJJhw4Y5gudXX32142dTpkzp8XsK1G3HqaeeSlJSEo888gipqanccMMNvP7665xyyimMGzeOhx9+mJCQEBYsWOA46dNVr7/+OqeddhplZWUsXryYhx56iNGjR/PNN9+0uF5P/lYLFy7k6aefZu7cuTz88MMYjUZOO+20VtcrLCxk0qRJrFq1ihtuuIGnnnqKQYMGccUVVzh6VB7teW9PZ+43Li6OqVOn8sEHH7S6/fvvv49Wq3UE/buynQC19cPTTz/N1VdfzWOPPUZycjJnn30277//fqvA2bvvvouiKMyfP7/F5TfccAO7du3i3nvvZcGCBSxbtoyzzjqrxUmbe++9l0WLFpGQkMBjjz3GvHnzePHFF5k9e3aLEx6vvPIK11xzDfHx8TzyyCOccMIJnHHGGS0C6MHBwV1eo11nPtc6+my0u+aaa7j99ts54YQTeOqpp7j88stZtmwZc+bMaXUSJyMjg4suuoiTTz6Zp556itGjR3f5byWEEMKJFCGEEF7ptddeUwBl1apVSnFxsZKbm6u89957SlRUlGI0GpWDBw8q27ZtUwDlyiuvbHHb2267TQGUH374QVEURfnkk08UQNm4cWO7j1dcXKwAyr
/+9a9Or/Gyyy5TUlJSHN9nZmYqgBITE6NUVFQ4Ll+8eLECKKNGjVIaGhocl1900UWKXq9X6urqHJelpKQogPLRRx8
"text/plain": [
"<Figure size 1600x500 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJsAAAHDCAYAAACDEhsgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOx9eZxcVZX/t6q7unpfspCQ0ElIIgnLQDCyCiQgEDZlycgWDWiU4AAKcYlxZpAGnYAjAiKgOC0owo+IIouyKAqMC4wIomggBE0gkJC99+7q6q73+6NNpd85306dfl2d7iTn+/nkA3X6vvvuu+fcc++7757viQVBEMDhcDgcDofD4XA4HA6Hw+HIA+JD3QCHw+FwOBwOh8PhcDgcDsfuA99scjgcDofD4XA4HA6Hw+Fw5A2+2eRwOBwOh8PhcDgcDofD4cgbfLPJ4XA4HA6Hw+FwOBwOh8ORN/hmk8PhcDgcDofD4XA4HA6HI2/wzSaHw+FwOBwOh8PhcDgcDkfe4JtNDofD4XA4HA6Hw+FwOByOvME3mxwOh8PhcDgcDofD4XA4HHmDbzY5HA6Hw+FwOBwOh8PhcDjyBt9scjgcuxwuvvhiTJo0aUju/cwzzyAWi+GZZ54Zkvv3B3fffTdisRhWr1491E1x5AmTJk3CGWecMdTNcOzhmD17Ng466KChbsYeiW1+/Y9//OOg32v27NmYPXv2oN/n4osvRnl5+aDfx+FwOBw7F77Z5HA4Qti2kN32r7i4GPvttx8uv/xyrF+/Pu/3a2trwzXXXLNLbN7sLOwuffLUU0/hhBNOQFVVFSoqKjBz5kwsW7YsVKalpQVXXnkl9tlnHySTSey///644447hqjFOw+xWAx33303gJ4NpGuuuSb7t+XLl+Oaa67xTcJhjPvuuw8333yzkq9duxbXXHMNXn755Z3epj0Z3u/DH7vivLYjP+1wOByO3Cgc6gY4HI7hiWuvvRb77rsvOjo68Nvf/hZ33HEHHnvsMfz1r39FaWlp3u7T1taGuro6ADB/Qf3ud7+LTCaTtzYMN0Tpk+GGu+66CwsWLMBJJ52E//qv/0JBQQFWrFiBNWvWZMt0d3djzpw5+OMf/4jLLrsM73nPe/Dkk0/i3/7t37B161Z86UtfGsInGDosX74cdXV1mD179pCd4HPsGPfddx/++te/4sorrwzJ165di7q6OkyaNAkzZswYkrbtifB+H/7YHeY1h8PhcPQPvtnkcDgoTj31VLzvfe8DAHziE5/AyJEj8Y1vfAMPP/wwLrjggiFtWyKRGNL7O3aM1atX47LLLsMVV1yBW265pc9yDz74IH7/+9+jvr4eH//4xwEAn/rUp/Cv//qvuO666/CJT3wCe+21185qtmMPRmtrK8rKyoa6GY49EF1dXchkMigqKhrqpuzWCIIAHR0dKCkpGeqmOBwOxx4DD6NzOBwmnHDCCQCAVatWAehZIF933XWYMmUKkskkJk2ahC996UtIpVKh6/74xz9izpw5GDVqFEpKSrDvvvtmNxZWr16N0aNHAwDq6uqyoXu5jqpLzqbVq1cjFovh61//Om677TZMnjwZpaWlOPnkk7FmzRoEQYDrrrsO++yzD0pKSnDmmWdiy5YtoTq3ceH84he/wIwZM1BcXIwDDjgADz74YM6++c1vfoMPf/jDmDBhApLJJGpra3HVVVehvb1dtbu8vBzvvPMOzjrrLJSXl2P06NH43Oc+h+7u7gH1yd/+9jeccMIJKCkpwT777IOvfOUr9PTXww8/jNNPPx3jxo1DMpnElClTcN1112XvDwBf/vKXkUgksHHjRnX9JZdcgurqanR0dPTZlm9/+9vo7u7GtddeC6AnVC4IAtpvAHD++eeH5Oeffz46Ojrw8MMP7/CZd4Qf/vCHOPzww1FaWoqamhocd9xx+MUvfpH9e199OmnSJFx88cUhWUNDA6688krU1tYimUxi6tSpuOGGG1T/rl
u3Dq+99hrS6XTkdt9999348Ic/DAA4/vjjs/qXoSe//e1vcfjhh6O4uBiTJ0/GD37wA1VXrnYHQYBJkybhzDPPVNd2dHSgqqoKCxcuzNnmXH1tsTlgOw/Q8uXLcfzxx6O0tBTjx4/H1772tZxtAHp0evnll+Pee+/FtGnTUFxcjJkzZ+J///d/Q+WuueYaxGIxLF++HBdeeCFqampwzDHHhJ5n5syZKCkpwYgRI3D++eeHTuTNnj0bP//5z/Hmm29m9TNp0iQ888wzOOywwwAAH/vYx7J/u/vuuwc8prb5jrfeegtnnHEGysvLMX78eNx2220AgFdeeQUnnHACysrKMHHiRNx33330mSX64nV7/PHHMWvWLFRUVKCyshKHHXaYqhNAZF21t7fj05/+NEaNGoWKigp86EMfwjvvvEPH5TvvvIOPf/zjGDNmDJLJJA488EB873vfy/59R/2+I+Sqd/369SgsLMyexumNFStWIBaL4Vvf+lZWZvETveeqm2++OTt//uEPf0BZWRk+85nPqHu9/fbbKCgowNKlS0PytrY2LFy4ECNHjkRlZSXmz5+PrVu3qutvv/12HHjggUgmkxg3bhwuu+wyNDQ0qHJ33nknpkyZgpKSEhx++OFZ/7wNLS0t/W5j7+e2zGs7mhu3IZPJ4Oabb8aBBx6I4uJijBkzBgsXLlTPvm1Of/LJJ/G+970PJSUl+M53vgPA7tMdDofDMTD4ZpPD4TDh73//OwBg5MiRAHpOO1199dV473vfi5tuugmzZs3C0qVLQxsHGzZswMknn4zVq1fji1/8Im699VbMmzcPzz//PABg9OjRWX6es88+G/fccw/uuecenHPOOZHaeO+99+L222/HFVdcgc9+9rN49tlnce655+I//uM/8MQTT2Dx4sW45JJL8Oijj+Jzn/ucun7lypU477zzcOqpp2Lp0qUoLCzEhz/8Yfzyl7/c4X0feOABtLW14VOf+hRuvfVWzJkzB7feeivmz5+vym4LHRs5ciS+/vWvY9asWbjxxhtx5513Ru6Td999F8cffzxefvllfPGLX8SVV16JH/zgB/RU0d13343y8nIsWrQIt9xyC2bOnImrr74aX/ziF7NlPvrRj6Krq0vxK3V2duLHP/4x5s6di+Li4j7b89RTT2H69Ol47LHHsM8++6CiogIjR47Ef/7nf4YW86lUCgUFBeqL/rYwzRdffLHPe+wIdXV1+OhHP4pEIoFrr70WdXV1qK2txa9//et+19XW1oZZs2bhhz/8IebPn49vfvObeP/7348lS5Zg0aJFobJLlizB/vvvj3feeSdSuwHguOOOw6c//WkAwJe+9KWs/vfff/9smTfeeAP/+q//ipNOOgk33ngjampqcPHFF+Nvf/tbv9odi8XwkY98BI8//rjafH300UfR1NSEj3zkIztsr6WvLTa3DVu3bsUpp5yCQw45BDfeeCOmT5+OxYsX4/HHHzf137PPPosrr7wSH/nIR3Dttddi8+bNOOWUU/DXv/5Vlf3whz+MtrY2/Nd//Rc++clPAgC++tWvYv78+XjPe96Db3zjG7jyyivxq1/9Cscdd1z2Bf3f//3fMWPGDIwaNSqrn5tvvhn7779/doP1kksuyf7tuOOOG/CYAnp8x6mnnora2lp87Wtfw6RJk3D55Zfj7rvvximnnIL3ve99uOGGG1BRUYH58+dnPwz0F3fffTdOP/10bNmyBUuWLMH111+PGTNm4IknngiVG4iuLr74Ytx666047bTTcMMNN6CkpASnn366Krd+/XoceeSReOqpp3D55ZfjlltuwdSpU7FgwYIsZ9aO+r0vWOodM2YMZs2ahR/96Efq+mXLlqGgoCC7MdwfPwH0hBnfeuutuOSSS3DjjTdiwoQJOPvss7Fs2TK1ufL//t//QxAEmDdvXkh++eWX49VXX8U111yD+fPn495778VZZ50V2ti/5pprcNlll2HcuHG48cYbMXfuXHznO9
/BySefHNoUr6+vx8KFCzF27Fh87Wtfw/vf/3586EMfCm2ylpeX97uN22CZ13LNjduwcOFCfP7zn8f73/9+3HLLLfj
"text/plain": [
"<Figure size 1600x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABJQAAAHDCAYAAAByGUCtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAACUCklEQVR4nOzdeXxTdfb/8XeStikQCpVVMFAoyOJCsQKCKAXRKqKijqKggKKgIyrWBeooUFCL4oIjbvgtogg/OyoibjCK1NERRwSZUUFEbIVhE9nSgqQ0ub8/mEbSm0JLl5s0ryePPPSe3Nyc5Obe5J5+FpthGIYAAAAAAACACrJbnQAAAAAAAAAiCwUlAAAAAAAAVAoFJQAAAAAAAFQKBSUAAAAAAABUCgUlAAAAAAAAVAoFJQAAAAAAAFQKBSUAAAAAAABUCgUlAAAAAAAAVAoFJQAAAAAAAFQKBSUAEWfUqFFKSkqy5Lnz8vJks9mUl5dnyfNXxty5c2Wz2VRQUGB1KqgmSUlJGjx4sNVpIMqlpaXp1FNPtTqNqFR6Xv/6669r/LnS0tKUlpZW488zatQouVyuGn8eAED1o6AEIEjpj9XSW3x8vE4++WSNGzdOO3bsqPbnO3DggKZMmRIRBZraUlfek48//lgDBgxQo0aN1LBhQ6Wmpio3NzdonaKiIo0fP14nnXSSnE6nunTpoueff96ijGuPzWbT3LlzJR0uEk2ZMiVw39q1azVlyhQKgWFswYIFmjlzpim+detWTZkyRWvWrKn1nKIZ73v4i8TvtaOdpwEAh8VYnQCA8DR16lS1a9dOBw8e1Oeff67nn39eH3zwgb777jvVr1+/2p7nwIEDysrKkqQK/yX0pZdekt/vr7Ycws3xvCfh5uWXX9bo0aN1/vnn65FHHpHD4dD69eu1efPmwDo+n0/p6en6+uuvddttt6ljx45aunSp/vznP2vPnj26//77LXwF1lm7dq2ysrKUlpZmWUs8HN2CBQv03Xffafz48UHxrVu3KisrS0lJSUpJSbEkt2jE+x7+6sL3GgDAjIISgJAuuuginXnmmZKkm266SU2aNNGTTz6pd955R9dee62lucXGxlr6/Di6goIC3Xbbbbr99tv19NNPl7vewoUL9cUXXygnJ0c33nijJOnWW2/Vn/70J02bNk033XSTmjdvXltpI4rt379fDRo0sDoNRKGSkhL5/X7FxcVZnUqdZhiGDh48qHr16lmdCgDUKXR5A1AhAwYMkCTl5+dLOvwjeNq0aUpOTpbT6VRSUpLuv/9+eb3eoMd9/fXXSk9PV9OmTVWvXj21a9cuUDwoKChQs2bNJElZWVmBbnbHalZedgylgoIC2Ww2Pf7443r22WfVvn171a9fXxdccIE2b94swzA0bdo0nXTSSapXr54uu+wy7d69O2ibpWPT/P3vf1dKSori4+PVtWtXLVy48JjvzWeffaarrrpKbdq0kdPplNvt1l133aXff//dlLfL5dKWLVs0ZMgQuVwuNWvWTPfcc498Pl+V3pPvv/9eAwYMUL169XTSSSfpoYceCtmK65133tHFF1+sVq1ayel0Kjk5WdOmTQs8vyRNnjxZsbGx2rlzp+nxY8aMUePGjXXw4MFyc3nhhRfk8/k0depUSYe7tRmGEfJ9k6RrrrkmKH7NNdfo4MGDeuedd476mo/mtddeU8+ePVW/fn0lJibq3HPP1d///vfA/eW9p0lJSRo1alRQbO/evRo/frzcbrecTqc6dOigRx991PT+btu2TT/88IMOHTp03HnPnTtXV111lSSpf//+gf1ftpvI559/rp49eyo+Pl7t27fXq6++atrWsfI2DENJSUm67LLLTI89ePCgGjVqpLFjxx4z52O91xX5zEl/jMuzdu1a9e/fX/Xr11fr1q312GOPHTMH6fA+HTdunObPn69OnTopPj5eqamp+sc//hG03pQpU2Sz2bR27VoNGzZMiYmJ6t
u3b9DrSU1NVb169XTCCSfommuuCWpZl5aWpvfff1+//PJLYP8kJSUpLy9PPXr0kCTdcMMNgfvmzp1b5WOq9NyxadMmDR48WC6XS61bt9azzz4rSfr22281YMAANWjQQG3bttWCBQtCvuayyhtn7cMPP1S/fv3UsGFDJSQkqEePHqZtSjruffX777/rjjvuUNOmTdWwYUNdeuml2rJlS8jjcsuWLbrxxhvVokULOZ1OnXLKKZozZ07g/qO970dzrO3u2LFDMTExgVY1R1q/fr1sNptmzZoViFXkPHHkd9XMmTMD359fffWVGjRooDvvvNP0XP/973/lcDiUnZ0dFD9w4IDGjh2rJk2aKCEhQSNGjNCePXtMj3/uued0yimnyOl0qlWrVrrtttu0d+9e03qzZ89WcnKy6tWrp549ewbOz6WKiooqneORr7si32tH+24s5ff7NXPmTJ1yyimKj49XixYtNHbsWNNrL/1OX7p0qc4880zVq1dPL774oqSKn9MBAMdGQQlAhWzcuFGS1KRJE0mHWy1NmjRJZ5xxhp566in169dP2dnZQcWBX3/9VRdccIEKCgo0ceJEPfPMMxo+fLi+/PJLSVKzZs0C4+VcfvnlmjdvnubNm6crrrjiuHKcP3++nnvuOd1+++26++679emnn+rqq6/WAw88oCVLlmjChAkaM2aM3n33Xd1zzz2mx2/YsEFDhw7VRRddpOzsbMXExOiqq67SRx99dNTnfeONN3TgwAHdeuuteuaZZ5Senq5nnnlGI0aMMK1b2s2rSZMmevzxx9WvXz898cQTmj179nG/J9u3b1f//v21Zs0aTZw4UePHj9err74asnXQ3Llz5XK5lJGRoaefflqpqamaNGmSJk6cGFjn+uuvV0lJiWm8o+LiYr355pu68sorFR8fX24+H3/8sTp37qwPPvhAJ510kho2bKgmTZrowQcfDPrB7vV65XA4TH+ZL+1SuWrVqnKf42iysrJ0/fXXKzY2VlOnTlVWVpbcbrc++eSTSm/rwIED6tevn1577TWNGDFCf/3rX3X22WcrMzNTGRkZQetmZmaqS5cu2rJly3HlLUnnnnuu7rjjDknS/fffH9j/Xbp0Cazz008/6U9/+pPOP/98PfHEE0pMTNSoUaP0/fffVypvm82m6667Th9++KGpwPruu+/K4/HouuuuO2q+FXmvK/KZK7Vnzx5deOGF6tatm5544gl17txZEyZM0Icfflih9+/TTz/V+PHjdd1112nq1KnatWuXLrzwQn333Xemda+66iodOHBAjzzyiG6++WZJ0sMPP6wRI0aoY8eOevLJJzV+/HgtW7ZM5557buAi/C9/+YtSUlLUtGnTwP6ZOXOmunTpEiiijhkzJnDfueeeW+VjSjp87rjooovkdrv12GOPKSkpSePGjdPcuXN14YUX6swzz9Sjjz6qhg0basSIEYHif2XNnTtXF198sXbv3q3MzExNnz5dKSkpWrJkSdB6VdlXo0aN0jPPPKNBgwbp0UcfVb169XTxxReb1tuxY4fOOussffzxxxo3bpyefvppdejQQaNHjw6MYXW09708FdluixYt1K9fP/3tb38zPT43N1cOhyNQ/K3MeUI63CX4mWee0ZgxY/TEE0+oTZs2uvzyy5Wbm2sqoPy///f/ZBiGhg8fHhQfN26c1q1bpylTpmjEiBGaP3++hgwZElS8nzJlim677Ta1atVKTzzxhK688kq9+OKLuuCCC4IK3zk5ORo7dqxatmypxx57TGeffbYuvfTSoEKqy+WqdI6lKvK9dqzvxlJjx47Vvffeq7PPPltPP/20brjhBs2fP1/p6emmYv769et17bXX6vzzz9fTTz+tlJSUSu8rAMAxGABwhJdfftmQZHz88cfGzp07jc2bNxuvv/660aRJE6NevXrGf//7X2PNmjWGJOOmm24Keuw999xjSDI++eQTwzAM4+233zYkGStXriz3+Xbu3GlIMiZPnl
zhHEeOHGm0bds2sJyfn29IMpo1a2bs3bs3EM/MzDQkGd26dTMOHToUiF977bVGXFyccfDgwUCsbdu2hiTjrbfeCsT
"text/plain": [
"<Figure size 1600x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"trl = 500  # index of the example trial shown below\n",
"\n",
"# logits of the example trial; a 2-D array, presumably (time steps, phoneme classes) - TODO confirm\n",
"logits = dat['decoded_logits'][trl]\n",
"fig_title = f'Post-implant day {dat[\"post_implant_day\"][trl]}, cue: \"{dat[\"cue_sentence\"][trl]}\"'\n",
"\n",
"print(f\"cue words: {dat['cue_sentence'][trl]}\")\n",
"print(f\"cue phonemes: {' '.join(dat['cue_sentence_phonemes'][trl]).replace('SIL',' | ')}\")\n",
"print(f\"decoded phonemes (raw): {' '.join(dat['decoded_phonemes_raw'][trl]).replace('SIL',' | ')}\")\n",
"print(f\"decoded words: {dat['decoded_sentence'][trl]}\")\n",
"print(f\"decoded phonemes: {' '.join(dat['decoded_sentence_phonemes'][trl]).replace('SIL',' | ')}\")\n",
"\n",
"\n",
"def show_logit_image(values, title):\n",
"    \"\"\"Show a 2-D array as an image whose rows are labelled with LOGIT_PHONE_DEF.\n",
"\n",
"    `values` is transposed before plotting, so its second axis becomes the\n",
"    image rows (y axis); `title` is used as the figure title.\n",
"    \"\"\"\n",
"    plt.figure(figsize=(16, 5))\n",
"    plt.imshow(values.T, aspect='auto', cmap='Blues', interpolation='none')\n",
"    plt.colorbar()\n",
"    plt.yticks(np.arange(len(LOGIT_PHONE_DEF)), LOGIT_PHONE_DEF, fontsize=8)\n",
"    plt.grid(axis='y', alpha=0.5)\n",
"    plt.title(title)\n",
"    plt.show()\n",
"\n",
"\n",
"# plot logits in line plot format\n",
"plt.figure(figsize=(16, 5))\n",
"plt.plot(logits, '.-', markersize=3, linewidth=1)\n",
"plt.xlim([0, logits.shape[0]])\n",
"plt.title(fig_title)\n",
"plt.show()\n",
"\n",
"\n",
"# plot logits in image format\n",
"show_logit_image(logits, fig_title)\n",
"\n",
"\n",
"# plot logits in image format with softmax\n",
"# subtract the per-row maximum before exponentiating: the result is mathematically\n",
"# unchanged, but np.exp can no longer overflow for large logits\n",
"shifted_logits = logits - np.max(logits, axis=1, keepdims=True)\n",
"exp_logits = np.exp(shifted_logits)\n",
"logits_softmax = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)\n",
"\n",
"show_logit_image(logits_softmax, fig_title)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# calculate and plot phoneme error rate and word error rate for each session"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Day 25, vocab size 50, rPER: 1.67%, WER: 0.44%\n",
"Day 27, vocab size 50, rPER: 2.11%, WER: 0.00%\n",
"Day 27, vocab size 125000, rPER: 15.36%, WER: 9.84%\n",
"Day 34, vocab size 125000, rPER: 12.89%, WER: 13.61%\n",
"Day 46, vocab size 125000, rPER: 9.50%, WER: 4.61%\n",
"Day 48, vocab size 125000, rPER: 11.70%, WER: 5.40%\n",
"Day 69, vocab size 125000, rPER: 12.45%, WER: 6.15%\n",
"Day 74, vocab size 125000, rPER: 10.17%, WER: 3.89%\n",
"Day 76, vocab size 125000, rPER: 10.36%, WER: 4.60%\n",
"Day 81, vocab size 125000, rPER: 10.58%, WER: 4.06%\n",
"Day 83, vocab size 125000, rPER: 10.13%, WER: 4.37%\n",
"Day 88, vocab size 125000, rPER: 6.93%, WER: 2.54%\n",
"Day 90, vocab size 125000, rPER: 7.27%, WER: 3.25%\n",
"Day 95, vocab size 125000, rPER: 7.76%, WER: 0.99%\n",
"Day 223, vocab size 125000, rPER: 9.33%, WER: 3.10%\n",
"Day 244, vocab size 125000, rPER: 7.33%, WER: 1.82%\n"
]
}
],
"source": [
"VOCAB_SIZES = (50, 125000)  # vocabulary sizes that are evaluated separately\n",
"\n",
"trial_days = np.asarray(dat['post_implant_day'])\n",
"trial_vocab_sizes = np.asarray(dat['vocab_size'])\n",
"\n",
"unique_days = np.unique(trial_days)  # np.unique already returns sorted values\n",
"\n",
"rper_by_day = {}  # raw phoneme error rate, keyed by (day, vocab_size)\n",
"wer_by_day = {}  # word error rate, keyed by (day, vocab_size)\n",
"\n",
"for day in unique_days:\n",
"    for vocab_size in VOCAB_SIZES:\n",
"\n",
"        # indices of the trials from this day that used this vocabulary size\n",
"        # (a single boolean mask instead of re-running np.where for every candidate index)\n",
"        ind = np.flatnonzero((trial_days == day) & (trial_vocab_sizes == vocab_size))\n",
"\n",
"        if ind.size == 0:\n",
"            continue\n",
"\n",
"        # r = reference (cued) sequences, h = hypothesis (decoded) sequences\n",
"        rper_day = calculate_aggregate_error_rate(\n",
"            r = [dat['cue_sentence_phonemes'][i] for i in ind],\n",
"            h = [dat['decoded_phonemes_raw'][i] for i in ind],\n",
"        )\n",
"\n",
"        wer_day = calculate_aggregate_error_rate(\n",
"            r = [dat['cue_sentence'][i].split() for i in ind],\n",
"            h = [dat['decoded_sentence'][i].split() for i in ind],\n",
"        )\n",
"\n",
"        # element 0 of each result is the error rate as a fraction\n",
"        print(f'Day {day}, vocab size {vocab_size}, rPER: {rper_day[0]*100:.2f}%, WER: {wer_day[0]*100:.2f}%')\n",
"\n",
"        rper_by_day[(day, vocab_size)] = rper_day\n",
"        wer_by_day[(day, vocab_size)] = wer_day\n"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABKYAAAKyCAYAAADvidZRAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/SrBM8AAAACXBIWXMAAA9hAAAPYQGoP6dpAAC+mklEQVR4nOzdeVxU9f7H8fcBFQUBNZUlULRNs1wyRXJfcuuWSqZZN5es7q30avza7KZpdaPbrmnrLbVb2opmZnaN3DNLjWuLuYWJCu4yAYoK398fcxlFFodhYGDm9Xw85uGc7/meM5/5wowzH77fz7GMMUYAAAAAAABAJfPzdAAAAAAAAADwTSSmAAAAAAAA4BEkpgAAAAAAAOARJKYAAAAAAADgESSmAAAAAAAA4BEkpgAAAAAAAOARJKYAAAAAAADgESSmAAAAAAAA4BEkpgAAAAAAAOARJKYAAAAAAADgEVUuMZWYmKgOHTooODhYjRs31uDBg7V169ZCfU6cOKF7771XF1xwgerWrasbb7xR+/fvL/W8xhhNmTJFERERqlOnjvr06aPt27dX5FMBAAAAAABAKapcYmrlypW699579e2332rZsmU6deqU+vbtq+zsbEef++67T5999pk++ugjrVy5Uvv27VN8fHyp533mmWc0Y8YMvfbaa1q/fr2CgoLUr18/nThxoqKfEgAAAAAAAIphGWOMp4MozcGDB9W4cWOtXLlS3bp1U2Zmpho1aqR58+Zp6NChkqRff/1VLVu21Lp169SpU6ci5zDGKDIyUv/3f/+n+++/X5KUmZmpsLAwzZkzRzfffHOlPicAAAAAAABUwRlT58rMzJQkNWjQQJK0ceNGnTp1Sn369HH0adGihZo0aaJ169YVe47U1FRlZGQUOiY0NFSxsbElHpObmyubzea4ZWZm6uDBg6rieTwAAAAAAIBqo0onpvLz8zVx4kR17txZV1xxhSQpIyNDtWrVUr169Qr1DQsLU0ZGRrHnKWgPCwtz+pjExESFhoY6bvXq1VPjxo31xx9/lPNZAQAAAAAAQJJqeDqA0tx777366aeftGbNmkp/7EmTJikhIcGxbbPZFB0drfz8fOXn51d6PAAAAAAAANWFn59zc6GqbGJq3LhxWrx4sVatWqWoqChHe3h4uE6ePKljx44VmjW1f/9+hYeHF3uugvb9+/crIiKi0DFt27Yt9piAgAAFBAQUaffz83N6cAEAAAAAAFCyKpdhMcZo3LhxWrBggb7++ms1a9as0P727durZs2aSk5OdrRt3bpVu3fvVlxcXLHnbNasmcLDwwsdY7PZtH79+hKPAQAAAAAAQMWqcompe++9V++++67mzZun4OBgZWRkKCMjQ8ePH5dkL1o+duxYJSQkaPny5dq4caPGjBmjuLi4Qlfka9GihRYsWCBJsixLEydO1JNPPqlFixbpxx9/1MiRIxUZGanBgwd74mkCAAAAAAD4vCq3lO/VV1+VJPXo0aNQ++zZszV69GhJ0osvvig/Pz/deOONys3NVb9+/fTKK68U6r9161bHFf0k6cEHH1R2drbuuusuHTt2TF26dNHSpUtVu3btCn0+AAAAAAAAKJ5ljDHlPUlmZqYCAwNVs2ZNd8RUJdlsNoWGhiozM1MhISGeDgcAAAAAAKDaK/OMqaysLH300UdKTk7W2rVrtW/fPp0+fVqSFBwcrCuvvFI9evTQoEGDdPXVV7s9YAAAAAAAAHgHp2dMpaWl6cknn9T8+fOVlZUlSapfv77CwsLUoEEDHT9+XEeOHNGePXuUl5cny7LUtm1bJSQk6NZbb63QJ1EZmDEFAAAAAADgXk4lph5++GHNmDFDeXl5GjBggIYNG6a4uLgiV8yTpJycHG3cuFH/+c9/NG/ePKWmpuqqq67Sm2++qXbt2lXIk6gMJKYAAAAAAADcy6nEVL169fS3v/1NEydOVIMGDcr0AF999ZUef/xx9enTR1OmTH
E5UE8jMQUAAAAAAOBeTiWmjh49qvr165frgdxxDk8iMQUAAAAAAOBefs50ckdCqTonpQAAAAAAAOB+TiWmAAAAAAAAAHdzS2Lq448/VufOnXXBBReoYcOG6tq1qxYuXOiOUwMAAAAAAMBLlTsx9cwzz2j48OEKCgrSHXfcoT//+c86fPiwbrzxRk2fPt0dMQKVwhgpO9t+O3/lNQAAAAAAUF5OFT8vTXh4uCZMmKBJkyY52k6fPq1evXopNTVVaWlp5Q6yKqD4uffLzpbq1rXfz8qSgoI8Gw8AAAAAAN7O6RlTAwYM0O7du4u0HzlyRB06dCjUVqNGDbVt21aHDx8uf4QAAAAAAADwSk4npv744w+1atVK06dP19mTrLp06aIHH3xQ69ev14kTJ5SZman58+drzpw56tKlS4UEDQAAAAAAgOqvTEv5XnnlFT3yyCNq0aKF/vWvf+mKK67Qtm3bNHjwYG3dutXRzxijyy+/XIsWLVLz5s0rJPDKxlI+78dSPgAAAAAAKleZa0zt3btX99xzj5YuXaoHHnhAU6ZMUc2aNfXll19q27ZtkqSWLVuqT58+siyrQoL2BBJT3o/EFAAAAAAAlcvl4ucffvihJkyYoJCQEL355pvq1q2bu2OrUkhMeT8SUwAAAAAAVC6na0yda9iwYfrll1/UuXNn9ezZU3/5y19ks9ncGRsAAAAAAAC8mEuJqUOHDkmS6tevr7ffflvLli3T119/rZYtW2rBggVuDRAAAAAAAADeyenEVG5uriZOnKi6desqLCxMdevW1X333aeTJ0+qV69e+vHHH3Xrrbdq+PDhio+PV0ZGRkXGDQAAAAAAgGrO6cTU3//+d82YMUM333yzZs2apREjRmjGjBl69NFHJUm1a9fWM888o/Xr12vXrl1q2bKl3njjjQoLHAAAAAAAANWb08XPmzRpos6dO2v+/PmOthEjRmjt2rXavXt3ob75+fl67rnn9PjjjysrK8u9EXsIxc+9H8XPAQAAAACoXE7PmMrOzlbjxo0LtTVq1EjZ2dlFT+rnpwcffFCbN28uf4QAAAAAAADwSjWc7ditWzfNnj1bcXFxuuqqq/TDDz9o7ty56t27d4nHNG/e3C1BAgAAAAAAwPs4vZRv7969uu6667R582ZZliVjjFq3bq3FixcrKiqqouP0OJbyeT+W8gEAAAAAULmcnjF14YUXatOmTfruu++Ulpam6OhodezYUX5+Tq8GBAAAAAAAABycnjHl65gx5f2YMQUAAAAAQOWqctOdVq1apeuvv16RkZGyLEsLFy4stN+yrGJvzz77bInnnDp1apH+LVq0qOBnAgAAAAAAgNI4lZjq37+/vv/+e5ceIDs7W08//bRmzZrldP82bdqU2D89Pb3Q7e2335ZlWbrxxhtLPW+rVq0KHbdmzZoyPxcAAAAAAAC4j1M1pg4ePKhOnTqpW7duGjlypOLj4xUaGlrqMd9++63effddvf/++zp+/Ljmzp3rVEADBgzQgAEDStwfHh5eaPvTTz9Vz549z3sFwBo1ahQ5FgAAAAAAAJ7jVGJq48aNmjt3rqZNm6axY8fqzjvv1GWXXab27dsrLCxM9erV04kTJ3TkyBFt3bpVGzZs0B9//CF/f3/dfPPNevLJJ9WkSRO3B79//359/vnnTiW9tm/frsjISNWuXVtxcXFKTEwsNabc3Fzl5uY6tm02myQpPz9f+fn55Q8eVY79x+r3v/v54scMAAAAAIBrnL1YXpmKnxtjtGTJEs2ePVsrVqzQkSNHin3g1q1ba8iQIbrjjjsUERHhfNTnBmdZWrBggQYPHlzs/meeeUZPP/209u3bp9q1a5d4ni+++EJZWVm67LLLlJ6ermnTpmnv3r366aefFBwcXOwxU6dO1bRp04q0b968ucRjUL3l5Fhq1aqpJOnnn39XYCDXBQAAAAAAwBUxMTFO9SvXVfm2bNmiPXv26PDhw6pTp44aNWqkVq1anXeZn7
POl5hq0aKFrr32Wr388stlOu+xY8fUtGlTvfDCCxo7dmyxfYqbMRUdHa2jR49yVT4vlZ0thYTYM7o2Wz5X5QMAAAA
"text/plain": [
"<Figure size 1200x700 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(12, 7))\n",
"ax1 = plt.subplot(211)  # top panel: raw phoneme error rate\n",
"ax2 = plt.subplot(212)  # bottom panel: word error rate\n",
"\n",
"# sessions are drawn at evenly spaced x positions (one per recording day), not on a true day scale\n",
"day_to_x = {day: x for x, day in enumerate(unique_days)}\n",
"\n",
"for ax, rate_by_day in ((ax1, rper_by_day), (ax2, wer_by_day)):\n",
"    labelled_vocab_sizes = set()\n",
"    for (day, vocab_size), rate in rate_by_day.items():\n",
"        x = day_to_x[day]\n",
"        color = 'r' if vocab_size==50 else 'b'\n",
"\n",
"        # label only the first marker of each vocabulary size so the legend has no duplicate entries\n",
"        label = '_nolegend_' if vocab_size in labelled_vocab_sizes else f'{vocab_size} words'\n",
"        labelled_vocab_sizes.add(vocab_size)\n",
"\n",
"        ax.plot(x, 100*rate[0], 'o', color=color, label=label)\n",
"        # vertical bar between rate[1] and rate[2]; presumably the lower / upper bounds\n",
"        # returned by calculate_aggregate_error_rate - TODO confirm\n",
"        ax.plot([x, x], [100*rate[1], 100*rate[2]], color=color)\n",
"\n",
"# legend (marker colour encodes the vocabulary size)\n",
"ax1.legend(frameon=False)\n",
"\n",
"# axis labels\n",
"ax1.set_ylabel('Raw phoneme error rate (%)', fontsize=14)\n",
"ax2.set_xlabel('Days post-implant', fontsize=14)\n",
"ax2.set_ylabel('Word error rate (%)', fontsize=14)\n",
"\n",
"for ax in (ax1, ax2):\n",
"    # tick labels\n",
"    ax.set_xticks(np.arange(len(unique_days)))\n",
"    ax.set_xticklabels([f'{d}' for d in unique_days])\n",
"\n",
"    # y limits\n",
"    ax.set_ylim([0, 20])\n",
"\n",
"    # remove spines\n",
"    ax.spines['top'].set_visible(False)\n",
"    ax.spines['right'].set_visible(False)\n",
"\n",
"    # grid\n",
"    ax.grid(axis='y', alpha=0.4)\n",
"\n",
"plt.tight_layout()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}