Skip to content

Commit

Permalink
Implement make_chain_from_prior([rng,] model, n_iters)
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 4, 2024
1 parent 2649a30 commit 2cc15b1
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
33 changes: 33 additions & 0 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,36 @@ function modify_value_representation(nt::NamedTuple)
end
return modified_nt
end

"""
make_chain_from_prior([rng,] model, n_iters)
Construct an MCMCChains.Chains object by sampling from the prior of `model` for
`n_iters` iterations.
"""
function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int)
# Sample from the prior
varinfos = [VarInfo(rng, model) for _ in 1:n_iters]
# Extract all varnames found in any dictionary. Doing it this way guards
# against the possibility of having different varnames in different
# dictionaries, e.g. for models that have dynamic variables / array sizes
varnames = OrderedSet{VarName}()
# Convert each varinfo into an OrderedDict of vns => params.
# We have to use varname_and_value_leaves so that each parameter is a scalar
dicts = map(varinfos) do t
vals = DynamicPPL.values_as(t, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
tuples = mapreduce(collect, vcat, iters)
push!(varnames, map(first, tuples)...)
OrderedDict(tuples)
end
# Convert back to list
varnames = collect(varnames)
# Construct matrix of values
vals = [get(dict, vn, missing) for dict in dicts, vn in varnames]
# Construct and return the Chains object
return Chains(vals, varnames)
end
function make_chain_from_prior(model::Model, n_iters::Int)
return make_chain_from_prior(Random.default_rng(), model, n_iters)
end

0 comments on commit 2cc15b1

Please sign in to comment.