Train a Hebbian Predictive Coding Network on short sequences of moving objects. Both invariant representations and a generative model are learned.
This repository provides the code for the experiments in our paper:
Matthias Brucklacher*, Sander M. Bohte, Jorge F. Mejias, Cyriel M. A. Pennartz - [Local minimization of prediction errors drives learning of invariant object representations in a generative network model of visual perception](https://www.biorxiv.org/content/10.1101/2022.07.18.500392v3)
*Corresponding author. Email: [email protected]
The model consists of three hierarchical areas that attempt to minimize reconstruction errors during both inference and learning. While this is similar to other Predictive Coding networks (Rao & Ballard, 1999; Dora et al., 2021), the inputs here are sequences rather than static images. Importantly, the network does not actively predict forward in time; activity in the representation neurons simply carries over from one frame to the next.
Without externally provided labels, the network learns invariant representations that can be read out linearly. The increasing invariance across the network hierarchy is also mirrored in a slower timescale of activity updates in higher network areas.
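To illustrate these dynamics, here is a minimal single-area sketch in Python. This is a hypothetical simplification, not the paper's three-area model: a representation `r` is updated to reduce the reconstruction error `e = x - W @ r`, and the generative weights `W` are then updated with a Hebbian product of error and activity.

```python
import numpy as np

# Minimal single-area predictive coding sketch (illustrative only; the actual
# model has three hierarchical areas and processes sequences of frames).
rng = np.random.default_rng(0)
x = rng.random(16)               # one input frame
W = rng.random((16, 8)) * 0.1    # generative weights: prediction = W @ r
r = np.zeros(8)                  # representation neurons; across the frames of
                                 # a sequence, r simply carries over (no reset)

for _ in range(200):
    e = x - W @ r                # reconstruction (prediction) error
    r += 0.1 * (W.T @ e)         # inference: adjust activity to reduce error
W += 0.01 * np.outer(e, r)       # Hebbian learning: error times activity
```

Because `r` is not reset between frames, slowly changing inputs naturally produce slowly changing representations, which is the mechanism behind the timescale hierarchy examined in Figure 6.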
- Set up the conda environment `pcenv` by running:
  ```
  conda env create -f environment.yml
  ```
  The required packages are listed in `environment.yml`; since only a few packages are needed, the environment can also be created quickly by hand.
- Activate the environment:
  ```
  source activate pcenv
  ```
- With the activated environment, install the local package `scripts` to allow absolute imports of modules. From the directory `PCInvariance`, run:
  ```
  pip install -e .
  ```
- Run all subsequent commands from the directory `scripts`.
- Fig. 4a:
  ```
  python train.py --data mnist_extended.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig4a --epochs 0
  python post_run_analysis/plot_rdm.py --simulation_id fig4a
  ```
- Fig. 4b:
  ```
  python train.py --data mnist_extended.npy --labels labels_mnist_extended.npy --trafos 1 1 1 --resultfolder fig4b --epochs 20
  python post_run_analysis/plot_rdm.py --simulation_id fig4b
  ```
- Fig. 4c:
  ```
  python train.py --data mnist_extended.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig4c --epochs 20
  python post_run_analysis/plot_rdm.py --simulation_id fig4c
  ```
- Fig. 4d:
  ```
  python train.py --data mnist_extended.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --noise_on 1 --resultfolder fig4d --epochs 20
  python post_run_analysis/plot_rdm.py --simulation_id fig4d
  ```
- Fig. 4e:
  ```
  python train.py --data mnist_extended.npy --labels labels_mnist_extended.npy --trafos 2 2 2 --resultfolder fig4e --epochs 20
  python post_run_analysis/plot_rdm.py --simulation_id fig4e
  ```
- Fig. 4f:
  ```
  python train.py --data smallnorb_extended.npy --labels smallnorb_labels.npy --trafos 0 0 0 --resultfolder fig4f --epochs 20
  python post_run_analysis/plot_rdm.py --simulation_id fig4f
  ```
- First, train networks:
  ```
  python train.py --data smallnorb_extended.npy --labels smallnorb_labels.npy --trafos 0 0 0 --resultfolder fig5a --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig5b_trafo-0_static-0_noise-0 --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 1 1 1 --resultfolder fig5b_trafo-1_static-0_noise-0 --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 2 2 2 --resultfolder fig5b_trafo-2_static-0_noise-0 --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig5b_trafo-0_static-1_noise-0 --do_train_static 1 --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 1 1 1 --resultfolder fig5b_trafo-1_static-1_noise-0 --do_train_static 1 --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 2 2 2 --resultfolder fig5b_trafo-2_static-1_noise-0 --do_train_static 1 --epochs 20 --n_runs 4
  ```
- Then plot:
  ```
  python post_run_analysis/plot_fig5ab.py
  ```
- Train networks with the standard architecture and with an increased number of neurons. To generate all necessary training runs, replace `<n_ipc>`, the number of instances per class in the training data, with each value in [1, 5, 10, 15, 20]:
  ```
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig5cd_arch-[2000-500-30]_nipc-<n_ipc> --epochs 20 --n_runs 4 --use_validation_data 1 --n_instances_per_class_train <n_ipc>
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig5cd_arch-[4000-2000-90]_nipc-<n_ipc> --epochs 20 --n_runs 4 --use_validation_data 1 --n_instances_per_class_train <n_ipc>
  ```
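The sweep over `n_ipc` can also be generated programmatically rather than typed out by hand. A small Python sketch (the flags are copied from the two commands above; launching via `subprocess` is optional):

```python
# Sketch: build the ten Fig. 5c/d training commands for both architectures
# and all values of n_ipc. Command strings are copied from the README.
base = ("python train.py --data mnist_extended_fast.npy "
        "--labels labels_mnist_extended.npy --trafos 0 0 0 "
        "--resultfolder fig5cd_arch-{arch}_nipc-{n} --epochs 20 --n_runs 4 "
        "--use_validation_data 1 --n_instances_per_class_train {n}")

commands = [base.format(arch=arch, n=n)
            for arch in ("[2000-500-30]", "[4000-2000-90]")
            for n in (1, 5, 10, 15, 20)]

for cmd in commands:
    print(cmd)  # replace with subprocess.run(cmd.split()) to launch the runs
```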
- Then compute baselines and plot:
  ```
  python post_run_analysis/plot_fig5cd.py
  ```
Figure 6. The network develops a hierarchy of timescales comparable to experimental data from rodent visual cortex.
- Compute autocorrelations:
  ```
  python post_run_analysis/autocorrelation_compute.py
  ```
- Run the statistical analysis and plot the autocorrelation decay:
  ```
  python post_run_analysis/autocorrelation_analyze.py
  ```
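For illustration, the decay timescale of an activity trace can be estimated from its autocorrelation function, as in the following sketch. This is a simplified stand-in using a synthetic AR(1) signal, not the repository's analysis code:

```python
import numpy as np

def autocorrelation(x, max_lag):
    """Normalized autocorrelation of a 1-D signal for lags 0..max_lag-1."""
    x = x - x.mean()
    var = np.dot(x, x)
    return np.array([np.dot(x[:len(x) - k], x[k:]) / var
                     for k in range(max_lag)])

# Synthetic signal: an AR(1) process with coefficient a has autocorrelation
# a**k, i.e. an exponential decay with timescale tau = -1 / log(a).
rng = np.random.default_rng(0)
a = 0.9
x = np.zeros(20000)
for t in range(1, len(x)):
    x[t] = a * x[t - 1] + rng.normal()

ac = autocorrelation(x, 50)
tau = -1.0 / np.log(ac[1])  # crude one-lag estimate of the decay timescale
```

A slower autocorrelation decay (larger timescale) in higher network areas is the signature of the timescale hierarchy reported in Figure 6.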
- First, train networks:
  ```
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 0 0 0 --resultfolder fig7b_translation --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 1 1 1 --resultfolder fig7b_rotation --epochs 20 --n_runs 4
  python train.py --data mnist_extended_fast.npy --labels labels_mnist_extended.npy --trafos 2 2 2 --resultfolder fig7b_scaling --epochs 20 --n_runs 4
  ```
- Then plot:
  ```
  python post_run_analysis/plot_fig7b.py --simulation_id_1 fig7b_translation --simulation_id_2 fig7b_rotation --simulation_id_3 fig7b_scaling
  ```