diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 63019aa44..bff53bf64 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -73,7 +73,6 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: - GROUP: All JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }} - uses: julia-actions/julia-processcoverage@v1 diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index e3dfc3030..36bf4939c 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -14,4 +14,4 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])' + run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])' diff --git a/.github/workflows/JuliaPre.yml b/.github/workflows/JuliaPre.yml index 98b0b0ffa..a9118a4bf 100644 --- a/.github/workflows/JuliaPre.yml +++ b/.github/workflows/JuliaPre.yml @@ -25,5 +25,3 @@ jobs: - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - GROUP: DynamicPPL diff --git a/Project.toml b/Project.toml index 61b3b7e81..51bac6df5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.5" +version = "0.32.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index 5a226e73b..d5c6bd690 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -279,12 +279,9 @@ VarInfo TypedVarInfo ``` -One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form. - -```@docs -link! -invlink! -``` +One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [transformation page](internals/transformations.md). +The [Transformations section below](#Transformations) describes the methods used for this. +In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. ```@docs set_flag! diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 4aa2aaa42..92a69d9ad 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end -@model function demo_assume_observe_literal() - # `assume` and literal `observe` +@model function demo_assume_multivariate_observe_literal() + # multivariate `assume` and literal `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m_dist = MvNormal(zeros(2), Diagonal(s)) return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) +function loglikelihood_true( + model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m +) return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_observe_literal)}, s, m + model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)}) return [@varname(s), @varname(m)] end @@ -377,7 +379,31 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function demo_assume_literal_dot_observe() +@model function demo_assume_observe_literal() + # univariate `assume` and literal `observe` + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_observe_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_observe_literal)}) + return [@varname(s), @varname(m)] +end + +@model function demo_assume_dot_observe_literal() # `assume` and literal `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @@ -385,18 +411,18 @@ end return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_literal_dot_observe)}, s, m + model::Model{typeof(demo_assume_dot_observe_literal)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_dot_observe_literal)}) return [@varname(s), @varname(m)] end @@ -574,8 +600,9 @@ const DemoModels = Union{ Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, Model{typeof(demo_assume_dot_observe)}, - Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_dot_observe_literal)}, Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_assume_multivariate_observe_literal)}, Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, @@ -585,7 +612,9 @@ const DemoModels = Union{ } const UnivariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_dot_observe_literal)}, + Model{typeof(demo_assume_observe_literal)}, } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) @@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, - Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_assume_multivariate_observe_literal)}, Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, @@ -759,9 +788,10 @@ const DEMO_MODELS = ( demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), demo_assume_dot_observe(), - demo_assume_observe_literal(), + demo_assume_multivariate_observe_literal(), demo_dot_assume_observe_index_literal(), - demo_assume_literal_dot_observe(), + demo_assume_dot_observe_literal(), + demo_assume_observe_literal(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f7ce569b3..cedb0efad 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -79,8 +79,6 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) function link!!( diff --git a/src/varinfo.jl b/src/varinfo.jl index bf2dd08c8..3ebb505e0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1221,27 +1221,6 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) end -""" - link!(vi::VarInfo, spl::Sampler) - -Transform the values of the random variables sampled by `spl` in `vi` from the support -of their distributions to the Euclidean space and set their corresponding `"trans"` -flag values to `true`. -""" -function link!(vi::VarInfo, spl::AbstractSampler) - Base.depwarn( - "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", - :link!, - ) - return _link!(vi, spl) -end -function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) - Base.depwarn( - "`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", - :link!, - ) - return _link!(vi, spl, spaceval) -end function _link!(vi::UntypedVarInfo, spl::AbstractSampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) @@ -1319,29 +1298,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -""" - invlink!(vi::VarInfo, spl::AbstractSampler) - -Transform the values of the random variables sampled by `spl` in `vi` from the -Euclidean space back to the support of their distributions and sets their corresponding -`"trans"` flag values to `false`. -""" -function invlink!(vi::VarInfo, spl::AbstractSampler) - Base.depwarn( - "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", - :invlink!, - ) - return _invlink!(vi, spl) -end - -function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) - Base.depwarn( - "`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", - :invlink!, - ) - return _invlink!(vi, spl, spaceval) -end - function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) diff --git a/test/Project.toml b/test/Project.toml index 48b6a7ad1..ef1d36b1d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,6 +20,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" diff --git a/test/ad.jl b/test/ad.jl index 768a55ad3..17981cf2a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -34,4 +34,43 @@ end end end + + @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin + # Failing model + t = 1:0.05:8 + σ = 0.3 + y = @. rand(sin(t) + Normal(0, σ)) + @model function state_space(y, TT, ::Type{T}=Float64) where {T} + # Priors + α ~ Normal(y[1], 0.001) + τ ~ Exponential(1) + η ~ filldist(Normal(0, 1), TT - 1) + σ ~ Exponential(1) + # create latent variable + x = Vector{T}(undef, TT) + x[1] = α + for t in 2:TT + x[t] = x[t - 1] + η[t - 1] * τ + end + # measurement model + y ~ MvNormal(x, σ^2 * I) + return x + end + model = state_space(y, length(t)) + + # Dummy sampling algorithm for testing. The test case can only be replicated + # with a custom sampler, it doesn't work with SampleFromPrior(). We need to + # overload assume so that model evaluation doesn't fail due to a lack + # of implementation + struct MyEmptyAlg end + DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () + DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = + DynamicPPL.assume(dist, vn, vi) + + # Compiling the ReverseDiff tape used to fail here + spl = Sampler(MyEmptyAlg()) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) + @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + end end diff --git a/test/model.jl b/test/model.jl index d163f55f0..a19cb29d2 100644 --- a/test/model.jl +++ b/test/model.jl @@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() + @testset "model.jl" begin @testset "convenience functions" begin - model = gdemo_default # defined in test/test_util.jl + model = GDEMO_DEFAULT # sample from model and extract variables vi = VarInfo(model) @@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12 - var_info = VarInfo(model) - vns = DynamicPPL.TestUtils.varnames(model) - syms = unique(DynamicPPL.getsym.(vns)) - - # generate a chain of sample parameter values. + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS N = 200 - vals_OrderedDict = mapreduce(hcat, 1:N) do _ - rand(OrderedDict, model) - end - vals_mat = mapreduce(hcat, 1:N) do i - [vals_OrderedDict[i][vn] for vn in vns] - end - i = 1 - for col in eachcol(vals_mat) - col_flattened = [] - [push!(col_flattened, x...) for x in col] - if i == 1 - chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) - else - chain_mat = vcat( - chain_mat, reshape(col_flattened, 1, length(col_flattened)) - ) - end - i += 1 - end - chain_mat = convert(Matrix{Float64}, chain_mat) - - # devise parameter names for chain - sample_values_vec = collect(values(vals_OrderedDict[1])) - symbol_names = [] - chain_sym_map = Dict() - for k in 1:length(keys(var_info)) - vn_parent = keys(var_info)[k] + chain = make_chain_from_prior(model, N) + logpriors = logprior(model, chain) + loglikelihoods = loglikelihood(model, chain) + logjoints = logjoint(model, chain) + + # Construct mapping of varname symbols to varname-parent symbols. + # Here, varname_leaves is used to ensure compatibility with the + # variables stored in the chain + var_info = VarInfo(model) + chain_sym_map = Dict{Symbol,Symbol}() + for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl + vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym - symbol_names = [symbol_names; Symbol(vn_child)] end end - chain = Chains(chain_mat, symbol_names) - # calculate the pointwise loglikelihoods for the whole chain using the newly written functions - logpriors = logprior(model, chain) - loglikelihoods = loglikelihood(model, chain) - logjoints = logjoint(model, chain) # compare them with true values for i in 1:N samples_dict = Dict() @@ -125,8 +100,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin + @model function multiple_types(x) + ns ~ filldist(Normal(0, 2.0), 3) + m ~ Uniform(0, 1) + return x ~ Normal(m, 1) + end + model = multiple_types(1) + chain = make_chain_from_prior(model, 10) + loglikelihood(model, chain) + logprior(model, chain) + logjoint(model, chain) + end + @testset "rng" begin - model = gdemo_default + model = GDEMO_DEFAULT for sampler in (SampleFromPrior(), SampleFromUniform()) for i in 1:10 @@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "defaults without VarInfo, Sampler, and Context" begin - model = gdemo_default + model = GDEMO_DEFAULT Random.seed!(100) - s, m = model() + retval = model() Random.seed!(100) - @test model(Random.default_rng()) == (s, m) + retval2 = model(Random.default_rng()) + @test retval2.s == retval.s + @test retval2.m == retval.m end @testset "nameof" begin @@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "Internal methods" begin - model = gdemo_default + model = GDEMO_DEFAULT # sample from model and extract variables vi = VarInfo(model) @@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "rand" begin - model = gdemo_default + model = GDEMO_DEFAULT Random.seed!(1776) s, m = model() @@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end - @testset "generated_quantities on `LKJCholesky`" begin + @testset "returned() on `LKJCholesky`" begin n = 10 d = 2 model = DynamicPPL.TestUtils.demo_lkjchol(d) @@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true ) # Test! - results = generated_quantities(model, chain) + results = returned(model, chain) for (x_true, result) in zip(xs, results) @test x_true.UL == result.x.UL end @@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true info=(varname_to_symbol=vns_to_syms_with_extra,), ) # Test! - results = generated_quantities(model, chain_with_extra) + results = returned(model, chain_with_extra) for (x_true, result) in zip(xs, results) @test x_true.UL == result.x.UL end diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..720ae55aa --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,20 @@ +@testset "model_utils.jl" begin + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Check that the values generated by value_iterator_from_chain + # match the values in the original chain + chain = make_chain_from_prior(model, 10) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + # Because value_iterator_from_chain groups varnames with + # the same parent symbol, we have to ungroup them here + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index aea02a337..38e6f87ca 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,115 +31,77 @@ using LinearAlgebra # Diagonal using JET: JET using Combinatorics: combinations +using OrderedCollections: OrderedSet using DynamicPPL: getargs_dottilde, getargs_tilde, Selector -const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) -const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") -const GROUP = get(ENV, "GROUP", "All") - Random.seed!(100) include("test_util.jl") @testset "DynamicPPL.jl" begin - if GROUP == "All" || GROUP == "DynamicPPL" - @testset "interface" begin - include("utils.jl") - include("compiler.jl") - include("varnamedvector.jl") - include("varinfo.jl") - include("simple_varinfo.jl") - include("model.jl") - include("sampler.jl") - include("independence.jl") - include("distribution_wrappers.jl") - include("contexts.jl") - include("context_implementations.jl") - include("logdensityfunction.jl") - include("linking.jl") - - include("threadsafe.jl") - - include("serialization.jl") - - include("pointwise_logdensities.jl") - - include("lkj.jl") - - include("debug_utils.jl") - end - - @testset "compat" begin - include(joinpath("compat", "ad.jl")) - end - - @testset "extensions" begin - include("ext/DynamicPPLMCMCChainsExt.jl") - include("ext/DynamicPPLJETExt.jl") - end - - @testset "ad" begin - include("ext/DynamicPPLForwardDiffExt.jl") - include("ext/DynamicPPLMooncakeExt.jl") - include("ad.jl") - end + @testset "interface" begin + include("utils.jl") + include("compiler.jl") + include("varnamedvector.jl") + include("varinfo.jl") + include("simple_varinfo.jl") + include("model.jl") + include("sampler.jl") + include("independence.jl") + include("distribution_wrappers.jl") + include("contexts.jl") + include("context_implementations.jl") + include("logdensityfunction.jl") + include("linking.jl") + include("threadsafe.jl") + include("serialization.jl") + include("pointwise_logdensities.jl") + include("lkj.jl") + include("debug_utils.jl") + end - @testset "prob and logprob macro" begin - @test_throws ErrorException prob"..." - @test_throws ErrorException logprob"..." - end + @testset "compat" begin + include(joinpath("compat", "ad.jl")) + end - @testset "doctests" begin - DocMeta.setdocmeta!( - DynamicPPL, - :DocTestSetup, - :(using DynamicPPL, Distributions); - recursive=true, - ) - doctestfilters = [ - # Older versions will show "0 element Array" instead of "Type[]". - r"(Any\[\]|0-element Array{.+,[0-9]+})", - # Older versions will show "Array{...,1}" instead of "Vector{...}". - r"(Array{.+,\s?1}|Vector{.+})", - # Older versions will show "Array{...,2}" instead of "Matrix{...}". - r"(Array{.+,\s?2}|Matrix{.+})", - # Errors from macros sometimes result in `LoadError: LoadError:` - # rather than `LoadError:`, depending on Julia version. - r"ERROR: (LoadError:\s)+", - # Older versions do not have `;;]` but instead just `]` at end of the line - # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line - r"(;;){0,1}\]$"m, - # Ignore the source of a warning in the doctest output, since this is dependent on host. - # This is a line that starts with "└ @ " and ends with the line number. - r"└ @ .+:[0-9]+", - ] - doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) - end + @testset "extensions" begin + include("ext/DynamicPPLMCMCChainsExt.jl") + include("ext/DynamicPPLJETExt.jl") end - if GROUP == "All" || GROUP == "Downstream" - @testset "turing" begin - try - # activate separate test environment - Pkg.activate(DIRECTORY_Turing_tests) - Pkg.develop(PackageSpec(; path=DIRECTORY_DynamicPPL)) - Pkg.instantiate() + @testset "ad" begin + include("ext/DynamicPPLForwardDiffExt.jl") + include("ext/DynamicPPLMooncakeExt.jl") + include("ad.jl") + end - # make sure that the new environment is considered `using` and `import` statements - # (not added automatically on Julia 1.3, see e.g. PR #209) - if !(joinpath(DIRECTORY_Turing_tests, "Project.toml") in Base.load_path()) - pushfirst!(LOAD_PATH, DIRECTORY_Turing_tests) - end + @testset "prob and logprob macro" begin + @test_throws ErrorException prob"..." + @test_throws ErrorException logprob"..." + end - include(joinpath("turing", "runtests.jl")) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." exception = err - end - end + @testset "doctests" begin + DocMeta.setdocmeta!( + DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true + ) + doctestfilters = [ + # Older versions will show "0 element Array" instead of "Type[]". + r"(Any\[\]|0-element Array{.+,[0-9]+})", + # Older versions will show "Array{...,1}" instead of "Vector{...}". + r"(Array{.+,\s?1}|Vector{.+})", + # Older versions will show "Array{...,2}" instead of "Matrix{...}". + r"(Array{.+,\s?2}|Matrix{.+})", + # Errors from macros sometimes result in `LoadError: LoadError:` + # rather than `LoadError:`, depending on Julia version. + r"ERROR: (LoadError:\s)+", + # Older versions do not have `;;]` but instead just `]` at end of the line + # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line + r"(;;){0,1}\]$"m, + # Ignore the source of a warning in the doctest output, since this is dependent on host. + # This is a line that starts with "└ @ " and ends with the line number. + r"└ @ .+:[0-9]+", + ] + doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end end diff --git a/test/test_util.jl b/test/test_util.jl index f1325b729..27a68456c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual) @test back(1)[1] ≈ grad end -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. - -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end -end - """ short_varinfo_name(vi::AbstractVarInfo) @@ -110,3 +76,36 @@ function modify_value_representation(nt::NamedTuple) end return modified_nt end + +""" + make_chain_from_prior([rng,] model, n_iters) + +Construct an MCMCChains.Chains object by sampling from the prior of `model` for +`n_iters` iterations. +""" +function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) + # Sample from the prior + varinfos = [VarInfo(rng, model) for _ in 1:n_iters] + # Extract all varnames found in any dictionary. Doing it this way guards + # against the possibility of having different varnames in different + # dictionaries, e.g. for models that have dynamic variables / array sizes + varnames = OrderedSet{VarName}() + # Convert each varinfo into an OrderedDict of vns => params. + # We have to use varname_and_value_leaves so that each parameter is a scalar + dicts = map(varinfos) do t + vals = DynamicPPL.values_as(t, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + tuples = mapreduce(collect, vcat, iters) + push!(varnames, map(first, tuples)...) + OrderedDict(tuples) + end + # Convert back to list + varnames = collect(varnames) + # Construct matrix of values + vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct and return the Chains object + return Chains(vals, varnames) +end +function make_chain_from_prior(model::Model, n_iters::Int) + return make_chain_from_prior(Random.default_rng(), model, n_iters) +end diff --git a/test/turing/Project.toml b/test/turing/Project.toml deleted file mode 100644 index 62e8cc8eb..000000000 --- a/test/turing/Project.toml +++ /dev/null @@ -1,19 +0,0 @@ -[deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" -HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" - -[compat] -Distributions = "0.25" -DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30, 0.31" -HypothesisTests = "0.11" -MCMCChains = "6" -ReverseDiff = "1.15" -Turing = "0.33, 0.34, 0.35" -julia = "1.7" diff --git a/test/turing/compiler.jl b/test/turing/compiler.jl deleted file mode 100644 index 5c46ab777..000000000 --- a/test/turing/compiler.jl +++ /dev/null @@ -1,348 +0,0 @@ -@testset "compiler.jl" begin - @testset "assume" begin - @model function test_assume() - x ~ Bernoulli(1) - y ~ Bernoulli(x / 2) - return x, y - end - - smc = SMC() - pg = PG(10) - - res1 = sample(test_assume(), smc, 1000) - res2 = sample(test_assume(), pg, 1000) - - check_numerical(res1, [:y], [0.5]; atol=0.1) - check_numerical(res2, [:y], [0.5]; atol=0.1) - - # Check that all xs are 1. - @test all(isone, res1[:x]) - @test all(isone, res2[:x]) - end - @testset "beta binomial" begin - prior = Beta(2, 2) - obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] - exact = Beta(prior.α + sum(obs), prior.β + length(obs) - sum(obs)) - meanp = exact.α / (exact.α + exact.β) - - @model function testbb(obs) - p ~ Beta(2, 2) - x ~ Bernoulli(p) - for i in 1:length(obs) - obs[i] ~ Bernoulli(p) - end - return p, x - end - - smc = SMC() - pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p), PG(10, :x)) - - chn_s = sample(testbb(obs), smc, 1000) - chn_p = sample(testbb(obs), pg, 2000) - chn_g = sample(testbb(obs), gibbs, 1500) - - check_numerical(chn_s, [:p], [meanp]; atol=0.05) - check_numerical(chn_p, [:x], [meanp]; atol=0.1) - check_numerical(chn_g, [:x], [meanp]; atol=0.1) - end - @testset "forbid global" begin - xs = [1.5 2.0] - # xx = 1 - - @model function fggibbstest(xs) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - # xx ~ Normal(m, sqrt(s)) # this is illegal - - for i in 1:length(xs) - xs[i] ~ Normal(m, sqrt(s)) - # for xx in xs - # xx ~ Normal(m, sqrt(s)) - end - return s, m - end - - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m)) - chain = sample(fggibbstest(xs), gibbs, 2) - end - @testset "new grammar" begin - x = Float64[1 2] - - @model function gauss(x) - priors = Array{Float64}(undef, 2) - priors[1] ~ InverseGamma(2, 3) # s - priors[2] ~ Normal(0, sqrt(priors[1])) # m - for i in 1:length(x) - x[i] ~ Normal(priors[2], sqrt(priors[1])) - end - return priors - end - - chain = sample(gauss(x), PG(10), 10) - chain = sample(gauss(x), SMC(), 10) - - @model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV} - priors = TV(undef, 2) - priors[1] ~ InverseGamma(2, 3) # s - priors[2] ~ Normal(0, sqrt(priors[1])) # m - for i in 1:length(x) - x[i] ~ Normal(priors[2], sqrt(priors[1])) - end - return priors - end - - @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) - @test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10) - - @test_throws ErrorException chain = sample( - gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10 - ) - @test_throws ErrorException chain = sample( - gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10 - ) - end - @testset "new interface" begin - obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] - - @model function newinterface(obs) - p ~ Beta(2, 2) - for i in 1:length(obs) - obs[i] ~ Bernoulli(p) - end - return p - end - - chain = sample( - newinterface(obs), - HMC(0.75, 3, :p, :x; adtype=AutoForwardDiff(; chunksize=2)), - 100, - ) - end - @testset "no return" begin - @model function noreturn(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt(s)) - end - end - - chain = sample(noreturn([1.5 2.0]), HMC(0.15, 6), 1000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]) - end - @testset "observe" begin - @model function test() - z ~ Normal(0, 1) - x ~ Bernoulli(1) - 1 ~ Bernoulli(x / 2) - 0 ~ Bernoulli(x / 2) - return x - end - - is = IS() - smc = SMC() - pg = PG(10) - - res_is = sample(test(), is, 10000) - res_smc = sample(test(), smc, 1000) - res_pg = sample(test(), pg, 100) - - @test all(isone, res_is[:x]) - @test res_is.logevidence ≈ 2 * log(0.5) - - @test all(isone, res_smc[:x]) - @test res_smc.logevidence ≈ 2 * log(0.5) - - @test all(isone, res_pg[:x]) - end - @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s)) - chn = sample(gdemo_default, alg, 1000) - end - @testset "vectorization @." begin - @model function vdemo1(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - @. x ~ Normal(m, sqrt(s)) - return s, m - end - - alg = HMC(0.01, 5) - x = randn(100) - res = sample(vdemo1(x), alg, 250) - - @model function vdemo1b(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - @. x ~ Normal(m, $(sqrt(s))) - return s, m - end - - res = sample(vdemo1b(x), alg, 250) - - @model function vdemo2(x) - μ ~ MvNormal(zeros(size(x, 1)), I) - @. x ~ $(MvNormal(μ, I)) - end - - D = 2 - alg = HMC(0.01, 5) - res = sample(vdemo2(randn(D, 100)), alg, 250) - - # Vector assumptions - N = 10 - alg = HMC(0.2, 4; adtype=AutoForwardDiff(; chunksize=N)) - - @model function vdemo3() - x = Vector{Real}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) - end - end - - t_loop = @elapsed res = sample(vdemo3(), alg, 1000) - - # Test for vectorize UnivariateDistribution - @model function vdemo4() - x = Vector{Real}(undef, N) - @. x ~ Normal(0, 2) - end - - t_vec = @elapsed res = sample(vdemo4(), alg, 1000) - - @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) - - t_mv = @elapsed res = sample(vdemo5(), alg, 1000) - - println("Time for") - println(" Loop : ", t_loop) - println(" Vec : ", t_vec) - println(" Mv : ", t_mv) - - # Transformed test - @model function vdemo6() - x = Vector{Real}(undef, N) - @. x ~ InverseGamma(2, 3) - end - - sample(vdemo6(), alg, 1000) - - N = 3 - @model function vdemo7() - x = Array{Real}(undef, N, N) - @. x ~ [InverseGamma(2, 3) for i in 1:N] - end - - sample(vdemo7(), alg, 1000) - end - @testset "vectorization .~" begin - @model function vdemo1(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - x .~ Normal(m, sqrt(s)) - return s, m - end - - alg = HMC(0.01, 5) - x = randn(100) - res = sample(vdemo1(x), alg, 250) - - @model function vdemo2(x) - μ ~ MvNormal(zeros(size(x, 1)), I) - return x .~ MvNormal(μ, I) - end - - D = 2 - alg = HMC(0.01, 5) - res = sample(vdemo2(randn(D, 100)), alg, 250) - - # Vector assumptions - N = 10 - alg = HMC(0.2, 4; adtype=AutoForwardDiff(; chunksize=N)) - - @model function vdemo3() - x = Vector{Real}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) - end - end - - t_loop = @elapsed res = sample(vdemo3(), alg, 1000) - - # Test for vectorize UnivariateDistribution - @model function vdemo4() - x = Vector{Real}(undef, N) - return x .~ Normal(0, 2) - end - - t_vec = @elapsed res = sample(vdemo4(), alg, 1000) - - @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) - - t_mv = @elapsed res = sample(vdemo5(), alg, 1000) - - println("Time for") - println(" Loop : ", t_loop) - println(" Vec : ", t_vec) - println(" Mv : ", t_mv) - - # Transformed test - @model function vdemo6() - x = Vector{Real}(undef, N) - return x .~ InverseGamma(2, 3) - end - - sample(vdemo6(), alg, 1000) - - @model function vdemo7() - x = Array{Real}(undef, N, N) - return x .~ [InverseGamma(2, 3) for i in 1:N] - end - - sample(vdemo7(), alg, 1000) - end - @testset "Type parameters" begin - N = 10 - alg = HMC(0.01, 5; adtype=AutoForwardDiff(; chunksize=N)) - x = randn(1000) - @model function vdemo1(::Type{T}=Float64) where {T} - x = Vector{T}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) - end - end - - t_loop = @elapsed res = sample(vdemo1(), alg, 250) - t_loop = @elapsed res = sample(vdemo1(DynamicPPL.TypeWrap{Float64}()), alg, 250) - - vdemo1kw(; T) = vdemo1(T) - t_loop = @elapsed res = sample( - vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 - ) - - @model function vdemo2(::Type{T}=Float64) where {T<:Real} - x = Vector{T}(undef, N) - @. x ~ Normal(0, 2) - end - - t_vec = @elapsed res = sample(vdemo2(), alg, 250) - t_vec = @elapsed res = sample(vdemo2(DynamicPPL.TypeWrap{Float64}()), alg, 250) - - vdemo2kw(; T) = vdemo2(T) - t_vec = @elapsed res = sample( - vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 - ) - - @model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector} - x = TV(undef, N) - @. x ~ InverseGamma(2, 3) - end - - sample(vdemo3(), alg, 250) - sample(vdemo3(DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) - - vdemo3kw(; T) = vdemo3(T) - sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) - end -end diff --git a/test/turing/loglikelihoods.jl b/test/turing/loglikelihoods.jl deleted file mode 100644 index 5f1c41572..000000000 --- a/test/turing/loglikelihoods.jl +++ /dev/null @@ -1,41 +0,0 @@ -@testset "loglikelihoods.jl" begin - @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - return y ~ Normal(m, √s) - end - - xs = randn(3) - y = randn() - model = demo(xs, y) - chain = sample(model, MH(), MCMCThreads(), 100, 2) - var_to_likelihoods = pointwise_loglikelihoods( - model, MCMCChains.get_sections(chain, :parameters) - ) - @test haskey(var_to_likelihoods, "xs[1]") - @test haskey(var_to_likelihoods, "xs[2]") - @test haskey(var_to_likelihoods, "xs[3]") - @test haskey(var_to_likelihoods, "y") - - for chain_idx in MCMCChains.chains(chain) - for (i, (s, m)) in enumerate(zip(chain[:, :s, chain_idx], chain[:, :m, chain_idx])) - @test logpdf(Normal(m, √s), xs[1]) == var_to_likelihoods["xs[1]"][i, chain_idx] - @test logpdf(Normal(m, √s), xs[2]) == var_to_likelihoods["xs[2]"][i, chain_idx] - @test logpdf(Normal(m, √s), xs[3]) == var_to_likelihoods["xs[3]"][i, chain_idx] - @test logpdf(Normal(m, √s), y) == var_to_likelihoods["y"][i, chain_idx] - end - end - - var_info = VarInfo(model) - results = pointwise_loglikelihoods(model, var_info) - var_to_likelihoods = Dict(string(vn) => ℓ for (vn, ℓ) in results) - s, m = var_info[SampleFromPrior()] - @test [logpdf(Normal(m, √s), xs[1])] == var_to_likelihoods["xs[1]"] - @test [logpdf(Normal(m, √s), xs[2])] == var_to_likelihoods["xs[2]"] - @test [logpdf(Normal(m, √s), xs[3])] == var_to_likelihoods["xs[3]"] - @test [logpdf(Normal(m, √s), y)] == var_to_likelihoods["y"] -end diff --git a/test/turing/model.jl b/test/turing/model.jl deleted file mode 100644 index 599fba21b..000000000 --- a/test/turing/model.jl +++ /dev/null @@ -1,27 +0,0 @@ -@testset "model.jl" begin - @testset "setval! & generated_quantities" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS - chain = sample(model, Prior(), 10) - # A simple way of checking that the computation is determinstic: run twice and compare. - res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) - res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) - @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) - end - end - - @testset "value_iterator_from_chain" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS - chain = sample(model, Prior(), 10; progress=false) - for (i, d) in enumerate(value_iterator_from_chain(model, chain)) - for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) - @test val_leaf == chain[i, Symbol(vn_leaf), 1] - end - end - end - end - end -end diff --git a/test/turing/runtests.jl b/test/turing/runtests.jl deleted file mode 100644 index b3ff72819..000000000 --- a/test/turing/runtests.jl +++ /dev/null @@ -1,24 +0,0 @@ -using DynamicPPL -using Turing -using LinearAlgebra -using ReverseDiff - -using Random -using Test - -setprogress!(false) - -Random.seed!(100) - -# load test utilities -include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl")) -include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl")) - -using .NumericalTests: check_numerical - -@testset "Turing" begin - include("compiler.jl") - include("loglikelihoods.jl") - include("model.jl") - include("varinfo.jl") -end diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl deleted file mode 100644 index e5b8eb79f..000000000 --- a/test/turing/varinfo.jl +++ /dev/null @@ -1,360 +0,0 @@ -@testset "varinfo.jl" begin - # Declare empty model to make the Sampler constructor work. - @model empty_model() = begin - x = 1 - end - - function randr( - vi::VarInfo, vn::VarName, dist::Distribution, spl::Sampler, count::Bool=false - ) - if !haskey(vi, vn) - r = rand(dist) - push!!(vi, vn, r, dist, spl) - r - elseif is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = rand(dist) - vi[vn] = DynamicPPL.tovec(r) - setorder!(vi, vn, get_num_produce(vi)) - r - else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) - vi[vn] - end - end - - @testset "link!" begin - # Test linking spl and vi: - # link!, invlink!, istrans - @model gdemo(x, y) = begin - s ~ InverseGamma(2, 3) - m ~ Uniform(0, 2) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) - end - model = gdemo(1.0, 2.0) - - vi = VarInfo() - meta = vi.metadata - model(vi, SampleFromUniform()) - @test all(x -> !istrans(vi, x), meta.vns) - - alg = HMC(0.1, 5) - spl = Sampler(alg, model) - v = copy(meta.vals) - link!(vi, spl) - @test all(x -> istrans(vi, x), meta.vns) - invlink!(vi, spl) - @test all(x -> !istrans(vi, x), meta.vns) - @test meta.vals == v - - vi = TypedVarInfo(vi) - meta = vi.metadata - alg = HMC(0.1, 5) - spl = Sampler(alg, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - v_s = copy(meta.s.vals) - v_m = copy(meta.m.vals) - link!(vi, spl) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals == v_s - @test meta.m.vals == v_m - - # Transforming only a subset of the variables - link!(vi, spl, Val((:m,))) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl, Val((:m,))) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals == v_s - @test meta.m.vals == v_m - end - @testset "orders" begin - csym = gensym() # unique per model - vn_z1 = @varname z[1] - vn_z2 = @varname z[2] - vn_z3 = @varname z[3] - vn_z4 = @varname z[4] - vn_a1 = @varname a[1] - vn_a2 = @varname a[2] - vn_b = @varname b - - vi = VarInfo() - dists = [Categorical([0.7, 0.3]), Normal()] - - spl1 = Sampler(PG(5), empty_model()) - spl2 = Sampler(PG(5), empty_model()) - - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] - @test get_num_produce(vi) == 3 - - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl1) - @test is_flagged(vi, vn_z1, "del") - @test is_flagged(vi, vn_a1, "del") - @test is_flagged(vi, vn_z2, "del") - @test is_flagged(vi, vn_a2, "del") - @test is_flagged(vi, vn_z3, "del") - - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] - @test get_num_produce(vi) == 3 - - vi = empty!!(TypedVarInfo(vi)) - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test vi.metadata.b.orders == [2] - @test get_num_produce(vi) == 3 - - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl1) - @test is_flagged(vi, vn_z1, "del") - @test is_flagged(vi, vn_a1, "del") - @test is_flagged(vi, vn_z2, "del") - @test is_flagged(vi, vn_a2, "del") - @test is_flagged(vi, vn_z3, "del") - - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] - @test get_num_produce(vi) == 3 - end - @testset "replay" begin - # Generate synthesised data - xs = rand(Normal(0.5, 1), 100) - - # Define model - @model function priorsinarray(xs, ::Type{T}=Float64) where {T} - begin - priors = Vector{T}(undef, 2) - priors[1] ~ InverseGamma(2, 3) - priors[2] ~ Normal(0, sqrt(priors[1])) - for i in 1:length(xs) - xs[i] ~ Normal(priors[2], sqrt(priors[1])) - end - priors - end - end - - # Sampling - chain = sample(priorsinarray(xs), HMC(0.01, 10), 10) - end - @testset "varname" begin - @model function mat_name_test() - p = Array{Any}(undef, 2, 2) - for i in 1:2, j in 1:2 - p[i, j] ~ Normal(0, 1) - end - return p - end - chain = sample(mat_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1, 1]"], [0]; atol=0.25) - - @model function marr_name_test() - p = Array{Array{Any}}(undef, 2) - p[1] = Array{Any}(undef, 2) - p[2] = Array{Any}(undef, 2) - for i in 1:2, j in 1:2 - p[i][j] ~ Normal(0, 1) - end - return p - end - - chain = sample(marr_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1][1]"], [0]; atol=0.25) - end - @testset "varinfo" begin - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - function test_varinfo!(vi) - @test getlogp(vi) === 0.0 - vi = setlogp!!(vi, 1) - @test getlogp(vi) === 1.0 - vi = acclogp!!(vi, 1) - @test getlogp(vi) === 2.0 - vi = resetlogp!!(vi) - @test getlogp(vi) === 0.0 - - spl2 = Sampler(PG(5, :w, :u), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = Sampler(PG(5, :x, :y, :z), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) - - @model igtest() = begin - x ~ InverseGamma(2, 3) - y ~ InverseGamma(2, 3) - z ~ InverseGamma(2, 3) - w ~ InverseGamma(2, 3) - u ~ InverseGamma(2, 3) - end - - # Test the update of group IDs - g_demo_f = igtest() - - # This test section no longer seems as applicable, considering the - # user will never end up using an UntypedVarInfo. The `VarInfo` - # Varible is also not passed around in the same way as it used to be. - - # TODO: Has to be fixed - - #= g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f) - vi = VarInfo() - g_demo_f(vi, SampleFromPrior()) - _, state = @inferred AbstractMCMC.step(Random.default_rng(), g_demo_f, g) - pg, hmc = state.states - @test pg isa TypedVarInfo - @test hmc isa Turing.Inference.HMCState - vi1 = state.vi - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == - [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()] - - @inferred g_demo_f(vi1, hmc) - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == - [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])] - - g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f) - pg, hmc = g.state.samplers - vi = empty!!(TypedVarInfo(vi)) - @inferred g_demo_f(vi, SampleFromPrior()) - pg.state.vi = vi - step!(Random.default_rng(), g_demo_f, pg, 1) - vi = pg.state.vi - @inferred g_demo_f(vi, hmc) - @test vi.metadata.x.gids[1] == Set([pg.selector]) - @test vi.metadata.y.gids[1] == Set([pg.selector]) - @test vi.metadata.z.gids[1] == Set([pg.selector]) - @test vi.metadata.w.gids[1] == Set([hmc.selector]) - @test vi.metadata.u.gids[1] == Set([hmc.selector]) =# - end - - @testset "Turing#2151: eltype(vi, spl)" begin - # build data - t = 1:0.05:8 - σ = 0.3 - y = @. rand(sin(t) + Normal(0, σ)) - - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - - # measurement model - y ~ MvNormal(x, σ^2 * I) - - return x - end - - n = 10 - model = state_space(y, length(t)) - @test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n - end - - if Threads.nthreads() > 1 - @testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin - @model function f(x) - ns ~ filldist(Normal(0, 2.0), 3) - m ~ Uniform(0, 1) - return x ~ Normal(m, 1) - end - model = f(1) - chain = sample(model, NUTS(), MCMCThreads(), 10, 2) - loglikelihood(model, chain) - logprior(model, chain) - logjoint(model, chain) - end - end -end diff --git a/test/varinfo.jl b/test/varinfo.jl index ae4319904..9a55cffb9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,8 @@ +# Dummy algorithm for testing +# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) +struct MyAlg{space} end +DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space + function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -14,9 +19,29 @@ function check_varinfo_keys(varinfo, vns) end end -# A simple "algorithm" which only has `s` variables in its space. -struct MySAlg end -DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) +function randr( + vi::DynamicPPL.VarInfo, + vn::VarName, + dist::Distribution, + spl::DynamicPPL.Sampler, + count::Bool=false, +) + if !haskey(vi, vn) + r = rand(dist) + push!!(vi, vn, r, dist, spl) + r + elseif DynamicPPL.is_flagged(vi, vn, "del") + DynamicPPL.unset_flag!(vi, vn, "del") + r = rand(dist) + vi[vn] = DynamicPPL.tovec(r) + DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + r + else + count && checkindex(vn, vi, spl) + DynamicPPL.updategid!(vi, vn, spl) + vi[vn] + end +end @testset "varinfo.jl" begin @testset "TypedVarInfo with Metadata" begin @@ -130,6 +155,26 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + + @testset "get/set/acc/resetlogp" begin + function test_varinfo_logp!(vi) + @test DynamicPPL.getlogp(vi) === 0.0 + vi = DynamicPPL.setlogp!!(vi, 1.0) + @test DynamicPPL.getlogp(vi) === 1.0 + vi = DynamicPPL.acclogp!!(vi, 1.0) + @test DynamicPPL.getlogp(vi) === 2.0 + vi = DynamicPPL.resetlogp!!(vi) + @test DynamicPPL.getlogp(vi) === 0.0 + end + + vi = VarInfo() + test_varinfo_logp!(vi) + test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(SimpleVarInfo()) + test_varinfo_logp!(SimpleVarInfo(Dict())) + test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -187,6 +232,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end + @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -339,6 +385,103 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test vals_prev == vi.metadata.x.vals end + @testset "setval! on chain" begin + # Define a helper function + """ + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + + Test `setval!` on `model` and `chain`. + + Worth noting that this only supports models containing symbols of the forms + `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. + """ + function test_setval!(model, chain; sample_idx=1, chain_idx=1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v + end + + @test v_true == chain_val + end + end + + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + # A simple way of checking that the computation is determinstic: run twice and compare. + res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) + res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) + @test all(res1 .== res2) + test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + end + end + + @testset "link!! and invlink!!" begin + @model gdemo(x, y) = begin + s ~ InverseGamma(2, 3) + m ~ Uniform(0, 2) + x ~ Normal(m, sqrt(s)) + y ~ Normal(m, sqrt(s)) + end + model = gdemo(1.0, 2.0) + + # Check that instantiating the model does not perform linking + vi = VarInfo() + meta = vi.metadata + model(vi, SampleFromUniform()) + @test all(x -> !istrans(vi, x), meta.vns) + + # Check that linking and invlinking set the `trans` flag accordingly + v = copy(meta.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.vns) + @test meta.vals ≈ v atol = 1e-10 + + # Check that linking and invlinking preserves the values + vi = TypedVarInfo(vi) + meta = vi.metadata + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + v_s = copy(meta.s.vals) + v_m = copy(meta.m.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable (`s`) but not the others (`m`) + spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) + link!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + invlink!!(vi, spl, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + end + @testset "istrans" begin @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) model = demo_constrained() @@ -737,7 +880,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) DynamicPPL.Metadata(), ) selector = DynamicPPL.Selector() - spl = Sampler(MySAlg(), model, selector) + spl = Sampler(MyAlg{(:s,)}(), model, selector) vns = DynamicPPL.TestUtils.varnames(model) vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) @@ -855,4 +998,147 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + + @testset "orders" begin + @model empty_model() = x = 1 + + csym = gensym() # unique per model + vn_z1 = @varname z[1] + vn_z2 = @varname z[2] + vn_z3 = @varname z[3] + vn_z4 = @varname z[4] + vn_a1 = @varname a[1] + vn_a2 = @varname a[2] + vn_b = @varname b + + vi = DynamicPPL.VarInfo() + dists = [Categorical([0.7, 0.3]), Normal()] + + spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 2] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 3] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + end + + @testset "varinfo ranges" begin + @model empty_model() = x = 1 + dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] + + function test_varinfo!(vi) + spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) + vn_w = @varname w + randr(vi, vn_w, dists[1], spl2, true) + + vn_x = @varname x + vn_y = @varname y + vn_z = @varname z + vns = [vn_x, vn_y, vn_z] + + spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) + for i in 1:3 + r = randr(vi, vns[i], dists[i], spl1, false) + val = vi[vns[i]] + @test sum(val - r) <= 1e-9 + end + + idcs = DynamicPPL._getidcs(vi, spl1) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 + else + @test length(idcs) == 3 + end + @test length(vi[spl1]) == 7 + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 + else + @test length(idcs) == 1 + end + @test length(vi[spl2]) == 1 + + vn_u = @varname u + randr(vi, vn_u, dists[1], spl2, true) + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 + else + @test length(idcs) == 2 + end + @test length(vi[spl2]) == 2 + end + vi = DynamicPPL.VarInfo() + test_varinfo!(vi) + test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) + end end