Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve graph library #219

Draft
wants to merge 26 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "ba9fb4c0-828e-4473-b6a1-cd2560fee5bf"
version = "0.6.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down Expand Up @@ -41,7 +42,7 @@ JuliaBUGSAdvancedMHExt = ["AdvancedMH", "MCMCChains"]
JuliaBUGSDynamicPPLExt = ["DynamicPPL"]
JuliaBUGSGraphMakieExt = ["GraphMakie", "GLMakie"]
JuliaBUGSGraphPlotExt = ["GraphPlot"]
JuliaBUGSMCMCChainsExt = ["DynamicPPL", "MCMCChains"]
JuliaBUGSMCMCChainsExt = ["MCMCChains"]

[compat]
ADTypes = "1.6"
Expand Down
57 changes: 24 additions & 33 deletions ext/JuliaBUGSAdvancedHMCExt.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
module JuliaBUGSAdvancedHMCExt

using AbstractMCMC
using AdvancedHMC
using AdvancedHMC: Transition, stat
using JuliaBUGS
using AbstractMCMC: AbstractMCMC
using AdvancedHMC: AdvancedHMC
using MCMCChains: MCMCChains
using JuliaBUGS:
AbstractBUGSModel, BUGSModel, Gibbs, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.BangBang
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Bijectors
using JuliaBUGS.Random
using MCMCChains: Chains
import JuliaBUGS: gibbs_internal
JuliaBUGS, Accessors, ADTypes, LogDensityProblems, LogDensityProblemsAD, Random

function AbstractMCMC.bundle_samples(
ts::Vector{<:Transition},
logdensitymodel::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
sampler::AdvancedHMC.AbstractHMCSampler,
ts::Vector{<:AdvancedHMC.Transition},
logdensitymodel,
sampler,
state,
chain_type::Type{Chains};
chain_type::Type{MCMCChains.Chains};
discard_initial=0,
thinning=1,
kwargs...,
)
params = [t.z.θ for t in ts]
stats_names = collect(keys(merge((; lp=ts[1].z.ℓπ.value), AdvancedHMC.stat(ts[1]))))
stats_values = [
vcat([ts[i].z.ℓπ.value..., collect(values(AdvancedHMC.stat(ts[i])))...]) for
Expand All @@ -33,7 +25,7 @@ function AbstractMCMC.bundle_samples(

return JuliaBUGS.gen_chains(
logdensitymodel,
[t.z.θ for t in ts],
params,
stats_names,
stats_values;
discard_initial=discard_initial,
Expand All @@ -43,24 +35,23 @@ function AbstractMCMC.bundle_samples(
end

function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::HMC
rng::Random.AbstractRNG,
sub_model::JuliaBUGS.BUGSModel,
sampler::AdvancedHMC.HMC,
state::AdvancedHMC.HMCState,
adtype::ADTypes.AbstractADType,
)
logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
)
t, s = AbstractMCMC.step(
rng,
logdensitymodel,
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model),
# update the log density in the state
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, sub_model)
state = Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, state.transition.z.θ, state.transition.z.r
)
updated_model = initialize!(cond_model, t.z.θ)
return JuliaBUGS.getparams(
BangBang.setproperty!!(
updated_model.base_model, :evaluation_env, updated_model.evaluation_env
),

logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(adtype, sub_model)
)
_, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state; n_adapts=0)
return initialize!(sub_model, s.transition.z.θ).evaluation_env, s
end

end
62 changes: 26 additions & 36 deletions ext/JuliaBUGSAdvancedMHExt.jl
Original file line number Diff line number Diff line change
@@ -1,60 +1,50 @@
module JuliaBUGSAdvancedMHExt

using AbstractMCMC
using AdvancedMH
using JuliaBUGS
using JuliaBUGS: BUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using JuliaBUGS.Random
using JuliaBUGS.Bijectors
using MCMCChains: Chains
import JuliaBUGS: gibbs_internal
using AbstractMCMC: AbstractMCMC
using AdvancedMH: AdvancedMH
using MCMCChains: MCMCChains
using JuliaBUGS: JuliaBUGS
using JuliaBUGS: Accessors, ADTypes, LogDensityProblems, LogDensityProblemsAD, Random

function AbstractMCMC.bundle_samples(
ts::Vector{<:AdvancedMH.AbstractTransition},
logdensitymodel::Union{
AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel},
AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
},
sampler::AdvancedMH.MHSampler,
ts::Vector{<:AdvancedMH.Transition},
logdensitymodel,
sampler,
state,
chain_type::Type{Chains};
chain_type::Type{MCMCChains.Chains};
discard_initial=0,
thinning=1,
kwargs...,
)
params = [t.params for t in ts]
stats_names = [:lp]
stats_values = [t.lp for t in ts]

return JuliaBUGS.gen_chains(
logdensitymodel,
[t.params for t in ts],
[:lp],
[t.lp for t in ts];
params,
stats_names,
stats_values;
discard_initial=discard_initial,
thinning=thinning,
kwargs...,
)
end

function JuliaBUGS.gibbs_internal(
rng::Random.AbstractRNG, cond_model::BUGSModel, sampler::AdvancedMH.MHSampler
rng::Random.AbstractRNG,
sub_model::JuliaBUGS.BUGSModel,
sampler::AdvancedMH.MHSampler,
state::AdvancedMH.Transition,
adtype::ADTypes.AbstractADType,
)
state = Accessors.@set state.lp = LogDensityProblems.logdensity(sub_model, state.params)

