add a tutorial on multi armed bandit tasks (#137)
* update readme

* add an example of multi-armed bandit task (first part)

* add support for missing observations
LegrandNico authored Nov 22, 2023
1 parent 3f5a501 commit e914f84
Showing 14 changed files with 2,119 additions and 169 deletions.
30 changes: 15 additions & 15 deletions README.md
@@ -4,10 +4,10 @@

# PyHGF: A Graph Neural Network Library for Predictive Coding

- PyHGF is a Python library that implements the generalized, nodalized and multilevel Hierarchical Gaussian Filters for predictive coding written on top of [JAX](https://jax.readthedocs.io/en/latest/jax.html). The library can create and manipulate graph neural networks that perform belief update through the diffusion of precision-weighted prediction errors under new observations. The core functions are derivable, JIT-able, and are designed to interface smoothly with other libraries in the JAX ecosystem for neural networks, reinforcement leanring, Bayesian inference or optimization.
+ PyHGF is a Python library written on top of [JAX](https://jax.readthedocs.io/en/latest/jax.html) to create and manipulate graph neural networks that can perform belief updates through the diffusion of predictions and precision-weighted prediction errors. These networks can serve as biologically plausible computational models of cognitive functions for computational psychiatry and reinforcement learning or as a generalisation of Bayesian filtering to arbitrarily sized graphical structures for signal processing. In their most standard form, these models are a generalisation and nodalisation of the Hierarchical Gaussian Filters (HGF) for predictive coding. The library is made modular and designed to facilitate the manipulation of probabilistic networks, so the user can focus on model design. The core functions are derivable, JIT-able, and designed to interface smoothly with other libraries in the JAX ecosystem for neural networks, reinforcement learning, Bayesian inference or optimization.

- * 📖 [API Documentation](https://ilabcode.github.io/pyhgf/api.html)
- * ✏️ [Tutorials, examples and exercises](https://ilabcode.github.io/pyhgf/tutorials.html)
+ * 📖 [API Documentation](https://ilabcode.github.io/pyhgf/)
+ * ✏️ [Tutorials and examples](https://ilabcode.github.io/pyhgf/learn.html)

## Getting started

@@ -21,24 +21,24 @@ The current version under development can be installed from the master branch of

`pip install "git+https://github.com/ilabcode/pyhgf.git"`

- ### How does it works?
+ ### How does it work?

- The nodalized Hierarchical Gaussian Filter consists of a network of probabilistic nodes hierarchically structured where each node can inherit its value and volatility sufficient statistics from other parents node. The presentation of a new observation at the lower level of the hierarchy (i.e. the input node) triggers a recursive update of the nodes' belief through the bottom-up propagation of precision-weighted prediction error.
+ The nodalized Hierarchical Gaussian Filter consists of a network of probabilistic nodes hierarchically structured where each node can inherit its value and volatility sufficient statistics from other parent nodes. The presentation of a new observation at the lower level of the hierarchy (i.e. the input node) triggers a recursive update of the nodes' belief through the bottom-up propagation of precision-weighted prediction error.
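To make the update concrete, here is a minimal, self-contained sketch of one precision-weighted belief update between an input node and its value parent. The variable names and the update equation are a simplified illustration, not pyhgf's actual internals:

```python
# Toy precision-weighted prediction-error update between an input node and
# its value parent. Names (mu, pi, delta) are illustrative only and do not
# reflect pyhgf's internal API.

def update_value_parent(parent_mu, parent_pi, child_pi, observation):
    """Update the parent's mean and precision from one new observation."""
    delta = observation - parent_mu                    # prediction error
    new_pi = parent_pi + child_pi                      # precision grows with evidence
    new_mu = parent_mu + (child_pi / new_pi) * delta   # precision-weighted step
    return new_mu, new_pi

mu, pi = 0.0, 1.0  # prior belief of the value parent
for u in [0.5, 0.8, 0.2, 0.9]:
    mu, pi = update_value_parent(mu, pi, child_pi=1.0, observation=u)

print(round(mu, 3), pi)  # → 0.48 5.0
```

Note how, with a fixed child precision, each new observation pulls the parent's mean toward the running average while its precision (confidence) accumulates.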

More generally, pyhgf operates on graph neural networks that can be defined and updated through the following variables:

- * The nodes attributes (dictionary) that store each node's parameters (value, precision, learning rates, volatility coupling, ...).
+ * The nodes' attributes (dictionary) that store each node's parameters (value, precision, learning rates, volatility coupling, ...).
* The edges (tuple) that lists, for each node, the indexes of the value and volatility parents.
* A set of update functions that operate on any of the 3 other variables, starting from a target node.
- * An update sequence (tuple) that define the order in which the update functions are called, and the target node.
+ * An update sequence (tuple) that defines the order in which the update functions are called, and the target node.
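A hypothetical miniature of these four variables, assuming a two-node network where node 0 receives observations and node 1 is its value parent — the data structures below are illustrative and do not follow pyhgf's exact schema:

```python
# Illustrative miniature of the four variables listed above (not pyhgf's
# actual data structures).

# 1. Node attributes: per-node parameters.
attributes = {
    0: {"mu": 0.0, "pi": 1.0},  # input node
    1: {"mu": 0.0, "pi": 1.0},  # value parent of node 0
}

# 2. Edges: for each node, (value parents, volatility parents); None = no parent.
edges = (
    ((1,), None),   # node 0 has node 1 as value parent
    (None, None),   # node 1 sits at the top of the hierarchy
)

# 3. An update function operating on the other variables from a target node.
def prediction_error(attributes, edges, node_idx, observation):
    """Send a precision-weighted prediction error to the value parent(s)."""
    value_parents, _ = edges[node_idx]
    for parent_idx in value_parents:
        parent = attributes[parent_idx]
        delta = observation - parent["mu"]
        parent["pi"] += attributes[node_idx]["pi"]
        parent["mu"] += (attributes[node_idx]["pi"] / parent["pi"]) * delta
    return attributes

# 4. The update sequence: (target node, function) pairs, applied in order.
update_sequence = ((0, prediction_error),)

for node_idx, update_fn in update_sequence:
    attributes = update_fn(attributes, edges, node_idx, observation=0.5)

print(attributes[1]["mu"])  # → 0.25
```

Because the network is described entirely by data (attributes, edges) plus a sequence of functions, the same propagation logic scales to arbitrarily shaped graphs.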

![png](https://raw.githubusercontent.com/ilabcode/pyhgf/master/docs/source/images/graph_networks.svg)

Value parents and volatility parents are nodes themselves. Any node can be a value and/or volatility parent for other nodes and have multiple value and/or volatility parents. A filtering structure consists of nodes embedding other nodes hierarchically. Nodes are parametrized by their sufficient statistics and parents. The transformations between nodes can be linear, non-linear, or any function (thus a *generalization* of the HGF).

- The resulting probabilistic network operates as a filter toward new observation. If a decision function (taking the whole model as a parameter) is also defined, behaviors can be triggered accordingly. By comparing those behaviors with actual outcomes, a surprise function can be optimized over the range of parameters of interest.
+ The resulting probabilistic network operates as a filter toward new observation. If a decision function (taking the whole model as a parameter) is also defined, behaviours can be triggered accordingly. By comparing those behaviours with actual outcomes, a surprise function can be optimized over the range of parameters of interest.
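As a toy illustration of that last step — fitting a parameter by minimizing surprise — consider a simple binary learner with a single learning rate. This sketch stands in for, and is not, pyhgf's actual fitting API:

```python
import math

# Toy sketch: fit one parameter by minimizing surprise, i.e. the negative
# log-probability of the observed binary outcomes. Names and structure are
# illustrative only.

def total_surprise(outcomes, learning_rate):
    """Surprise of a simple delta-rule learner on a binary outcome series."""
    p, surprise = 0.5, 0.0
    for u in outcomes:
        p_safe = min(max(p, 1e-6), 1 - 1e-6)          # avoid log(0)
        surprise += -math.log(p_safe if u == 1 else 1 - p_safe)
        p += learning_rate * (u - p)                   # belief update
    return surprise

outcomes = [1, 1, 0, 1, 1, 1, 0, 1]
grid = [i / 20 for i in range(1, 20)]                  # candidate learning rates
best = min(grid, key=lambda a: total_surprise(outcomes, a))
print(best, total_surprise(outcomes, best))
```

In practice, the surprise returned by the model would be minimized (or used as a log-likelihood) with a gradient-based optimizer or an MCMC sampler from the JAX ecosystem, rather than a grid search.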

- You can find a deeper introduction on how to create and manipulate networks under the following link:
+ You can find a deeper introduction to how to create and manipulate networks under the following link:

* 🎓 [How to create and manipulate networks of probabilistic nodes](https://ilabcode.github.io/pyhgf/notebooks/0-Creating_networks.html#creating-and-manipulating-networks-of-probabilistic-nodes)

@@ -86,12 +86,12 @@ print(f"Model's surprise = {surprise}")
hgf.plot_trajectories();
```

`Creating a binary Hierarchical Gaussian Filter with 2 levels.`
`... Create the update sequence from the network structure.`
`... Create the belief propagation function.`
`... Cache the belief propagation function.`
`Adding 320 new observations.`
`Model's surprise = 203.6395263671875`

![png](https://raw.githubusercontent.com/ilabcode/pyhgf/master/docs/source/images/trajectories.png)

2 changes: 1 addition & 1 deletion docs/source/api.rst
@@ -4,7 +4,7 @@


.. contents:: Table of Contents
-    :depth: 2
+    :depth: 5

API
+++
16 changes: 8 additions & 8 deletions docs/source/index.md
@@ -2,12 +2,12 @@

[![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit&logoColor=white)](https://github.com/pre-commit/pre-commit) [![license](https://img.shields.io/badge/License-GPL%20v3-blue.svg)](https://github.com/ilabcode/pyhgf/blob/master/LICENSE) [![codecov](https://codecov.io/gh/ilabcode/pyhgf/branch/master/graph/badge.svg)](https://codecov.io/gh/ilabcode/pyhgf) [![black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/) [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) [![pip](https://badge.fury.io/py/pyhgf.svg)](https://badge.fury.io/py/pyhgf)

- # The multilevel, generalized and nodalized Hierarchical Gaussian Filter for predictive coding
+ # PyHGF: A Graph Neural Network Library for Predictive Coding

- PyHGF is a Python library that implements the generalized, nodalized and multilevel Hierarchical Gaussian Filters for predictive coding written on top of [JAX](https://jax.readthedocs.io/en/latest/jax.html). The library can create and manipulate graph neural networks that perform belief update through the diffusion of precision-weighted prediction errors under new observations. The core functions are derivable, JIT-able, and are designed to interface smoothly with other libraries in the JAX ecosystem for neural networks, reinforcement leanring, Bayesian inference or optimization.
+ PyHGF is a Python library written on top of [JAX](https://jax.readthedocs.io/en/latest/jax.html) to create and manipulate graph neural networks that can perform belief updates through the diffusion of predictions and precision-weighted prediction errors. These networks can serve as biologically plausible computational models of cognitive functions for computational psychiatry and reinforcement learning or as a generalisation of Bayesian filtering to arbitrarily sized graphical structures for signal processing. In their most standard form, these models are a generalisation and nodalisation of the Hierarchical Gaussian Filters (HGF) for predictive coding. The library is made modular and designed to facilitate the manipulation of probabilistic networks, so the user can focus on model design. The core functions are derivable, JIT-able, and designed to interface smoothly with other libraries in the JAX ecosystem for neural networks, reinforcement learning, Bayesian inference or optimization.

* 📖 [API Documentation](https://ilabcode.github.io/pyhgf/)
- * ✏️ [Tutorials and examples](https://ilabcode.github.io/pyhgf/tutorials.html)
+ * ✏️ [Tutorials and examples](https://ilabcode.github.io/pyhgf/learn.html)

## Getting started

@@ -25,16 +25,16 @@ The current version under development can be installed from the master branch of
pip install "git+https://github.com/ilabcode/pyhgf.git"
```

- ### How does it works?
+ ### How does it work?

- The nodalized Hierarchical Gaussian Filter consists of a network of probabilistic nodes hierarchically structured where each node can inherit its value and volatility sufficient statistics from other parents node. The presentation of a new observation at the lower level of the hierarchy (i.e. the input node) triggers a recursive update of the nodes' belief through the bottom-up propagation of precision-weighted prediction error.
+ The nodalized Hierarchical Gaussian Filter consists of a network of probabilistic nodes hierarchically structured where each node can inherit its value and volatility sufficient statistics from other parent nodes. The presentation of a new observation at the lower level of the hierarchy (i.e. the input node) triggers a recursive update of the nodes' belief through the bottom-up propagation of precision-weighted prediction error.

More generally, pyhgf operates on graph neural networks that can be defined and updated through the following variables:

* The node parameters (dictionary) that store each node's parameters (value, precision, learning rates, volatility coupling, ...).
- * The node structure (tuple) that list, for each node, the indexes of the value and volatility parents.
+ * The node structure (tuple) that lists, for each node, the indexes of the value and volatility parents.
* A set of update functions that operate on any of the 3 other variables, starting from a target node.
- * An update sequence (tuple) that define the order in which the update functions are called, and the target node.
+ * An update sequence (tuple) that defines the order in which the update functions are called, and the target node.

![png](./images/graph_networks.svg)

@@ -52,7 +52,7 @@ The pyhgf package includes pre-implemented standard HGF models that can be used

### Model fitting

- Here we demonstrate how to fit a two-level binary Hierarchical Gaussian filter. The input time series are binary outcome from {cite:p}`Iglesias2021`.
+ Here we demonstrate how to fit a two-level binary Hierarchical Gaussian filter. The input time series are binary outcomes from {cite:p}`Iglesias2021`.

```python
from pyhgf.model import HGF
6 changes: 6 additions & 0 deletions docs/source/learn.md
@@ -118,6 +118,12 @@ Recovering parameters from the generative model and using the sampling functiona
:link-type: ref
:img-top: ./images/input_mean_precision.png

:::

:::{grid-item-card} Multi-armed bandit task with independent reward and punishments
:link: example_3
:link-type: ref

:::
::::

6 changes: 4 additions & 2 deletions docs/source/notebooks/1.2-Categorical_HGF.ipynb

Large diffs are not rendered by default.

