```python
jpc.pc_energy_fn(
    network: PyTree[Callable],
    activities: PyTree[ArrayLike],
    output: ArrayLike,
    input: Optional[ArrayLike] = None
) -> Scalar
```
Computes the free energy for a feedforward neural network of the form

$$\mathcal{F}(\mathbf{z}; θ) = \frac{1}{N} \sum_{i=1}^N \sum_{\ell=1}^L \| \mathbf{z}_{i,\ell} - f_\ell(\mathbf{z}_{i,\ell-1}; θ) \|^2$$

given parameters \(θ\), free activities \(\mathbf{z}\), output \(\mathbf{z}_L = \mathbf{y}\) and optionally input \(\mathbf{z}_0 = \mathbf{x}\). The activity of each layer \(\mathbf{z}_\ell\) is some function of the previous layer, e.g. \(f_\ell(W_\ell \mathbf{z}_{\ell-1} + \mathbf{b}_\ell)\) for a fully connected layer.
Note

The `input` and `output` correspond to the prior and observation of the generative model, respectively.
Main arguments:

- `network`: List of callable network layers.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.

Returns:

The total energy normalised by batch size.
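For illustration, a minimal sketch of evaluating the energy on a toy model. The two-layer network, weights, and data below are hypothetical placeholders, and the activities are simply set to the feedforward predictions (see `jpc.init_activities_with_ffwd` below):

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# Toy two-layer network as a list of callables (hypothetical weights W1, W2).
W1 = jax.random.normal(k1, (10, 32)) * 0.1
W2 = jax.random.normal(k2, (32, 5)) * 0.1
network = [lambda z: jnp.tanh(z @ W1), lambda z: z @ W2]

x = jnp.ones((64, 10))   # input (prior), batch of 64
y = jnp.zeros((64, 5))   # output (observation)

# Clamp the activities to the feedforward predictions and evaluate the energy.
activities = jpc.init_activities_with_ffwd(network, x)
energy = jpc.pc_energy_fn(network, activities, output=y, input=x)
```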
```python
jpc.hpc_energy_fn(
    amortiser: PyTree[Callable],
    generator: PyTree[Callable],
    activities: PyTree[ArrayLike],
    output: ArrayLike,
    input: ArrayLike
) -> Scalar
```
Computes the free energy for a 'hybrid' predictive coding network (Tschantz et al., 2023).

```bibtex
@article{tschantz2023hybrid,
  title={Hybrid predictive coding: Inferring, fast and slow},
  author={Tschantz, Alexander and Millidge, Beren and Seth, Anil K and Buckley, Christopher L},
  journal={PLoS Computational Biology},
  volume={19},
  number={8},
  pages={e1011280},
  year={2023},
  publisher={Public Library of Science San Francisco, CA USA}
}
```
Note

The input is required, so this function currently only supports supervised training.
Main arguments:

- `amortiser`: List of callable layers for the network amortising the inference of the generative model.
- `generator`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation of the generative model (or input of the amortiser).
- `input`: Prior of the generative model (or output of the amortiser).

Returns:

The total energy normalised by batch size.
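A minimal sketch of the intended wiring, with hypothetical layers and dimensions: the generator maps the prior to the observation, and the amortiser runs in the opposite direction. Initialising the activities from a Gaussian over the generator's layers is just one plausible choice here, not a prescription of the library:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
kg1, kg2, ka1, ka2, kz = jax.random.split(key, 5)

def make_layer(key, d_in, d_out):
    # Hypothetical helper: a callable layer with fixed random weights.
    W = jax.random.normal(key, (d_in, d_out)) * 0.1
    return lambda z: jnp.tanh(z @ W)

# Generator: prior (dim 5) -> observation (dim 10); amortiser: the reverse.
generator = [make_layer(kg1, 5, 32), make_layer(kg2, 32, 10)]
amortiser = [make_layer(ka1, 10, 32), make_layer(ka2, 32, 5)]

x = jnp.ones((64, 5))    # prior of the generative model
y = jnp.zeros((64, 10))  # observation (input of the amortiser)

activities = jpc.init_activities_from_gaussian(
    kz, generator, mode="supervised", batch_size=64
)
energy = jpc.hpc_energy_fn(amortiser, generator, activities, output=y, input=x)
```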
```python
jpc.compute_pc_param_grads(
    network: PyTree[Callable],
    activities: PyTree[ArrayLike],
    output: ArrayLike,
    input: Optional[ArrayLike] = None
) -> PyTree[Array]
```
Computes the gradient of the energy with respect to network parameters, \(\partial \mathcal{F} / \partial θ\).
Main arguments:

- `network`: List of callable network layers.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.

Returns:

List of parameter gradients for each network layer.
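A sketch of a gradient computation under some assumptions: the layers are written as Equinox modules so that their parameters are array leaves of the network PyTree, and the activities are left at their feedforward values (in practice they would typically first be relaxed with `jpc.solve_pc_activities`, documented below):

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import jpc

class Layer(eqx.Module):
    """Toy fully connected layer whose weights live in the module PyTree."""
    W: jax.Array

    def __init__(self, key, d_in, d_out):
        self.W = jax.random.normal(key, (d_in, d_out)) * 0.1

    def __call__(self, z):
        return jnp.tanh(z @ self.W)

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
network = [Layer(k1, 10, 32), Layer(k2, 32, 5)]

x = jnp.ones((64, 10))
y = jnp.zeros((64, 5))

activities = jpc.init_activities_with_ffwd(network, x)
grads = jpc.compute_pc_param_grads(network, activities, output=y, input=x)
# `grads` mirrors the structure of `network`: one gradient PyTree per layer,
# ready to be passed to an optimiser such as optax.
```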
```python
jpc.compute_gen_param_grads(
    amortiser: PyTree[Callable],
    generator: PyTree[Callable],
    activities: PyTree[ArrayLike],
    output: ArrayLike,
    input: ArrayLike
) -> PyTree[Array]
```
Computes the gradient of the energy with respect to the parameters of a generative model, \(\partial \mathcal{F} / \partial θ\).

Note

This has the same functionality as `compute_pc_param_grads` but can be used together with `compute_amort_param_grads` for a more user-friendly API when training hybrid predictive coding networks.
Main arguments:

- `amortiser`: List of callable layers for the network amortising the inference of the generative model.
- `generator`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Prior of the generative model.

Returns:

List of parameter gradients for each layer of the generative network.
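A sketch mirroring the previous example for the hybrid setting; the `Layer` module, dimensions, and data are hypothetical:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import jpc

class Layer(eqx.Module):
    W: jax.Array

    def __init__(self, key, d_in, d_out):
        self.W = jax.random.normal(key, (d_in, d_out)) * 0.1

    def __call__(self, z):
        return jnp.tanh(z @ self.W)

key = jax.random.PRNGKey(0)
kg1, kg2, ka1, ka2, kz = jax.random.split(key, 5)
generator = [Layer(kg1, 5, 32), Layer(kg2, 32, 10)]
amortiser = [Layer(ka1, 10, 32), Layer(ka2, 32, 5)]

x = jnp.ones((64, 5))    # prior
y = jnp.zeros((64, 10))  # observation

activities = jpc.init_activities_from_gaussian(
    kz, generator, mode="supervised", batch_size=64
)
gen_grads = jpc.compute_gen_param_grads(
    amortiser, generator, activities, output=y, input=x
)
```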
```python
jpc.compute_amort_param_grads(
    amortiser: PyTree[Callable],
    generator: PyTree[Callable],
    activities: PyTree[ArrayLike],
    output: ArrayLike,
    input: ArrayLike
) -> PyTree[Array]
```
Computes the gradient of the energy with respect to the parameters of an amortised model, \(\partial \mathcal{F} / \partial \phi\).
Main arguments:

- `amortiser`: List of callable layers for a network amortising the inference of the generative model.
- `generator`: List of callable layers for the generative network.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Prior of the generative model.

Returns:

List of parameter gradients for each layer of the amortiser.
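Following the note under `compute_gen_param_grads`, the two gradient functions are naturally used side by side; a hypothetical joint computation at the same activities:

```python
import jax
import jax.numpy as jnp
import equinox as eqx
import jpc

class Layer(eqx.Module):
    W: jax.Array

    def __init__(self, key, d_in, d_out):
        self.W = jax.random.normal(key, (d_in, d_out)) * 0.1

    def __call__(self, z):
        return jnp.tanh(z @ self.W)

key = jax.random.PRNGKey(0)
kg1, kg2, ka1, ka2, kz = jax.random.split(key, 5)
generator = [Layer(kg1, 5, 32), Layer(kg2, 32, 10)]
amortiser = [Layer(ka1, 10, 32), Layer(ka2, 32, 5)]

x = jnp.ones((64, 5))
y = jnp.zeros((64, 10))

activities = jpc.init_activities_from_gaussian(
    kz, generator, mode="supervised", batch_size=64
)

# Gradients for both networks at the same activities, one call each.
gen_grads = jpc.compute_gen_param_grads(
    amortiser, generator, activities, output=y, input=x
)
amort_grads = jpc.compute_amort_param_grads(
    amortiser, generator, activities, output=y, input=x
)
```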
```python
jpc.solve_pc_activities(
    generator: PyTree[Callable],
    activities: PyTree[ArrayLike],
    output: ArrayLike,
    input: Optional[ArrayLike] = None,
    amortiser: Optional[PyTree[Callable]] = None,
    solver: AbstractSolver = Dopri5(),
    n_iters: int = 300,
    stepsize_controller: AbstractStepSizeController = PIDController(rtol=1e-05, atol=1e-05),
    dt: Union[float, int] = None,
    record_iters: bool = False
) -> PyTree[Array]
```
Solves the activity (inference) dynamics of a predictive coding network.

This is a wrapper around `diffrax.diffeqsolve` to integrate the gradient ODE system `_neg_activity_grad` defining the PC activity dynamics

$$\frac{\partial \mathbf{z}}{\partial t} = -\frac{\partial \mathcal{F}}{\partial \mathbf{z}}$$

where \(\mathcal{F}\) is the free energy, \(\mathbf{z}\) are the activities, with \(\mathbf{z}_L\) clamped to some target and \(\mathbf{z}_0\) optionally equal to some prior.
Main arguments:

- `generator`: List of callable layers for the generative model.
- `activities`: List of activities for each layer free to vary.
- `output`: Observation or target of the generative model.
- `input`: Optional prior of the generative model.

Other arguments:

- `amortiser`: Optional list of callable layers for a network amortising the inference of the generative model.
- `solver`: diffrax (ODE) solver to be used. Default is `Dopri5`.
- `n_iters`: Number of integration steps (300 by default).
- `stepsize_controller`: diffrax controller for step size integration. Defaults to `PIDController`.
- `dt`: Integration step size. Defaults to `None`, since the step size is automatically determined by the default `PIDController`.
- `record_iters`: If `True`, returns all integration steps. `False` by default.

Returns:

List with solution of the activity dynamics for each layer.
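A sketch of inference on a toy generative model, once with the default adaptive solver and once with a fixed-step Euler configuration (the particular solver settings are illustrative only, not a recommendation):

```python
import jax
import jax.numpy as jnp
import diffrax
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
W1 = jax.random.normal(k1, (10, 32)) * 0.1
W2 = jax.random.normal(k2, (32, 5)) * 0.1
generator = [lambda z: jnp.tanh(z @ W1), lambda z: z @ W2]

x = jnp.ones((64, 10))
y = jnp.zeros((64, 5))

# Default adaptive solver (Dopri5 with a PID step size controller).
activities = jpc.init_activities_with_ffwd(generator, x)
equilibrated = jpc.solve_pc_activities(generator, activities, output=y, input=x)

# Fixed-step Euler integration instead of the adaptive default.
equilibrated = jpc.solve_pc_activities(
    generator, activities, output=y, input=x,
    solver=diffrax.Euler(),
    stepsize_controller=diffrax.ConstantStepSize(),
    dt=0.1,
    n_iters=100,
)
```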
```python
jpc.init_activities_with_ffwd(
    network: PyTree[Callable],
    input: ArrayLike
) -> PyTree[Array]
```
Initialises the layers' activities with a feedforward pass.

Main arguments:

- `network`: List of callable network layers.
- `input`: Input to the network.

Returns:

List with feedforward values of each layer.
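A minimal sketch with a hypothetical two-layer network; the returned list contains one array of activities per layer:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
W1 = jax.random.normal(k1, (10, 32)) * 0.1
W2 = jax.random.normal(k2, (32, 5)) * 0.1
network = [lambda z: jnp.tanh(z @ W1), lambda z: z @ W2]

x = jnp.ones((64, 10))
activities = jpc.init_activities_with_ffwd(network, x)
# One array per layer; for this toy setup the shapes would be
# along the lines of (64, 32) and (64, 5).
print([a.shape for a in activities])
```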
```python
jpc.init_activities_from_gaussian(
    key: PRNGKeyArray,
    network: PyTree[Callable],
    mode: str,
    batch_size: int,
    sigma: Scalar = 0.05
) -> PyTree[Array]
```
Initialises network activities with samples from a zero-mean Gaussian, \(\mathbf{z} \sim \mathcal{N}(0, \sigma^2)\).

Main arguments:

- `key`: `jax.random.PRNGKey` for sampling.
- `network`: List of callable network layers.
- `mode`: If `'supervised'`, all hidden layers are initialised. If `'unsupervised'`, the input layer is also initialised.
- `batch_size`: Size of the data batch.
- `sigma`: Standard deviation of the Gaussian to sample activities from.

Returns:

List of randomly initialised activities for each layer.
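A sketch of both modes on a hypothetical network:

```python
import jax
import jax.numpy as jnp
import jpc

key = jax.random.PRNGKey(0)
k1, k2, k_sup, k_unsup = jax.random.split(key, 4)
W1 = jax.random.normal(k1, (10, 32)) * 0.1
W2 = jax.random.normal(k2, (32, 5)) * 0.1
network = [lambda z: jnp.tanh(z @ W1), lambda z: z @ W2]

# Supervised: hidden activities only; input and output are clamped to data.
supervised = jpc.init_activities_from_gaussian(
    k_sup, network, mode="supervised", batch_size=64, sigma=0.05
)

# Unsupervised: the input layer is sampled as well.
unsupervised = jpc.init_activities_from_gaussian(
    k_unsup, network, mode="unsupervised", batch_size=64, sigma=0.05
)
```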