Skip to content

Commit

Permalink
Refactor graph creation and node function handling (#163)
Browse files Browse the repository at this point in the history
This PR introduces several enhancements and refactors the graph creation
process and node function handling in JuliaBUGS. The main changes
include:

* Sharing node function expressions across nodes:

Previously, each node in the graph had its own unique node function
expression. This PR modifies the behavior so that nodes originating from
the same statement in the model definition share the same node function
expression.

  Example:
  ```julia
  @bugs begin
      for i in 1:2
          x[i] ~ dnorm(0, 1)
          y[i] ~ dnorm(x[i], i)
      end
  end
  ````
In the previous version, the nodes for x[1], x[2], y[1], and y[2] would
have separate node function expressions. Now, they will share the
expressions dnorm(0, 1) and dnorm(x[i], i) based on the corresponding
statements.

The function `replace_constants_in_expr` used to plugin all the scalar
values into the node function expr, now the function is removed. The new
node function is a function takes all the variable on the RHS as
arguments (including loop variables).
  E.g., 
  ```julia
  function (;x::AbstractArray{Float64}, i::Int)
      return dnorm(x[1], i)
  end
  ```
The binding of loop variables to values are stored at nodes, and used
when evaluating node function.

This change reduces memory usage and paves the way for potentially
evaluating node functions once and using compiled functions during model
evaluation.

* Simplifying the graph creation process:
  The graph building algorithm has been overhauled for clarity.

  Example:
  ```julia
  @bugs begin
      x[1:2] ~ dmnorm(...)
      x[3] ~ dnorm(0, 1)
      y ~ dnorm(sum(x[2:3]), 1)
  end
  ```
In the previous version, temporary nodes were created for variables used
on the RHS that were not explicitly defined in the model, such as `x[1],
x[2], x[2:3]`. These temporary nodes needed to be removed later, adding
complexity to the graph construction process.
  
  The new approach follows a two-stage process:
  
- In the first stage, nodes are created for all variables explicitly
defined in the model, also a matrix containing node id is created for
each variable. In this example, `x[1:2]` has id `1`, `x[3]` has id `2`,
`y` has id `3`. And id tracker looks like `x_ids = [1, 1, 2]`
  
- In the second stage, edges are inserted between the nodes based on the
dependencies specified in the model statements. The node id is looked up
and edges are created accordingly.
  
This eliminates the need for creating and removing temporary nodes,
resulting in a cleaner and more efficient graph construction process.

* Renaming `_eval` to `bugs_eval`

* The Var struct and associated functions have been removed. Variables
are now represented using Tuple{Symbol,
Vararg{Union{Int,UnitRange{Int}}}}.
  ```julia
  # Previous representation using Var
  x = Var(:x)
  y = Var(:y, (1, 2))
  
  # New representation using tuples
  x = (:x,)
  y = (:y, 1, 2)
  ```

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sunxd3 and github-actions[bot] authored Mar 16, 2024
1 parent e57b892 commit dddc106
Show file tree
Hide file tree
Showing 22 changed files with 467 additions and 877 deletions.
15 changes: 6 additions & 9 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "JuliaBUGS"
uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.4.1"
version = "0.5.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -20,11 +20,9 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
Expand All @@ -51,18 +49,17 @@ Bijectors = "0.13"
Distributions = "0.23.8, 0.24, 0.25"
Documenter = "0.27, 1"
DynamicPPL = "0.22, 0.23, 0.24"
Graphs = "1.4.1"
Graphs = "1"
JuliaSyntax = "0.4"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.6"
LogDensityProblemsAD = "1"
LogExpFunctions = "0.3"
MacroTools = "0.5.6"
MetaGraphsNext = "0.5, 0.6"
MacroTools = "0.5"
MetaGraphsNext = "0.6, 0.7"
PDMats = "0.10, 0.11"
Setfield = "0.7.1, 0.8, 1"
SpecialFunctions = "2"
StaticArrays = "1.9"
UnPack = "1"
Statistics = "1.9"
julia = "1.9"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Documenter
using JuliaBUGS
using JuliaBUGS: @bugs, compile, BUGSModel, BUGSGraph, ConcreteNodeInfo
using JuliaBUGS: @bugs, compile, BUGSModel, BUGSGraph
using MetaGraphsNext
using JuliaBUGS.BUGSPrimitives
using DynamicPPL: SimpleVarInfo
Expand Down
1 change: 0 additions & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,5 @@
@bugs
compile
BUGSModel
ConcreteNodeInfo
BUGSGraph
```
2 changes: 1 addition & 1 deletion ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using JuliaBUGS:
find_generated_vars,
LogDensityContext,
evaluate!!,
_eval
bugs_eval
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.DynamicPPL
using JuliaBUGS.LogDensityProblems
Expand Down
1 change: 0 additions & 1 deletion ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Random
using JuliaBUGS.Bijectors
using JuliaBUGS.UnPack
using MCMCChains: Chains
import JuliaBUGS: gibbs_internal

Expand Down
23 changes: 9 additions & 14 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using MetaGraphsNext
using Random
using Setfield
using StaticArrays
using UnPack

using DynamicPPL: DynamicPPL, SimpleVarInfo

Expand All @@ -33,9 +31,8 @@ include("parser/Parser.jl")
using .Parser

include("utils.jl")
include("variable_types.jl")
include("compiler_pass.jl")
include("graphs.jl")
include("compiler_pass.jl")
include("model.jl")
include("logdensityproblems.jl")
include("gibbs.jl")
Expand Down Expand Up @@ -127,11 +124,13 @@ function finish_checking_repeated_assignments(
end
end

function compute_node_functions(model_def, eval_env)
pass = NodeFunctions(eval_env)
function create_graph(model_def, eval_env)
pass = AddVertices(model_def, eval_env)
analyze_block(pass, model_def)
pass = AddEdges(pass.env, pass.g, pass.vertex_id_tracker)
analyze_block(pass, model_def)
vars, node_args, node_functions, dependencies = post_process(pass)
return vars, node_args, node_functions, dependencies

return pass.g
end

function semantic_analysis(model_def, data)
Expand Down Expand Up @@ -164,12 +163,8 @@ function compile(model_def::Expr, data, inits; is_transformed=true)
data, inits = check_input(data), check_input(inits)
eval_env = semantic_analysis(model_def, data)
model_def = concretize_colon_indexing(model_def, eval_env)
vars, node_args, node_functions, dependencies = compute_node_functions(
model_def, eval_env
)
g = create_BUGSGraph(vars, node_args, node_functions, dependencies)
sorted_nodes = map(Base.Fix1(label_for, g), topological_sort(g))
return BUGSModel(g, sorted_nodes, eval_env, inits; is_transformed=is_transformed)
g = create_graph(model_def, eval_env)
return BUGSModel(g, eval_env, inits; is_transformed=is_transformed)
end

"""
Expand Down
Loading

2 comments on commit dddc106

@sunxd3
Copy link
Member Author

@sunxd3 sunxd3 commented on dddc106 Mar 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/103005

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" dddc106f09e78a265bd412d7bcd4cd101473a9f6
git push origin v0.5.0

Please sign in to comment.