Update example notebooks
francesco-innocenti committed Nov 26, 2024
1 parent 75d4bb9 commit 5c5d910
Showing 4 changed files with 102 additions and 284 deletions.
85 changes: 21 additions & 64 deletions examples/discriminative_pc.ipynb
@@ -8,7 +8,7 @@
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/discriminative_pc.ipynb)\n",
"\n",
"This notebook demonstrates how to train a neural network with predictive coding (PC) to discriminate or classify MNIST digits."
"This notebook demonstrates how to train a simple feedforward network with predictive coding (PC) to discriminate or classify MNIST digits."
]
},
{
@@ -45,11 +45,13 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Hyperparameters\n",
"\n",
"We define some global parameters, including network architecture, learning rate, batch size etc."
"We define some global parameters, including network architecture, learning rate, batch size, etc."
]
},
{
@@ -137,9 +139,7 @@
},
{
"cell_type": "markdown",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"metadata": {},
"source": [
"## Network\n",
"\n",
@@ -150,43 +150,7 @@
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Sequential(\n",
" layers=(\n",
" Linear(\n",
" weight=f32[300,784],\n",
" bias=f32[300],\n",
" in_features=784,\n",
" out_features=300,\n",
" use_bias=True\n",
" ),\n",
" Lambda(fn=<wrapped function relu>)\n",
" )\n",
"), Sequential(\n",
" layers=(\n",
" Linear(\n",
" weight=f32[300,300],\n",
" bias=f32[300],\n",
" in_features=300,\n",
" out_features=300,\n",
" use_bias=True\n",
" ),\n",
" Lambda(fn=<wrapped function relu>)\n",
" )\n",
"), Linear(\n",
" weight=f32[10,300],\n",
" bias=f32[10],\n",
" in_features=300,\n",
" out_features=10,\n",
" use_bias=True\n",
")]\n"
]
}
],
"outputs": [],
"source": [
"key = jax.random.PRNGKey(SEED)\n",
"_, *subkeys = jax.random.split(key, 4)\n",
@@ -204,15 +168,14 @@
" ],\n",
" ),\n",
" nn.Linear(300, 10, key=subkeys[2]),\n",
"]\n",
"print(network)"
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also use the utility `jpc.get_fc_network` to define an MLP or fully connected network with some activation functions."
"You can also use the utility `jpc.make_mlp` to define a multi-layer perceptron (MLP) or fully connected network with some activation function (see docs [here](https://thebuckleylab.github.io/jpc/api/make_mlp/) for more details)."
]
},
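A hedged sketch of how `jpc.make_mlp` might be called to build the same 784-300-300-10 ReLU network as above; the keyword names below are assumptions for illustration, not the confirmed signature (the linked docs are authoritative):

```python
# Sketch only: the argument names (input_dim, width, depth, output_dim,
# act_fn) are assumptions about jpc.make_mlp, not its confirmed API.
import jax
import jpc

key = jax.random.PRNGKey(0)
network = jpc.make_mlp(
    key,
    input_dim=784,   # MNIST pixels
    width=300,       # hidden width, as in the manual definition above
    depth=3,         # three weight layers: 784-300, 300-300, 300-10
    output_dim=10,   # digit classes
    act_fn="relu"
)
```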
{
@@ -274,9 +237,7 @@
"source": [
"## Train and test\n",
"\n",
"A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already \"jitted\" for performance.\n",
"\n",
"Below we simply wrap each of these functions in our training and test loops, respectively."
"A PC network can be updated in a single line of code with `jpc.make_pc_step()` (see the [docs](https://thebuckleylab.github.io/jpc/api/Training/#jpc.make_pc_step) for more details). Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy (docs [here](https://thebuckleylab.github.io/jpc/api/Testing/#jpc.test_discriminative_pc)). Note that these functions are already \"jitted\" for optimised performance. Below we simply wrap each of these functions in training and test loops, respectively."
]
},
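Schematically, a training step and an accuracy check might look like the sketch below. The keyword names and return structure are assumptions (the linked docs give the true signatures), and `train_loader`, `test_x`, and `test_y` are hypothetical placeholders:

```python
# Illustrative sketch: argument names and the returned dict keys are
# assumptions about the jpc API, not confirmed signatures.
import jpc
import optax

optim = optax.adam(LEARNING_RATE)
opt_state = optim.init(network)

for x_batch, y_batch in train_loader:  # hypothetical data loader
    # One PC step: run inference on the latent activities, then
    # update the weights (return structure assumed).
    result = jpc.make_pc_step(
        model=network,
        optim=optim,
        opt_state=opt_state,
        input=x_batch,
        output=y_batch
    )
    network, opt_state = result["model"], result["opt_state"]

# Accuracy of the trained network on held-out data.
accuracy = jpc.test_discriminative_pc(
    model=network,
    input=test_x,   # hypothetical test inputs
    output=test_y   # hypothetical test labels
)
```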
{
@@ -352,24 +313,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train iter 100, train loss=0.018149, avg test accuracy=93.790062\n",
"Train iter 200, train loss=0.012088, avg test accuracy=95.142227\n",
"Train iter 300, train loss=0.016424, avg test accuracy=95.723160\n"
"Train iter 100, train loss=0.015131, avg test accuracy=93.389420\n",
"Train iter 200, train loss=0.017173, avg test accuracy=95.102165\n",
"Train iter 300, train loss=0.013283, avg test accuracy=95.783257\n"
]
}
],
"source": [
"import warnings\n",
"with warnings.catch_warnings():\n",
" warnings.simplefilter('ignore')\n",
" \n",
" train(\n",
" model=network,\n",
" lr=LEARNING_RATE,\n",
" batch_size=BATCH_SIZE,\n",
" test_every=TEST_EVERY,\n",
" n_train_iters=N_TRAIN_ITERS\n",
" )"
"train(\n",
" model=network,\n",
" lr=LEARNING_RATE,\n",
" batch_size=BATCH_SIZE,\n",
" test_every=TEST_EVERY,\n",
" n_train_iters=N_TRAIN_ITERS\n",
")"
]
}
],
@@ -394,4 +351,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
161 changes: 18 additions & 143 deletions examples/hybrid_pc.ipynb

Large diffs are not rendered by default.

67 changes: 44 additions & 23 deletions examples/linear_net_theoretical_energy.ipynb
@@ -4,9 +4,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Theoretical energy of deep linear networks\n",
"# Theoretical PC energy of deep linear networks\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/linear_net_theoretical_energy.ipynb)"
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/thebuckleylab/jpc/blob/main/examples/linear_net_theoretical_energy.ipynb)\n",
"\n",
"This notebook demonstrates how to compute the theoretical PC energy at the inference equilibrium $\\mathcal{F}^*$ when $\\mathcal{F}|_{\\nabla_{\\mathbf{z}} \\mathcal{F} = \\mathbf{0}}$ for a deep linear network with input and output $(\\mathbf{x}_i, \\mathbf{y}_i)$ (see [Innocenti et al., 2024](https://arxiv.org/abs/2408.11979))\n",
"\n",
"\\begin{equation}\n",
" \\mathcal{F}^* = \\frac{1}{2N} \\sum_{i=1}^N (\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i)^T S^{-1}(\\mathbf{y}_i - W_{L:1}\\mathbf{x}_i)\n",
"\\end{equation}\n",
"\n",
"where $S = I_{d_y} + \\sum_{\\ell=2}^L (W_{L:\\ell})(W_{L:\\ell})^T$ and $W_{k:\\ell} = W_k \\dots W_\\ell$ for $\\ell, k \\in 1,\\dots, L$.\n"
]
},
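For concreteness, $\mathcal{F}^*$ can also be evaluated directly from the weight matrices in plain JAX, independent of any `jpc` helper; a minimal sketch implementing the equation above:

```python
import jax.numpy as jnp

def theoretical_energy(Ws, X, Y):
    """Equilibrated PC energy F* of a deep linear network.

    Ws: weight matrices [W_1, ..., W_L], applied as W_L ... W_1 x.
    X:  inputs, shape (N, d_x). Y: targets, shape (N, d_y).
    """
    N, d_y = Y.shape
    # Partial products W_{L:l} = W_L ... W_l, built from the top down.
    prods = []
    P = None
    for W in reversed(Ws):
        P = W if P is None else P @ W
        prods.append(P)
    prods = prods[::-1]            # prods[l-1] = W_{L:l}
    W_L1 = prods[0]                # full product W_{L:1}
    # S = I + sum_{l=2}^L W_{L:l} W_{L:l}^T
    S = jnp.eye(d_y) + sum(P @ P.T for P in prods[1:])
    r = Y - X @ W_L1.T             # residuals, shape (N, d_y)
    # 1/(2N) * sum_i r_i^T S^{-1} r_i
    return 0.5 * jnp.mean(jnp.einsum("ni,ij,nj->n", r, jnp.linalg.inv(S), r))
```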
{
@@ -52,11 +60,13 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Hyperparameters\n",
"\n",
"We define some global parameters, including network architecture, learning rate, batch size etc."
"We define some global parameters, including network architecture, learning rate, batch size, etc."
]
},
{
@@ -75,7 +85,9 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Dataset\n",
"\n",
Expand All @@ -88,6 +100,9 @@
"metadata": {},
"outputs": [],
"source": [
"#@title data utils\n",
"\n",
"\n",
"def get_mnist_loaders(batch_size):\n",
" train_data = MNIST(train=True, normalise=True)\n",
" test_data = MNIST(train=False, normalise=True)\n",
Expand Down Expand Up @@ -191,9 +206,13 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Linear network"
"## Linear network\n",
"\n",
"We'll use a linear network with 10 hidden layers as an example."
]
},
{
Expand All @@ -214,13 +233,13 @@
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Train and test\n",
"\n",
"A PC network can be trained in a single line of code with `jpc.make_pc_step()`. See the documentation for more. Similarly, we can use `jpc.test_discriminative_pc()` to compute the network accuracy. Note that these functions are already \"jitted\" for performance.\n",
"\n",
"Below we simply wrap each of these functions in our training and test loops, respectively."
"To compute the theoretical energy, we can use `jpc.linear_equilib_energy()` (see the the [docs](https://thebuckleylab.github.io/jpc/api/Analytical%20tools/#jpc.linear_equilib_energy) for more details) which as clear from the equation above just takes the model and the data."
]
},
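A hedged usage sketch; the keyword names are assumptions rather than the confirmed signature (the linked docs are authoritative):

```python
# Sketch only: argument names are assumptions about
# jpc.linear_equilib_energy -- see the Analytical tools docs above.
theory_energy = jpc.linear_equilib_energy(
    network=network,  # the deep linear network defined above
    x=x_batch,        # hypothetical batch of inputs
    y=y_batch         # hypothetical batch of targets
)
```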
{
@@ -302,7 +321,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run"
"## Run\n",
"\n",
"Below we plot the theoretical energy against the numerical one."
]
},
{
@@ -314,16 +335,16 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Train iter 10, train loss=0.067622, avg test accuracy=59.535255\n",
"Train iter 20, train loss=0.054068, avg test accuracy=75.230370\n",
"Train iter 30, train loss=0.063356, avg test accuracy=77.453926\n",
"Train iter 40, train loss=0.051848, avg test accuracy=80.048080\n",
"Train iter 50, train loss=0.061488, avg test accuracy=82.061295\n",
"Train iter 60, train loss=0.044830, avg test accuracy=80.789261\n",
"Train iter 70, train loss=0.045716, avg test accuracy=84.174683\n",
"Train iter 80, train loss=0.053921, avg test accuracy=82.041267\n",
"Train iter 90, train loss=0.040125, avg test accuracy=83.072914\n",
"Train iter 100, train loss=0.050980, avg test accuracy=83.974358\n"
"Train iter 10, train loss=0.070245, avg test accuracy=64.853767\n",
"Train iter 20, train loss=0.075006, avg test accuracy=69.961937\n",
"Train iter 30, train loss=0.055347, avg test accuracy=70.292465\n",
"Train iter 40, train loss=0.057690, avg test accuracy=78.275238\n",
"Train iter 50, train loss=0.052301, avg test accuracy=79.607368\n",
"Train iter 60, train loss=0.051747, avg test accuracy=80.909454\n",
"Train iter 70, train loss=0.053040, avg test accuracy=80.238380\n",
"Train iter 80, train loss=0.047872, avg test accuracy=81.029648\n",
"Train iter 90, train loss=0.051192, avg test accuracy=82.662262\n",
"Train iter 100, train loss=0.054825, avg test accuracy=83.533653\n"
]
},
{
@@ -377,4 +398,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}