Skip to content

Commit

Permalink
Merge pull request #160 from biaslab/dev-issue-151
Browse files Browse the repository at this point in the history
feat: new functional dependencies pipeline `RequireMarginal`
  • Loading branch information
bvdmitri authored Jul 6, 2022
2 parents 4f85423 + a60c944 commit 58433aa
Show file tree
Hide file tree
Showing 11 changed files with 1,068 additions and 554 deletions.
406 changes: 4 additions & 402 deletions demo/Expectation Propagation.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion demo/Normalizing Flow Tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@
" y_lat2[k] ~ dot(y_lat1[k], [1, 1])\n",
"\n",
" # specify observations\n",
" y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }\n",
" y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) }\n",
"\n",
" end\n",
"\n",
Expand Down
12 changes: 6 additions & 6 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ makedocs(
],
"Library" => [
"Messages" => "lib/message.md",
"Functional forms" => "lib/form.md",
"Prod implementation" => "lib/prod.md",
"Factor nodes" => [
"Overview" => "lib/node.md",
"Overview" => "lib/nodes/nodes.md",
"Flow" => "lib/nodes/flow.md"
],
"Math utils" => "lib/math.md",
"Helper utils" => "lib/helpers.md",
"Exported methods" => "lib/methods.md"
"Functional forms" => "lib/form.md",
"Prod implementation" => "lib/prod.md",
"Math utils" => "lib/math.md",
"Helper utils" => "lib/helpers.md",
"Exported methods" => "lib/methods.md"
],
"Examples" => [
"Overview" => "examples/overview.md",
Expand Down
2 changes: 1 addition & 1 deletion docs/src/examples/flow_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ The corresponding probabilistic model for the binary classification task can be
y_lat2[k] ~ dot(y_lat1[k], [1, 1])
# specify observations
y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }
y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) }
end
Expand Down
4 changes: 2 additions & 2 deletions docs/src/examples/probit.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ p = plot!(p, data_x[2:end], label = "x")
for k = 2:nr_samples + 1
x[k] ~ NormalMeanPrecision(x[k - 1] + 0.1, 100)
y[k - 1] ~ Probit(x[k]) where {
# Probit node by default uses RequireInbound pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge
# Probit node by default uses RequireMessage pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge
# To change initial value use may specify it manually, like. Changes to the initial message may improve stability in some situations
pipeline = RequireInbound(in = NormalMeanPrecision(0, 0.01))
pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01))
}
end
Expand Down
68 changes: 0 additions & 68 deletions docs/src/lib/node.md

This file was deleted.

107 changes: 107 additions & 0 deletions docs/src/lib/nodes/nodes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@

# [Nodes implementation](@id lib-node)

In the message passing framework, one of the most important concepts is a factor node.
A factor node represents a local function in a factorised representation of a generative model.

!!! note
To quickly check the list of all available factor nodes that can be used in the model specification language, call `?make_node` or `Base.doc(make_node)`.

## [Adding a custom node](@id lib-custom-node)

`ReactiveMP.jl` exports the `@node` macro that allows for quick definition of a factor node with a __fixed__ number of edges. The interface is the following:

```julia
struct MyNewCustomNode end

@node MyNewCustomNode Stochastic [ x, y, z ]
# ^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^ ^^^^^^^^^^^
# Node's tag/name Node's type A fixed set of edges
# Another possible The very first edge (in this example `x`) is considered
# value is to be the output of the node
# `Deterministic`
```

This expression registers a new node that can be used with the inference engine. Note howeve, that the `@node` macro does not generate any message passing update rules.
These must be defined using the `@rule` macro.

## [Node types](@id lib-node-types)

We distinguish different types of factor nodes in order to have better control over Bethe Free Energy computation.
Each factor node has either the [`Deterministic`](@ref) or [`Stochastic`](@ref) functional form type.

```@docs
Deterministic
Stochastic
isdeterministic
isstochastic
sdtype
```

```@setup lib-node-types
using ReactiveMP
```

For example the `+` node has the [`Deterministic`](@ref) type:

```@example lib-node-types
plus_node = make_node(+)
println("Is `+` node deterministic: ", isdeterministic(plus_node))
println("Is `+` node stochastic: ", isstochastic(plus_node))
nothing #hide
```

On the other hand, the `Bernoulli` node has the [`Stochastic`](@ref) type:

```@example lib-node-types
bernoulli_node = make_node(Bernoulli)
println("Is `Bernoulli` node deterministic: ", isdeterministic(bernoulli_node))
println("Is `Bernoulli` node stochastic: ", isstochastic(bernoulli_node))
```

To get an actual instance of the type object we use [`sdtype`](@ref) function:

```@example lib-node-types
println("sdtype() of `+` node is ", sdtype(plus_node))
println("sdtype() of `Bernoulli` node is ", sdtype(bernoulli_node))
nothing #hide
```

## [Node functional dependencies pipeline](@id lib-node-functional-dependencies-pipeline)

The generic implementation of factor nodes in ReactiveMP supports custom functional dependency pipelines. Briefly, the __functional dependencies pipeline__ defines what
dependencies are need to compute a single message. As an example, consider the belief-propagation message update equation for a factor node $f$ with three edges: $x$, $y$ and $z$:

```math
\mu(x) = \int \mu(y) \mu(z) f(x, y, z) \mathrm{d}y \mathrm{d}z
```

Here we see that in the standard setting for the belief-propagation message out of edge $x$, we need only messages from the edges $y$ and $z$. In contrast, consider the variational message update rule equation with mean-field assumption:

```math
\mu(x) = \exp \int q(y) q(z) \log f(x, y, z) \mathrm{d}y \mathrm{d}z
```

We see that in this setting, we do not need messages $\mu(y)$ and $\mu(z)$, but only the marginals $q(y)$ and $q(z)$. The purpose of a __functional dependencies pipeline__ is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in pipelines is presented below:

```@docs
DefaultFunctionalDependencies
RequireMessageFunctionalDependencies
RequireMarginalFunctionalDependencies
RequireEverythingFunctionalDependencies
```

## [Node traits](@id lib-node-traits)

Each factor node has to define the [`as_node_functional_form`](@ref) trait function and to specify a [`ValidNodeFunctionalForm`](@ref) singleton as a return object. By default [`as_node_functional_form`](@ref) returns [`UndefinedNodeFunctionalForm`](@ref). Objects that do not specify this property correctly cannot be used in model specification.

!!! note
[`@node`](@ref) macro does that automatically

```@docs
ValidNodeFunctionalForm
UndefinedNodeFunctionalForm
as_node_functional_form
```
Loading

0 comments on commit 58433aa

Please sign in to comment.