logdensitymodel = AbstractMCMC.LogDensityModel(
LogDensityProblemsAD.ADgradient(:ReverseDiff, cond_model)
)
t, s = AbstractMCMC.step(
rng,
logdensitymodel,
sampler;
n_adapts=0,
initial_params=JuliaBUGS.getparams(cond_model),
)
updated_model = initialize!(cond_model, t.params)
return JuliaBUGS.getparams(
BangBang.setproperty!!(
updated_model.base_model, :evaluation_env, updated_model.evaluation_env
),
LogDensityProblemsAD.ADgradient(adtype, sub_model)
)
_, s = AbstractMCMC.step(rng, logdensitymodel, sampler, state)
return JuliaBUGS.initialize!(sub_model, s.params).evaluation_env, s
end

end
80 changes: 53 additions & 27 deletions ext/JuliaBUGSMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,37 @@
module JuliaBUGSMCMCChainsExt

using JuliaBUGS
using JuliaBUGS: AbstractBUGSModel, find_generated_vars, LogDensityContext, evaluate!!
using JuliaBUGS.AbstractPPL
using JuliaBUGS.BUGSPrimitives
using JuliaBUGS.LogDensityProblems
using JuliaBUGS.LogDensityProblemsAD
using DynamicPPL
using AbstractMCMC
using AbstractMCMC: AbstractMCMC
using MCMCChains: Chains
using JuliaBUGS:
JuliaBUGS, AbstractPPL, BUGSPrimitives, LogDensityProblems, LogDensityProblemsAD

function JuliaBUGS.gen_chains(
model::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel},
samples,
stats_names,
stats_values;
function AbstractMCMC.bundle_samples(
ts,
logdensitymodel::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel},
sampler::JuliaBUGS.Gibbs,
state,
::Type{Chains};
discard_initial=0,
thinning=1,
kwargs...,
)
return JuliaBUGS.gen_chains(
model.logdensity,
samples,
stats_names,
stats_values;
discard_initial=discard_initial,
thinning=thinning,
kwargs...,
logdensitymodel, ts, [], []; discard_initial=discard_initial, kwargs...
)
end

function get_bugsmodel(model::AbstractMCMC.LogDensityModel{<:JuliaBUGS.BUGSModel})
return model.logdensity
end

function get_bugsmodel(
model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper}
)
ad_wrapper = model.logdensity
return Base.parent(ad_wrapper)::JuliaBUGS.BUGSModel
end

function JuliaBUGS.gen_chains(
model::AbstractMCMC.LogDensityModel{<:LogDensityProblemsAD.ADGradientWrapper},
model::AbstractMCMC.LogDensityModel,
samples,
stats_names,
stats_values;
Expand All @@ -40,7 +40,7 @@ function JuliaBUGS.gen_chains(
kwargs...,
)
return JuliaBUGS.gen_chains(
model.logdensity.ℓ,
get_bugsmodel(model),
samples,
stats_names,
stats_values;
Expand All @@ -62,13 +62,15 @@ function JuliaBUGS.gen_chains(
param_vars = model.parameters
g = model.g

generated_vars = find_generated_vars(g)
generated_vars = JuliaBUGS.find_generated_quantities_variables(g)
generated_vars = [v for v in model.sorted_nodes if v in generated_vars] # keep the order

param_vals = []
generated_quantities = []
for i in axes(samples)[1]
evaluation_env = first(evaluate!!(model, LogDensityContext(), samples[i]))
evaluation_env = first(
JuliaBUGS.evaluate!!(model, JuliaBUGS.LogDensityContext(), samples[i])
)
push!(
param_vals,
[AbstractPPL.get(evaluation_env, param_var) for param_var in param_vars],
Expand All @@ -84,13 +86,13 @@ function JuliaBUGS.gen_chains(

param_name_leaves = collect(
Iterators.flatten([
collect(DynamicPPL.varname_leaves(vn, param_vals[1][i])) for
collect(varname_leaves(vn, param_vals[1][i])) for
(i, vn) in enumerate(param_vars)
],),
)
generated_varname_leaves = collect(
Iterators.flatten([
collect(DynamicPPL.varname_leaves(vn, generated_quantities[1][i])) for
collect(varname_leaves(vn, generated_quantities[1][i])) for
(i, vn) in enumerate(generated_vars)
],),
)
Expand Down Expand Up @@ -129,4 +131,28 @@ function JuliaBUGS.gen_chains(
)
end

# utils: copied from DynamicPPL

varname_leaves(vn::JuliaBUGS.VarName, ::Real) = [vn]
function varname_leaves(vn::JuliaBUGS.VarName, val::AbstractArray{<:Union{Real,Missing}})
return (
JuliaBUGS.VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)) for
I in CartesianIndices(val)
)
end
function varname_leaves(vn::JuliaBUGS.VarName, val::AbstractArray)
return Iterators.flatten(
varname_leaves(
JuliaBUGS.VarName(vn, Accessors.IndexLens(Tuple(I)) ∘ getoptic(vn)), val[I]
) for I in CartesianIndices(val)
)
end
function varname_leaves(vn::JuliaBUGS.VarName, val::NamedTuple)
iter = Iterators.map(keys(val)) do sym
optic = Accessors.PropertyLens{sym}()
varname_leaves(JuliaBUGS.VarName(vn, optic ∘ getoptic(vn)), optic(val))
end
return Iterators.flatten(iter)
end

end
1 change: 1 addition & 0 deletions src/JuliaBUGS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module JuliaBUGS
using AbstractMCMC
using AbstractPPL
using Accessors
using ADTypes
using BangBang
using Bijectors: Bijectors
using Distributions
Expand Down
Loading
Loading