Skip to content

Commit

Permalink
Merge branch 'main' into sharp-bits
Browse files Browse the repository at this point in the history
  • Loading branch information
bvdmitri authored Dec 23, 2024
2 parents 45957fa + 74753ca commit 8d82cba
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 4 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v2
with:
version: '1.10' # 1.11 is not supported by Cairo/GraphViz
- uses: actions/cache@v3
id: examples
with:
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RxInfer"
uuid = "86711068-29c9-4ff7-b620-ae75d7495b3d"
authors = ["Bagaev Dmitry <[email protected]> and contributors"]
version = "3.8.0"
version = "3.8.2"

[deps]
BayesBase = "b4ee3484-f114-42fe-b91c-797d54a0c67e"
Expand Down Expand Up @@ -36,7 +36,7 @@ DomainSets = "0.5.2, 0.6, 0.7"
ExponentialFamily = "1.5"
ExponentialFamilyProjection = "1.1"
FastCholesky = "1.3.0"
GraphPPL = "~4.4.0"
GraphPPL = "~4.5.0"
LinearAlgebra = "1.9"
MacroTools = "0.5.6"
Optim = "1.0.0"
Expand Down
4 changes: 2 additions & 2 deletions codemeta.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
"downloadUrl": "https://github.com/reactivebayes/RxInfer.jl/releases",
"issueTracker": "https://github.com/reactivebayes/RxInfer.jl/issues",
"name": "RxInfer.jl",
"version": "3.8.0",
"version": "3.8.2",
"description": "Julia package for automated, scalable and efficient Bayesian inference on factor graphs with reactive message passing. ",
"applicationCategory": "Statistics",
"developmentStatus": "active",
"readme": "https://reactivebayes.github.io/RxInfer.jl/stable/",
"softwareVersion": "3.8.0",
"softwareVersion": "3.8.2",
"keywords": [
"Bayesian inference",
"message passing",
Expand Down
84 changes: 84 additions & 0 deletions docs/src/manuals/customization/custom-node.md
Original file line number Diff line number Diff line change
Expand Up @@ -309,3 +309,87 @@ nothing # hide
```

Congratulations! You have successfully implemented your own custom node in `RxInfer`. We went through the definition of a node to the implementation of the update rules and marginal posterior calculations. Finally we tested our custom node in a model and checked if we implemented everything correctly.

# [Custom node experimental functionality](@id custom-node-experimental)

!!! warning "Experimental features"
The functionality described below is experimental and subject to change in future releases. Use it with caution in production code.

## [Rules that require a reference to a node object](@id inference-ruleswithnode)

In some advanced scenarios, you might need access to the node object itself within a message passing rule. This can be useful when:
- You need to inspect the current state of other variables in the model
- You want to implement complex message passing schemes that depend on the global model state
- You're experimenting with custom inference algorithms that require access to the factor graph structure

Here's how to implement a rule with node access. First we define a custom node and a simple model that uses this node:

```@example custom-node-node-in-a-rule
using RxInfer
struct MyExperimentalNode end
@node MyExperimentalNode Stochastic [ out, θ ]
@model function my_experimental_model(y)
θ ~ Normal(mean = 0.0, variance = 1.0)
y ~ MyExperimentalNode(θ)
end
```

Second, we enable instruction to the inference backend to pass node reference to the rule.

```@example custom-node-node-in-a-rule
# Enable node reference passing for this node type
ReactiveMP.call_rule_is_node_required(::Type{<:MyExperimentalNode}) = ReactiveMP.CallRuleNodeRequired()
```

!!! note "Performance Impact"
Enabling node reference passing can negatively impact performance as it requires additional bookkeeping during inference.

!!! danger "Global State"
Setting `call_rule_is_node_required` for existing nodes (like `NormalMeanVariance`) affects all models globally and will affect code that depends on your package. Only safe to use this for your custom nodes.

The `call_rule_is_node_required` function is used to instruct the inference backend to pass the node object to the rule. After this is set, we can use the `getnode()` function to access the node object within the rule.

```@example custom-node-node-in-a-rule
@rule MyExperimentalNode(:θ, Marginalisation) (q_out::Any, ) = begin
node = getnode()
# Access interface index
ii = ReactiveMP.interfaceindex(node, :θ)
# Get interface object
θi = ReactiveMP.getinterfaces(node)[ii]
# Get variable object
θv = ReactiveMP.getvariable(θi)
# By default, `germarginal` ignores marginals set in the @initialization block
# `IncludeAll` overrides this behavior and includes all marginals
qθ = Rocket.getrecent(ReactiveMP.getmarginal(θv, IncludeAll()))
# This is a simple rule that returns a NormalMeanVariance distribution
# It could be replaced with any other rule that returns a distribution
return NormalMeanVariance(mean(qθ) + mean(q_out), var(qθ))
end
```

### Running inference with the custom node and rule

Here's a full example showing how to use this functionality:

```@example custom-node-node-in-a-rule
initialization = @initialization begin
q(θ) = NormalMeanVariance(3.14, 2.71)
end
result = infer(
model = my_experimental_model(),
data = (y = 1.0, ),
initialization = initialization
)
nothing #hide
```

As we can see, the print statement in the rule is executed, which means that the node reference passing is working as expected. This feature opens up possibilities for advanced inference scenarios, but should be used judiciously. Consider whether your use case truly requires access to the node object, as simpler solutions using standard message passing rules are often sufficient and more maintainable.



0 comments on commit 8d82cba

Please sign in to comment.