Skip to content

Commit

Permalink
assignment_1
Browse files Browse the repository at this point in the history
  • Loading branch information
kgogina committed Aug 3, 2024
1 parent 0542d62 commit 2bc0809
Show file tree
Hide file tree
Showing 2 changed files with 833 additions and 60 deletions.
157 changes: 130 additions & 27 deletions 01_materials/labs/lab_2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,9 +36,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQ0AAAEnCAYAAACzJRZYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAYzUlEQVR4nO3dfVBU1/0G8GflZRFlQdBViASpYaKIyJsmgikYI5ai0XS0NNWIWtOQoJFQG0uSGS0dwTZtahINEbUYhyomqUYbIwZawHQMFVGMFYNafEHRoKgs0hYDnN8fGfaXjaJ8cd9Yn8/Mnck9nsv5XiKP91723KNRSikQEXVTH1sXQES9C0ODiEQYGkQkwtAgIhGGBhGJMDSISIShQUQiDA0iEmFoEJEIQ8MCNm3aBI1GgzNnzti6FIs5c+YMNBoNNm3aJD62tLQUGo0GH374odnq6fyapaWlPTq+uLgYkydPhp+fH7RaLfR6PR5//HF88sknZqvRUTA0LCAxMRGff/45fH19bV0KdVNjYyNGjRqFP/7xj/j000+xbt06uLi4IDExEfn5+bYuz64427oARzRo0CAMGjTI1mWQQFJSEpKSkkzapk6disDAQOTm5mLOnDk2qsz+8ErDAm53exIXF4eQkBB8/vnniI6ORt++fTFs2DDk5eUBAHbv3o2IiAi4u7tj9OjRKCwsNPmap06dwvz58xEUFAR3d3c88MADmDZtGo4ePXrL+MeOHUN8fDzc3d0xaNAgpKamYvfu3be9fC8uLsakSZOg0+ng7u6OmJgY/O1vf+vReUtqBID//e9/SE9Px5AhQ9C3b1/Exsbi8OHDt/Q7ePAgnnzySXh7e8PNzQ3h4eF4//33e1SjhIuLC7y8vODszH9bv42hYUWXLl3C/PnzsXDhQuzcuROjR4/GggULkJmZiYyMDLz88sv4y1/+gv79+2PGjBmor683HltfXw8fHx+sWrUKhYWFWLt2LZydnfHII4+gpqbG2O/ixYuIjY1FTU0NcnJysHnzZjQ3N2PRokW31JOfn4/4+HjodDq89957eP/99+Ht7Y0pU6b0KDi6W2OnV155BbW1tdiwYQM2bNiA+vp6xMXFoba21tinpKQEMTExuH79Ot59913s3LkTYWFhSEpKuuvzlM7nLvPmzev2OXR0dKCtrQ319fVYvnw5Tpw4gV/84hfdPv6+oMjs8vLyFAB1+vRpY1tsbKwCoA4ePGhsa2xsVE5OTqpv377qwoULxvaqqioFQL311ltdjtHW1qZu3rypgoKC1EsvvWRs/+Uvf6k0Go06duyYSf8pU6YoAKqkpEQppVRLS4vy9vZW06ZNM+nX3t6uxowZo8aNG3fHczx9+rQCoPLy8sQ1lpSUKAAqIiJCdXR0GNvPnDmjXFxc1MKFC41tI0aMUOHh4errr782+dpTp05Vvr6+qr293eRrdp5f59dzcnJSCxYsuOO5fFvn9wmA0ul0avv27d0+9n7BKw0r8vX1RWRkpHHf29sber0eYWFh8PPzM7aPHDkSAHD27FljW1tbG7KyshAcHAxXV1c4OzvD1dUVJ0+exPHjx439ysrKEBISguDgYJOxn376aZP9/fv34+rVq0hOTkZbW5tx6+jowA9+8ANUVFSgpaVFdH7drbHTT3/6U2g0GuN+QEAAoqOjUVJSAuCb250vv/wSs2fPNn79zu2HP/whLl68eNsrmG9/vba2NmzcuLHb5/D222/jwIED2LlzJ6ZMmYKkpCRs3bq128ffD3izZkXe3t63tLm6ut7S7urqCuCbe/5O6enpWLt2LZYtW4bY2FgMGDAAffr0wcKFC/Hf//7X2K+xsRGBgYG3jDN48GCT/a+++goAMHPmzC7rvXr1Kvr169eNM5PV2GnIkCG3bTty5IhJjUuXLsXSpUtvO+aVK1e6XV93BAUFGf/7ySefREJCAlJTU5GUlIQ+ffhvLMDQ6DXy8/Mxd+5cZGVlmbRfuXIFXl5exn0fHx/jD9u3Xbp0yWR/4MCBAL75l/XRRx+97ZjfDRpz1dhVTZ1tPj4+JjVmZGTgRz/60W3HfPjhh0U1So0bNw6FhYW4fPmy+PvhqBgavYRGo4FWqzVp2717Ny5cuICHHnrI2BYbG4vf//73qK6uNrlFKSgoMDk2JiYGXl5eqK6uvu1DUkvW2Gnr1q1IT0833qKcPXsW+/fvx9y5cwF8EwhBQUE4cuTILUFkDUoplJWVwcvLyxhkxNDoNaZOnYpNmzZhxIgRCA0NRWVlJV5//XUMHTrUpF9aWhr+9Kc/ISEhAZmZmRg8eDC2bNmCL7/8EgCMl9j9+/fH22+/jeTkZFy9ehUzZ86EXq/H5cuXceTIEVy+fBk5OTkWqbFTQ0MDnnrqKTz77LNoamrC8uXL4ebmhoyMDGOfdevWISEhAVOmTMG8efPwwAMP4OrVqzh+/DgOHTqEDz74oMt6zp49i+HDhyM5OfmuzzWmT5+OMWPGICwsDD4+Pqivr8emTZtQVlZm/C0QfYPfiV7izTffhIuLC7Kzs3Hjxg1ERERg+/bteO2110z6+fn5oaysDGlpaUhJSYG7uzueeuopZGZmIjk52eQ2Yc6cOXjwwQfxu9/9Ds899xyam5uND2Ylv6aU1tgpKysLFRUVmD9/PgwGA8aNG4eCggIMHz7c2GfixIk4cOAAVq5cibS0NFy7dg0+Pj4IDg7Gj3/84zvWo5RCe3s72tvb71p7TEwMPvzwQ6xZswYGgwFeXl6IiorCxx9/jMTERNk3wsFplOLbyO8HP//5z7F161Y0NjYaH7QS9QSvNBxQZmYm/Pz88L3vfQ83btzAxx9/jA0bNuC1115jYNA9Y2g4IBcXF7z++us4f/482traEBQUhDfeeANLliyxdWnkAHh7QkQi/LQKEYkwNIhIhKFBRCJWfxDa0dGB+vp6eHh4mExWIiLbUkqhubkZfn5+d5xnY/XQqK+vh7+/v7WHJaJuqqur6/JTvIANQsPDwwPAN4XpdDprD29Vy5Yts/qY7777rlXHCwkJsep4APDCCy9YfczO6fmOzGAwwN/f3/gz2hWrh0bnLYlOp3P40Pju5C1H5OTkZPUx+/bta/UxHf3v6rfd7bEBH4QSkQhDg4hEGBpEJMLQICIRhgYRiTA0iEiEoUFEIj0KjXfeeQeBgYFwc3NDZGQkPvvsM3PXRUR2Shwa27ZtQ1paGl599VUcPnwYjz32GBISEnDu3DlL1EdEdkYcGm+88QZ+9rOfYeHChRg5ciRWr14Nf39/8Zuriah3EoXGzZs3UVlZifj4eJP2+Ph47N+/36yFEZF9Es09uXLlCtrb229ZaWrw4MG3XS0LAFpbW9Ha2mrcNxgMPSiTiOxFjx6EfndCi1Kqy0ku2dnZ8PT0NG6cFk/Uu4lCY+DAgXBycrrlqqKhoaHLdS4zMjLQ1NRk3Orq6npeLRHZnCg0XF1dERkZiaKiIpP2oqIiREdH3/YYrVZrnAZ/P0yHJ3J04vdppKen45lnnkFUVBTGjx+P3NxcnDt3DikpKZaoj4jsjDg0kpKS0NjYiMzMTFy8eBEhISH45JNPEBAQYIn6iMjO9OjNXS+88IJNXrlGRLbHuSdEJMLQICIRhgYRiTA0iEiEoUFEIgwNIhJhaBCRCEODiEQ0SillzQENBgM8PT3R1NTk8PNQNm3aZPUxBwwYYNXxZsyYYdXxbMXKPyY20d2fTV5pEJEIQ4OIRBgaRCTC0CAiEYYGEYkwNIhIhKFBRCIMDSISYWgQkYg4NPbt24dp06bBz88PGo0GH330kQXKIiJ7JQ6NlpYWjBkzBmvWrLFEPURk58QvFk5ISEBCQoIlaiGiXoDPNIhIpEdLGEhwAWgix2LxKw0uAE3kWCweGlwAmsixWPz2RKvVQqvVWnoYIrIScWjcuHEDp06dMu6fPn0aVVVV8Pb2xoMPPmjW4ojI/ohD4+DBg5g4caJxPz09HQCQnJxsk9fbEZF1iUMjLi7uvnhfIhHdHj+nQUQiDA0iEmFoEJEIQ4OIRBgaRCTC0CAiEYYGEYkwNIhIxOJzT+5n8+bNs/qYK1assOp4np6eVh0PAN577z2rj0n/j1caRCTC0CAiEYYGEYkwNIhIhKFBRCIMDSISYWgQkQhDg4hEGBpEJCIKjezsbIwdOxYeHh7Q6/WYMWMGampqLFUbEdkhUWiUlZUhNTUV5eXlKCoqQltbG+Lj49HS0mKp+ojIzojmnhQWFprs5+XlQa/Xo7KyEt///vfNWhgR2ad7mrDW1NQEAPD29u6yD9dyJXIsPX4QqpRCeno6JkyYgJCQkC77cS1XIsfS49BYtGgRvvjiC2zduvWO/biWK5Fj6dHtyeLFi7Fr1y7s27cPQ4cOvWNfruVK5FhEoaGUwuLFi7Fjxw6UlpYiMDDQUnURkZ0ShUZqaiq2bNmCnTt3wsPDA5cuXQLwzdub+vbta5ECici+iJ5p5OTkoKmpCXFxcfD19TVu27Zts1R9RGRnxLcnRHR/49wTIhJhaBCRCEODiEQYGkQkwtAgIhGGBhGJMDSISIShQUQiXADawYSHh1t1PC8vL6uOBwABAQFWH5P+H680iEiEoUFEIgwNIhJhaBCRCEODiEQYGkQkwtAgIhGGBhGJMDSISET8jtDQ0FDodDrodDqMHz8ee/bssVRtRGSHRKExdOhQrFq1CgcPHsTBgwfx+OOPY/r06Th27Jil6iMiOyOaezJt2jST/ZUrVyInJwfl5eUYNWqUWQsjIvvU4wlr7e3t+OCDD9DS0oLx48d32Y8LQBM5FvGD0KNHj6J///7QarVISUnBjh07EBwc3GV/LgBN5FjEofHwww+jqqoK5eXleP7555GcnIzq6uou+3MBaCLHIr49cXV1xUMPPQQAiIqKQkVFBd58802sW7futv25ADSRY7nnz2kopUyeWRCRYxNdabzyyitISEiAv78/mpubUVBQgNLSUhQWFlqqPiKyM6LQ+Oqrr/DMM8/g4sWL8PT0RGhoKAoLCzF58mRL1UdEdkYUGhs3brRUHUTUS3DuCRGJMDSISIShQUQiDA0iEmFoEJEIQ4OIRBgaRCTCtVwdzPTp0606XklJiVXHA4C4uDirj1lVVWX1MYcNG2b1MbuDVxpEJMLQICIRhgYRiTA0iEiEoUFEIgwNIhJhaBCRCEODiEQYGkQkwtAgIpF7Co3s7GxoNBqkpaWZqRwisnc9Do2Kigrk5uYiNDTUnPUQkZ3rUWjcuHEDs2fPxvr16zFgwABz10REdqxHoZGamorExEQ88cQTd+3b2toKg8FgshFR7yWeGl9QUIBDhw6hoqKiW/2zs7Px61//WlwYEdkn0ZVGXV0dlixZgvz8fLi5uXXrGC4ATeRYRFcalZWVaGhoQGRkpLGtvb0d+/btw5o1a9Da2gonJyeTY7gANJFjEYXGpEmTcPToUZO2+fPnY8SIEVi2bNktgUFEjkcUGh4eHggJCTFp69evH3x8fG5pJyLHxE+EEpHIPb9YuLS01AxlEFFvwSsNIhJhaBCRCEODiEQYGkQkwtAgIhGGBhGJMDSISESjlFLWHNBgMMDT0xNNTU3Q6XTWHJocxIwZM6w+5vXr160+prU/A9Xdn01eaRCRCEODiEQYGkQkwtAgIhGGBhGJMDSISIShQUQiDA0iEmFoEJEIQ4OIREShsWLFCmg0GpNtyJAhlqqNiOyQ+B2ho0aNQnFxsXGfyxYQ3V/EoeHs7MyrC6L7mPiZxsmTJ+Hn54fAwED85Cc/QW1t7R37cwFoIsciCo1HHnkEmzdvxt69e7F+/XpcunQJ0dHRaGxs7PKY7OxseHp6Gjd/f/97LpqIbOee3qfR0tKC4cOH4+WXX0Z6evpt+7S2tqK1tdW4bzAY4O/vz/dpUI/xfRqW0d33adzTYkn9+vXD6NGjcfLkyS77cAFoIsdyT5/TaG1txfHjx+Hr62uueojIzolCY+nSpSgrK8Pp06fxz3/+EzNnzoTBYEBycrKl6iMiOyO6PTl//jyefvppXLlyBYMGDcKjjz6K8vJyBAQEWKo+IrIzotAoKCiwVB1E1Etw7gkRiTA0iEiEoUFEIgwNIhJhaBCRCEODiEQYGkQkck9zT+jOrD3hyBZjVlVVWXU8wDbf17CwMKuPaa94pUFEIgwNIhJhaBCRCEODiEQYGkQkwtAgIhGGBhGJMDSISIShQUQiDA0iEhGHxoULFzBnzhz4+PjA3d0dYWFhqKystERtRGSHRHNPrl27hpiYGEycOBF79uyBXq/Hv//9b3h5eVmoPCKyN6LQ+O1vfwt/f3/k5eUZ24YNG2bumojIjoluT3bt2oWoqCjMmjULer0e4eHhWL9+/R2P4QLQRI5FFBq1tbXIyclBUFAQ9u7di5SUFLz44ovYvHlzl8dwAWgixyIKjY6ODkRERCArKwvh4eF47rnn8OyzzyInJ6fLYzIyMtDU1GTc6urq7rloIrIdUWj4+voiODjYpG3kyJE4d+5cl8dotVrodDqTjYh6L1FoxMTEoKamxqTtxIkTXJaR6D4iCo2XXnoJ5eXlyMrKwqlTp7Blyxbk5uYiNTXVUvURkZ0RhcbYsWOxY8cObN26FSEhIfjNb36D1atXY/bs2Zaqj4jsjPjFwlOnTsXUqVMtUQsR9QKce0JEIgwNIhJhaBCRCEODiEQYGkQkwtAgIhGGBhGJcAFoC1q9erXVx7T2gsy2eJ9KWlqa1cdcsWKF1ce0V7zSICIRhgYRiTA0iEiEoUFEIgwNIhJhaBCRCEODiEQYGkQkwtAgIhFRaAwbNgwajeaWje8IJbp/iD5GXlFRgfb2duP+v/71L0yePBmzZs0ye2FEZJ9EoTFo0CCT/VWrVmH48OGIjY01a1FEZL96/Ezj5s2byM/Px4IFC6DRaMxZExHZsR7Pcv3oo49w/fp1zJs37479Wltb0draatznAtBEvVuPrzQ2btyIhIQE+Pn53bEfF4Amciw9Co2zZ8+iuLgYCxcuvGtfLgBN5Fh6dHuSl5cHvV6PxMTEu/bVarXQarU9GYaI7JD4SqOjowN5eXlITk6GszNf/EV0vxGHRnFxMc6dO4cFCxZYoh4isnPiS4X4+HgopSxRCxH1Apx7QkQiDA0iEmFoEJEIQ4OIRBgaRCTC0CAiEYYGEYlY/SOdnZ/xuB9mu3799ddWH7Ojo8Oq47W1tVl1PAAms6at5X74+9p5jnf7HJZGWfmTWufPn+dMVyI7VldXh6FDh3b551YPjY6ODtTX18PDw0P08h6DwQB/f3/U1dVBp9NZsELb4nk6jt52jkopNDc3w8/PD336dP3kwuq3J3369Lljit2NTqfrFf8D7hXP03H0pnP09PS8ax8+CCUiEYYGEYn0mtDQarVYvny5w7/Qh+fpOBz1HK3+IJSIerdec6VBRPaBoUFEIgwNIhJhaBCRSK8IjXfeeQeBgYFwc3NDZGQkPvvsM1uXZFbZ2dkYO3YsPDw8oNfrMWPGDNTU1Ni6LIvLzs6GRqNBWlqarUsxuwsXLmDOnDnw8fGBu7s7wsLCUFlZaeuyzMLuQ2Pbtm1IS0vDq6++isOHD+Oxxx5DQkICzp07Z+vSzKasrAypqakoLy9HUVER2traEB8fj5aWFluXZjEVFRXIzc1FaGiorUsxu2vXriEmJgYuLi7Ys2cPqqur8Yc//AFeXl62Ls08lJ0bN26cSklJMWkbMWKE+tWvfmWjiiyvoaFBAVBlZWW2LsUimpubVVBQkCoqKlKxsbFqyZIlti7JrJYtW6YmTJhg6zIsxq6vNG7evInKykrEx8ebtMfHx2P//v02qsrympqaAADe3t42rsQyUlNTkZiYiCeeeMLWpVjErl27EBUVhVmzZkGv1yM8PBzr16+3dVlmY9ehceXKFbS3t2Pw4MEm7YMHD8alS5dsVJVlKaWQnp6OCRMmICQkxNblmF1BQQEOHTqE7OxsW5diMbW1tcjJyUFQUBD27t2LlJQUvPjii9i8ebOtSzOLXrGu4nen0CulRNPqe5NFixbhiy++wD/+8Q9bl2J2dXV1WLJkCT799FO4ubnZuhyL6ejoQFRUFLKysgAA4eHhOHbsGHJycjB37lwbV3fv7PpKY+DAgXBycrrlqqKhoeGWqw9HsHjxYuzatQslJSX39PoAe1VZWYmGhgZERkbC2dkZzs7OKCsrw1tvvQVnZ2e0t7fbukSz8PX1RXBwsEnbyJEjHebhvV2HhqurKyIjI1FUVGTSXlRUhOjoaBtVZX5KKSxatAjbt2/H3//+dwQGBtq6JIuYNGkSjh49iqqqKuMWFRWF2bNno6qqCk5OTrYu0SxiYmJu+ZX5iRMnEBAQYKOKzMzGD2LvqqCgQLm4uKiNGzeq6upqlZaWpvr166fOnDlj69LM5vnnn1eenp6qtLRUXbx40bj95z//sXVpFueIvz05cOCAcnZ2VitXrlQnT55Uf/7zn5W7u7vKz8+3dWlmYfehoZRSa9euVQEBAcrV1VVFREQ43K8iAdx2y8vLs3VpFueIoaGUUn/9619VSEiI0mq1asSIESo3N9fWJZkNp8YTkYhdP9MgIvvD0CAiEYYGEYkwNIhIhKFBRCIMDSISYWgQkQhDg4hEGBpEJMLQICIRhgYRiTA0iEjk/wAAXRINA99d9gAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sample_index = 45\n",
"plt.figure(figsize=(3, 3))\n",
Expand All @@ -58,7 +69,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -91,7 +102,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -101,18 +112,43 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([0., 0., 0., 1., 0., 0., 0., 0., 0., 0.])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot(n_classes=10, y=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],\n",
" [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],\n",
" [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"one_hot(n_classes=10, y=[0, 4, 9, 1])"
]
Expand Down Expand Up @@ -143,7 +179,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {
"collapsed": false
},
Expand All @@ -164,9 +200,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 9,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[9.99662391e-01 3.35349373e-04 2.25956630e-06]\n"
]
}
],
"source": [
"print(softmax([10, 2, -3]))"
]
Expand All @@ -181,9 +225,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[9.99662391e-01 3.35349373e-04 2.25956630e-06]\n",
" [2.47262316e-03 9.97527377e-01 1.38536042e-11]]\n"
]
}
],
"source": [
"X = np.array([[10, 2, -3],\n",
" [-1, 5, -20]])\n",
Expand All @@ -199,18 +252,36 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 11,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n"
]
}
],
"source": [
"print(np.sum(softmax([10, 2, -3])))"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 12,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"softmax of 2 vectors:\n",
"[[9.99662391e-01 3.35349373e-04 2.25956630e-06]\n",
" [2.47262316e-03 9.97527377e-01 1.38536042e-11]]\n"
]
}
],
"source": [
"print(\"softmax of 2 vectors:\")\n",
"X = np.array([[10, 2, -3],\n",
Expand All @@ -227,9 +298,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1. 1.]\n"
]
}
],
"source": [
"print(np.sum(softmax(X), axis=1))"
]
Expand All @@ -251,9 +330,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.01005033585350145\n"
]
}
],
"source": [
"def nll(Y_true, Y_pred):\n",
" Y_true = np.asarray(Y_true)\n",
Expand All @@ -279,9 +366,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.605170185988091\n"
]
}
],
"source": [
"print(nll([1, 0, 0], [0.01, 0.01, .98]))"
]
Expand All @@ -295,9 +390,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 17,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.010050335853503449\n"
]
}
],
"source": [
"# Check that the average NLL of the following 3 almost perfect\n",
"# predictions is close to 0\n",
Expand Down Expand Up @@ -809,7 +912,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.9.15"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 2bc0809

Please sign in to comment.