From 696bb3d734646053157482f1fcd8cdbc66591a23 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 29 Nov 2024 14:06:58 +0000 Subject: [PATCH] Implement make_chain_from_prior([rng,] model, n_iters) --- test/test_util.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_util.jl b/test/test_util.jl index f1325b729..c71d7b486 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -110,3 +110,32 @@ function modify_value_representation(nt::NamedTuple) end return modified_nt end + +""" + make_chain_from_prior([rng,] model, n_iters) + +Construct an MCMCChains.Chains object by sampling from the prior of `model` for +`n_iters` iterations. +""" +function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) + # Sample from the prior + varinfos = [VarInfo(rng, model) for _ in 1:n_iters] + # Convert each varinfo into an OrderedDict of vns => params. + # We have to use varname_and_value_leaves so that each parameter is a scalar + dicts = map(varinfos) do t + vals = DynamicPPL.values_as(t, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + tuples = mapreduce(collect, vcat, iters) + OrderedDict(tuples) + end + # Extract all varnames found in any dictionary. Doing it this way guards + # against the possibility of having different varnames in different + # dictionaries, e.g. for models that have dynamic variables / array sizes + all_varnames = collect(union(map(keys, dicts)...)) + vals = [get(dict, vn, missing) for dict in dicts, vn in all_varnames] + # Construct and return the Chains object + return Chains(vals, all_varnames) +end +function make_chain_from_prior(model::Model, n_iters::Int) + return make_chain_from_prior(Random.default_rng(), model, n_iters) +end