diff --git a/src/JuliaBUGS.jl b/src/JuliaBUGS.jl index d56cfb599..6b2214f07 100644 --- a/src/JuliaBUGS.jl +++ b/src/JuliaBUGS.jl @@ -227,20 +227,20 @@ function compile(model_def::Expr, data, inits; is_transformed=true) check_input(data) check_input(inits) - scalars, array_sizes = program!(CollectVariables(), model_def, data) - has_new_val, transformed_variables = program!( + scalars, array_sizes = analyze_program(CollectVariables(), model_def, data) + has_new_val, transformed_variables = analyze_program( ConstantPropagation(scalars, array_sizes), model_def, data ) while has_new_val - has_new_val, transformed_variables = program!( + has_new_val, transformed_variables = analyze_program( ConstantPropagation(false, transformed_variables), model_def, data ) end - array_bitmap, transformed_variables = program!( + array_bitmap, transformed_variables = analyze_program( PostChecking(data, transformed_variables), model_def, data ) merged_data = merge_with_coalescence(deepcopy(data), transformed_variables) - vars, array_sizes, array_bitmap, node_args, node_functions, dependencies = program!( + vars, array_sizes, array_bitmap, node_args, node_functions, dependencies = analyze_program( NodeFunctions(array_sizes, array_bitmap), model_def, merged_data ) g = create_BUGSGraph(vars, node_args, node_functions, dependencies) diff --git a/src/compiler_pass.jl b/src/compiler_pass.jl index 02cfb1c52..59963f96d 100644 --- a/src/compiler_pass.jl +++ b/src/compiler_pass.jl @@ -1,87 +1,47 @@ -""" - CompilerPass - -Abstract supertype for all compiler passes. Concrete subtypes should store data needed and artifacts. -""" abstract type CompilerPass end -""" - program!(pass::CompilerPass, expr::Expr, env, vargs...) - -The entry point for a compiler pass, which traverses the AST and performs specific actions like assignment and for-loop processing. -This function should be implemented for every concrete subtype of CompilerPass. - -Arguments: -- pass: Instance of a concrete CompilerPass subtype. -- expr: An Expr object representing the AST to be traversed. -- env: A Dict object representing the environment. -""" -function program!(pass::CompilerPass, expr::Expr, env, vargs...) +function analyze_program(pass::CompilerPass, expr::Expr, env, vargs...) for ex in expr.args if Meta.isexpr(ex, :(=)) - assignment!(pass, ex, env, vargs...) - elseif MacroTools.@capture(ex, lhs_ ~ rhs_) - assignment!(pass, ex, env, vargs...) + analyze_assignment(pass, ex, env, vargs...) + elseif Meta.isexpr(ex, :call) && ex.args[1] == :(~) + analyze_assignment(pass, ex, env, vargs...) elseif Meta.isexpr(ex, :for) - for_loop!(pass, ex, env, vargs...) + analyze_for_loop(pass, ex, env, vargs...) else - error() + error("Unsupported expression in top level: $ex") end end return post_process(pass, expr, env, vargs...) end -""" - for_loop!(pass::CompilerPass, expr, env, vargs...) - -Processes a for-loop from a traversed AST. -""" -function for_loop!(pass::CompilerPass, expr, env, vargs...) - loop_var = expr.args[1].args[1] - lb, ub = expr.args[1].args[2].args[2:end] - body = expr.args[2] - - loop_var = Symbol(loop_var) +function analyze_for_loop(pass::CompilerPass, expr, env, vargs...) + loop_var, lb, ub, body = decompose_for_expr(expr) lb = Int(evaluate(lb, env)) ub = Int(evaluate(ub, env)) + for i in lb:ub for ex in body.args if Meta.isexpr(ex, :(=)) - assignment!(pass, ex, merge(env, NamedTuple{(loop_var,)}((i,))), vargs...) - elseif ex.head == :call && ex.args[1] == :(~) - assignment!(pass, ex, merge(env, NamedTuple{(loop_var,)}((i,))), vargs...) + analyze_assignment( + pass, ex, merge(env, NamedTuple{(loop_var,)}((i,))), vargs... + ) + elseif Meta.isexpr(ex, :call) && ex.args[1] == :(~) + analyze_assignment( + pass, ex, merge(env, NamedTuple{(loop_var,)}((i,))), vargs... + ) elseif Meta.isexpr(ex, :for) - for_loop!(pass, ex, merge(env, NamedTuple{(loop_var,)}((i,))), vargs...) + analyze_for_loop( + pass, ex, merge(env, NamedTuple{(loop_var,)}((i,))), vargs... + ) else - error() + error("Unsupported expression in for loop body: $ex") end end end end -""" - assignment!(pass::CompilerPass, expr::Expr, env, vargs...) - -Performs an assignment operation on a traversed AST. Should be implemented for every concrete subtype of CompilerPass. - -Arguments: -- pass: Instance of a concrete CompilerPass subtype. -- expr: An Expr object representing the assignment operation. -- env: A Dict object representing the environment. -""" -function assignment!(::CompilerPass, expr::Expr, env, vargs...) end - -""" - post_process(pass::CompilerPass, expr, env, vargs...) - -Performs any post-processing necessary after traversing the AST. Should be implemented for every concrete subtype of CompilerPass. - -Arguments: -- pass: Instance of a concrete CompilerPass subtype. -- expr: An Expr object representing the traversed AST. -- env: A Dict object representing the environment. -""" -function post_process(pass::CompilerPass, expr, env, vargs...) end +function analyze_assignment end @enum VariableTypes begin Logical @@ -367,7 +327,7 @@ is_resolved(::Array{Missing}) = false is_resolved(::Union{Symbol,Expr}) = false is_resolved(::Any) = false -function assignment!(pass::CollectVariables, expr::Expr, env) +function analyze_assignment(pass::CollectVariables, expr::Expr, env) if Meta.isexpr(expr, :(=)) lhs_expr = expr.args[1] else # Expr(:call, :(~), ...) @@ -456,7 +416,7 @@ function has_value(transformed_variables, v::Var) end end -function assignment!(pass::ConstantPropagation, expr::Expr, env) +function analyze_assignment(pass::ConstantPropagation, expr::Expr, env) if Meta.isexpr(expr, :(=)) && !should_skip_eval(expr.args[2]) lhs = find_variables_on_lhs(expr.args[1], env) @@ -520,7 +480,7 @@ function PostChecking(data, transformed_variables::Dict) ) end -function assignment!(pass::PostChecking, expr::Expr, env) +function analyze_assignment(pass::PostChecking, expr::Expr, env) @inline set_value!(d::Dict, value, v::Scalar) = d[v.name] = value @inline set_value!(d::Dict, value, v::Var) = d[v.name][v.indices...] = value @inline get_value(d::Dict, v::Scalar) = d[v.name] @@ -846,7 +806,7 @@ try_cast_to_int(x::Integer) = x try_cast_to_int(x::Real) = Int(x) # will error if !isinteger(x) try_cast_to_int(x) = x # catch other types, e.g. UnitRange, Colon -function assignment!(pass::NodeFunctions, expr::Expr, env) +function analyze_assignment(pass::NodeFunctions, expr::Expr, env) @capture(expr, lhs_expr_ ~ rhs_expr_) || @capture(expr, lhs_expr_ = rhs_expr_) var_type = Meta.isexpr(expr, :(=)) ? Logical : Stochastic diff --git a/src/utils.jl b/src/utils.jl index 5b4623f84..e1d520cb2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,17 @@ +""" + decompose_for_expr(expr::Expr) + +Decompose a for-loop expression into its components. The function returns four items: the +loop variable, the lower bound, the upper bound, and the body of the loop. +""" +@inline function decompose_for_expr(expr::Expr) + loop_var::Symbol = expr.args[1].args[1] + lb::Union{Int,Float64,Symbol,Expr} = expr.args[1].args[2].args[2] + ub::Union{Int,Float64,Symbol,Expr} = expr.args[1].args[2].args[3] + body::Expr = expr.args[2] + return loop_var, lb, ub, body +end + """ _eval(expr, env) diff --git a/test/methadone/methadone.jl b/test/methadone/methadone.jl index be753f3f5..9cfb52348 100644 --- a/test/methadone/methadone.jl +++ b/test/methadone/methadone.jl @@ -87,10 +87,10 @@ end ## @time model = compile(model_def, data, inits); -@time vars, array_sizes, transformed_variables, array_bitmap = JuliaBUGS.program!( +@time vars, array_sizes, transformed_variables, array_bitmap = JuliaBUGS.analyze_program( JuliaBUGS.CollectVariables(), model_def, data ); -vars, array_sizes, array_bitmap, link_functions, node_args, node_functions, dependencies = JuliaBUGS.program!( +vars, array_sizes, array_bitmap, link_functions, node_args, node_functions, dependencies = JuliaBUGS.analyze_program( JuliaBUGS.NodeFunctions(vars, array_sizes, array_bitmap), model_def, data ); diff --git a/test/passes.jl b/test/passes.jl index 5c898b3f6..98e36ee0e 100644 --- a/test/passes.jl +++ b/test/passes.jl @@ -6,20 +6,20 @@ end data = (b=1, e=[1, 2]) - scalars, array_sizes = program!(CollectVariables(), model_def, data) - has_new_val, transformed_variables = program!( + scalars, array_sizes = analyze_program(CollectVariables(), model_def, data) + has_new_val, transformed_variables = analyze_program( ConstantPropagation(scalars, array_sizes), model_def, data ) @test has_new_val == true @test transformed_variables[:a] == 2 - has_new_val, transformed_variables = program!( + has_new_val, transformed_variables = analyze_program( ConstantPropagation(false, transformed_variables), model_def, data ) @test has_new_val == true @test transformed_variables[:c] == 6 - has_new_val, transformed_variables = program!( + has_new_val, transformed_variables = analyze_program( ConstantPropagation(false, transformed_variables), model_def, data ) @test has_new_val == false @@ -30,26 +30,26 @@ end data = JuliaBUGS.BUGSExamples.VOLUME_I[m].data inits = JuliaBUGS.BUGSExamples.VOLUME_I[m].inits[1] - scalars, array_sizes = program!(CollectVariables(), model_def, data) + scalars, array_sizes = analyze_program(CollectVariables(), model_def, data) - has_new_val, transformed_variables = program!( + has_new_val, transformed_variables = analyze_program( ConstantPropagation(scalars, array_sizes), model_def, data ) @test has_new_val == true @test all(!ismissing, transformed_variables[:Y]) - has_new_val, transformed_variables = program!( + has_new_val, transformed_variables = analyze_program( ConstantPropagation(false, transformed_variables), model_def, data ) @test has_new_val == true @test all(!ismissing, transformed_variables[:dN]) - array_bitmap, transformed_variables = program!( + array_bitmap, transformed_variables = analyze_program( PostChecking(data, transformed_variables), model_def, data ) - vars, array_sizes, array_bitmap, node_args, node_functions, dependencies = program!( + vars, array_sizes, array_bitmap, node_args, node_functions, dependencies = analyze_program( NodeFunctions(array_sizes, array_bitmap), model_def, merge_with_coalescence(data, transformed_variables), diff --git a/test/runtests.jl b/test/runtests.jl index e2e065840..969982c25 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,7 +23,7 @@ using JuliaBUGS: MHFromPrior, NodeFunctions, PostChecking, - program!, + analyze_program, SimpleVarInfo, Stochastic, stochastic_inneighbors, diff --git a/test/utils.jl b/test/utils.jl index f14c1b9f5..13c1b565c 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,3 +1,24 @@ +@testset "Decompose for loop" begin + ex = MacroTools.@q for i in 1:3 + x[i] = i + for j in 1:3 + y[i, j] = i + j + end + end + + loop_var, lb, ub, body = JuliaBUGS.decompose_for_expr(ex) + + @test loop_var == :i + @test lb == 1 + @test ub == 3 + @test body == MacroTools.@q begin + x[i] = i + for j in 1:3 + y[i, j] = i + j + end + end +end + # Tests for `getparams`, using `Rats` model @testset "`getparams` with Rats" begin m = :rats