diff --git a/demo/Expectation Propagation.ipynb b/demo/Expectation Propagation.ipynb
index 5f189be29..c7b680364 100644
--- a/demo/Expectation Propagation.ipynb
+++ b/demo/Expectation Propagation.ipynb
@@ -125,155 +125,7 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
+ "image/svg+xml": "\n\n"
},
"execution_count": 5,
"metadata": {},
@@ -312,9 +164,9 @@
" for k = 2:nr_samples + 1\n",
" x[k] ~ NormalMeanPrecision(x[k - 1] + 0.1, 100)\n",
" y[k - 1] ~ Probit(x[k]) where {\n",
- " # Probit node by default uses RequireInbound pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge\n",
+ " # Probit node by default uses RequireMessage pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge\n",
" # To change initial value use may specify it manually, like. Changes to the initial message may improve stability in some situations\n",
- " pipeline = RequireInbound(in = NormalMeanPrecision(0, 0.01)) \n",
+ " pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01)) \n",
" }\n",
" end\n",
" \n",
@@ -379,257 +231,7 @@
"outputs": [
{
"data": {
- "image/svg+xml": [
- "\n",
- "\n"
- ]
+ "image/svg+xml": "\n\n"
},
"execution_count": 8,
"metadata": {},
diff --git a/demo/Normalizing Flow Tutorial.ipynb b/demo/Normalizing Flow Tutorial.ipynb
index e17649dba..0220cb70a 100644
--- a/demo/Normalizing Flow Tutorial.ipynb
+++ b/demo/Normalizing Flow Tutorial.ipynb
@@ -637,7 +637,7 @@
" y_lat2[k] ~ dot(y_lat1[k], [1, 1])\n",
"\n",
" # specify observations\n",
- " y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }\n",
+ " y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) }\n",
"\n",
" end\n",
"\n",
diff --git a/docs/make.jl b/docs/make.jl
index 97637315c..1bd4889d0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -26,15 +26,15 @@ makedocs(
],
"Library" => [
"Messages" => "lib/message.md",
- "Functional forms" => "lib/form.md",
- "Prod implementation" => "lib/prod.md",
"Factor nodes" => [
- "Overview" => "lib/node.md",
+ "Overview" => "lib/nodes/nodes.md",
"Flow" => "lib/nodes/flow.md"
],
- "Math utils" => "lib/math.md",
- "Helper utils" => "lib/helpers.md",
- "Exported methods" => "lib/methods.md"
+ "Functional forms" => "lib/form.md",
+ "Prod implementation" => "lib/prod.md",
+ "Math utils" => "lib/math.md",
+ "Helper utils" => "lib/helpers.md",
+ "Exported methods" => "lib/methods.md"
],
"Examples" => [
"Overview" => "examples/overview.md",
diff --git a/docs/src/examples/flow_tutorial.md b/docs/src/examples/flow_tutorial.md
index 4dd5f2f10..4cbe338bc 100644
--- a/docs/src/examples/flow_tutorial.md
+++ b/docs/src/examples/flow_tutorial.md
@@ -346,7 +346,7 @@ The corresponding probabilistic model for the binary classification task can be
y_lat2[k] ~ dot(y_lat1[k], [1, 1])
# specify observations
- y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }
+ y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) }
end
diff --git a/docs/src/examples/probit.md b/docs/src/examples/probit.md
index 46c425227..1ee84f44f 100644
--- a/docs/src/examples/probit.md
+++ b/docs/src/examples/probit.md
@@ -80,9 +80,9 @@ p = plot!(p, data_x[2:end], label = "x")
for k = 2:nr_samples + 1
x[k] ~ NormalMeanPrecision(x[k - 1] + 0.1, 100)
y[k - 1] ~ Probit(x[k]) where {
- # Probit node by default uses RequireInbound pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge
+ # Probit node by default uses RequireMessage pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge
# To change initial value use may specify it manually, like. Changes to the initial message may improve stability in some situations
- pipeline = RequireInbound(in = NormalMeanPrecision(0, 0.01))
+ pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01))
}
end
diff --git a/docs/src/lib/node.md b/docs/src/lib/node.md
deleted file mode 100644
index 8d1fe755c..000000000
--- a/docs/src/lib/node.md
+++ /dev/null
@@ -1,68 +0,0 @@
-
-# [Nodes implementation](@id lib-node)
-
-In message passing framework one of the most important concepts is factor node.
-Factor node represents a local function in a factorised representation of a generative model.
-
-!!! note
- To quickly check the list of all available factor nodes that can be used in the model specification language call `?make_node` or `Base.doc(make_node)`.
-
-## [Node traits](@id lib-node-traits)
-
-Each factor node has to define [`as_node_functional_form`](@ref) trait function and to specify [`ValidNodeFunctionalForm`](@ref) singleton as a return object. By default [`as_node_functional_form`](@ref) returns [`UndefinedNodeFunctionalForm`](@ref). Objects that do not specify this property correctly cannot be used in model specification.
-
-!!! note
- [`@node`](@ref) macro does that automatically
-
-```@docs
-ValidNodeFunctionalForm
-UndefinedNodeFunctionalForm
-as_node_functional_form
-```
-
-## [Node types](@id lib-node-types)
-
-We distinguish different types of factor nodes to have a better control over Bethe Free Energy computation.
-Each factor node has either [`Deterministic`](@ref) or [`Stochastic`](@ref) functional form type.
-
-```@docs
-Deterministic
-Stochastic
-isdeterministic
-isstochastic
-sdtype
-```
-
-```@setup lib-node-types
-using ReactiveMP
-```
-
-For example `+` node has the [`Deterministic`](@ref) type:
-
-```@example lib-node-types
-plus_node = make_node(+)
-
-println("Is `+` node deterministic: ", isdeterministic(plus_node))
-println("Is `+` node stochastic: ", isstochastic(plus_node))
-nothing #hide
-```
-
-On the other hand `Bernoulli` node has the [`Stochastic`](@ref) type:
-
-```@example lib-node-types
-bernoulli_node = make_node(Bernoulli)
-
-println("Is `Bernoulli` node deterministic: ", isdeterministic(bernoulli_node))
-println("Is `Bernoulli` node stochastic: ", isstochastic(bernoulli_node))
-```
-
-To get an actual instance of the type object we use [`sdtype`](@ref) function:
-
-```@example lib-node-types
-println("sdtype() of `+` node is ", sdtype(plus_node))
-println("sdtype() of `Bernoulli` node is ", sdtype(bernoulli_node))
-nothing #hide
-```
-
-## [Node factorisation constraints](@id lib-node-factorisation-constraints)
-
diff --git a/docs/src/lib/nodes/nodes.md b/docs/src/lib/nodes/nodes.md
new file mode 100644
index 000000000..0e112add6
--- /dev/null
+++ b/docs/src/lib/nodes/nodes.md
@@ -0,0 +1,107 @@
+
+# [Nodes implementation](@id lib-node)
+
+In the message passing framework, one of the most important concepts is a factor node.
+A factor node represents a local function in a factorised representation of a generative model.
+
+!!! note
+ To quickly check the list of all available factor nodes that can be used in the model specification language, call `?make_node` or `Base.doc(make_node)`.
+
+## [Adding a custom node](@id lib-custom-node)
+
+`ReactiveMP.jl` exports the `@node` macro that allows for quick definition of a factor node with a __fixed__ number of edges. The interface is the following:
+
+```julia
+struct MyNewCustomNode end
+
+@node MyNewCustomNode Stochastic [ x, y, z ]
+# ^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^ ^^^^^^^^^^^
+# Node's tag/name Node's type A fixed set of edges
+# Another possible The very first edge (in this example `x`) is considered
+# value is to be the output of the node
+# `Deterministic`
+```
+
+This expression registers a new node that can be used with the inference engine. Note howeve, that the `@node` macro does not generate any message passing update rules.
+These must be defined using the `@rule` macro.
+
+## [Node types](@id lib-node-types)
+
+We distinguish different types of factor nodes in order to have better control over Bethe Free Energy computation.
+Each factor node has either the [`Deterministic`](@ref) or [`Stochastic`](@ref) functional form type.
+
+```@docs
+Deterministic
+Stochastic
+isdeterministic
+isstochastic
+sdtype
+```
+
+```@setup lib-node-types
+using ReactiveMP
+```
+
+For example the `+` node has the [`Deterministic`](@ref) type:
+
+```@example lib-node-types
+plus_node = make_node(+)
+
+println("Is `+` node deterministic: ", isdeterministic(plus_node))
+println("Is `+` node stochastic: ", isstochastic(plus_node))
+nothing #hide
+```
+
+On the other hand, the `Bernoulli` node has the [`Stochastic`](@ref) type:
+
+```@example lib-node-types
+bernoulli_node = make_node(Bernoulli)
+
+println("Is `Bernoulli` node deterministic: ", isdeterministic(bernoulli_node))
+println("Is `Bernoulli` node stochastic: ", isstochastic(bernoulli_node))
+```
+
+To get an actual instance of the type object we use [`sdtype`](@ref) function:
+
+```@example lib-node-types
+println("sdtype() of `+` node is ", sdtype(plus_node))
+println("sdtype() of `Bernoulli` node is ", sdtype(bernoulli_node))
+nothing #hide
+```
+
+## [Node functional dependencies pipeline](@id lib-node-functional-dependencies-pipeline)
+
+The generic implementation of factor nodes in ReactiveMP supports custom functional dependency pipelines. Briefly, the __functional dependencies pipeline__ defines what
+dependencies are need to compute a single message. As an example, consider the belief-propagation message update equation for a factor node $f$ with three edges: $x$, $y$ and $z$:
+
+```math
+\mu(x) = \int \mu(y) \mu(z) f(x, y, z) \mathrm{d}y \mathrm{d}z
+```
+
+Here we see that in the standard setting for the belief-propagation message out of edge $x$, we need only messages from the edges $y$ and $z$. In contrast, consider the variational message update rule equation with mean-field assumption:
+
+```math
+\mu(x) = \exp \int q(y) q(z) \log f(x, y, z) \mathrm{d}y \mathrm{d}z
+```
+
+We see that in this setting, we do not need messages $\mu(y)$ and $\mu(z)$, but only the marginals $q(y)$ and $q(z)$. The purpose of a __functional dependencies pipeline__ is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in pipelines is presented below:
+
+```@docs
+DefaultFunctionalDependencies
+RequireMessageFunctionalDependencies
+RequireMarginalFunctionalDependencies
+RequireEverythingFunctionalDependencies
+```
+
+## [Node traits](@id lib-node-traits)
+
+Each factor node has to define the [`as_node_functional_form`](@ref) trait function and to specify a [`ValidNodeFunctionalForm`](@ref) singleton as a return object. By default [`as_node_functional_form`](@ref) returns [`UndefinedNodeFunctionalForm`](@ref). Objects that do not specify this property correctly cannot be used in model specification.
+
+!!! note
+ [`@node`](@ref) macro does that automatically
+
+```@docs
+ValidNodeFunctionalForm
+UndefinedNodeFunctionalForm
+as_node_functional_form
+```
diff --git a/src/node.jl b/src/node.jl
index 2ef318729..1122cf6a1 100644
--- a/src/node.jl
+++ b/src/node.jl
@@ -9,7 +9,9 @@ export iscontain, isfactorised, getinterface
export clusters, clusterindex
export connect!, activate!
export make_node
-export DefaultFunctionalDependencies, RequireInboundFunctionalDependencies, RequireEverythingFunctionalDependencies
+export DefaultFunctionalDependencies,
+ RequireMessageFunctionalDependencies,
+ RequireMarginalFunctionalDependencies, RequireEverythingFunctionalDependencies
export @node
using Rocket
@@ -21,7 +23,7 @@ import Base: getindex, setindex!, firstindex, lastindex
## Node traits
"""
- ValidNodeFunctionalForm
+ ValidNodeFunctionalForm
Trait specification for an object that can be used in model specification as a factor node.
@@ -146,7 +148,7 @@ struct FullFactorisation end
"""
collect_factorisation(nodetype, factorisation)
-This function converts given factorisation to a correct internal factorisation representation for a given node.
+This function converts given factorisation to a correct internal factorisation representation for a given node.
See also: [`MeanField`](@ref), [`FullFactorisation`](@ref)
"""
@@ -155,7 +157,7 @@ function collect_factorisation end
"""
collect_meta(nodetype, meta)
-This function converts given meta object to a correct internal meta representation for a given node.
+This function converts given meta object to a correct internal meta representation for a given node.
Fallbacks to `default_meta` in case if meta is `nothing`.
See also: [`default_meta`](@ref), [`FactorNode`](@ref)
@@ -236,7 +238,7 @@ local_constraint(interface::NodeInterface) = interface.local_constraint
"""
tag(interface)
-Returns a tag of the interface in the form of `Val{ name(interface) }`.
+Returns a tag of the interface in the form of `Val{ name(interface) }`.
The major difference between tag and name is that it is possible to dispath on interface's tag in message computation rule.
See also: [`NodeInterface`](@ref), [`name`](@ref)
@@ -306,8 +308,8 @@ get_pipeline_stages(interface::NodeInterface) = get_pipeline_stages(connectedvar
"""
IndexedNodeInterface
-`IndexedNodeInterface` object represents a repetative node-variable connection.
-Used in cases when node may connect different number of random variables with the same name, e.g. means and precisions of Gaussian Mixture node.
+`IndexedNodeInterface` object represents a repetative node-variable connection.
+Used in cases when a node may connect to a different number of random variables with the same name, e.g. means and precisions of a Gaussian Mixture node.
See also: [`name`](@ref), [`tag`](@ref), [`messageout`](@ref), [`messagein`](@ref)
"""
@@ -339,20 +341,25 @@ get_pipeline_stages(interface::IndexedNodeInterface) = get_pipelin
"""
FactorNodeLocalMarginal
-This object represents local marginals for some specific factor node.
-Local marginal can be joint in case of structured factorisation.
+This object represents local marginals for some specific factor node.
+The local marginal can be joint in case of structured factorisation.
Local to factor node marginal also can be shared with a corresponding marginal of some random variable.
See also: [`FactorNodeLocalMarginals`](@ref)
"""
mutable struct FactorNodeLocalMarginal
index :: Int
+ first :: Int
name :: Symbol
stream :: Union{Nothing, AbstractSubscribable{<:Marginal}}
- FactorNodeLocalMarginal(index::Int, name::Symbol) = new(index, name, nothing)
+ FactorNodeLocalMarginal(index::Int, first::Int, name::Symbol) = new(index, first, name, nothing)
end
+# `First` defines the index of the first element in the joint marginal
+# E.g. if the set of variables is (x, y, z, w) and joint is `z_w`, first is equal to 3
+Base.first(localmarginal::FactorNodeLocalMarginal) = localmarginal.first
+
index(localmarginal::FactorNodeLocalMarginal) = localmarginal.index
name(localmarginal::FactorNodeLocalMarginal) = localmarginal.name
@@ -362,7 +369,7 @@ setstream!(localmarginal::FactorNodeLocalMarginal, observable) = localmarginal.s
"""
FactorNodeLocalMarginals
-This object acts as an iterable and indexable proxy for local marginals for some node.
+This object acts as an iterable and indexable proxy for local marginals for some node.
"""
struct FactorNodeLocalMarginals{M}
marginals::M
@@ -371,10 +378,10 @@ end
function FactorNodeLocalMarginals(variablenames, factorisation)
marginal_names = map(fcluster -> clustername(map(i -> variablenames[i], fcluster)), factorisation)
index = 0 # its better not to use zip or enumerate here to preserve tuple-like structure
- marginals = map((mname) -> begin
+ marginals = map(marginal_names) do mname
index += 1
- FactorNodeLocalMarginal(index, mname)
- end, marginal_names)
+ return FactorNodeLocalMarginal(index, first(factorisation[index]), mname)
+ end
return FactorNodeLocalMarginals(marginals)
end
@@ -454,7 +461,7 @@ getpipeline(factornode::FactorNode) = factornode.pipeline
clustername(cluster) = mapreduce(v -> name(v), (a, b) -> Symbol(a, :_, b), cluster)
-# Cluster is reffered to a tuple of node interfaces
+# Cluster is reffered to a tuple of node interfaces
clusters(factornode::FactorNode) =
map(factor -> map(i -> @inbounds(interfaces(factornode)[i]), factor), factorisation(factornode))
@@ -523,7 +530,7 @@ collect_pipeline(T::Any, stage::AbstractPipelineStage) = Fact
collect_pipeline(T::Any, fdp::AbstractNodeFunctionalDependenciesPipeline) = FactorNodePipeline(fdp, EmptyPipelineStage())
collect_pipeline(T::Any, pipeline::FactorNodePipeline) = pipeline
-## Functional Dependencies
+## Functional Dependencies
function message_dependencies end
function marginal_dependencies end
@@ -531,50 +538,91 @@ function marginal_dependencies end
Base.:+(left::AbstractNodeFunctionalDependenciesPipeline, right::AbstractPipelineStage) = FactorNodePipeline(left, right)
Base.:+(left::FactorNodePipeline, right::AbstractPipelineStage) = FactorNodePipeline(left.functional_dependencies, left.extra_stages + right)
-### Default
+### Default
"""
DefaultFunctionalDependencies
+
+This pipeline translates directly to enforcing a variational message passing scheme. In order to compute a message out of some edge, this pipeline requires
+messages from edges within the same edge-cluster and marginals over other edge-clusters.
+
+See also: [`ReactiveMP.RequireMessageFunctionalDependencies`](@ref), [`ReactiveMP.RequireMarginalFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref)
"""
struct DefaultFunctionalDependencies <: AbstractNodeFunctionalDependenciesPipeline end
-function message_dependencies(::DefaultFunctionalDependencies, nodeinterfaces, varcluster, iindex)
+function message_dependencies(
+ ::DefaultFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
# First we remove current edge index from the list of dependencies
vdependencies = TupleTools.deleteat(varcluster, varclusterindex(varcluster, iindex))
# Second we map interface indices to the actual interfaces
return map(inds -> map(i -> @inbounds(nodeinterfaces[i]), inds), vdependencies)
end
-function marginal_dependencies(::DefaultFunctionalDependencies, nodelocalmarginals, varcluster, cindex)
+function marginal_dependencies(
+ ::DefaultFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
return TupleTools.deleteat(nodelocalmarginals, cindex)
end
-### With inbound
+### With inbound messages
-struct RequireInboundFunctionalDependencies{I, S} <: AbstractNodeFunctionalDependenciesPipeline
- indices :: I
- start_with :: S
-end
+"""
+ RequireMessageFunctionalDependencies(indices::Tuple, start_with::Tuple)
-struct InterfacePluginStartWithMessage{M, S}
- msg :: M
- start_with :: S
-end
+The same as `DefaultFunctionalDependencies`, but in order to compute a message out of some edge also requires the inbound message on the this edge.
-name(p::InterfacePluginStartWithMessage) = name(p.msg)
-messagein(p::InterfacePluginStartWithMessage) = messagein(p.start_with, p)
+# Arguments
-messagein(::Nothing, p::InterfacePluginStartWithMessage) = messagein(p.msg)
+- `indices`::Tuple, tuple of integers, which indicates what edges should require inbound messages
+- `start_with::Tuple`, tuple of `nothing` or `<:Distribution`, which specifies the initial inbound messages for edges in `indices`
-function messagein(something, p::InterfacePluginStartWithMessage)
- output = messagein(p.msg)
- if isnothing(getrecent(output))
- setmessage!(output, something)
- end
- return output
+Note: `start_with` uses `setmessage!` mechanism, hence, it can be visible by other listeners on the same edge. Explicit call to `setmessage!` overwrites whatever has been passed in `start_with`.
+
+`@model` macro accepts a simplified construction of this pipeline:
+
+```julia
+@model function some_model()
+ # ...
+ y ~ NormalMeanVariance(x, τ) where {
+ pipeline = RequireMessage(x = vague(NormalMeanPrecision), τ)
+ # ^^^ ^^^
+ # request 'inbound' for 'x' we may do the same for 'τ',
+ # and initialise with `vague(...)` but here we skip initialisation
+ }
+ # ...
end
+```
-function message_dependencies(dependencies::RequireInboundFunctionalDependencies, nodeinterfaces, varcluster, iindex)
+Deprecation warning: `RequireInboundFunctionalDependencies` has been deprecated in favor of `RequireMessageFunctionalDependencies`.
+
+See also: [`ReactiveMP.DefaultFunctionalDependencies`](@ref), [`ReactiveMP.RequireMarginalFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref)
+"""
+struct RequireMessageFunctionalDependencies{I, S} <: AbstractNodeFunctionalDependenciesPipeline
+ indices :: I
+ start_with :: S
+end
+
+Base.@deprecate_binding RequireInboundFunctionalDependencies RequireMessageFunctionalDependencies
+
+function message_dependencies(
+ dependencies::RequireMessageFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
# First we find dependency index in `indices`, we use it later to find `start_with` distribution
depindex = findfirst((i) -> i === iindex, dependencies.indices)
@@ -582,52 +630,194 @@ function message_dependencies(dependencies::RequireInboundFunctionalDependencies
# If we have `depindex` in our `indices` we include it in our list of functional dependencies. It effectively forces rule to require inbound message
if depindex !== nothing
# `mapindex` is a lambda function here
- mapindex = let nodeinterfaces = nodeinterfaces, depindex = depindex
- (i) -> begin
- interface = @inbounds nodeinterfaces[i]
- # InterfacePluginStartWithMessage is a proxy structure for `name` and `messagein` method for an interface
- # It returns the same name but modifies `messagein` to return an observable with `start_with` operator
- return if i === iindex
- InterfacePluginStartWithMessage(interface, dependencies.start_with[depindex])
- else
- interface
- end
- end
+ output = messagein(nodeinterfaces[iindex])
+ start_with = dependencies.start_with[depindex]
+ # Initialise now, if message has not been initialised before and `start_with` element is not empty
+ if isnothing(getrecent(output)) && !isnothing(start_with)
+ setmessage!(output, start_with)
end
- return map(inds -> map(mapindex, inds), varcluster)
+ return map(inds -> map(i -> @inbounds(nodeinterfaces[i]), inds), varcluster)
else
- return message_dependencies(DefaultFunctionalDependencies(), nodeinterfaces, varcluster, iindex)
+ return message_dependencies(
+ DefaultFunctionalDependencies(),
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+ )
end
end
-function marginal_dependencies(::RequireInboundFunctionalDependencies, nodelocalmarginals, varcluster, cindex)
- return marginal_dependencies(DefaultFunctionalDependencies(), nodelocalmarginals, varcluster, cindex)
+function marginal_dependencies(
+ ::RequireMessageFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
+ return marginal_dependencies(
+ DefaultFunctionalDependencies(),
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+ )
+end
+
+### With marginals
+
+"""
+ RequireMarginalFunctionalDependencies(indices::Tuple, start_with::Tuple)
+
+Similar to `DefaultFunctionalDependencies`, but in order to compute a message out of some edge also requires the posterior marginal on that edge.
+
+# Arguments
+
+- `indices`::Tuple, tuple of integers, which indicates what edges should require their own marginals
+- `start_with::Tuple`, tuple of `nothing` or `<:Distribution`, which specifies the initial marginal for edges in `indices`
+
+Note: `start_with` uses the `setmarginal!` mechanism, hence it can be visible to other listeners on the same edge. Explicit calls to `setmarginal!` overwrites whatever has been passed in `start_with`.
+
+`@model` macro accepts a simplified construction of this pipeline:
+
+```julia
+@model function some_model()
+ # ...
+ y ~ NormalMeanVariance(x, τ) where {
+ pipeline = RequireMarginal(x = vague(NormalMeanPrecision), τ)
+ # ^^^ ^^^
+ # request 'marginal' for 'x' we may do the same for 'τ',
+ # and initialise with `vague(...)` but here we skip initialisation
+ }
+ # ...
+end
+```
+
+Note: The simplified construction in `@model` macro syntax is only available in `GraphPPL.jl` of version `>2.2.0`.
+
+See also: [`ReactiveMP.DefaultFunctionalDependencies`](@ref), [`ReactiveMP.RequireMessageFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref)
+"""
+struct RequireMarginalFunctionalDependencies{I, S} <: AbstractNodeFunctionalDependenciesPipeline
+ indices :: I
+ start_with :: S
+end
+
+function message_dependencies(
+ ::RequireMarginalFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
+ return message_dependencies(
+ DefaultFunctionalDependencies(),
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+ )
+end
+
+function marginal_dependencies(
+ dependencies::RequireMarginalFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
+ # First we find dependency index in `indices`, we use it later to find `start_with` distribution
+ depindex = findfirst((i) -> i === iindex, dependencies.indices)
+
+ if depindex !== nothing
+ # We create an auxiliary local marginal with non-standard index here and inject it to other standard dependencies
+ extra_localmarginal = FactorNodeLocalMarginal(-1, iindex, name(nodeinterfaces[iindex]))
+ vmarginal = getmarginal(connectedvar(nodeinterfaces[iindex]), IncludeAll())
+ start_with = dependencies.start_with[depindex]
+ # Initialise now, if marginal has not been initialised before and `start_with` element is not empty
+ if isnothing(getrecent(vmarginal)) && !isnothing(start_with)
+ setmarginal!(vmarginal, start_with)
+ end
+ setstream!(extra_localmarginal, vmarginal)
+ default = marginal_dependencies(
+ DefaultFunctionalDependencies(),
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+ )
+ # Find insertion position (probably might be implemented more efficiently)
+ insertafter = sum(first(el) < iindex ? 1 : 0 for el in default; init = 0)
+ return TupleTools.insertafter(default, insertafter, (extra_localmarginal,))
+ else
+ return marginal_dependencies(
+ DefaultFunctionalDependencies(),
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+ )
+ end
end
### Everything
+"""
+ RequireEverythingFunctionalDependencies
+
+This pipeline specifies that in order to compute a message of some edge update rules request everything that is available locally.
+This includes all inbound messages (including on the same edge) and marginals over all local edge-clusters (this may or may not include marginals on single edges, depends on the local factorisation constraint).
+
+See also: [`DefaultFunctionalDependencies`](@ref), [`RequireMessageFunctionalDependencies`](@ref), [`RequireMarginalFunctionalDependencies`](@ref)
+"""
struct RequireEverythingFunctionalDependencies <: AbstractNodeFunctionalDependenciesPipeline end
-function ReactiveMP.message_dependencies(::RequireEverythingFunctionalDependencies, nodeinterfaces, varcluster, iindex)
- # Return all node interfaces including the edge we are trying to compuate a message on
+function ReactiveMP.message_dependencies(
+ ::RequireEverythingFunctionalDependencies,
+ nodeinterfaces,
+ nodelocalmarginals,
+ varcluster,
+ cindex,
+ iindex
+)
+ # Return all node interfaces including the edge we are trying to compute a message on
return nodeinterfaces
end
function ReactiveMP.marginal_dependencies(
::RequireEverythingFunctionalDependencies,
+ nodeinterfaces,
nodelocalmarginals,
varcluster,
- cindex
+ cindex,
+ iindex
)
# Returns only local marginals based on local q factorisation, it does not return all possible combinations of all joint posterior marginals
return nodelocalmarginals
end
-###
+###
default_functional_dependencies_pipeline(_) = DefaultFunctionalDependencies()
-### Generic
+### Generic `functional_dependencies` for `AbstractFactorNode`
+
+function functional_dependencies(factornode::AbstractFactorNode, iname::Symbol)
+ return functional_dependencies(get_pipeline_dependencies(getpipeline(factornode)), factornode, iname)
+end
+
+function functional_dependencies(factornode::AbstractFactorNode, iindex::Int)
+ return functional_dependencies(get_pipeline_dependencies(getpipeline(factornode)), factornode, iindex)
+end
+
+### `FactorNode` implementation of `functional_dependencies`
function functional_dependencies(dependencies, factornode::FactorNode, iname::Symbol)
return functional_dependencies(dependencies, factornode, interfaceindex(factornode, iname))
@@ -642,8 +832,8 @@ function functional_dependencies(dependencies, factornode::FactorNode, iindex::I
varcluster = @inbounds nodeclusters[cindex]
- messages = message_dependencies(dependencies, nodeinterfaces, varcluster, iindex)
- marginals = marginal_dependencies(dependencies, nodelocalmarginals, varcluster, cindex)
+ messages = message_dependencies(dependencies, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex)
+ marginals = marginal_dependencies(dependencies, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex)
return tuple(messages...), tuple(marginals...)
end
@@ -670,18 +860,15 @@ function get_marginals_observable(factornode, marginals)
end
function activate!(model, factornode::AbstractFactorNode)
- fform = functionalform(factornode)
- meta = metadata(factornode)
- node_pipeline = getpipeline(factornode)
-
- node_pipeline_dependencies = get_pipeline_dependencies(node_pipeline)
+ fform = functionalform(factornode)
+ meta = metadata(factornode)
+ node_pipeline = getpipeline(factornode)
node_pipeline_extra_stages = get_pipeline_stages(node_pipeline)
for (iindex, interface) in enumerate(interfaces(factornode))
cvariable = connectedvar(interface)
if cvariable !== nothing && (israndom(cvariable) || isdata(cvariable))
- message_dependencies, marginal_dependencies =
- functional_dependencies(node_pipeline_dependencies, factornode, iindex)
+ message_dependencies, marginal_dependencies = functional_dependencies(factornode, iindex)
msgs_names, msgs_observable = get_messages_observable(factornode, message_dependencies)
marginal_names, marginals_observable = get_marginals_observable(factornode, marginal_dependencies)
@@ -841,7 +1028,7 @@ import .MacroHelpers
# Examples
```julia
-struct MyNormalDistribution
+struct MyNormalDistribution
mean :: Float64
var :: Float64
end
@@ -849,7 +1036,7 @@ end
@node MyNormalDistribution Stochastic [ out, mean, var ]
```
-```julia
+```julia
@node typeof(+) Deterministic [ out, in1, in2 ]
```
@@ -898,7 +1085,7 @@ macro node(fformtype, sdtype, interfaces_list)
$non_unique_error_sym =
(fformtype, names) ->
"""
- Non-unique variables used for the creation of the `$(fformtype)` node, which is disallowed.
+ Non-unique variables used for the creation of the `$(fformtype)` node, which is disallowed.
Check creation of the `$(fformtype)` with the `[ $(join(names, ", ")) ]` arguments.
"""
)
@@ -939,9 +1126,9 @@ macro node(fformtype, sdtype, interfaces_list)
end
end
- # By default every argument passed to a factorisation option of the node is transformed by
+ # By default every argument passed to a factorisation option of the node is transformed by
# `collect_factorisation` function to have a tuple like structure.
- # The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects
+ # The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects
# to their tuple of indices equivalents. For deterministic nodes any factorisation is replaced by a FullFactorisation equivalent
factorisation_collectors = if sdtype === :Stochastic
quote
diff --git a/src/nodes/probit.jl b/src/nodes/probit.jl
index 381d314e8..1a5008a61 100644
--- a/src/nodes/probit.jl
+++ b/src/nodes/probit.jl
@@ -17,7 +17,7 @@ ProbitMeta(; p = 32) = ProbitMeta(p)
default_meta(::Type{Probit}) = ProbitMeta(32)
default_functional_dependencies_pipeline(::Type{<:Probit}) =
- RequireInboundFunctionalDependencies((2,), (vague(NormalMeanPrecision),))
+ RequireMessageFunctionalDependencies((2,), (vague(NormalMeanPrecision),))
default_interface_local_constraint(::Type{<:Probit}, edge::Val{:in}) = MomentMatching()
default_interface_local_constraint(::Type{<:Probit}, edge::Val{:out}) = Marginalisation()
diff --git a/test/models/test_probit.jl b/test/models/test_probit.jl
index e224c6b95..bf2c0aff0 100644
--- a/test/models/test_probit.jl
+++ b/test/models/test_probit.jl
@@ -23,7 +23,7 @@ using StatsFuns: normcdf
for k in 2:nr_samples+1
x[k] ~ NormalMeanPrecision(x[k-1] + 0.1, 100)
y[k-1] ~ Probit(x[k]) where {
- pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0))
+ pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0))
}
end
diff --git a/test/test_node.jl b/test/test_node.jl
index 32968357c..ea4c68736 100644
--- a/test/test_node.jl
+++ b/test/test_node.jl
@@ -2,6 +2,7 @@ module ReactiveMPNodeTest
using Test
using ReactiveMP
+using Rocket
using Distributions
@testset "FactorNode" begin
@@ -22,6 +23,691 @@ using Distributions
@test_throws MethodError sdtype(0)
end
+ @testset "Functional dependencies pipelines" begin
+ struct DummyStochasticNode end
+
+ @node DummyStochasticNode Stochastic [x, y, z]
+
+ function make_dummy_model(factorisation, pipeline)
+ m = ReactiveMP.FactorGraphModel()
+ x = randomvar(m, :x)
+ y = randomvar(m, :y)
+ z = randomvar(m, :z)
+ make_node(m, FactorNodeCreationOptions(nothing, nothing, nothing), Uninformative, x)
+ make_node(m, FactorNodeCreationOptions(nothing, nothing, nothing), Uninformative, y)
+ make_node(m, FactorNodeCreationOptions(nothing, nothing, nothing), Uninformative, z)
+ node =
+ make_node(m, FactorNodeCreationOptions(factorisation, nothing, pipeline), DummyStochasticNode, x, y, z)
+ activate!(m)
+ return m, x, y, z, node
+ end
+
+ @testset "Default functional dependencies" begin
+ @testset "Default functional dependencies: FullFactorisation" begin
+ # We test `FullFactorisation` case here
+ m, x, y, z, node = make_dummy_model(FullFactorisation(), DefaultFunctionalDependencies())
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) ===
+ DefaultFunctionalDependencies()
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z
+ @test length(x_mgdeps) === 0
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z
+ @test length(y_mgdeps) === 0
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y
+ @test length(z_mgdeps) === 0
+ end
+
+ @testset "Default functional dependencies: MeanField" begin
+ # We test `MeanField` case here
+ m, x, y, z, node = make_dummy_model(MeanField(), DefaultFunctionalDependencies())
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) ===
+ DefaultFunctionalDependencies()
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 0
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 0
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y
+ end
+
+ @testset "Default functional dependencies: Structured factorisation" begin
+ # We test `(x, y), (z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 2), (3,)), DefaultFunctionalDependencies())
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) ===
+ DefaultFunctionalDependencies()
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :y
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :x
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 0
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x_y
+
+ ## --- ##
+
+ # We test `(x, ), (y, z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1,), (2, 3)), DefaultFunctionalDependencies())
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) ===
+ DefaultFunctionalDependencies()
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y_z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :z
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :y
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x
+
+ ## --- ##
+
+ # We test `(x, z), (y, )` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 3), (2,)), DefaultFunctionalDependencies())
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) ===
+ DefaultFunctionalDependencies()
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :z
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 0
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x_z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :x
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :y
+ end
+ end
+
+ @testset "Require inbound message functional dependencies" begin
+ @testset "Require inbound message functional dependencies: FullFactorisation" begin
+ # Require inbound message on `x`
+ pipeline = RequireMessageFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),))
+
+ # We test `FullFactorisation` case here
+ m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y &&
+ name(x_msgdeps[3]) === :z
+ @test length(x_mgdeps) === 0
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(x_msgdeps[1]))) == (0.123, 0.123)
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z
+ @test length(y_mgdeps) === 0
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y
+ @test length(z_mgdeps) === 0
+
+ ## -- ##
+
+ # Require inbound message on `y` and `z`
+ pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `FullFactorisation` case here
+ m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z
+ @test length(x_mgdeps) === 0
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y &&
+ name(y_msgdeps[3]) === :z
+ @test length(y_mgdeps) === 0
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[2]))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y &&
+ name(z_msgdeps[3]) === :z
+ @test length(z_mgdeps) === 0
+ @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[3])))
+ end
+
+ @testset "Require inbound message functional dependencies: MeanField" begin
+ # Require inbound message on `x`
+ pipeline = RequireMessageFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),))
+
+ # We test `MeanField` case here
+ m, x, y, z, node = make_dummy_model(MeanField(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :x
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(x_msgdeps[1]))) == (0.123, 0.123)
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 0
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 0
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y
+
+ ## -- ##
+
+ # Require inbound message on `y` and `z`
+ pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `MeanField` case here
+ m, x, y, z, node = make_dummy_model(MeanField(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :y
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[1]))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :z
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y
+ @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[1])))
+ end
+
+ @testset "Require inbound message dependencies: Structured factorisation" begin
+ # Require inbound message on `y` and `z`
+ pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `(x, y), (z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 2), (3,)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :y
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :z
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[2]))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :z
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x_y
+ @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[1])))
+
+ ## --- ##
+
+ # Require inbound message on `y` and `z`
+ pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `(x, ), (y, z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1,), (2, 3)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y_z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :y && name(y_msgdeps[2]) === :z
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[1]))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :y && name(z_msgdeps[2]) === :z
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x
+ @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[2])))
+
+ ## --- ##
+
+ # Require inbound message on `y` and `z`
+ pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `(x, z), (y, )` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 3), (2,)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :z
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :y
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x_z
+ @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[1]))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :z
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :y
+ @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[2])))
+ end
+ end
+
+ @testset "Require marginal functional dependencies" begin
+ @testset "Require marginal functional dependencies: FullFactorisation" begin
+ # Require marginal on `x`
+ pipeline = RequireMarginalFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),))
+
+ # We test `FullFactorisation` case here
+ m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :x
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(x, IncludeAll()))) == (0.123, 0.123)
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z
+ @test length(y_mgdeps) === 0
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y
+ @test length(z_mgdeps) === 0
+
+ ## -- ##
+
+ # Require marginals on `y` and `z`
+ pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `FullFactorisation` case here
+ m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z
+ @test length(x_mgdeps) === 0
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :y
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :z
+ @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll())))
+ end
+
+ @testset "Require marginal functional dependencies: MeanField" begin
+ # Require marginal on `x`
+ pipeline = RequireMarginalFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),))
+
+ # We test `MeanField` case here
+ m, x, y, z, node = make_dummy_model(MeanField(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 3 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y &&
+ name(x_mgdeps[3]) === :z
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(x, IncludeAll()))) == (0.123, 0.123)
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 0
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 0
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y
+
+ ## -- ##
+
+ # Require marginals on `y` and `z`
+ pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `MeanField` case here
+ m, x, y, z, node = make_dummy_model(MeanField(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 0
+ @test length(y_mgdeps) === 3 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y &&
+ name(y_mgdeps[3]) === :z
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 0
+ @test length(z_mgdeps) === 3 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y &&
+ name(z_mgdeps[3]) === :z
+ @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll())))
+ end
+
+ @testset "Require marginal functional dependencies: Structured factorisation" begin
+ # Require marginal on `y` and `z`
+ pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `(x, y), (z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 2), (3,)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :y
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :x
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :y && name(y_mgdeps[2]) === :z
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 0
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x_y && name(z_mgdeps[2]) === :z
+ @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll())))
+
+ ## --- ##
+
+ # Require marginals on `y` and `z`
+ pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `(x, ), (y, z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1,), (2, 3)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 0
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y_z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :z
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :y
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :z
+ @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll())))
+
+ ## --- ##
+
+ # Require marginals on `y` and `z`
+ pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing))
+
+ # We test `(x, z), (y, )` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 3), (2,)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :z
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 0
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x_z && name(y_mgdeps[2]) === :y
+ @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123)
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :x
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :y && name(z_mgdeps[2]) === :z
+ @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll())))
+ end
+ end
+
+ @testset "Require everything functional dependencies" begin
+ @testset "Require everything functional dependencies: FullFactorisation" begin
+ pipeline = RequireEverythingFunctionalDependencies()
+
+ # We test `FullFactorisation` case here
+ m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y &&
+ name(x_msgdeps[3]) === :z
+ @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :x_y_z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y &&
+ name(y_msgdeps[3]) === :z
+ @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x_y_z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y &&
+ name(z_msgdeps[3]) === :z
+ @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x_y_z
+ end
+
+ @testset "Require everything functional dependencies: MeanField" begin
+ pipeline = RequireEverythingFunctionalDependencies()
+
+ # We test `MeanField` case here
+ m, x, y, z, node = make_dummy_model(MeanField(), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 3 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y &&
+ name(x_mgdeps[3]) === :z
+ @test length(x_mgdeps) === 3 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y &&
+ name(x_mgdeps[3]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y &&
+ name(y_msgdeps[3]) === :z
+ @test length(y_mgdeps) === 3 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y &&
+ name(y_mgdeps[3]) === :z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y &&
+ name(z_msgdeps[3]) === :z
+ @test length(z_mgdeps) === 3 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y &&
+ name(z_mgdeps[3]) === :z
+ end
+
+ @testset "Require everything dependencies: Structured factorisation" begin
+ pipeline = RequireEverythingFunctionalDependencies()
+
+ # We test `(x, y), (z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 2), (3,)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y &&
+ name(x_msgdeps[3]) === :z
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :x_y && name(x_mgdeps[2]) === :z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y &&
+ name(y_msgdeps[3]) === :z
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x_y && name(y_mgdeps[2]) === :z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y &&
+ name(z_msgdeps[3]) === :z
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x_y && name(z_mgdeps[2]) === :z
+
+ ## --- ##
+
+ pipeline = RequireEverythingFunctionalDependencies()
+
+ # We test `(x, ), (y, z)` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1,), (2, 3)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y &&
+ name(x_msgdeps[3]) === :z
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y_z
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y &&
+ name(y_msgdeps[3]) === :z
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y_z
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y &&
+ name(z_msgdeps[3]) === :z
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y_z
+
+ ## --- ##
+
+ pipeline = RequireEverythingFunctionalDependencies()
+
+ # We test `(x, z), (y, )` factorisation case here
+ m, x, y, z, node = make_dummy_model(((1, 3), (2,)), pipeline)
+
+ # Test that pipeline dependencies have been set properly
+ @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline
+
+ x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x)
+
+ @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y &&
+ name(x_msgdeps[3]) === :z
+ @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :x_z && name(x_mgdeps[2]) === :y
+
+ y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y)
+
+ @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y &&
+ name(y_msgdeps[3]) === :z
+ @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x_z && name(y_mgdeps[2]) === :y
+
+ z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z)
+
+ @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y &&
+ name(z_msgdeps[3]) === :z
+ @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x_z && name(z_mgdeps[2]) === :y
+ end
+ end
+ end
+
@testset "@node macro" begin
# Testing Stochastic node specification