Skip to content

Commit

Permalink
Improve performance of logdensity computation (#228)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
sunxd3 and github-actions[bot] authored Oct 25, 2024
1 parent 53e9de6 commit ce29cc1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 46 deletions.
10 changes: 3 additions & 7 deletions src/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,9 @@ struct NodeInfo{F}
loop_vars::NamedTuple
end

"""
BUGSGraph
The `BUGSGraph` object represents the graph structure for a BUGS model. It is a type alias for
`MetaGraphsNext.MetaGraph`.
"""
const BUGSGraph = MetaGraph
const BUGSGraph = MetaGraph{
Int,Graphs.SimpleDiGraph{Int},<:VarName,<:NodeInfo,Nothing,Nothing,<:Any,Float64
}

"""
find_generated_vars(g::BUGSGraph)
Expand Down
78 changes: 40 additions & 38 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,18 @@ Return a vector of `VarName` containing the names of all the variables in the mo
"""
variables(m::BUGSModel) = collect(labels(m.g))

function prepare_arg_values(
args::Tuple{Vararg{Symbol}}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {lvars}
return NamedTuple{args}(Tuple(
map(args) do arg
if arg in lvars
loop_vars[arg]
else
AbstractPPL.get(evaluation_env, @varname($arg))
end
end,
))
@generated function prepare_arg_values(
::Val{args}, evaluation_env::NamedTuple, loop_vars::NamedTuple{lvars}
) where {args,lvars}
fields = []
for arg in args
if arg in lvars
push!(fields, :(loop_vars[$(QuoteNode(arg))]))
else
push!(fields, :(evaluation_env[$(QuoteNode(arg))]))
end
end
return :(NamedTuple{$(args)}(($(fields...),)))
end

function BUGSModel(
Expand All @@ -99,7 +99,7 @@ function BUGSModel(

for vn in sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = Base.invokelatest(node_function; args...)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
Expand Down Expand Up @@ -179,7 +179,7 @@ function initialize!(model::BUGSModel, initial_params::NamedTuple)
check_input(initial_params)
for vn in model.sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = model.g[vn]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
if !is_stochastic
value = Base.invokelatest(node_function; args...)
BangBang.@set!! model.evaluation_env = setindex!!(
Expand Down Expand Up @@ -243,7 +243,7 @@ function getparams(model::BUGSModel)
end
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
transformed_value = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
Expand All @@ -267,7 +267,7 @@ function getparams_as_ordereddict(model::BUGSModel)
d[v] = AbstractPPL.get(model.evaluation_env, v)
else
(; node_function, node_args, loop_vars) = model.g[v]
args = prepare_arg_values(node_args, model.evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), model.evaluation_env, loop_vars)
dist = node_function(; args...)
d[v] = Bijectors.transform(
Bijectors.bijector(dist), AbstractPPL.get(model.evaluation_env, v)
Expand Down Expand Up @@ -321,7 +321,20 @@ function AbstractPPL.condition(
)
end

return BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env)
g = copy(model.g)
for vn in sorted_blanket_with_vars
if vn in new_parameters
continue
end
ni = g[vn]
if ni.is_stochastic && !ni.is_observed
ni = @set ni.is_observed = true
g[vn] = ni
end
end

new_model = BUGSModel(model, new_parameters, sorted_blanket_with_vars, evaluation_env)
return BangBang.setproperty!!(new_model, :g, g)
end

function AbstractPPL.decondition(model::BUGSModel, var_group::Vector{<:VarName})
Expand Down Expand Up @@ -387,7 +400,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ctx::SamplingContext)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -410,7 +423,7 @@ function AbstractPPL.evaluate!!(model::BUGSModel, ::DefaultContext)
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = setindex!!(evaluation_env, value, vn)
Expand All @@ -436,51 +449,40 @@ end
function AbstractPPL.evaluate!!(
model::BUGSModel, ::LogDensityContext, flattened_values::AbstractVector
)
param_lengths = if model.transformed
model.transformed_param_length
else
model.untransformed_param_length
end

if length(flattened_values) != param_lengths
error(
"The length of `flattened_values` does not match the length of the parameters in the model",
)
end

var_lengths = if model.transformed
model.transformed_var_lengths
else
model.untransformed_var_lengths
end

sorted_nodes = model.sorted_nodes
g = model.g
evaluation_env = deepcopy(model.evaluation_env)
current_idx = 1
logp = 0.0
for vn in sorted_nodes
(; is_stochastic, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(node_args, evaluation_env, loop_vars)
for vn in model.sorted_nodes
(; is_stochastic, is_observed, node_function, node_args, loop_vars) = g[vn]
args = prepare_arg_values(Val(node_args), evaluation_env, loop_vars)
if !is_stochastic
value = node_function(; args...)
evaluation_env = BangBang.setindex!!(evaluation_env, value, vn)
else
dist = node_function(; args...)
if vn in model.parameters
if !is_observed
l = var_lengths[vn]
if model.transformed
b = Bijectors.bijector(dist)
b_inv = Bijectors.inverse(b)
reconstructed_value = reconstruct(
b_inv, dist, flattened_values[current_idx:(current_idx + l - 1)]
b_inv,
dist,
view(flattened_values, current_idx:(current_idx + l - 1)),
)
value, logjac = Bijectors.with_logabsdet_jacobian(
b_inv, reconstructed_value
)
else
value = reconstruct(
dist, flattened_values[current_idx:(current_idx + l - 1)]
dist, view(flattened_values, current_idx:(current_idx + l - 1))
)
logjac = 0.0
end
Expand Down
3 changes: 2 additions & 1 deletion test/graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ decond_model = AbstractPPL.decondition(cond_model, [a, l])
c_value = 4.0
mb_logp = begin
logp = 0
logp += logpdf(dnorm(1.0, c_value), 1.0) # a
f = 2.0 - 1.0
logp += logpdf(dnorm(f, c_value), 1.0) # a
logp += logpdf(dnorm(0.0, 1.0), 2.0) # b
logp += logpdf(dnorm(0.0, 1.0), -2.0) # l
logp += logpdf(dnorm(-2.0, 1.0), c_value) # c
Expand Down

0 comments on commit ce29cc1

Please sign in to comment.