diff --git a/demo/Expectation Propagation.ipynb b/demo/Expectation Propagation.ipynb index 5f189be29..c7b680364 100644 --- a/demo/Expectation Propagation.ipynb +++ b/demo/Expectation Propagation.ipynb @@ -125,155 +125,7 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] + "image/svg+xml": "\n\n\n \n \n \n\n\n\n \n \n \n\n\n\n \n \n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" }, "execution_count": 5, "metadata": {}, @@ -312,9 +164,9 @@ " for k = 2:nr_samples + 1\n", " x[k] ~ NormalMeanPrecision(x[k - 1] + 0.1, 100)\n", " y[k - 1] ~ Probit(x[k]) where {\n", - " # Probit node by default uses RequireInbound pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge\n", + " # Probit node by default uses RequireMessage pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge\n", " # To change initial value use may specify it manually, like. Changes to the initial message may improve stability in some situations\n", - " pipeline = RequireInbound(in = NormalMeanPrecision(0, 0.01)) \n", + " pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01)) \n", " }\n", " end\n", " \n", @@ -379,257 +231,7 @@ "outputs": [ { "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - " \n", - " \n", - " \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ] + "image/svg+xml": "\n\n\n \n \n \n\n\n\n \n \n \n\n\n\n \n \n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n \n \n \n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n" }, "execution_count": 8, "metadata": {}, diff --git a/demo/Normalizing Flow Tutorial.ipynb b/demo/Normalizing Flow Tutorial.ipynb index e17649dba..0220cb70a 100644 --- a/demo/Normalizing Flow Tutorial.ipynb +++ b/demo/Normalizing Flow Tutorial.ipynb @@ -637,7 +637,7 @@ " y_lat2[k] ~ dot(y_lat1[k], [1, 1])\n", "\n", " # specify observations\n", - " y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) }\n", + " y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) }\n", "\n", " end\n", "\n", diff --git a/docs/make.jl b/docs/make.jl index 97637315c..1bd4889d0 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -26,15 +26,15 @@ makedocs( ], "Library" => [ "Messages" => "lib/message.md", - "Functional forms" => "lib/form.md", - "Prod implementation" => "lib/prod.md", "Factor nodes" => [ - "Overview" => "lib/node.md", + "Overview" => "lib/nodes/nodes.md", "Flow" => "lib/nodes/flow.md" ], - "Math utils" => "lib/math.md", - "Helper utils" => "lib/helpers.md", - "Exported methods" => "lib/methods.md" + "Functional forms" => "lib/form.md", + "Prod implementation" => "lib/prod.md", + "Math utils" => "lib/math.md", + "Helper utils" => "lib/helpers.md", + "Exported methods" => "lib/methods.md" ], "Examples" => [ "Overview" => "examples/overview.md", diff --git a/docs/src/examples/flow_tutorial.md b/docs/src/examples/flow_tutorial.md index 4dd5f2f10..4cbe338bc 100644 --- a/docs/src/examples/flow_tutorial.md +++ b/docs/src/examples/flow_tutorial.md @@ -346,7 +346,7 @@ The corresponding probabilistic model for the binary classification task can be y_lat2[k] ~ dot(y_lat1[k], [1, 1]) # specify observations - y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) } + y[k] ~ Probit(y_lat2[k]) # default: where { pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) } end diff --git a/docs/src/examples/probit.md b/docs/src/examples/probit.md index 46c425227..1ee84f44f 100644 --- a/docs/src/examples/probit.md +++ b/docs/src/examples/probit.md @@ -80,9 +80,9 @@ p = plot!(p, data_x[2:end], label = "x") for k = 2:nr_samples + 1 x[k] ~ NormalMeanPrecision(x[k - 1] + 0.1, 100) y[k - 1] ~ Probit(x[k]) where { - # Probit node by default uses RequireInbound pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge + # Probit node by default uses RequireMessage pipeline with vague(NormalMeanPrecision) message as initial value for `in` edge # To change initial value use may specify it manually, like. Changes to the initial message may improve stability in some situations - pipeline = RequireInbound(in = NormalMeanPrecision(0, 0.01)) + pipeline = RequireMessage(in = NormalMeanPrecision(0, 0.01)) } end diff --git a/docs/src/lib/node.md b/docs/src/lib/node.md deleted file mode 100644 index 8d1fe755c..000000000 --- a/docs/src/lib/node.md +++ /dev/null @@ -1,68 +0,0 @@ - -# [Nodes implementation](@id lib-node) - -In message passing framework one of the most important concepts is factor node. -Factor node represents a local function in a factorised representation of a generative model. - -!!! note - To quickly check the list of all available factor nodes that can be used in the model specification language call `?make_node` or `Base.doc(make_node)`. - -## [Node traits](@id lib-node-traits) - -Each factor node has to define [`as_node_functional_form`](@ref) trait function and to specify [`ValidNodeFunctionalForm`](@ref) singleton as a return object. By default [`as_node_functional_form`](@ref) returns [`UndefinedNodeFunctionalForm`](@ref). Objects that do not specify this property correctly cannot be used in model specification. - -!!! note - [`@node`](@ref) macro does that automatically - -```@docs -ValidNodeFunctionalForm -UndefinedNodeFunctionalForm -as_node_functional_form -``` - -## [Node types](@id lib-node-types) - -We distinguish different types of factor nodes to have a better control over Bethe Free Energy computation. -Each factor node has either [`Deterministic`](@ref) or [`Stochastic`](@ref) functional form type. - -```@docs -Deterministic -Stochastic -isdeterministic -isstochastic -sdtype -``` - -```@setup lib-node-types -using ReactiveMP -``` - -For example `+` node has the [`Deterministic`](@ref) type: - -```@example lib-node-types -plus_node = make_node(+) - -println("Is `+` node deterministic: ", isdeterministic(plus_node)) -println("Is `+` node stochastic: ", isstochastic(plus_node)) -nothing #hide -``` - -On the other hand `Bernoulli` node has the [`Stochastic`](@ref) type: - -```@example lib-node-types -bernoulli_node = make_node(Bernoulli) - -println("Is `Bernoulli` node deterministic: ", isdeterministic(bernoulli_node)) -println("Is `Bernoulli` node stochastic: ", isstochastic(bernoulli_node)) -``` - -To get an actual instance of the type object we use [`sdtype`](@ref) function: - -```@example lib-node-types -println("sdtype() of `+` node is ", sdtype(plus_node)) -println("sdtype() of `Bernoulli` node is ", sdtype(bernoulli_node)) -nothing #hide -``` - -## [Node factorisation constraints](@id lib-node-factorisation-constraints) - diff --git a/docs/src/lib/nodes/nodes.md b/docs/src/lib/nodes/nodes.md new file mode 100644 index 000000000..0e112add6 --- /dev/null +++ b/docs/src/lib/nodes/nodes.md @@ -0,0 +1,107 @@ + +# [Nodes implementation](@id lib-node) + +In the message passing framework, one of the most important concepts is a factor node. +A factor node represents a local function in a factorised representation of a generative model. + +!!! note + To quickly check the list of all available factor nodes that can be used in the model specification language, call `?make_node` or `Base.doc(make_node)`. + +## [Adding a custom node](@id lib-custom-node) + +`ReactiveMP.jl` exports the `@node` macro that allows for quick definition of a factor node with a __fixed__ number of edges. The interface is the following: + +```julia +struct MyNewCustomNode end + +@node MyNewCustomNode Stochastic [ x, y, z ] +# ^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^ ^^^^^^^^^^^ +# Node's tag/name Node's type A fixed set of edges +# Another possible The very first edge (in this example `x`) is considered +# value is to be the output of the node +# `Deterministic` +``` + +This expression registers a new node that can be used with the inference engine. Note howeve, that the `@node` macro does not generate any message passing update rules. +These must be defined using the `@rule` macro. + +## [Node types](@id lib-node-types) + +We distinguish different types of factor nodes in order to have better control over Bethe Free Energy computation. +Each factor node has either the [`Deterministic`](@ref) or [`Stochastic`](@ref) functional form type. + +```@docs +Deterministic +Stochastic +isdeterministic +isstochastic +sdtype +``` + +```@setup lib-node-types +using ReactiveMP +``` + +For example the `+` node has the [`Deterministic`](@ref) type: + +```@example lib-node-types +plus_node = make_node(+) + +println("Is `+` node deterministic: ", isdeterministic(plus_node)) +println("Is `+` node stochastic: ", isstochastic(plus_node)) +nothing #hide +``` + +On the other hand, the `Bernoulli` node has the [`Stochastic`](@ref) type: + +```@example lib-node-types +bernoulli_node = make_node(Bernoulli) + +println("Is `Bernoulli` node deterministic: ", isdeterministic(bernoulli_node)) +println("Is `Bernoulli` node stochastic: ", isstochastic(bernoulli_node)) +``` + +To get an actual instance of the type object we use [`sdtype`](@ref) function: + +```@example lib-node-types +println("sdtype() of `+` node is ", sdtype(plus_node)) +println("sdtype() of `Bernoulli` node is ", sdtype(bernoulli_node)) +nothing #hide +``` + +## [Node functional dependencies pipeline](@id lib-node-functional-dependencies-pipeline) + +The generic implementation of factor nodes in ReactiveMP supports custom functional dependency pipelines. Briefly, the __functional dependencies pipeline__ defines what +dependencies are need to compute a single message. As an example, consider the belief-propagation message update equation for a factor node $f$ with three edges: $x$, $y$ and $z$: + +```math +\mu(x) = \int \mu(y) \mu(z) f(x, y, z) \mathrm{d}y \mathrm{d}z +``` + +Here we see that in the standard setting for the belief-propagation message out of edge $x$, we need only messages from the edges $y$ and $z$. In contrast, consider the variational message update rule equation with mean-field assumption: + +```math +\mu(x) = \exp \int q(y) q(z) \log f(x, y, z) \mathrm{d}y \mathrm{d}z +``` + +We see that in this setting, we do not need messages $\mu(y)$ and $\mu(z)$, but only the marginals $q(y)$ and $q(z)$. The purpose of a __functional dependencies pipeline__ is to determine functional dependencies (a set of messages or marginals) that are needed to compute a single message. By default, `ReactiveMP.jl` uses so-called `DefaultFunctionalDependencies` that correctly implements belief-propagation and variational message passing schemes (including both mean-field and structured factorisations). The full list of built-in pipelines is presented below: + +```@docs +DefaultFunctionalDependencies +RequireMessageFunctionalDependencies +RequireMarginalFunctionalDependencies +RequireEverythingFunctionalDependencies +``` + +## [Node traits](@id lib-node-traits) + +Each factor node has to define the [`as_node_functional_form`](@ref) trait function and to specify a [`ValidNodeFunctionalForm`](@ref) singleton as a return object. By default [`as_node_functional_form`](@ref) returns [`UndefinedNodeFunctionalForm`](@ref). Objects that do not specify this property correctly cannot be used in model specification. + +!!! note + [`@node`](@ref) macro does that automatically + +```@docs +ValidNodeFunctionalForm +UndefinedNodeFunctionalForm +as_node_functional_form +``` diff --git a/src/node.jl b/src/node.jl index 2ef318729..1122cf6a1 100644 --- a/src/node.jl +++ b/src/node.jl @@ -9,7 +9,9 @@ export iscontain, isfactorised, getinterface export clusters, clusterindex export connect!, activate! export make_node -export DefaultFunctionalDependencies, RequireInboundFunctionalDependencies, RequireEverythingFunctionalDependencies +export DefaultFunctionalDependencies, + RequireMessageFunctionalDependencies, + RequireMarginalFunctionalDependencies, RequireEverythingFunctionalDependencies export @node using Rocket @@ -21,7 +23,7 @@ import Base: getindex, setindex!, firstindex, lastindex ## Node traits """ - ValidNodeFunctionalForm + ValidNodeFunctionalForm Trait specification for an object that can be used in model specification as a factor node. @@ -146,7 +148,7 @@ struct FullFactorisation end """ collect_factorisation(nodetype, factorisation) -This function converts given factorisation to a correct internal factorisation representation for a given node. +This function converts given factorisation to a correct internal factorisation representation for a given node. See also: [`MeanField`](@ref), [`FullFactorisation`](@ref) """ @@ -155,7 +157,7 @@ function collect_factorisation end """ collect_meta(nodetype, meta) -This function converts given meta object to a correct internal meta representation for a given node. +This function converts given meta object to a correct internal meta representation for a given node. Fallbacks to `default_meta` in case if meta is `nothing`. See also: [`default_meta`](@ref), [`FactorNode`](@ref) @@ -236,7 +238,7 @@ local_constraint(interface::NodeInterface) = interface.local_constraint """ tag(interface) -Returns a tag of the interface in the form of `Val{ name(interface) }`. +Returns a tag of the interface in the form of `Val{ name(interface) }`. The major difference between tag and name is that it is possible to dispath on interface's tag in message computation rule. See also: [`NodeInterface`](@ref), [`name`](@ref) @@ -306,8 +308,8 @@ get_pipeline_stages(interface::NodeInterface) = get_pipeline_stages(connectedvar """ IndexedNodeInterface -`IndexedNodeInterface` object represents a repetative node-variable connection. -Used in cases when node may connect different number of random variables with the same name, e.g. means and precisions of Gaussian Mixture node. +`IndexedNodeInterface` object represents a repetative node-variable connection. +Used in cases when a node may connect to a different number of random variables with the same name, e.g. means and precisions of a Gaussian Mixture node. See also: [`name`](@ref), [`tag`](@ref), [`messageout`](@ref), [`messagein`](@ref) """ @@ -339,20 +341,25 @@ get_pipeline_stages(interface::IndexedNodeInterface) = get_pipelin """ FactorNodeLocalMarginal -This object represents local marginals for some specific factor node. -Local marginal can be joint in case of structured factorisation. +This object represents local marginals for some specific factor node. +The local marginal can be joint in case of structured factorisation. Local to factor node marginal also can be shared with a corresponding marginal of some random variable. See also: [`FactorNodeLocalMarginals`](@ref) """ mutable struct FactorNodeLocalMarginal index :: Int + first :: Int name :: Symbol stream :: Union{Nothing, AbstractSubscribable{<:Marginal}} - FactorNodeLocalMarginal(index::Int, name::Symbol) = new(index, name, nothing) + FactorNodeLocalMarginal(index::Int, first::Int, name::Symbol) = new(index, first, name, nothing) end +# `First` defines the index of the first element in the joint marginal +# E.g. if the set of variables is (x, y, z, w) and joint is `z_w`, first is equal to 3 +Base.first(localmarginal::FactorNodeLocalMarginal) = localmarginal.first + index(localmarginal::FactorNodeLocalMarginal) = localmarginal.index name(localmarginal::FactorNodeLocalMarginal) = localmarginal.name @@ -362,7 +369,7 @@ setstream!(localmarginal::FactorNodeLocalMarginal, observable) = localmarginal.s """ FactorNodeLocalMarginals -This object acts as an iterable and indexable proxy for local marginals for some node. +This object acts as an iterable and indexable proxy for local marginals for some node. """ struct FactorNodeLocalMarginals{M} marginals::M @@ -371,10 +378,10 @@ end function FactorNodeLocalMarginals(variablenames, factorisation) marginal_names = map(fcluster -> clustername(map(i -> variablenames[i], fcluster)), factorisation) index = 0 # its better not to use zip or enumerate here to preserve tuple-like structure - marginals = map((mname) -> begin + marginals = map(marginal_names) do mname index += 1 - FactorNodeLocalMarginal(index, mname) - end, marginal_names) + return FactorNodeLocalMarginal(index, first(factorisation[index]), mname) + end return FactorNodeLocalMarginals(marginals) end @@ -454,7 +461,7 @@ getpipeline(factornode::FactorNode) = factornode.pipeline clustername(cluster) = mapreduce(v -> name(v), (a, b) -> Symbol(a, :_, b), cluster) -# Cluster is reffered to a tuple of node interfaces +# Cluster is reffered to a tuple of node interfaces clusters(factornode::FactorNode) = map(factor -> map(i -> @inbounds(interfaces(factornode)[i]), factor), factorisation(factornode)) @@ -523,7 +530,7 @@ collect_pipeline(T::Any, stage::AbstractPipelineStage) = Fact collect_pipeline(T::Any, fdp::AbstractNodeFunctionalDependenciesPipeline) = FactorNodePipeline(fdp, EmptyPipelineStage()) collect_pipeline(T::Any, pipeline::FactorNodePipeline) = pipeline -## Functional Dependencies +## Functional Dependencies function message_dependencies end function marginal_dependencies end @@ -531,50 +538,91 @@ function marginal_dependencies end Base.:+(left::AbstractNodeFunctionalDependenciesPipeline, right::AbstractPipelineStage) = FactorNodePipeline(left, right) Base.:+(left::FactorNodePipeline, right::AbstractPipelineStage) = FactorNodePipeline(left.functional_dependencies, left.extra_stages + right) -### Default +### Default """ DefaultFunctionalDependencies + +This pipeline translates directly to enforcing a variational message passing scheme. In order to compute a message out of some edge, this pipeline requires +messages from edges within the same edge-cluster and marginals over other edge-clusters. + +See also: [`ReactiveMP.RequireMessageFunctionalDependencies`](@ref), [`ReactiveMP.RequireMarginalFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref) """ struct DefaultFunctionalDependencies <: AbstractNodeFunctionalDependenciesPipeline end -function message_dependencies(::DefaultFunctionalDependencies, nodeinterfaces, varcluster, iindex) +function message_dependencies( + ::DefaultFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) # First we remove current edge index from the list of dependencies vdependencies = TupleTools.deleteat(varcluster, varclusterindex(varcluster, iindex)) # Second we map interface indices to the actual interfaces return map(inds -> map(i -> @inbounds(nodeinterfaces[i]), inds), vdependencies) end -function marginal_dependencies(::DefaultFunctionalDependencies, nodelocalmarginals, varcluster, cindex) +function marginal_dependencies( + ::DefaultFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) return TupleTools.deleteat(nodelocalmarginals, cindex) end -### With inbound +### With inbound messages -struct RequireInboundFunctionalDependencies{I, S} <: AbstractNodeFunctionalDependenciesPipeline - indices :: I - start_with :: S -end +""" + RequireMessageFunctionalDependencies(indices::Tuple, start_with::Tuple) -struct InterfacePluginStartWithMessage{M, S} - msg :: M - start_with :: S -end +The same as `DefaultFunctionalDependencies`, but in order to compute a message out of some edge also requires the inbound message on the this edge. -name(p::InterfacePluginStartWithMessage) = name(p.msg) -messagein(p::InterfacePluginStartWithMessage) = messagein(p.start_with, p) +# Arguments -messagein(::Nothing, p::InterfacePluginStartWithMessage) = messagein(p.msg) +- `indices`::Tuple, tuple of integers, which indicates what edges should require inbound messages +- `start_with::Tuple`, tuple of `nothing` or `<:Distribution`, which specifies the initial inbound messages for edges in `indices` -function messagein(something, p::InterfacePluginStartWithMessage) - output = messagein(p.msg) - if isnothing(getrecent(output)) - setmessage!(output, something) - end - return output +Note: `start_with` uses `setmessage!` mechanism, hence, it can be visible by other listeners on the same edge. Explicit call to `setmessage!` overwrites whatever has been passed in `start_with`. + +`@model` macro accepts a simplified construction of this pipeline: + +```julia +@model function some_model() + # ... + y ~ NormalMeanVariance(x, τ) where { + pipeline = RequireMessage(x = vague(NormalMeanPrecision), τ) + # ^^^ ^^^ + # request 'inbound' for 'x' we may do the same for 'τ', + # and initialise with `vague(...)` but here we skip initialisation + } + # ... end +``` -function message_dependencies(dependencies::RequireInboundFunctionalDependencies, nodeinterfaces, varcluster, iindex) +Deprecation warning: `RequireInboundFunctionalDependencies` has been deprecated in favor of `RequireMessageFunctionalDependencies`. + +See also: [`ReactiveMP.DefaultFunctionalDependencies`](@ref), [`ReactiveMP.RequireMarginalFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref) +""" +struct RequireMessageFunctionalDependencies{I, S} <: AbstractNodeFunctionalDependenciesPipeline + indices :: I + start_with :: S +end + +Base.@deprecate_binding RequireInboundFunctionalDependencies RequireMessageFunctionalDependencies + +function message_dependencies( + dependencies::RequireMessageFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) # First we find dependency index in `indices`, we use it later to find `start_with` distribution depindex = findfirst((i) -> i === iindex, dependencies.indices) @@ -582,52 +630,194 @@ function message_dependencies(dependencies::RequireInboundFunctionalDependencies # If we have `depindex` in our `indices` we include it in our list of functional dependencies. It effectively forces rule to require inbound message if depindex !== nothing # `mapindex` is a lambda function here - mapindex = let nodeinterfaces = nodeinterfaces, depindex = depindex - (i) -> begin - interface = @inbounds nodeinterfaces[i] - # InterfacePluginStartWithMessage is a proxy structure for `name` and `messagein` method for an interface - # It returns the same name but modifies `messagein` to return an observable with `start_with` operator - return if i === iindex - InterfacePluginStartWithMessage(interface, dependencies.start_with[depindex]) - else - interface - end - end + output = messagein(nodeinterfaces[iindex]) + start_with = dependencies.start_with[depindex] + # Initialise now, if message has not been initialised before and `start_with` element is not empty + if isnothing(getrecent(output)) && !isnothing(start_with) + setmessage!(output, start_with) end - return map(inds -> map(mapindex, inds), varcluster) + return map(inds -> map(i -> @inbounds(nodeinterfaces[i]), inds), varcluster) else - return message_dependencies(DefaultFunctionalDependencies(), nodeinterfaces, varcluster, iindex) + return message_dependencies( + DefaultFunctionalDependencies(), + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex + ) end end -function marginal_dependencies(::RequireInboundFunctionalDependencies, nodelocalmarginals, varcluster, cindex) - return marginal_dependencies(DefaultFunctionalDependencies(), nodelocalmarginals, varcluster, cindex) +function marginal_dependencies( + ::RequireMessageFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) + return marginal_dependencies( + DefaultFunctionalDependencies(), + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex + ) +end + +### With marginals + +""" + RequireMarginalFunctionalDependencies(indices::Tuple, start_with::Tuple) + +Similar to `DefaultFunctionalDependencies`, but in order to compute a message out of some edge also requires the posterior marginal on that edge. + +# Arguments + +- `indices`::Tuple, tuple of integers, which indicates what edges should require their own marginals +- `start_with::Tuple`, tuple of `nothing` or `<:Distribution`, which specifies the initial marginal for edges in `indices` + +Note: `start_with` uses the `setmarginal!` mechanism, hence it can be visible to other listeners on the same edge. Explicit calls to `setmarginal!` overwrites whatever has been passed in `start_with`. + +`@model` macro accepts a simplified construction of this pipeline: + +```julia +@model function some_model() + # ... + y ~ NormalMeanVariance(x, τ) where { + pipeline = RequireMarginal(x = vague(NormalMeanPrecision), τ) + # ^^^ ^^^ + # request 'marginal' for 'x' we may do the same for 'τ', + # and initialise with `vague(...)` but here we skip initialisation + } + # ... +end +``` + +Note: The simplified construction in `@model` macro syntax is only available in `GraphPPL.jl` of version `>2.2.0`. + +See also: [`ReactiveMP.DefaultFunctionalDependencies`](@ref), [`ReactiveMP.RequireMessageFunctionalDependencies`](@ref), [`ReactiveMP.RequireEverythingFunctionalDependencies`](@ref) +""" +struct RequireMarginalFunctionalDependencies{I, S} <: AbstractNodeFunctionalDependenciesPipeline + indices :: I + start_with :: S +end + +function message_dependencies( + ::RequireMarginalFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) + return message_dependencies( + DefaultFunctionalDependencies(), + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex + ) +end + +function marginal_dependencies( + dependencies::RequireMarginalFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) + # First we find dependency index in `indices`, we use it later to find `start_with` distribution + depindex = findfirst((i) -> i === iindex, dependencies.indices) + + if depindex !== nothing + # We create an auxiliary local marginal with non-standard index here and inject it to other standard dependencies + extra_localmarginal = FactorNodeLocalMarginal(-1, iindex, name(nodeinterfaces[iindex])) + vmarginal = getmarginal(connectedvar(nodeinterfaces[iindex]), IncludeAll()) + start_with = dependencies.start_with[depindex] + # Initialise now, if marginal has not been initialised before and `start_with` element is not empty + if isnothing(getrecent(vmarginal)) && !isnothing(start_with) + setmarginal!(vmarginal, start_with) + end + setstream!(extra_localmarginal, vmarginal) + default = marginal_dependencies( + DefaultFunctionalDependencies(), + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex + ) + # Find insertion position (probably might be implemented more efficiently) + insertafter = sum(first(el) < iindex ? 1 : 0 for el in default; init = 0) + return TupleTools.insertafter(default, insertafter, (extra_localmarginal,)) + else + return marginal_dependencies( + DefaultFunctionalDependencies(), + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex + ) + end end ### Everything +""" + RequireEverythingFunctionalDependencies + +This pipeline specifies that in order to compute a message of some edge update rules request everything that is available locally. +This includes all inbound messages (including on the same edge) and marginals over all local edge-clusters (this may or may not include marginals on single edges, depends on the local factorisation constraint). + +See also: [`DefaultFunctionalDependencies`](@ref), [`RequireMessageFunctionalDependencies`](@ref), [`RequireMarginalFunctionalDependencies`](@ref) +""" struct RequireEverythingFunctionalDependencies <: AbstractNodeFunctionalDependenciesPipeline end -function ReactiveMP.message_dependencies(::RequireEverythingFunctionalDependencies, nodeinterfaces, varcluster, iindex) - # Return all node interfaces including the edge we are trying to compuate a message on +function ReactiveMP.message_dependencies( + ::RequireEverythingFunctionalDependencies, + nodeinterfaces, + nodelocalmarginals, + varcluster, + cindex, + iindex +) + # Return all node interfaces including the edge we are trying to compute a message on return nodeinterfaces end function ReactiveMP.marginal_dependencies( ::RequireEverythingFunctionalDependencies, + nodeinterfaces, nodelocalmarginals, varcluster, - cindex + cindex, + iindex ) # Returns only local marginals based on local q factorisation, it does not return all possible combinations of all joint posterior marginals return nodelocalmarginals end -### +### default_functional_dependencies_pipeline(_) = DefaultFunctionalDependencies() -### Generic +### Generic `functional_dependencies` for `AbstractFactorNode` + +function functional_dependencies(factornode::AbstractFactorNode, iname::Symbol) + return functional_dependencies(get_pipeline_dependencies(getpipeline(factornode)), factornode, iname) +end + +function functional_dependencies(factornode::AbstractFactorNode, iindex::Int) + return functional_dependencies(get_pipeline_dependencies(getpipeline(factornode)), factornode, iindex) +end + +### `FactorNode` implementation of `functional_dependencies` function functional_dependencies(dependencies, factornode::FactorNode, iname::Symbol) return functional_dependencies(dependencies, factornode, interfaceindex(factornode, iname)) @@ -642,8 +832,8 @@ function functional_dependencies(dependencies, factornode::FactorNode, iindex::I varcluster = @inbounds nodeclusters[cindex] - messages = message_dependencies(dependencies, nodeinterfaces, varcluster, iindex) - marginals = marginal_dependencies(dependencies, nodelocalmarginals, varcluster, cindex) + messages = message_dependencies(dependencies, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex) + marginals = marginal_dependencies(dependencies, nodeinterfaces, nodelocalmarginals, varcluster, cindex, iindex) return tuple(messages...), tuple(marginals...) end @@ -670,18 +860,15 @@ function get_marginals_observable(factornode, marginals) end function activate!(model, factornode::AbstractFactorNode) - fform = functionalform(factornode) - meta = metadata(factornode) - node_pipeline = getpipeline(factornode) - - node_pipeline_dependencies = get_pipeline_dependencies(node_pipeline) + fform = functionalform(factornode) + meta = metadata(factornode) + node_pipeline = getpipeline(factornode) node_pipeline_extra_stages = get_pipeline_stages(node_pipeline) for (iindex, interface) in enumerate(interfaces(factornode)) cvariable = connectedvar(interface) if cvariable !== nothing && (israndom(cvariable) || isdata(cvariable)) - message_dependencies, marginal_dependencies = - functional_dependencies(node_pipeline_dependencies, factornode, iindex) + message_dependencies, marginal_dependencies = functional_dependencies(factornode, iindex) msgs_names, msgs_observable = get_messages_observable(factornode, message_dependencies) marginal_names, marginals_observable = get_marginals_observable(factornode, marginal_dependencies) @@ -841,7 +1028,7 @@ import .MacroHelpers # Examples ```julia -struct MyNormalDistribution +struct MyNormalDistribution mean :: Float64 var :: Float64 end @@ -849,7 +1036,7 @@ end @node MyNormalDistribution Stochastic [ out, mean, var ] ``` -```julia +```julia @node typeof(+) Deterministic [ out, in1, in2 ] ``` @@ -898,7 +1085,7 @@ macro node(fformtype, sdtype, interfaces_list) $non_unique_error_sym = (fformtype, names) -> """ - Non-unique variables used for the creation of the `$(fformtype)` node, which is disallowed. + Non-unique variables used for the creation of the `$(fformtype)` node, which is disallowed. Check creation of the `$(fformtype)` with the `[ $(join(names, ", ")) ]` arguments. """ ) @@ -939,9 +1126,9 @@ macro node(fformtype, sdtype, interfaces_list) end end - # By default every argument passed to a factorisation option of the node is transformed by + # By default every argument passed to a factorisation option of the node is transformed by # `collect_factorisation` function to have a tuple like structure. - # The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects + # The default recipe is simple: for stochastic nodes we convert `FullFactorisation` and `MeanField` objects # to their tuple of indices equivalents. For deterministic nodes any factorisation is replaced by a FullFactorisation equivalent factorisation_collectors = if sdtype === :Stochastic quote diff --git a/src/nodes/probit.jl b/src/nodes/probit.jl index 381d314e8..1a5008a61 100644 --- a/src/nodes/probit.jl +++ b/src/nodes/probit.jl @@ -17,7 +17,7 @@ ProbitMeta(; p = 32) = ProbitMeta(p) default_meta(::Type{Probit}) = ProbitMeta(32) default_functional_dependencies_pipeline(::Type{<:Probit}) = - RequireInboundFunctionalDependencies((2,), (vague(NormalMeanPrecision),)) + RequireMessageFunctionalDependencies((2,), (vague(NormalMeanPrecision),)) default_interface_local_constraint(::Type{<:Probit}, edge::Val{:in}) = MomentMatching() default_interface_local_constraint(::Type{<:Probit}, edge::Val{:out}) = Marginalisation() diff --git a/test/models/test_probit.jl b/test/models/test_probit.jl index e224c6b95..bf2c0aff0 100644 --- a/test/models/test_probit.jl +++ b/test/models/test_probit.jl @@ -23,7 +23,7 @@ using StatsFuns: normcdf for k in 2:nr_samples+1 x[k] ~ NormalMeanPrecision(x[k-1] + 0.1, 100) y[k-1] ~ Probit(x[k]) where { - pipeline = RequireInbound(in = NormalMeanPrecision(0, 1.0)) + pipeline = RequireMessage(in = NormalMeanPrecision(0, 1.0)) } end diff --git a/test/test_node.jl b/test/test_node.jl index 32968357c..ea4c68736 100644 --- a/test/test_node.jl +++ b/test/test_node.jl @@ -2,6 +2,7 @@ module ReactiveMPNodeTest using Test using ReactiveMP +using Rocket using Distributions @testset "FactorNode" begin @@ -22,6 +23,691 @@ using Distributions @test_throws MethodError sdtype(0) end + @testset "Functional dependencies pipelines" begin + struct DummyStochasticNode end + + @node DummyStochasticNode Stochastic [x, y, z] + + function make_dummy_model(factorisation, pipeline) + m = ReactiveMP.FactorGraphModel() + x = randomvar(m, :x) + y = randomvar(m, :y) + z = randomvar(m, :z) + make_node(m, FactorNodeCreationOptions(nothing, nothing, nothing), Uninformative, x) + make_node(m, FactorNodeCreationOptions(nothing, nothing, nothing), Uninformative, y) + make_node(m, FactorNodeCreationOptions(nothing, nothing, nothing), Uninformative, z) + node = + make_node(m, FactorNodeCreationOptions(factorisation, nothing, pipeline), DummyStochasticNode, x, y, z) + activate!(m) + return m, x, y, z, node + end + + @testset "Default functional dependencies" begin + @testset "Default functional dependencies: FullFactorisation" begin + # We test `FullFactorisation` case here + m, x, y, z, node = make_dummy_model(FullFactorisation(), DefaultFunctionalDependencies()) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === + DefaultFunctionalDependencies() + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z + @test length(x_mgdeps) === 0 + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z + @test length(y_mgdeps) === 0 + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y + @test length(z_mgdeps) === 0 + end + + @testset "Default functional dependencies: MeanField" begin + # We test `MeanField` case here + m, x, y, z, node = make_dummy_model(MeanField(), DefaultFunctionalDependencies()) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === + DefaultFunctionalDependencies() + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 0 + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 0 + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y + end + + @testset "Default functional dependencies: Structured factorisation" begin + # We test `(x, y), (z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 2), (3,)), DefaultFunctionalDependencies()) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === + DefaultFunctionalDependencies() + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :y + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :x + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 0 + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x_y + + ## --- ## + + # We test `(x, ), (y, z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1,), (2, 3)), DefaultFunctionalDependencies()) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === + DefaultFunctionalDependencies() + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y_z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :z + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :y + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x + + ## --- ## + + # We test `(x, z), (y, )` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 3), (2,)), DefaultFunctionalDependencies()) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === + DefaultFunctionalDependencies() + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :z + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 0 + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x_z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :x + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :y + end + end + + @testset "Require inbound message functional dependencies" begin + @testset "Require inbound message functional dependencies: FullFactorisation" begin + # Require inbound message on `x` + pipeline = RequireMessageFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),)) + + # We test `FullFactorisation` case here + m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y && + name(x_msgdeps[3]) === :z + @test length(x_mgdeps) === 0 + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(x_msgdeps[1]))) == (0.123, 0.123) + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z + @test length(y_mgdeps) === 0 + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y + @test length(z_mgdeps) === 0 + + ## -- ## + + # Require inbound message on `y` and `z` + pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `FullFactorisation` case here + m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z + @test length(x_mgdeps) === 0 + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y && + name(y_msgdeps[3]) === :z + @test length(y_mgdeps) === 0 + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[2]))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y && + name(z_msgdeps[3]) === :z + @test length(z_mgdeps) === 0 + @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[3]))) + end + + @testset "Require inbound message functional dependencies: MeanField" begin + # Require inbound message on `x` + pipeline = RequireMessageFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),)) + + # We test `MeanField` case here + m, x, y, z, node = make_dummy_model(MeanField(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :x + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(x_msgdeps[1]))) == (0.123, 0.123) + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 0 + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 0 + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y + + ## -- ## + + # Require inbound message on `y` and `z` + pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `MeanField` case here + m, x, y, z, node = make_dummy_model(MeanField(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :y + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[1]))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :z + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y + @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[1]))) + end + + @testset "Require inbound message dependencies: Structured factorisation" begin + # Require inbound message on `y` and `z` + pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `(x, y), (z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 2), (3,)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :y + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :z + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[2]))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :z + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x_y + @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[1]))) + + ## --- ## + + # Require inbound message on `y` and `z` + pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `(x, ), (y, z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1,), (2, 3)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y_z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :y && name(y_msgdeps[2]) === :z + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[1]))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :y && name(z_msgdeps[2]) === :z + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x + @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[2]))) + + ## --- ## + + # Require inbound message on `y` and `z` + pipeline = RequireMessageFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `(x, z), (y, )` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 3), (2,)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :z + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :y + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x_z + @test mean_var(Rocket.getrecent(ReactiveMP.messagein(y_msgdeps[1]))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :z + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :y + @test isnothing(Rocket.getrecent(ReactiveMP.messagein(z_msgdeps[2]))) + end + end + + @testset "Require marginal functional dependencies" begin + @testset "Require marginal functional dependencies: FullFactorisation" begin + # Require marginal on `x` + pipeline = RequireMarginalFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),)) + + # We test `FullFactorisation` case here + m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :x + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(x, IncludeAll()))) == (0.123, 0.123) + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z + @test length(y_mgdeps) === 0 + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y + @test length(z_mgdeps) === 0 + + ## -- ## + + # Require marginals on `y` and `z` + pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `FullFactorisation` case here + m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 2 && name(x_msgdeps[1]) === :y && name(x_msgdeps[2]) === :z + @test length(x_mgdeps) === 0 + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 2 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :z + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :y + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 2 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :z + @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll()))) + end + + @testset "Require marginal functional dependencies: MeanField" begin + # Require marginal on `x` + pipeline = RequireMarginalFunctionalDependencies((1,), (NormalMeanVariance(0.123, 0.123),)) + + # We test `MeanField` case here + m, x, y, z, node = make_dummy_model(MeanField(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 3 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y && + name(x_mgdeps[3]) === :z + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(x, IncludeAll()))) == (0.123, 0.123) + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 0 + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 0 + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y + + ## -- ## + + # Require marginals on `y` and `z` + pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `MeanField` case here + m, x, y, z, node = make_dummy_model(MeanField(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :y && name(x_mgdeps[2]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 0 + @test length(y_mgdeps) === 3 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y && + name(y_mgdeps[3]) === :z + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 0 + @test length(z_mgdeps) === 3 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y && + name(z_mgdeps[3]) === :z + @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll()))) + end + + @testset "Require marginal functional dependencies: Structured factorisation" begin + # Require marginal on `y` and `z` + pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `(x, y), (z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 2), (3,)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :y + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :x + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :y && name(y_mgdeps[2]) === :z + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 0 + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x_y && name(z_mgdeps[2]) === :z + @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll()))) + + ## --- ## + + # Require marginals on `y` and `z` + pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `(x, ), (y, z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1,), (2, 3)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 0 + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y_z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 1 && name(y_msgdeps[1]) === :z + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :y + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :z + @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll()))) + + ## --- ## + + # Require marginals on `y` and `z` + pipeline = RequireMarginalFunctionalDependencies((2, 3), (NormalMeanVariance(0.123, 0.123), nothing)) + + # We test `(x, z), (y, )` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 3), (2,)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 1 && name(x_msgdeps[1]) === :z + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :y + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 0 + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x_z && name(y_mgdeps[2]) === :y + @test mean_var(Rocket.getrecent(ReactiveMP.getmarginal(y, IncludeAll()))) == (0.123, 0.123) + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 1 && name(z_msgdeps[1]) === :x + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :y && name(z_mgdeps[2]) === :z + @test isnothing(Rocket.getrecent(ReactiveMP.getmarginal(z, IncludeAll()))) + end + end + + @testset "Require everything functional dependencies" begin + @testset "Require everything functional dependencies: FullFactorisation" begin + pipeline = RequireEverythingFunctionalDependencies() + + # We test `FullFactorisation` case here + m, x, y, z, node = make_dummy_model(FullFactorisation(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y && + name(x_msgdeps[3]) === :z + @test length(x_mgdeps) === 1 && name(x_mgdeps[1]) === :x_y_z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y && + name(y_msgdeps[3]) === :z + @test length(y_mgdeps) === 1 && name(y_mgdeps[1]) === :x_y_z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y && + name(z_msgdeps[3]) === :z + @test length(z_mgdeps) === 1 && name(z_mgdeps[1]) === :x_y_z + end + + @testset "Require everything functional dependencies: MeanField" begin + pipeline = RequireEverythingFunctionalDependencies() + + # We test `MeanField` case here + m, x, y, z, node = make_dummy_model(MeanField(), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 3 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y && + name(x_mgdeps[3]) === :z + @test length(x_mgdeps) === 3 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y && + name(x_mgdeps[3]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y && + name(y_msgdeps[3]) === :z + @test length(y_mgdeps) === 3 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y && + name(y_mgdeps[3]) === :z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y && + name(z_msgdeps[3]) === :z + @test length(z_mgdeps) === 3 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y && + name(z_mgdeps[3]) === :z + end + + @testset "Require everything dependencies: Structured factorisation" begin + pipeline = RequireEverythingFunctionalDependencies() + + # We test `(x, y), (z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 2), (3,)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y && + name(x_msgdeps[3]) === :z + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :x_y && name(x_mgdeps[2]) === :z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y && + name(y_msgdeps[3]) === :z + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x_y && name(y_mgdeps[2]) === :z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y && + name(z_msgdeps[3]) === :z + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x_y && name(z_mgdeps[2]) === :z + + ## --- ## + + pipeline = RequireEverythingFunctionalDependencies() + + # We test `(x, ), (y, z)` factorisation case here + m, x, y, z, node = make_dummy_model(((1,), (2, 3)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y && + name(x_msgdeps[3]) === :z + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :x && name(x_mgdeps[2]) === :y_z + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y && + name(y_msgdeps[3]) === :z + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x && name(y_mgdeps[2]) === :y_z + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y && + name(z_msgdeps[3]) === :z + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x && name(z_mgdeps[2]) === :y_z + + ## --- ## + + pipeline = RequireEverythingFunctionalDependencies() + + # We test `(x, z), (y, )` factorisation case here + m, x, y, z, node = make_dummy_model(((1, 3), (2,)), pipeline) + + # Test that pipeline dependencies have been set properly + @test ReactiveMP.get_pipeline_dependencies(ReactiveMP.getpipeline(node)) === pipeline + + x_msgdeps, x_mgdeps = ReactiveMP.functional_dependencies(node, :x) + + @test length(x_msgdeps) === 3 && name(x_msgdeps[1]) === :x && name(x_msgdeps[2]) === :y && + name(x_msgdeps[3]) === :z + @test length(x_mgdeps) === 2 && name(x_mgdeps[1]) === :x_z && name(x_mgdeps[2]) === :y + + y_msgdeps, y_mgdeps = ReactiveMP.functional_dependencies(node, :y) + + @test length(y_msgdeps) === 3 && name(y_msgdeps[1]) === :x && name(y_msgdeps[2]) === :y && + name(y_msgdeps[3]) === :z + @test length(y_mgdeps) === 2 && name(y_mgdeps[1]) === :x_z && name(y_mgdeps[2]) === :y + + z_msgdeps, z_mgdeps = ReactiveMP.functional_dependencies(node, :z) + + @test length(z_msgdeps) === 3 && name(z_msgdeps[1]) === :x && name(z_msgdeps[2]) === :y && + name(z_msgdeps[3]) === :z + @test length(z_mgdeps) === 2 && name(z_mgdeps[1]) === :x_z && name(z_mgdeps[2]) === :y + end + end + end + @testset "@node macro" begin # Testing Stochastic node specification