From dddc106f09e78a265bd412d7bcd4cd101473a9f6 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Sat, 16 Mar 2024 14:40:06 +0000 Subject: [PATCH] Refactor graph creation and node function handling (#163) This PR introduces several enhancements and refactors the graph creation process and node function handling in JuliaBUGS. The main changes include: * Sharing node function expressions across nodes: Previously, each node in the graph had its own unique node function expression. This PR modifies the behavior so that nodes originating from the same statement in the model definition share the same node function expression. Example: ```julia @bugs begin for i in 1:2 x[i] ~ dnorm(0, 1) y[i] ~ dnorm(x[i], i) end end ```` In the previous version, the nodes for x[1], x[2], y[1], and y[2] would have separate node function expressions. Now, they will share the expressions dnorm(0, 1) and dnorm(x[i], i) based on the corresponding statements. The function `replace_constants_in_expr` used to plugin all the scalar values into the node function expr, now the function is removed. The new node function is a function takes all the variable on the RHS as arguments (including loop variables). E.g., ```julia function (;x::AbstractArray{Float64}, i::Int) return dnorm(x[1], i) end ``` The binding of loop variables to values are stored at nodes, and used when evaluating node function. This change reduces memory usage and paves the way for potentially evaluating node functions once and using compiled functions during model evaluation. * Simplifying the graph creation process: The graph building algorithm has been overhauled for clarity. Example: ```julia @bugs begin x[1:2] ~ dmnorm(...) x[3] ~ dnorm(0, 1) y ~ dnorm(sum(x[2:3]), 1) end ``` In the previous version, temporary nodes were created for variables used on the RHS that were not explicitly defined in the model, such as `x[1], x[2], x[2:3]`. These temporary nodes needed to be removed later, adding complexity to the graph construction process. The new approach follows a two-stage process: - In the first stage, nodes are created for all variables explicitly defined in the model, also a matrix containing node id is created for each variable. In this example, `x[1:2]` has id `1`, `x[3]` has id `2`, `y` has id `3`. And id tracker looks like `x_ids = [1, 1, 2]` - In the second stage, edges are inserted between the nodes based on the dependencies specified in the model statements. The node id is looked up and edges are created accordingly. This eliminates the need for creating and removing temporary nodes, resulting in a cleaner and more efficient graph construction process. * Renaming `_eval` to `bugs_eval` * The Var struct and associated functions have been removed. Variables are now represented using Tuple{Symbol, Vararg{Union{Int,UnitRange{Int}}}}. ```julia # Previous representation using Var x = Var(:x) y = Var(:y, (1, 2)) # New representation using tuples x = (:x,) y = (:y, 1, 2) ``` --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 15 +- docs/make.jl | 2 +- docs/src/api.md | 1 - ext/JuliaBUGSAdvancedHMCExt.jl | 2 +- ext/JuliaBUGSAdvancedMHExt.jl | 1 - src/JuliaBUGS.jl | 23 +- src/compiler_pass.jl | 390 +++++++++++++++------------------ src/graphs.jl | 195 +---------------- src/model.jl | 371 ++++++++++++++++--------------- src/utils.jl | 53 ++--- src/variable_types.jl | 131 ----------- test/gibbs.jl | 12 +- test/graphs.jl | 2 +- test/logp_tests/binomial.jl | 6 +- test/logp_tests/blockers.jl | 14 +- test/logp_tests/bones.jl | 8 +- test/logp_tests/dogs.jl | 2 +- test/logp_tests/gamma.jl | 6 +- test/logp_tests/rats.jl | 53 ++--- test/profile.jl | 13 +- test/run_logp_tests.jl | 26 --- test/runtests.jl | 18 +- 22 files changed, 467 insertions(+), 877 deletions(-) delete mode 100644 src/variable_types.jl delete mode 100644 test/run_logp_tests.jl diff --git a/Project.toml b/Project.toml index 0dcd235fd..d87f9fdc5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "JuliaBUGS" uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf" -version = "0.4.1" +version = "0.5.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -20,11 +20,9 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [weakdeps] AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" @@ -51,18 +49,17 @@ Bijectors = "0.13" Distributions = "0.23.8, 0.24, 0.25" Documenter = "0.27, 1" DynamicPPL = "0.22, 0.23, 0.24" -Graphs = "1.4.1" +Graphs = "1" JuliaSyntax = "0.4" LogDensityProblems = "2" -LogDensityProblemsAD = "1.6" +LogDensityProblemsAD = "1" LogExpFunctions = "0.3" -MacroTools = "0.5.6" -MetaGraphsNext = "0.5, 0.6" +MacroTools = "0.5" +MetaGraphsNext = "0.6, 0.7" PDMats = "0.10, 0.11" -Setfield = "0.7.1, 0.8, 1" SpecialFunctions = "2" StaticArrays = "1.9" -UnPack = "1" +Statistics = "1.9" julia = "1.9" [extras] diff --git a/docs/make.jl b/docs/make.jl index 0692c8562..1b1b2ab1c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,6 @@ using Documenter using JuliaBUGS -using JuliaBUGS: @bugs, compile, BUGSModel, BUGSGraph, ConcreteNodeInfo +using JuliaBUGS: @bugs, compile, BUGSModel, BUGSGraph using MetaGraphsNext using JuliaBUGS.BUGSPrimitives using DynamicPPL: SimpleVarInfo diff --git a/docs/src/api.md b/docs/src/api.md index 3798cd067..b0516d325 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -4,6 +4,5 @@ @bugs compile BUGSModel -ConcreteNodeInfo BUGSGraph ``` \ No newline at end of file diff --git a/ext/JuliaBUGSAdvancedHMCExt.jl b/ext/JuliaBUGSAdvancedHMCExt.jl index dad9f4851..0d67bc85f 100644 --- a/ext/JuliaBUGSAdvancedHMCExt.jl +++ b/ext/JuliaBUGSAdvancedHMCExt.jl @@ -11,7 +11,7 @@ using JuliaBUGS: find_generated_vars, LogDensityContext, evaluate!!, - _eval + bugs_eval using JuliaBUGS.BUGSPrimitives using JuliaBUGS.DynamicPPL using JuliaBUGS.LogDensityProblems diff --git a/ext/JuliaBUGSAdvancedMHExt.jl b/ext/JuliaBUGSAdvancedMHExt.jl index b7d014cbb..307a716dc 100644 --- a/ext/JuliaBUGSAdvancedMHExt.jl +++ b/ext/JuliaBUGSAdvancedMHExt.jl @@ -10,7 +10,6 @@ using JuliaBUGS.LogDensityProblems using JuliaBUGS.LogDensityProblemsAD using JuliaBUGS.Random using JuliaBUGS.Bijectors -using JuliaBUGS.UnPack using MCMCChains: Chains import JuliaBUGS: gibbs_internal diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index 0125ee56c..83ec02f77 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -10,9 +10,7 @@ using LogDensityProblems, LogDensityProblemsAD using MacroTools using MetaGraphsNext using Random -using Setfield using StaticArrays -using UnPack using DynamicPPL: DynamicPPL, SimpleVarInfo @@ -33,9 +31,8 @@ include("parser/Parser.jl") using .Parser include("utils.jl") -include("variable_types.jl") -include("compiler_pass.jl") include("graphs.jl") +include("compiler_pass.jl") include("model.jl") include("logdensityproblems.jl") include("gibbs.jl") @@ -127,11 +124,13 @@ function finish_checking_repeated_assignments( end end -function compute_node_functions(model_def, eval_env) - pass = NodeFunctions(eval_env) +function create_graph(model_def, eval_env) + pass = AddVertices(model_def, eval_env) + analyze_block(pass, model_def) + pass = AddEdges(pass.env, pass.g, pass.vertex_id_tracker) analyze_block(pass, model_def) - vars, node_args, node_functions, dependencies = post_process(pass) - return vars, node_args, node_functions, dependencies + + return pass.g end function semantic_analysis(model_def, data) @@ -164,12 +163,8 @@ function compile(model_def::Expr, data, inits; is_transformed=true) data, inits = check_input(data), check_input(inits) eval_env = semantic_analysis(model_def, data) model_def = concretize_colon_indexing(model_def, eval_env) - vars, node_args, node_functions, dependencies = compute_node_functions( - model_def, eval_env - ) - g = create_BUGSGraph(vars, node_args, node_functions, dependencies) - sorted_nodes = map(Base.Fix1(label_for, g), topological_sort(g)) - return BUGSModel(g, sorted_nodes, eval_env, inits; is_transformed=is_transformed) + g = create_graph(model_def, eval_env) + return BUGSModel(g, eval_env, inits; is_transformed=is_transformed) end """ diff --git a/src/compiler_pass.jl b/src/compiler_pass.jl index 7c7e74f0c..067927841 100644 --- a/src/compiler_pass.jl +++ b/src/compiler_pass.jl @@ -469,22 +469,6 @@ function analyze_statement(pass::DataTransformation, expr::Expr, loop_vars::Name end end -""" - NodeFunctions - -A pass that analyze node functions of variables and their dependencies. -""" -struct NodeFunctions <: CompilerPass - env::NamedTuple - vars::Dict - node_args::Dict - node_functions::Dict - dependencies::Dict -end -function NodeFunctions(eval_env) - return NodeFunctions(eval_env, Dict(), Dict(), Dict(), Dict()) -end - """ evaluate_and_track_dependencies(var, env) @@ -504,50 +488,48 @@ Array elements and array variables are represented by tuples in the returned val # Examples ```jldoctest julia> evaluate_and_track_dependencies(:(x[a]), (x=[missing, missing], a = missing)) -(missing, (:a, (:x, 1:2)), (:x, :a)) +(missing, (:a, (:x, 1:2))) julia> evaluate_and_track_dependencies(:(x[a]), (x=[missing, missing], a = 1)) -(missing, ((:x, 1),), (:x, :a)) +(missing, ((:x, 1),)) julia> evaluate_and_track_dependencies(:(x[y[1]+1]+a+1), (x=[missing, missing], y = [missing, missing], a = missing)) -(missing, ((:y, 1), (:x, 1:2), :a), (:x, :y, :a)) +(missing, ((:y, 1), (:x, 1:2), :a)) julia> evaluate_and_track_dependencies(:(x[a, b]), (x = [1 2 3; 4 5 6], a = missing, b = missing)) -(missing, (:a, :b, (:x, 1:2, 1:3)), (:x, :a, :b)) +(missing, (:a, :b, (:x, 1:2, 1:3))) julia> evaluate_and_track_dependencies(:(getindex(x[1:2, 1:3], a, b)), (x = [1 2 3; 4 5 6], a = missing, b = missing)) -(missing, (:a, :b), (:x, :a, :b)) +(missing, (:a, :b)) julia> evaluate_and_track_dependencies(:(getindex(x[1:2, 1:3], 1, 1)), (x = [1 2 3; 4 5 6], a = missing, b = missing)) -(1, (), (:x,)) +(1, ()) julia> evaluate_and_track_dependencies(:(getindex(x[1:2, 1:3], a, b)), (x = [1 2 missing; 4 5 6], a = missing, b = missing)) -(missing, ((:x, 1:2, 1:3), :a, :b), (:x, :a, :b)) +(missing, ((:x, 1:2, 1:3), :a, :b)) ``` """ -evaluate_and_track_dependencies(var::Union{Int,Float64}, env) = var, (), () -evaluate_and_track_dependencies(var::UnitRange, env) = var, (), () +evaluate_and_track_dependencies(var::Union{Int,Float64}, env) = var, () +evaluate_and_track_dependencies(var::UnitRange, env) = var, () function evaluate_and_track_dependencies(var::Symbol, env) if var ∈ (:nothing, :missing, :(:)) - return var, (), () + return var, () end if env[var] === missing - return var, (var,), (var,) + return var, (var,) else - return env[var], (), (var,) + return env[var], () end end function evaluate_and_track_dependencies(var::Expr, env) - dependencies, node_func_args = [], [] + dependencies = [] if Meta.isexpr(var, :ref) v, indices... = var.args - push!(node_func_args, v) for i in eachindex(indices) ret = evaluate_and_track_dependencies(indices[i], env) index = ret[1] indices[i] = index isa Float64 ? Int(index) : index dependencies = union!(dependencies, ret[2]) - node_func_args = union!(node_func_args, ret[3]) end value = nothing @@ -556,9 +538,8 @@ function evaluate_and_track_dependencies(var::Expr, env) end value = env[v][indices...] if is_resolved(value) - return value, Tuple(dependencies), Tuple(node_func_args) + return value, Tuple(dependencies) else - # TODO: what if value is partially missing? push!(dependencies, (v, indices...)) end else @@ -573,7 +554,7 @@ function evaluate_and_track_dependencies(var::Expr, env) ), ) end - return missing, Tuple(dependencies), Tuple(node_func_args) + return missing, Tuple(dependencies) elseif Meta.isexpr(var, :call) f, args... = var.args value = nothing @@ -591,7 +572,6 @@ function evaluate_and_track_dependencies(var::Expr, env) for i in eachindex(indices) ret = evaluate_and_track_dependencies(indices[i], env) union!(dependencies, ret[2]) - union!(node_func_args, ret[3]) indices[i] = ret[1] end if any(!is_resolved, indices) @@ -608,8 +588,7 @@ function evaluate_and_track_dependencies(var::Expr, env) ret = evaluate_and_track_dependencies(arg2, env) union!(dependencies, ret[2]) - union!(node_func_args, ret[3]) - return missing, Tuple(dependencies), Tuple(node_func_args) + return missing, Tuple(dependencies) elseif f === :deviance @warn( "`deviance` function is not supported in JuliaBUGS, `deviance` will be treated as a general function." @@ -619,17 +598,14 @@ function evaluate_and_track_dependencies(var::Expr, env) ret = evaluate_and_track_dependencies(args[i], env) args[i] = ret[1] union!(dependencies, ret[2]) - union!(node_func_args, ret[3]) end value = nothing if all(is_resolved, args) && f ∈ BUGSPrimitives.BUGS_FUNCTIONS ∪ (:+, :-, :*, :/, :^, :(:), :getindex) - return getfield(JuliaBUGS, f)(args...), - Tuple(dependencies), - Tuple(node_func_args) + return getfield(JuliaBUGS, f)(args...), Tuple(dependencies) else - return missing, Tuple(dependencies), Tuple(node_func_args) + return missing, Tuple(dependencies) end end else @@ -637,210 +613,188 @@ function evaluate_and_track_dependencies(var::Expr, env) end end -""" - replace_constants_in_expr(x, env) - -Replace the constants in the expression `x` with their actual values from the environment `env` if the values are concrete. +function build_node_functions( + expr::Expr, + eval_env::NamedTuple, + f_dict::Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}}, + loop_vars::Tuple{Vararg{Symbol}}, +) + for statement in expr.args + if is_deterministic(statement) || is_stochastic(statement) + rhs = if is_deterministic(statement) + statement.args[2] + else + statement.args[3] + end + args, node_func_expr = make_function_expr(rhs, eval_env) + # node_func = eval(node_func_expr) + node_func = nothing + f_dict[statement] = (args, node_func_expr, node_func) + elseif Meta.isexpr(statement, :for) + loop_var, _, _, body = decompose_for_expr(statement) + build_node_functions(body, eval_env, f_dict, (loop_var, loop_vars...)) + else + error("Unknown statement type: $statement") + end + end + return f_dict +end -# Examples -```jldoctest -julia> env = Dict(:a => 1, :b => 2, :c => 3); +function make_function_expr(expr, env::NamedTuple{vars}) where {vars} + args = Tuple(keys(extract_variable_names_and_numdims(expr, ()))) + arg_exprs = Expr[] + for v in args + if v ∈ vars + value = env[v] + if value isa Int + push!(arg_exprs, Expr(:(::), v, :Int)) + elseif value isa Float64 + push!(arg_exprs, Expr(:(::), v, :Float64)) + elseif value isa Missing + push!(arg_exprs, Expr(:(::), v, :(Union{Int,Float64}))) + elseif value isa AbstractArray + T = nonmissingtype(eltype(value)) + if T === Union{} + T = Float64 + end + push!(arg_exprs, Expr(:(::), v, :(AbstractArray{$T}))) + else + error("Unexpected argument type: $(typeof(value))") + end + else # loop variable + push!(arg_exprs, Expr(:(::), v, :Int)) + end + end -julia> replace_constants_in_expr(:(a * b + c), env) -:(1 * 2 + 3) + return args, MacroTools.@q function (; $(arg_exprs...)) + return $(expr) + end +end -julia> replace_constants_in_expr(:(a + b * sin(c)), env) # won't try to evaluate function calls -:(1 + 2 * sin(3)) +mutable struct AddVertices <: CompilerPass + const env::NamedTuple + const g::MetaGraph + vertex_id_tracker::NamedTuple + const f_dict::Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}} +end -julia> replace_constants_in_expr(:(x[a]), Dict(:x => [10, 20, 30], :a => 2)) # indexing into arrays are done if possible -20 +function AddVertices(model_def::Expr, eval_env::NamedTuple) + g = MetaGraph(DiGraph(); label_type=VarName, vertex_data_type=NodeInfo) + vertex_id_tracker = Dict{Symbol,Any}() + for (k, v) in pairs(eval_env) + if v isa AbstractArray + vertex_id_tracker[k] = zeros(Int, size(v)) + else + vertex_id_tracker[k] = 0 + end + end -julia> replace_constants_in_expr(:(x[a] + b), Dict(:x => [10, 20, 30], :a => 2, :b => 5)) -:(20 + 5) + f_dict = build_node_functions( + model_def, eval_env, Dict{Expr,Tuple{Tuple{Vararg{Symbol}},Expr,Any}}(), () + ) -julia> replace_constants_in_expr(:(x[1] + y[1]), Dict(:x => [10, 20, 30], :y => [40, 50, 60])) -:(10 + 40) -``` -""" -function replace_constants_in_expr(x, env) - result = _replace_constants_in_expr(x, env) - while result != x - x = result - result = _replace_constants_in_expr(x, env) - end - return x + return AddVertices(eval_env, g, NamedTuple(vertex_id_tracker), f_dict) end -_replace_constants_in_expr(x::Number, env) = x -function _replace_constants_in_expr(x::Symbol, env) - if haskey(env, x) && env[x] isa Number - return env[x] +function analyze_statement(pass::AddVertices, expr::Expr, loop_vars::NamedTuple) + lhs_expr = is_deterministic(expr) ? expr.args[1] : expr.args[2] + env = merge(pass.env, loop_vars) + lhs = simplify_lhs(env, lhs_expr) + is_stochastic = false + is_observed = false + lhs_value = if lhs isa Symbol + env[lhs] + else + var, indices... = lhs + env[var][indices...] end - return x -end -function _replace_constants_in_expr(x::Expr, env) - if Meta.isexpr(x, :ref) - v, indices... = x.args - if haskey(env, v) && all(x -> x isa Union{Int,Float64}, indices) - val = env[v][map(Int, indices)...] - return ismissing(val) ? x : val - else - for i in eachindex(indices) - indices[i] = _replace_constants_in_expr(indices[i], env) - end - return Expr(:ref, v, indices...) + if Meta.isexpr(expr, :(=)) + if is_resolved(lhs_value) + return nothing end - elseif Meta.isexpr(x, :call) - if x.args[1] === :cumulative || x.args[1] === :density - if length(x.args) != 3 - error( - "`cumulative` and `density` are special functions in BUGS and takes two arguments, got $(length(x.args) - 1)", - ) - end - f, arg1, arg2 = x.args - if arg1 isa Symbol - return Expr(:call, f, arg1, _replace_constants_in_expr(arg2, env)) - elseif Meta.isexpr(arg1, :ref) - v, indices... = arg1.args - for i in eachindex(indices) - indices[i] = _replace_constants_in_expr(indices[i], env) - end - return Expr( - :call, - f, - Expr(:ref, v, indices...), - _replace_constants_in_expr(arg2, env), - ) - else - error( - "First argument to `cumulative` and `density` must be variable, got $(x.args[2])", - ) - end - elseif x.args[1] === :deviance - @warn( - "`deviance` function is not supported in JuliaBUGS, `deviance` will be treated as a general function." - ) - else - x = deepcopy(x) # because we are mutating the args - for i in 2:length(x.args) - x.args[i] = _replace_constants_in_expr(x.args[i], env) - end - return x + else + is_stochastic = true + if is_resolved(lhs_value) + is_observed = true end + end + + args, node_function_expr, node_function = pass.f_dict[expr] + + vn = if lhs isa Symbol + AbstractPPL.VarName{lhs}(AbstractPPL.IdentityLens()) + else + v, indices... = lhs + AbstractPPL.VarName{v}(AbstractPPL.IndexLens(indices)) + end + add_vertex!( + pass.g, + vn, + NodeInfo( + is_stochastic, is_observed, node_function_expr, node_function, args, loop_vars + ), + ) + if lhs isa Symbol + pass.vertex_id_tracker = BangBang.setproperty!!( + pass.vertex_id_tracker, lhs, code_for(pass.g, vn) + ) else - error("Unexpected expression type: $x") + v, indices... = lhs + if any(indices) do i + i isa UnitRange + end + pass.vertex_id_tracker[v][indices...] .= code_for(pass.g, vn) + else + pass.vertex_id_tracker[v][indices...] = code_for(pass.g, vn) + end end end -function create_array_var(n, env) - return Var(n, Tuple([1:i for i in size(env[n])])) +struct AddEdges <: CompilerPass + env::NamedTuple + g::MetaGraph + vertex_id_tracker::NamedTuple end -function analyze_statement(pass::NodeFunctions, expr::Expr, loop_vars::NamedTuple) +function analyze_statement(pass::AddEdges, expr::Expr, loop_vars::NamedTuple) + lhs_expr, rhs_expr = is_deterministic(expr) ? expr.args[1:2] : expr.args[2:3] env = merge(pass.env, loop_vars) - - if is_deterministic(expr) - lhs_expr, rhs_expr = expr.args[1:2] - var_type = Logical + lhs = simplify_lhs(env, lhs_expr) + lhs_value = if lhs isa Symbol + env[lhs] else - lhs_expr, rhs_expr = expr.args[2:3] - var_type = Stochastic + var, indices... = lhs + env[var][indices...] + end + if Meta.isexpr(expr, :(=)) && is_resolved(lhs_value) + return nothing end - simplified_lhs = simplify_lhs(env, lhs_expr) - lhs_var = if simplified_lhs isa Symbol - Var(simplified_lhs) + _, dependencies = evaluate_and_track_dependencies(rhs_expr, env) + + lhs_vn = if lhs isa Symbol + @varname($lhs) else - v, indices... = simplified_lhs - Var(v, Tuple(indices)) + v, indices... = lhs + AbstractPPL.VarName{v}(AbstractPPL.IndexLens(indices)) end - var_type == Logical && - evaluate(lhs_expr, env) isa Union{Number,Array{<:Number}} && - return nothing - pass.vars[lhs_var] = var_type - rhs = evaluate(rhs_expr, env) - - if rhs isa Symbol - @assert lhs_var isa Union{Scalar,ArrayElement} - node_function = MacroTools.@q ($(rhs)) -> $(rhs) - node_args = [Var(rhs)] - dependencies = [Var(rhs)] - elseif Meta.isexpr(rhs, :ref) && - all(x -> x isa Union{Number,UnitRange}, rhs.args[2:end]) - @assert var_type == Logical # if rhs is a variable, then the expression must be logical - rhs_var = Var(rhs.args[1], Tuple(rhs.args[2:end])) - rhs_array_var = create_array_var(rhs_var.name, env) - size(rhs_var) == size(lhs_var) || - error("Size mismatch between lhs and rhs at expression $expr") - if lhs_var isa ArrayElement - node_function = MacroTools.@q ($(rhs_var.name)::Array) -> - $(rhs_var.name)[$(rhs_var.indices...)] - node_args = [rhs_array_var] - dependencies = [rhs_var] + for var in dependencies + vertex_code = if var isa Symbol + pass.vertex_id_tracker[var] else - # rhs is not evaluated into a concrete value, then at least some elements of the rhs array are not data - non_data_vars = filter(x -> x isa Var, evaluate(rhs_var, env)) - # for now: evaluate(rhs_var, env) will produce scalarized `Var`s, so dependencies - # may contain `Auxiliary Nodes`, this should be okay, but maybe we should keep things uniform - # by keep `dependencies` only variables in the model, not auxiliary nodes - node_function = MacroTools.@q ($(rhs_var.name)::Array) -> - $(rhs_var.name)[$(rhs_var.indices...)] - node_args = [rhs_array_var] - dependencies = non_data_vars + v, indices... = var + pass.vertex_id_tracker[v][indices...] end - else - rhs_expr = replace_constants_in_expr(rhs_expr, env) - evaled_rhs, dependencies, node_args = evaluate_and_track_dependencies(rhs_expr, env) - - # TODO: since we are not evaluating the node function expressions anymore, we don't have to store the expression like anonymous functions - # rhs can be evaluated into a concrete value here, because including transformed variables in the data - # is effectively constant propagation - if is_resolved(evaled_rhs) - node_function = Expr(:(->), Expr(:tuple), Expr(:block, evaled_rhs)) - node_args = [] - # we can also directly save the evaled variable to `env` and later convert to var_store - # issue is that we need to do this in steps, const propagation need to a separate pass - # otherwise the variable in previous expressions will not be evaluated to the concrete value - else - node_args = collect(Any, node_args) - for i in eachindex(node_args) - if env[node_args[i]] isa AbstractArray - node_args[i] = create_array_var(node_args[i], env) - else - node_args[i] = Var(node_args[i]) - end - end - - dependencies = collect(Any, dependencies) - for i in eachindex(dependencies) - if dependencies[i] isa Symbol - dependencies[i] = Var(dependencies[i]) - else - v, indices... = dependencies[i] - dependencies[i] = Var(v, Tuple(indices)) - end - end - args = similar(node_args, Any) - for (i, arg) in enumerate(node_args) - if arg isa ArrayVar - args[i] = Expr(:(::), arg.name, :Array) - elseif arg isa Scalar - args[i] = arg.name - else - error("Unexpected argument type: $arg") - end + vertex_code = filter( + !iszero, vertex_code isa AbstractArray ? vertex_code : [vertex_code] + ) + vertex_labels = map(x -> label_for(pass.g, x), vertex_code) + for r in vertex_labels + if r != lhs_vn + add_edge!(pass.g, r, lhs_vn) end - node_function = Expr(:(->), Expr(:tuple, args...), rhs_expr) end end - - pass.node_args[lhs_var] = node_args - pass.node_functions[lhs_var] = node_function - return pass.dependencies[lhs_var] = dependencies -end - -function post_process(pass::NodeFunctions) - return pass.vars, pass.node_args, pass.node_functions, pass.dependencies end diff --git a/src/graphs.jl b/src/graphs.jl index ed2fd6ad3..8064642b4 100644 --- a/src/graphs.jl +++ b/src/graphs.jl @@ -1,182 +1,19 @@ -abstract type NodeInfo end - -""" - AuxiliaryNodeInfo - -Indicate the node is created by the compiler and not in the original BUGS model. These nodes -are only used to determine dependencies. - -E.g., x[1:2] ~ dmnorm(...); y = x[1] + 1 -In this case, x[1] is an auxiliary node because it doesn't appear on the LHS of any expression. -But we must still introduce it to determine the dependency between `y` and `x[1:2]`. - -In the current implementation, `AuxiliaryNodeInfo` is only used when constructing the graph, -and will all be removed right before returning the graph. -""" -struct AuxiliaryNodeInfo <: NodeInfo end - -""" - ConcreteNodeInfo - -Defines the information stored in each node of the BUGS graph, encapsulating the essential characteristics -and functions associated with a node within the BUGS model's dependency graph. - -# Fields - -- `node_type::VariableTypes`: Specifies whether the node is a stochastic or logical variable. -- `node_function_expr::Expr`: The node function expression. -- `node_args::Vector{VarName}`: A vector containing the names of the variables that are - arguments to the node function. - -""" -struct ConcreteNodeInfo <: NodeInfo - node_type::VariableTypes +struct NodeInfo{F} + is_stochastic::Bool + is_observed::Bool node_function_expr::Expr - node_args::Vector -end - -function Base.show(io::IO, n::ConcreteNodeInfo) - if n isa ConcreteNodeInfo - print( - io, - "ConcreteNodeInfo(\n", - "\tNode Type: ", - n.node_type, - "\n", - "\tNode Function Expression: ", - n.node_function_expr, - "\n", - "\tNode Arguments: ", - n.node_args, - "\n", - ")", - ) - end -end - -function ConcreteNodeInfo(var::Var, vars, node_functions, node_args) - return ConcreteNodeInfo( - vars[var], - node_functions[var], - map(v -> AbstractPPL.VarName{v.name}(AbstractPPL.IdentityLens()), node_args[var]), - ) -end - -function NodeInfo(var::Var, vars, node_functions, node_args) - if var in keys(vars) - return ConcreteNodeInfo(var, vars, node_functions, node_args) - else - return AuxiliaryNodeInfo() - end + node_function::F + node_args::Tuple{Vararg{Symbol}} + loop_vars::NamedTuple end """ BUGSGraph The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for -[`MetaGraphsNext.MetaGraph`](https://juliagraphs.org/MetaGraphsNext.jl/dev/api/#MetaGraphsNext.MetaGraph) -with node type specified to [`ConcreteNodeInfo`](@ref). +`MetaGraphsNext.MetaGraph`. """ -const BUGSGraph = MetaGraph{ - Int64,SimpleDiGraph{Int64},VarName,NodeInfo,Nothing,Nothing,Nothing,Float64 -} - -function create_BUGSGraph(vars, node_args, node_functions, dependencies) - g = MetaGraph( - SimpleDiGraph{Int64}(); - weight_function=nothing, - label_type=VarName, - vertex_data_type=NodeInfo, - ) - for l in keys(vars) # l for LHS variable - l_vn = to_varname(l) - check_and_add_vertex!(g, l_vn, NodeInfo(l, vars, node_functions, node_args)) - # The use of AuxiliaryNodeInfo is also to save computation, becasue otherwise, - # every time we introduce a new node, we need to check `subsumes` or by all the existing nodes. - scalarize_then_add_edge!(g, l; lhs_or_rhs=:lhs) - for r in dependencies[l] - r_vn = to_varname(r) - check_and_add_vertex!(g, r_vn, NodeInfo(r, vars, node_functions, node_args)) - add_edge!(g, r_vn, l_vn) - scalarize_then_add_edge!(g, r; lhs_or_rhs=:rhs) - end - end - check_undeclared_variables(g, vars) - remove_auxiliary_nodes!(g) - return g -end - -""" - check_undeclared_variables - -Check for undeclared variables within the model definition -""" -function check_undeclared_variables(g::BUGSGraph, vars) - undeclared_vars = VarName[] - for v in labels(g) - if g[v] isa AuxiliaryNodeInfo - children = outneighbor_labels(g, v) - parents = inneighbor_labels(g, v) - if isempty(parents) || isempty(children) - if !any( - AbstractPPL.subsumes(u, v) || AbstractPPL.subsumes(v, u) for # corner case x[1:1] and x[1], e.g. Leuk - u in to_varname.(keys(vars)) - ) - push!(undeclared_vars, v) - end - end - end - end - if !isempty(undeclared_vars) - error("Undeclared variables: $(string.(Symbol.(undeclared_vars)))") - end -end - -function remove_auxiliary_nodes!(g::BUGSGraph) - for v in collect(labels(g)) - if g[v] isa AuxiliaryNodeInfo - # fix dependencies - children = outneighbor_labels(g, v) - parents = inneighbor_labels(g, v) - for c in children - for p in parents - @assert !any(x -> x isa AuxiliaryNodeInfo, (g[c], g[p])) "Auxiliary nodes should not have neighbors that are also auxiliary nodes, but at least one of $(g[c]) and $(g[p]) are." - add_edge!(g, p, c) - end - end - delete!(g, v) - end - end -end - -function check_and_add_vertex!(g::BUGSGraph, v::VarName, data::NodeInfo) - if haskey(g, v) - data isa AuxiliaryNodeInfo && return nothing - if g[v] isa AuxiliaryNodeInfo - set_data!(g, v, data) - end - else - add_vertex!(g, v, data) - end -end - -function scalarize_then_add_edge!(g::BUGSGraph, v::Var; lhs_or_rhs=:lhs) - scalarized_v = vcat(scalarize(v)...) - length(scalarized_v) == 1 && return nothing - v = to_varname(v) - for v_elem in map(to_varname, scalarized_v) - add_vertex!(g, v_elem, AuxiliaryNodeInfo()) # may fail, in that case, the existing node may be concrete, so we don't need to add it - if lhs_or_rhs == :lhs # if an edge exist between v and scalaized elements, don't add again - !Graphs.has_edge(g, code_for(g, v_elem), code_for(g, v)) && - add_edge!(g, v, v_elem) - elseif lhs_or_rhs == :rhs - !Graphs.has_edge(g, code_for(g, v), code_for(g, v_elem)) && - add_edge!(g, v_elem, v) - else - error("Unknown argument $lhs_or_rhs") - end - end -end +const BUGSGraph = MetaGraph """ find_generated_vars(g::BUGSGraph) @@ -194,7 +31,7 @@ function find_generated_vars(g) generated_vars = VarName[] for n in graph_roots - if g[n].node_type == Logical + if !g[n].is_stochastic push!(generated_vars, n) # graph roots that are Logical nodes are generated variables find_generated_vars_recursive_helper(g, n, generated_vars) end @@ -274,18 +111,10 @@ function stochastic_neighbors( stochastic_neighbors_vec = VarName[] logical_en_route = VarName[] # logical variables for u in f(g, v) - if g[u] isa ConcreteNodeInfo - if g[u].node_type == Stochastic - push!(stochastic_neighbors_vec, u) - else - push!(logical_en_route, u) - ns = stochastic_neighbors(g, u, f) - for n in ns - push!(stochastic_neighbors_vec, n) - end - end + if g[u].is_stochastic + push!(stochastic_neighbors_vec, u) else - # auxiliary nodes are not counted as logical nodes + push!(logical_en_route, u) ns = stochastic_neighbors(g, u, f) for n in ns push!(stochastic_neighbors_vec, n) diff --git a/src/model.jl b/src/model.jl index b6f0bc52a..84e71d3fc 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,8 +1,11 @@ -# AbstractBUGSModel can't be a subtype of AbstractProbabilisticProgram (<: AbstractMCMC.AbstractModel) +# AbstractBUGSModel subtype `AbstractPPL.AbstractProbabilisticProgram` (which subtypes `AbstractMCMC.AbstractModel`) # because it will then dispatched to https://github.com/TuringLang/AbstractMCMC.jl/blob/d7c549fe41a80c1f164423c7ac458425535f624b/src/sample.jl#L81 # instead of https://github.com/TuringLang/AbstractMCMC.jl/blob/d7c549fe41a80c1f164423c7ac458425535f624b/src/logdensityproblems.jl#L90 abstract type AbstractBUGSModel end +# TODO: currently the evaluated `node_function` is not used, two issues need to fixed before getting rid of `bugs_eval`: +# `cumulative`/`density` requires special treatment in `bugs_eval`; implicit casting of Float64 in indices + """ BUGSModel @@ -33,7 +36,7 @@ struct BUGSModel <: AbstractBUGSModel untransformed_param_length::Int transformed_param_length::Int untransformed_var_lengths::Dict{VarName,Int} - transformed_var_lengths::Dict{VarName,Int} # TODO: store this as a delta from `untransformed_var_lengths`? + transformed_var_lengths::Dict{VarName,Int} varinfo::SimpleVarInfo distributions::Dict{VarName,Distribution} @@ -67,91 +70,97 @@ Return the names of the generated variables in the model. """ generated_variables(m::BUGSModel) = find_generated_vars(m.g) -struct UninitializedVariableError <: Exception - msg::String +function prepare_arg_values( + args::Tuple{Vararg{Symbol}}, vi::SimpleVarInfo, loop_vars::NamedTuple{lvars} +) where {lvars} + return NamedTuple{args}(Tuple( + map(args) do arg + if arg in lvars + loop_vars[arg] + else + vi[@varname($arg)] + end + end, + )) end function BUGSModel( - g::BUGSGraph, - sorted_nodes::Vector{<:VarName}, - eval_env::NamedTuple, - inits; - is_transformed::Bool=true, + g::BUGSGraph, eval_env::NamedTuple, inits::NamedTuple; is_transformed::Bool=true ) - vs = initialize_var_store(eval_env) - vi = SimpleVarInfo(vs, 0.0) - dist_store = Dict{VarName,Distribution}() + sorted_nodes = map(topological_sort(g)) do node + label_for(g, node) + end + vi = SimpleVarInfo( + NamedTuple{keys(eval_env)}( + map( + k -> begin + v = eval_env[k] + if v === missing + 0.0 + elseif v isa AbstractArray + if eltype(v) === Missing + zeros(size(v)...) + elseif Missing <: eltype(v) + coalesce.(v, zero(nonmissingtype(eltype(v)))) + else + v + end + else + v + end + end, + keys(eval_env), + ), + ), + 0.0, + ) + + distributions = Dict{VarName,Distribution}() parameters = VarName[] - untransformed_param_length = 0 - transformed_param_length = 0 - untransformed_var_lengths = Dict{VarName,Int}() - transformed_var_lengths = Dict{VarName,Int}() + untransformed_param_length, transformed_param_length = 0, 0 + untransformed_var_lengths, transformed_var_lengths = Dict{VarName,Int}(), + Dict{VarName,Int}() + for vn in sorted_nodes - @assert !(g[vn] isa AuxiliaryNodeInfo) "Auxiliary nodes should not be in the graph, but $(g[vn]) is." - - ni = g[vn] - @unpack node_type, node_function_expr, node_args = ni - args = Dict(getsym(arg) => vi[arg] for arg in node_args) # TODO: get rid of this - expr = node_function_expr.args[2] - if node_type == JuliaBUGS.Logical - value = try - _eval(expr, args, dist_store) - catch e - rethrow( - # UninitializedVariableError( - # "Encounter error when evaluating the RHS of $vn. Try to initialize variables $(join(collect(keys(args)), ", ")) directly first if not yet.", - # ), - e, - ) - end + (; is_stochastic, is_observed, node_function_expr, node_function, node_args, loop_vars) = g[vn] + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + if !is_stochastic + value = bugs_eval(expr, args, distributions) + # value = Base.invokelatest(node_function; arg_values...) vi = setindex!!(vi, value, vn) else - dist = try - _eval(expr, args, dist_store) - catch _ - rethrow( - UninitializedVariableError( - "Encounter support error when evaluating the distribution of $vn. Try to initialize variables $(join(collect(keys(args)), ", ")) first if not yet.", - ), - ) + dist = bugs_eval(expr, args, distributions) + # dist = Base.invokelatest(node_function; arg_values...) + distributions[vn] = dist + + if is_observed + continue end - dist_store[vn] = dist - value = AbstractPPL.get(eval_env, vn) - if !is_resolved(value) # not observed - push!(parameters, vn) - this_param_length = length(dist) - untransformed_param_length += this_param_length - - @assert length(dist) == _length(vn) begin - "The dimensionality of distribution $dist: $(length(dist)) does not match length of variable $vn: $(_length(vn)), " * - "please note that if the distribution is a multivariate distribution, " * - "the left hand side variable should use explicit indexing, e.g. x[1:2] ~ dmnorm(...)." - end - if bijector(dist) == identity - this_param_transformed_length = this_param_length - else - this_param_transformed_length = length(Bijectors.transformed(dist)) - end - untransformed_var_lengths[vn] = this_param_length - transformed_var_lengths[vn] = this_param_transformed_length - transformed_param_length += this_param_transformed_length - value = try - AbstractPPL.get(inits, vn) - catch _ - missing - end - if !is_resolved(value) # not initialized - vi = setindex!!(vi, rand(dist), vn) - else - vi = setindex!!(vi, value, vn) - end + + push!(parameters, vn) + untransformed_var_lengths[vn] = length(dist) + # not all distributions are defined for `Bijectors.transformed` + transformed_var_lengths[vn] = if bijector(dist) == identity + untransformed_var_lengths[vn] else - vi = setindex!!(vi, value, vn) + length(Bijectors.transformed(dist)) end + untransformed_param_length += untransformed_var_lengths[vn] + transformed_param_length += transformed_var_lengths[vn] + + initialization = try + AbstractPPL.get(inits, vn) + catch _ + missing + end + # TODO: this will cause partially initialized value to be redrawn + if !is_resolved(initialization) + initialization = rand(dist) + end + vi = setindex!!(vi, initialization, vn) end end - @assert (isempty(parameters) ? 0 : sum(_length(x) for x in parameters)) == - untransformed_param_length "$(isempty(parameters) ? 0 : sum(_length(x) for x in parameters)) $untransformed_param_length" return BUGSModel( is_transformed, untransformed_param_length, @@ -159,7 +168,7 @@ function BUGSModel( untransformed_var_lengths, transformed_var_lengths, vi, - dist_store, + distributions, parameters, sorted_nodes, g, @@ -167,49 +176,34 @@ function BUGSModel( ) end -function initialize_var_store(eval_env::NamedTuple) - var_store = Dict{Symbol,Any}() - for k in keys(eval_env) - v = eval_env[k] - if v === missing - var_store[k] = 0.0 - elseif v isa AbstractArray && Missing <: eltype(v) - var_store[k] = map(x -> x === missing ? 0.0 : x, v) - else - var_store[k] = v - end - end - return NamedTuple(var_store) -end - """ - get_params_varinfo(m::BUGSModel[, vi::SimpleVarInfo]) + get_params_varinfo(model::BUGSModel[, vi::SimpleVarInfo]) Returns a `SimpleVarInfo` object containing only the parameter values of the model. -If `vi` is provided, it will be used; otherwise, `m.varinfo` will be used. +If `vi` is provided, it will be used; otherwise, `model.varinfo` will be used. """ -function get_params_varinfo(m::BUGSModel) - return get_params_varinfo(m, m.varinfo) +function get_params_varinfo(model::BUGSModel) + return get_params_varinfo(model, model.varinfo) end -function get_params_varinfo(m::BUGSModel, vi::SimpleVarInfo) - if !m.transformed +function get_params_varinfo(model::BUGSModel, vi::SimpleVarInfo) + if !model.transformed d = Dict{VarName,Any}() - for param in m.parameters + for param in model.parameters d[param] = vi[param] end return SimpleVarInfo(d, vi.logp, DynamicPPL.NoTransformation()) else d = Dict{VarName,Any}() - g = m.g - for vn in m.sorted_nodes - ni = g[vn] - @unpack node_type, node_function_expr, node_args = ni - args = Dict(getsym(arg) => vi[arg] for arg in node_args) - expr = node_function_expr.args[2] - if vn in m.parameters - dist = _eval(expr, args, m.distributions) - linked_val = DynamicPPL.link(dist, vi[vn]) - d[vn] = linked_val + g = model.g + for v in model.sorted_nodes + (; is_stochastic, node_function_expr, node_function, node_args, loop_vars) = g[v] + if v in model.parameters + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + dist = bugs_eval(expr, args, model.distributions) + # dist = node_function(; args...) + linked_val = DynamicPPL.link(dist, vi[v]) + d[v] = linked_val end end return SimpleVarInfo(d, vi.logp, DynamicPPL.DynamicTransformation()) @@ -217,55 +211,56 @@ function get_params_varinfo(m::BUGSModel, vi::SimpleVarInfo) end """ - getparams(m::BUGSModel[, vi::SimpleVarInfo]; transformed::Bool=false) + getparams(model::BUGSModel[, vi::SimpleVarInfo]; transformed::Bool=false) Extract the parameter values from the model as a flattened vector, ordered topologically. If `transformed` is set to true, the parameters are provided in the transformed space. """ -function getparams(m::BUGSModel; transformed::Bool=false) - return getparams(m, m.varinfo; transformed=transformed) +function getparams(model::BUGSModel; transformed::Bool=false) + return getparams(model, model.varinfo; transformed=transformed) end -function getparams(m::BUGSModel, vi::SimpleVarInfo; transformed::Bool=false) - if !transformed - param_vals = Vector{Float64}(undef, m.untransformed_param_length) - pos = 1 - for p in m.parameters - val = vi[p] - len = m.untransformed_var_lengths[p] - if isa(val, Real) +function getparams(model::BUGSModel, vi::SimpleVarInfo; transformed::Bool=false) + param_vals = Vector{Float64}( + undef, + transformed ? model.transformed_param_length : model.untransformed_param_length, + ) + pos = 1 + for v in model.parameters + if !transformed + val = vi[v] + len = model.untransformed_var_lengths[v] + if val isa AbstractArray + param_vals[pos:(pos + len - 1)] .= vec(val) + else param_vals[pos] = val - pos += 1 + end + else + (; node_function_expr, node_args, loop_vars) = model.g[v] + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + dist = bugs_eval(expr, args, model.distributions) + # dist = node_function(; args...) + linked_val = Bijectors.link(dist, vi[v]) + len = model.transformed_var_lengths[v] + if linked_val isa AbstractArray + param_vals[pos:(pos + len - 1)] .= vec(linked_val) else - param_vals[pos:(pos + len - 1)] .= vec(val) - pos += len + param_vals[pos] = linked_val end end - return param_vals - else - transformed_param_vals = Vector{Float64}(undef, m.transformed_param_length) - pos = 1 - for v in m.parameters - ni = m.g[v] - args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...) - dist = _eval(ni.node_function_expr.args[2], args, m.distributions) - - link_vals = Bijectors.link(dist, vi[v]) - len = m.transformed_var_lengths[v] - transformed_param_vals[pos:(pos + len - 1)] .= link_vals - pos += len - end - return transformed_param_vals + pos += len end + return param_vals end """ - setparams!!(m::BUGSModel, flattened_values::AbstractVector; transformed::Bool=false) + setparams!!(model::BUGSModel, flattened_values::AbstractVector; transformed::Bool=false) Update the parameter values of a `BUGSModel` with new values provided in a flattened vector. Only the parameter values are updated, the values of logical variables are kept unchanged. -This function adopt the bangbang convention, i.e. it modifies the model in place when possible. +This function adopts the `BangBang` convention, i.e. it modifies the model in place when possible. # Arguments - `m::BUGSModel`: The model to update. @@ -276,23 +271,24 @@ This function adopt the bangbang convention, i.e. it modifies the model in place `SimpleVarInfo`: The updated `varinfo` with the new parameter values set. """ function setparams!!( - m::BUGSModel, flattened_values::AbstractVector; transformed::Bool=false + model::BUGSModel, flattened_values::AbstractVector; transformed::Bool=false ) pos = 1 - vi = m.varinfo - for v in m.parameters - ni = m.g[v] - args = (; (getsym(arg) => vi[arg] for arg in ni.node_args)...) - dist = _eval(ni.node_function_expr.args[2], args, m.distributions) + vi = model.varinfo + for v in model.parameters + (; node_function_expr, node_args, loop_vars) = model.g[v] + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + dist = bugs_eval(expr, args, model.distributions) len = if transformed - m.transformed_var_lengths[v] + model.transformed_var_lengths[v] else - m.untransformed_var_lengths[v] + model.untransformed_var_lengths[v] end if transformed - link_vals = flattened_values[pos:(pos + len - 1)] - sample_val = DynamicPPL.invlink_and_reconstruct(dist, link_vals) + linked_vals = flattened_values[pos:(pos + len - 1)] + sample_val = DynamicPPL.invlink_and_reconstruct(dist, linked_vals) else sample_val = flattened_values[pos:(pos + len - 1)] end @@ -302,14 +298,13 @@ function setparams!!( return vi end -# TODO: For now, a varinfo contains all model parameters is returned; alternatively, can return the generated quantities function (model::BUGSModel)() - vi, logp = evaluate!!(model, SamplingContext()) + vi, _ = evaluate!!(model, SamplingContext()) return get_params_varinfo(model, vi) end function settrans(model::BUGSModel, bool::Bool=!(model.transformed)) - return @set model.transformed = bool + return BangBang.setproperty!!(model, :transformed, bool) end function AbstractPPL.condition( @@ -383,7 +378,7 @@ end function check_var_group(var_group::Vector{<:VarName}, model::BUGSModel) non_vars = filter(var -> var ∉ labels(model.g), var_group) - logical_vars = filter(var -> model.g[var].node_type == Logical, var_group) + logical_vars = filter(var -> !model.g[var].is_stochastic, var_group) isempty(non_vars) || error("Variables $(non_vars) are not in the model") return isempty(logical_vars) || error( "Variables $(logical_vars) are not stochastic variables, conditioning on them is not supported", @@ -426,19 +421,18 @@ function AbstractPPL.evaluate!!(model::BUGSModel, rng::Random.AbstractRNG) return evaluate!!(model, SamplingContext(rng)) end function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext) - @unpack varinfo, g, sorted_nodes = model + (; varinfo, g, sorted_nodes) = model vi = deepcopy(varinfo) logp = 0.0 for vn in sorted_nodes - ni = g[vn] - @unpack node_type, node_function_expr, node_args = ni - args = Dict(getsym(arg) => vi[arg] for arg in node_args) - expr = node_function_expr.args[2] - if node_type == JuliaBUGS.Logical - value = _eval(expr, args, model.distributions) + (; is_stochastic, node_function_expr, node_args, loop_vars) = g[vn] + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + if !is_stochastic + value = bugs_eval(expr, args, model.distributions) vi = setindex!!(vi, value, vn) else - dist = _eval(expr, args, model.distributions) + dist = bugs_eval(expr, args, model.distributions) value = rand(ctx.rng, dist) # just sample from the prior logp += logpdf(dist, value) vi = setindex!!(vi, value, vn) @@ -451,20 +445,18 @@ function AbstractPPL.evaluate!!(model::BUGSModel) return AbstractPPL.evaluate!!(model, DefaultContext()) end function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext) - sorted_nodes = model.sorted_nodes - g = model.g - vi = deepcopy(model.varinfo) + (; sorted_nodes, g, varinfo) = model + vi = deepcopy(varinfo) logp = 0.0 for vn in sorted_nodes - ni = g[vn] - @unpack node_type, node_function_expr, node_args = ni - args = Dict(getsym(arg) => vi[arg] for arg in node_args) - expr = node_function_expr.args[2] - if node_type == JuliaBUGS.Logical # be conservative -- always propagate values of logical nodes - value = _eval(expr, args, model.distributions) + (; is_stochastic, node_function_expr, node_args, loop_vars) = g[vn] + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + if !is_stochastic + value = bugs_eval(expr, args, model.distributions) vi = setindex!!(vi, value, vn) else - dist = _eval(expr, args, model.distributions) + dist = bugs_eval(expr, args, model.distributions) value = vi[vn] if model.transformed # although the values stored in `vi` are in their original space, @@ -485,31 +477,38 @@ end function AbstractPPL.evaluate!!( model::BUGSModel, ::LogDensityContext, flattened_values::AbstractVector ) - @assert length(flattened_values) == ( - if model.transformed - model.transformed_param_length - else - model.untransformed_param_length - end - ) + param_lengths = if model.transformed + model.transformed_param_length + else + model.untransformed_param_length + end + + if length(flattened_values) != param_lengths + error( + "The length of `flattened_values` does not match the length of the parameters in the model", + ) + end + + var_lengths = if model.transformed + model.transformed_var_lengths + else + model.untransformed_var_lengths + end - var_lengths = - model.transformed ? model.transformed_var_lengths : model.untransformed_var_lengths sorted_nodes = model.sorted_nodes g = model.g vi = deepcopy(model.varinfo) current_idx = 1 logp = 0.0 for vn in sorted_nodes - ni = g[vn] - @unpack node_type, node_function_expr, node_args = ni - args = (; map(arg -> getsym(arg) => vi[arg], node_args)...) - expr = node_function_expr.args[2] - if node_type == JuliaBUGS.Logical - value = _eval(expr, args, model.distributions) + (; is_stochastic, node_function_expr, node_args, loop_vars) = g[vn] + args = prepare_arg_values(node_args, vi, loop_vars) + expr = node_function_expr.args[2].args[1].args[1] + if !is_stochastic + value = bugs_eval(expr, args, model.distributions) vi = setindex!!(vi, value, vn) else - dist = _eval(expr, args, model.distributions) + dist = bugs_eval(expr, args, model.distributions) if vn in model.parameters l = var_lengths[vn] if model.transformed diff --git a/src/utils.jl b/src/utils.jl index 86592438b..909c7b5be 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -495,16 +495,17 @@ function simple_arithmetic_eval(data::NamedTuple, expr::Expr) end """ - _eval(expr, env, dist_store) + bugs_eval(expr, env, dist_store) -`_eval` mimics `Base.eval`, but uses precompiled functions. This is possible because the expressions we want to -evaluate only have two kinds of expressions: function calls and indexing. +`bugs_eval` mimics `Base.eval`'s behavior: it traverse the Expr, for function call, it will use `getfield(JuliaBUGS, f)` to get the function. +`bugs_eval` assumes that the Expr only has two kinds of expressions: function calls and indexing. `env` is a data structure mapping symbols in `expr` to values, values can be arrays or scalars. +`dist_store` is a data structure storing runtime `Distribution` objects associated with random variables. """ -function _eval(expr::Number, env, dist_store) +function bugs_eval(expr::Number, env, dist_store) return expr end -function _eval(expr::Symbol, env, dist_store) +function bugs_eval(expr::Symbol, env, dist_store) if expr == :nothing return nothing elseif expr == :(:) @@ -513,10 +514,10 @@ function _eval(expr::Symbol, env, dist_store) return env[expr] end end -function _eval(expr::AbstractRange, env, dist_store) +function bugs_eval(expr::AbstractRange, env, dist_store) return expr end -function _eval(expr::Expr, env, dist_store) +function bugs_eval(expr::Expr, env, dist_store) if Meta.isexpr(expr, :call) f = expr.args[1] if f === :cumulative || f === :density @@ -529,7 +530,7 @@ function _eval(expr::Expr, env, dist_store) dist = if Meta.isexpr(rv1, :ref) var, indices... = rv1.args for i in eachindex(indices) - indices[i] = _eval(indices[i], env, dist_store) + indices[i] = bugs_eval(indices[i], env, dist_store) end vn = AbstractPPL.VarName{var}( AbstractPPL.Setfield.IndexLens(Tuple(indices)) @@ -543,30 +544,38 @@ function _eval(expr::Expr, env, dist_store) "the first argument of density function should be a variable, but got $(rv1).", ) end - rv2 = _eval(rv2, env, dist_store) + rv2 = bugs_eval(rv2, env, dist_store) if f === :cumulative return cdf(dist, rv2) else return pdf(dist, rv2) end else - args = [_eval(arg, env, dist_store) for arg in expr.args[2:end]] + args = [bugs_eval(arg, env, dist_store) for arg in expr.args[2:end]] if f isa Expr # `JuliaBUGS.some_function` like f = f.args[2].value end return getfield(JuliaBUGS, f)(args...) # assume all functions used are available under `JuliaBUGS` end elseif Meta.isexpr(expr, :ref) - array = _eval(expr.args[1], env, dist_store) - indices = [_eval(arg, env, dist_store) for arg in expr.args[2:end]] + array = bugs_eval(expr.args[1], env, dist_store) + indices = [bugs_eval(arg, env, dist_store) for arg in expr.args[2:end]] + # TODO: should just ban implicit type casting + indices = map(indices) do index + if index isa Float64 + Int(index) + else + index + end + end return array[indices...] elseif Meta.isexpr(expr, :block) - return _eval(expr.args[end], env, dist_store) + return bugs_eval(expr.args[end], env, dist_store) else error("Unknown expression type: $expr") end end -function _eval(expr, env, dist_store) +function bugs_eval(expr, env, dist_store) return error("Unknown expression type: $expr of type $(typeof(expr))") end @@ -585,22 +594,6 @@ function evaluate(vn::VarName, env) return ismissing(ret) ? nothing : ret end -""" - _length(vn::VarName) - -Return the length of a possible variable identified by `vn`. -Only valid if `vn` is: - - a scalar - - an array indexing whose indices are concrete(no `start`, `end`, `:`) - -! Should not be used outside of the usage demonstrated in this file. - -""" -function _length(vn::VarName) - getlens(vn) isa Setfield.IdentityLens && return 1 - return prod([length(index_range) for index_range in getlens(vn).indices]) -end - # Resolves: setindex!!([1 2; 3 4], [2 3; 4 5], 1:2, 1:2) # returns 2×2 Matrix{Any} # Alternatively, can overload BangBang.possible( # ::typeof(BangBang._setindex!), ::C, ::T, ::Vararg diff --git a/src/variable_types.jl b/src/variable_types.jl deleted file mode 100644 index 8e49c2fbc..000000000 --- a/src/variable_types.jl +++ /dev/null @@ -1,131 +0,0 @@ -""" - Var - -A lightweight type for representing variables in a model. -""" -abstract type Var end - -struct Scalar <: Var - name::Symbol - indices::Tuple{} -end - -struct ArrayElement{N} <: Var - name::Symbol - indices::NTuple{N,Int} -end - -struct ArrayVar{N} <: Var - name::Symbol - indices::NTuple{N,Union{Int,UnitRange,Colon}} -end - -Var(name::Symbol) = Scalar(name, ()) -function Var(name::Symbol, indices) - indices = map(indices) do i - if i isa AbstractFloat - isinteger(i) && return Int(i) - error("Indices must be integers.") - end - return i - end - all(x -> x isa Integer, indices) && return ArrayElement(name, indices) - return ArrayVar(name, indices) -end - -Base.size(::Scalar) = () -Base.size(::ArrayElement) = () -function Base.size(v::ArrayVar) - if any(x -> x isa Colon, v.indices) - error("Can't get size of an array with colon indices.") - end - return Tuple(map(length, v.indices)) -end - -Base.Symbol(v::Scalar) = v.name -function Base.Symbol(v::Var) - return Symbol(v.name, "[", join(v.indices, ", "), "]") -end - -toexpr(r::Number) = r -toexpr(r::UnitRange) = Expr(:call, :(:), r.start, r.stop) -toexpr(v::Scalar) = v.name -toexpr(v::Var) = Expr(:ref, v.name, toexpr.(v.indices)...) - -function hash(v::Var, h::UInt) - return hash(v.name, hash(v.indices, h)) -end - -function Base.:(==)(v1::Var, v2::Var) - typeof(v1) != typeof(v2) && return false - return v1.name == v2.name && v1.indices == v2.indices -end - -Base.show(io::IO, v::Scalar) = print(io, v.name) -function Base.show(io::IO, v::Var) - return print(io, v.name, "[", join(v.indices, ", "), "]") -end - -function to_varname(v::Scalar) - lens = AbstractPPL.IdentityLens() - return AbstractPPL.VarName{v.name}(lens) -end -function to_varname(v::Var) - lens = AbstractPPL.IndexLens(v.indices) - return AbstractPPL.VarName{v.name}(lens) -end - -""" - scalarize(v::Var) - -Return an array of `Var`s that are scalarized from `v`. If `v` is a scalar, return an array of length 1 containing `v`. -All indices of `v` must be integer or UnitRange. - -# Examples -```jldoctest -julia> scalarize(Var(:x, (1, 2:3))) -2-element Vector{Var}: - x[1, 2] - x[1, 3] -``` -""" -scalarize(v::Scalar) = [v] -scalarize(v::ArrayElement) = [v] -function scalarize(v::Var) - collected_indices = collect(Iterators.product(v.indices...)) - scalarized_vars = Array{Var}(undef, size(collected_indices)...) - for i in eachindex(collected_indices) - scalarized_vars[i] = Var(v.name, collected_indices[i]) - end - return scalarized_vars -end - -""" - evaluate(v::Var, env) - -Evaluate `v` in the environment `env`. If `v` is a scalar, return the value of `v` in `env`. If `v` is an array, -return an array of the same size as `v` with the values of `v` in `env` and `Var`s for the missing values. If `v` -represent a multi-dimensional array, the return value is always scalarized, even when no array elements are data. - -# Examples -```jldoctest; setup=:(using JuliaBUGS: evaluate) -julia> evaluate(Var(:x, (1:2, )), Dict(:x => [1, missing])) -2-element Vector{Any}: - 1 - x[2] -``` -""" -function evaluate(v::Var, env) - if !haskey(env, v.name) - return v isa Scalar ? v : scalarize(v) - end - if v isa Scalar - return env[v.name] - end - if v isa ArrayElement - value = env[v.name][v.indices...] - return ismissing(value) ? v : value - end - value = map(x -> evaluate(x, env), scalarize(v)) - return reshape(value, size(v)) -end diff --git a/test/gibbs.jl b/test/gibbs.jl index 3ab333ed8..973199429 100644 --- a/test/gibbs.jl +++ b/test/gibbs.jl @@ -27,10 +27,14 @@ model = compile(model_def, data, (;)) # use NamedTuple for SimpleVarinfo - model = @set model.varinfo = begin - vi = model.varinfo - SimpleVarInfo(DynamicPPL.values_as(vi, NamedTuple), vi.logp, vi.transformation) - end + model = JuliaBUGS.BangBang.setproperty!!( + model, + :varinfo, + begin + vi = model.varinfo + SimpleVarInfo(DynamicPPL.values_as(vi, NamedTuple), vi.logp, vi.transformation) + end, + ) # single step p_s, st_init = AbstractMCMC.step( diff --git a/test/graphs.jl b/test/graphs.jl index 397194c0b..6f72f89b5 100644 --- a/test/graphs.jl +++ b/test/graphs.jl @@ -75,7 +75,7 @@ end logp += logpdf(dnorm(2.0, 1.0), 4.0) # d, where g = 2.0 logp += logpdf(dnorm(4.0, 4.0), 5.0) # e, where h = 4.0 logp -end ≈ evaluate!!(model, LogDensityContext(), [4.0, 2.0, -2.0, 3.0, 1.0, 5.0, 4.0])[2] atol = +end ≈ evaluate!!(model, LogDensityContext(), [-2.0, 4.0, 3.0, 2.0, 1.0, 4.0, 5.0])[2] atol = 1e-8 # AuxiliaryNodeInfo diff --git a/test/logp_tests/binomial.jl b/test/logp_tests/binomial.jl index fc635e8df..4dea1ba41 100644 --- a/test/logp_tests/binomial.jl +++ b/test/logp_tests/binomial.jl @@ -14,7 +14,11 @@ dppl_model = dppl_gamma_model() bugs_logp = JuliaBUGS.evaluate!!(JuliaBUGS.settrans(bugs_model, false), DefaultContext())[2] params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model -@test params_in_dppl_model(dppl_model) == keys(params_vi) +@test keys( + DynamicPPL.evaluate!!( + dppl_model, SimpleVarInfo(Dict{VarName,Any}()), DynamicPPL.SamplingContext() + )[2], +) == keys(params_vi) p = DynamicPPL.LogDensityFunction(dppl_model) t_p = DynamicPPL.LogDensityFunction( diff --git a/test/logp_tests/blockers.jl b/test/logp_tests/blockers.jl index 3d824fa46..fc9696ea6 100644 --- a/test/logp_tests/blockers.jl +++ b/test/logp_tests/blockers.jl @@ -1,6 +1,6 @@ -bugs_model_def = JuliaBUGS.BUGSExamples.VOLUME_I[:blockers].model_def -data = JuliaBUGS.BUGSExamples.VOLUME_I[:blockers].data -inits = JuliaBUGS.BUGSExamples.VOLUME_I[:blockers].inits[1] +bugs_model_def = JuliaBUGS.BUGSExamples.blockers.model_def +data = JuliaBUGS.BUGSExamples.blockers.data +inits = JuliaBUGS.BUGSExamples.blockers.inits[1] bugs_model = compile(bugs_model_def, data, inits) vi = bugs_model.varinfo @@ -31,13 +31,17 @@ vi = bugs_model.varinfo return sigma end -@unpack rt, nt, rc, nc, Num = data +(; rt, nt, rc, nc, Num) = data dppl_model = blockers(rc, rt, nc, nt, Num) bugs_logp = JuliaBUGS.evaluate!!(JuliaBUGS.settrans(bugs_model, false), DefaultContext())[2] params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model -@test params_in_dppl_model(dppl_model) == keys(params_vi) +@test keys( + DynamicPPL.evaluate!!( + dppl_model, SimpleVarInfo(Dict{VarName,Any}()), DynamicPPL.SamplingContext() + )[2], +) == keys(params_vi) dppl_logp = DynamicPPL.evaluate!!( diff --git a/test/logp_tests/bones.jl b/test/logp_tests/bones.jl index 02b3518b3..2629979b5 100644 --- a/test/logp_tests/bones.jl +++ b/test/logp_tests/bones.jl @@ -33,13 +33,17 @@ vi = bugs_model.varinfo end end -@unpack grade, nChild, nInd, ncat, gamma, delta = data +(; grade, nChild, nInd, ncat, gamma, delta) = data dppl_model = bones(grade, nChild, nInd, ncat, gamma, delta) bugs_logp = JuliaBUGS.evaluate!!(JuliaBUGS.settrans(bugs_model, false), DefaultContext())[2] params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model -@test params_in_dppl_model(dppl_model) == keys(params_vi) +@test keys( + DynamicPPL.evaluate!!( + dppl_model, SimpleVarInfo(Dict{VarName,Any}()), DynamicPPL.SamplingContext() + )[2], +) == keys(params_vi) dppl_logp = DynamicPPL.evaluate!!( diff --git a/test/logp_tests/dogs.jl b/test/logp_tests/dogs.jl index 7d7c6b096..17f8079cc 100644 --- a/test/logp_tests/dogs.jl +++ b/test/logp_tests/dogs.jl @@ -36,7 +36,7 @@ vi = bugs_model.varinfo return A, B end -@unpack Dogs, Trials, Y = data +(; Dogs, Trials, Y) = data dppl_model = dogs(Dogs, Trials, Y, 1 .- Y) bugs_logp = JuliaBUGS.evaluate!!(JuliaBUGS.settrans(bugs_model, false), DefaultContext())[2] diff --git a/test/logp_tests/gamma.jl b/test/logp_tests/gamma.jl index b78efa9f9..28e993d08 100644 --- a/test/logp_tests/gamma.jl +++ b/test/logp_tests/gamma.jl @@ -15,7 +15,11 @@ dppl_model = dppl_gamma_model() bugs_logp = JuliaBUGS.evaluate!!(JuliaBUGS.settrans(bugs_model, false), DefaultContext())[2] params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model -@test params_in_dppl_model(dppl_model) == keys(params_vi) +@test keys( + DynamicPPL.evaluate!!( + dppl_model, SimpleVarInfo(Dict{VarName,Any}()), DynamicPPL.SamplingContext() + )[2], +) == keys(params_vi) p = DynamicPPL.LogDensityFunction(dppl_model) t_p = DynamicPPL.LogDensityFunction( diff --git a/test/logp_tests/rats.jl b/test/logp_tests/rats.jl index 7970037a6..f76bac8af 100644 --- a/test/logp_tests/rats.jl +++ b/test/logp_tests/rats.jl @@ -1,64 +1,49 @@ -# prepare data -data = load_dictionary(:rats, :data, true) -inits = load_dictionary(:rats, :init, true) +model_def = JuliaBUGS.BUGSExamples.rats.model_def +data = JuliaBUGS.BUGSExamples.rats.data +inits = JuliaBUGS.BUGSExamples.rats.inits[1] -@unpack N, T, x, xbar, Y = data - -# prepare models -model_def = @bugs begin - for i in 1:N - for j in 1:T - Y[i, j] ~ dnorm(mu[i, j], tau_c) - mu[i, j] = alpha[i] + beta[i] * (x[j] - xbar) - end - alpha[i] ~ dnorm(alpha_c, alpha_tau) - beta[i] ~ dnorm(beta_c, beta_tau) - end - tau_c ~ dgamma(0.001, 0.001) - sigma = 1 / sqrt(tau_c) - alpha_c ~ dnorm(0.0, 1.0E-6) - alpha_tau ~ dgamma(0.001, 0.001) - beta_c ~ dnorm(0.0, 1.0E-6) - beta_tau ~ dgamma(0.001, 0.001) - alpha0 = alpha_c - xbar * beta_c -end bugs_model = compile(model_def, data, inits); vi = bugs_model.varinfo @model function rats(Y, x, xbar, N, T) - tau_c ~ dgamma(0.001, 0.001) - sigma = 1 / sqrt(tau_c) + var"tau.c" ~ dgamma(0.001, 0.001) + sigma = 1 / sqrt(var"tau.c") - alpha_c ~ dnorm(0.0, 1.0E-6) - alpha_tau ~ dgamma(0.001, 0.001) + var"alpha.c" ~ dnorm(0.0, 1.0E-6) + var"alpha.tau" ~ dgamma(0.001, 0.001) - beta_c ~ dnorm(0.0, 1.0E-6) - beta_tau ~ dgamma(0.001, 0.001) + var"beta.c" ~ dnorm(0.0, 1.0E-6) + var"beta.tau" ~ dgamma(0.001, 0.001) - alpha0 = alpha_c - xbar * beta_c + alpha0 = var"alpha.c" - xbar * var"beta.c" alpha = Vector{Real}(undef, N) beta = Vector{Real}(undef, N) for i in 1:N - alpha[i] ~ dnorm(alpha_c, alpha_tau) - beta[i] ~ dnorm(beta_c, beta_tau) + alpha[i] ~ dnorm(var"alpha.c", var"alpha.tau") + beta[i] ~ dnorm(var"beta.c", var"beta.tau") for j in 1:T mu = alpha[i] + beta[i] * (x[j] - xbar) - Y[i, j] ~ dnorm(mu, tau_c) + Y[i, j] ~ dnorm(mu, var"tau.c") end end return sigma, alpha0 end +(; N, T, x, xbar, Y) = data dppl_model = rats(Y, x, xbar, N, T) bugs_model = JuliaBUGS.settrans(bugs_model, false) bugs_logp = JuliaBUGS.evaluate!!(bugs_model, DefaultContext())[2] params_vi = JuliaBUGS.get_params_varinfo(bugs_model, vi) # test if JuliaBUGS and DynamicPPL agree on parameters in the model -@test params_in_dppl_model(dppl_model) == keys(params_vi) +@test keys( + DynamicPPL.evaluate!!( + dppl_model, SimpleVarInfo(Dict{VarName,Any}()), DynamicPPL.SamplingContext() + )[2], +) == keys(params_vi) dppl_logp = DynamicPPL.evaluate!!( diff --git a/test/profile.jl b/test/profile.jl index 03909c8b7..bf99dde3a 100644 --- a/test/profile.jl +++ b/test/profile.jl @@ -28,17 +28,8 @@ for name in keys(BUGSExamples.VOLUME_I) $non_data_scalars, $non_data_array_sizes, $model_def, $data ) - _suite["NodeFunction"] = @benchmarkable JuliaBUGS.compute_node_functions( - $model_def, $eval_env - ) - model_def = JuliaBUGS.concretize_colon_indexing(model_def, eval_env) - vars, node_args, node_functions, dependencies = JuliaBUGS.compute_node_functions( - model_def, eval_env - ) - _suite["GraphCreation"] = @benchmarkable JuliaBUGS.create_BUGSGraph( - $vars, $node_args, $node_functions, $dependencies - ) + _suite["GraphCreation"] = @benchmarkable JuliaBUGS.create_graph($model_def, $eval_env) tune!(_suite) suite[string(name)] = _suite @@ -50,7 +41,7 @@ function create_result_dict(results) result_dict = Dict{String,Dict{String,Dict{String,String}}}() for (name, example_suite) in results _d = Dict{String,Dict{String,String}}() - for k in ("CollectVariables", "DataTransformation", "NodeFunction", "GraphCreation") + for k in ("CollectVariables", "DataTransformation", "GraphCreation") __d = Dict{String,String}() med = median(example_suite[k]) min = minimum(example_suite[k]) diff --git a/test/run_logp_tests.jl b/test/run_logp_tests.jl deleted file mode 100644 index ccabe64d2..000000000 --- a/test/run_logp_tests.jl +++ /dev/null @@ -1,26 +0,0 @@ -function load_dictionary(example_name, data_or_init, replace_period=true) - example = JuliaBUGS.BUGSExamples.VOLUME_I[example_name] - if data_or_init == :data - _d = example.data - elseif data_or_init == :init - _d = example.inits[1] - else - error("data_or_init must be either :data or :init") - end - d = Dict{Symbol,Any}() - for _k in keys(_d) - if replace_period - k = Symbol(replace(String(_k), "." => "_")) - end - d[k] = _d[_k] - end - return d -end - -function params_in_dppl_model(dppl_model) - return keys( - DynamicPPL.evaluate!!( - dppl_model, SimpleVarInfo(Dict{VarName,Any}()), DynamicPPL.SamplingContext() - )[2], - ) -end diff --git a/test/runtests.jl b/test/runtests.jl index 401a0237b..d9b2680dd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,28 +6,20 @@ using Bijectors using Distributions using Documenter using DynamicPPL -using DynamicPPL: getlogp, settrans!! +using DynamicPPL: getlogp, settrans!!, SimpleVarInfo using Graphs, MetaGraphsNext using JuliaBUGS using JuliaBUGS: BUGSGraph, - CollectVariables, - ConcreteNodeInfo, - DataTransformation, DefaultContext, evaluate!!, get_params_varinfo, - Logical, LogDensityContext, MHFromPrior, - NodeFunctions, - SimpleVarInfo, - Stochastic, stochastic_inneighbors, stochastic_neighbors, stochastic_outneighbors, markov_blanket, - Var, Gibbs using JuliaBUGS.BUGSPrimitives using JuliaBUGS.BUGSPrimitives: mean @@ -37,9 +29,7 @@ using MacroTools using MCMCChains using Random using ReverseDiff -using Setfield using Test -using UnPack AbstractMCMC.setprogress!(false) @@ -54,11 +44,8 @@ else JuliaBUGS, BUGSExamples, @bugs, - Var, - create_array_var, - replace_constants_in_expr, evaluate_and_track_dependencies, - scalarize, + evaluate, concretize_colon_indexing, extract_variable_names_and_numdims, extract_variables_in_bounds_and_lhs_indices, @@ -115,7 +102,6 @@ else include("passes.jl") @testset "Log Probability Test" begin - include("run_logp_tests.jl") @testset "Single stochastic variable test" begin @testset "test for $s" for s in [:binomial, :gamma, :lkj, :dwish, :ddirich] include("logp_tests/$s.jl")