From 6eebe2fbbc6a6b45d8132ecbe2349223831b6822 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 29 Nov 2024 14:41:36 +0000 Subject: [PATCH] Re-add almost all integration tests into DPPL test suite proper --- test/ad.jl | 39 ++++++ test/model.jl | 90 ++++++-------- test/model_utils.jl | 16 +++ test/runtests.jl | 6 +- test/test_util.jl | 34 ----- test/varinfo.jl | 294 +++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 386 insertions(+), 93 deletions(-) create mode 100644 test/model_utils.jl diff --git a/test/ad.jl b/test/ad.jl index 768a55ad3..17981cf2a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -34,4 +34,43 @@ end end end + + @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin + # Failing model + t = 1:0.05:8 + σ = 0.3 + y = @. rand(sin(t) + Normal(0, σ)) + @model function state_space(y, TT, ::Type{T}=Float64) where {T} + # Priors + α ~ Normal(y[1], 0.001) + τ ~ Exponential(1) + η ~ filldist(Normal(0, 1), TT - 1) + σ ~ Exponential(1) + # create latent variable + x = Vector{T}(undef, TT) + x[1] = α + for t in 2:TT + x[t] = x[t - 1] + η[t - 1] * τ + end + # measurement model + y ~ MvNormal(x, σ^2 * I) + return x + end + model = state_space(y, length(t)) + + # Dummy sampling algorithm for testing. The test case can only be replicated + # with a custom sampler, it doesn't work with SampleFromPrior(). We need to + # overload assume so that model evaluation doesn't fail due to a lack + # of implementation + struct MyEmptyAlg end + DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () + DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = + DynamicPPL.assume(dist, vn, vi) + + # Compiling the ReverseDiff tape used to fail here + spl = Sampler(MyEmptyAlg()) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) + @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + end end diff --git a/test/model.jl b/test/model.jl index d163f55f0..6e8cbb47d 100644 --- a/test/model.jl +++ b/test/model.jl @@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() + @testset "model.jl" begin @testset "convenience functions" begin - model = gdemo_default # defined in test/test_util.jl + model = GDEMO_DEFAULT # defined in test/test_util.jl # sample from model and extract variables vi = VarInfo(model) @@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12 - var_info = VarInfo(model) - vns = DynamicPPL.TestUtils.varnames(model) - syms = unique(DynamicPPL.getsym.(vns)) - - # generate a chain of sample parameter values. + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS N = 200 - vals_OrderedDict = mapreduce(hcat, 1:N) do _ - rand(OrderedDict, model) - end - vals_mat = mapreduce(hcat, 1:N) do i - [vals_OrderedDict[i][vn] for vn in vns] - end - i = 1 - for col in eachcol(vals_mat) - col_flattened = [] - [push!(col_flattened, x...) for x in col] - if i == 1 - chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) - else - chain_mat = vcat( - chain_mat, reshape(col_flattened, 1, length(col_flattened)) - ) - end - i += 1 - end - chain_mat = convert(Matrix{Float64}, chain_mat) - - # devise parameter names for chain - sample_values_vec = collect(values(vals_OrderedDict[1])) - symbol_names = [] - chain_sym_map = Dict() - for k in 1:length(keys(var_info)) - vn_parent = keys(var_info)[k] + chain = make_chain_from_prior(model, N) + logpriors = logprior(model, chain) + loglikelihoods = loglikelihood(model, chain) + logjoints = logjoint(model, chain) + + # Construct mapping of varname symbols to varname-parent symbols. + # Here, varname_leaves is used to ensure compatibility with the + # variables stored in the chain + var_info = VarInfo(model) + chain_sym_map = Dict{Symbol,Symbol}() + for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl + vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym - symbol_names = [symbol_names; Symbol(vn_child)] end end - chain = Chains(chain_mat, symbol_names) - # calculate the pointwise loglikelihoods for the whole chain using the newly written functions - logpriors = logprior(model, chain) - loglikelihoods = loglikelihood(model, chain) - logjoints = logjoint(model, chain) # compare them with true values for i in 1:N samples_dict = Dict() @@ -125,8 +100,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin + @model function multiple_types(x) + ns ~ filldist(Normal(0, 2.0), 3) + m ~ Uniform(0, 1) + return x ~ Normal(m, 1) + end + model = multiple_types(1) + chain = make_chain_from_prior(model, 10) + loglikelihood(model, chain) + logprior(model, chain) + logjoint(model, chain) + end + @testset "rng" begin - model = gdemo_default + model = GDEMO_DEFAULT for sampler in (SampleFromPrior(), SampleFromUniform()) for i in 1:10 @@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "defaults without VarInfo, Sampler, and Context" begin - model = gdemo_default + model = GDEMO_DEFAULT Random.seed!(100) - s, m = model() + retval = model() Random.seed!(100) - @test model(Random.default_rng()) == (s, m) + retval2 = model(Random.default_rng()) + @test retval2.s == retval.s + @test retval2.m == retval.m end @testset "nameof" begin @@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "Internal methods" begin - model = gdemo_default + model = GDEMO_DEFAULT # sample from model and extract variables vi = VarInfo(model) @@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "rand" begin - model = gdemo_default + model = GDEMO_DEFAULT Random.seed!(1776) s, m = model() @@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end - @testset "generated_quantities on `LKJCholesky`" begin + @testset "returned() on `LKJCholesky`" begin n = 10 d = 2 model = DynamicPPL.TestUtils.demo_lkjchol(d) @@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true ) # Test! - results = generated_quantities(model, chain) + results = returned(model, chain) for (x_true, result) in zip(xs, results) @test x_true.UL == result.x.UL end @@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true info=(varname_to_symbol=vns_to_syms_with_extra,), ) # Test! - results = generated_quantities(model, chain_with_extra) + results = returned(model, chain_with_extra) for (x_true, result) in zip(xs, results) @test x_true.UL == result.x.UL end diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..4d3a08412 --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,16 @@ +@testset "model_utils.jl" begin + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d2f092a77..0393161f6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ using Distributions using LinearAlgebra # Diagonal using Combinatorics: combinations +using OrderedCollections: OrderedSet using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -50,15 +51,10 @@ include("test_util.jl") include("context_implementations.jl") include("logdensityfunction.jl") include("linking.jl") - include("threadsafe.jl") - include("serialization.jl") - include("pointwise_logdensities.jl") - include("lkj.jl") - include("debug_utils.jl") end diff --git a/test/test_util.jl b/test/test_util.jl index 0611c594f..27a68456c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual) @test back(1)[1] ≈ grad end -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. - -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end -end - """ short_varinfo_name(vi::AbstractVarInfo) diff --git a/test/varinfo.jl b/test/varinfo.jl index ae4319904..9a55cffb9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,8 @@ +# Dummy algorithm for testing +# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) +struct MyAlg{space} end +DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space + function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -14,9 +19,29 @@ function check_varinfo_keys(varinfo, vns) end end -# A simple "algorithm" which only has `s` variables in its space. -struct MySAlg end -DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) +function randr( + vi::DynamicPPL.VarInfo, + vn::VarName, + dist::Distribution, + spl::DynamicPPL.Sampler, + count::Bool=false, +) + if !haskey(vi, vn) + r = rand(dist) + push!!(vi, vn, r, dist, spl) + r + elseif DynamicPPL.is_flagged(vi, vn, "del") + DynamicPPL.unset_flag!(vi, vn, "del") + r = rand(dist) + vi[vn] = DynamicPPL.tovec(r) + DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + r + else + count && checkindex(vn, vi, spl) + DynamicPPL.updategid!(vi, vn, spl) + vi[vn] + end +end @testset "varinfo.jl" begin @testset "TypedVarInfo with Metadata" begin @@ -130,6 +155,26 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + + @testset "get/set/acc/resetlogp" begin + function test_varinfo_logp!(vi) + @test DynamicPPL.getlogp(vi) === 0.0 + vi = DynamicPPL.setlogp!!(vi, 1.0) + @test DynamicPPL.getlogp(vi) === 1.0 + vi = DynamicPPL.acclogp!!(vi, 1.0) + @test DynamicPPL.getlogp(vi) === 2.0 + vi = DynamicPPL.resetlogp!!(vi) + @test DynamicPPL.getlogp(vi) === 0.0 + end + + vi = VarInfo() + test_varinfo_logp!(vi) + test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(SimpleVarInfo()) + test_varinfo_logp!(SimpleVarInfo(Dict())) + test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -187,6 +232,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end + @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -339,6 +385,103 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test vals_prev == vi.metadata.x.vals end + @testset "setval! on chain" begin + # Define a helper function + """ + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + + Test `setval!` on `model` and `chain`. + + Worth noting that this only supports models containing symbols of the forms + `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. + """ + function test_setval!(model, chain; sample_idx=1, chain_idx=1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v + end + + @test v_true == chain_val + end + end + + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + # A simple way of checking that the computation is determinstic: run twice and compare. + res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) + res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) + @test all(res1 .== res2) + test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + end + end + + @testset "link!! and invlink!!" begin + @model gdemo(x, y) = begin + s ~ InverseGamma(2, 3) + m ~ Uniform(0, 2) + x ~ Normal(m, sqrt(s)) + y ~ Normal(m, sqrt(s)) + end + model = gdemo(1.0, 2.0) + + # Check that instantiating the model does not perform linking + vi = VarInfo() + meta = vi.metadata + model(vi, SampleFromUniform()) + @test all(x -> !istrans(vi, x), meta.vns) + + # Check that linking and invlinking set the `trans` flag accordingly + v = copy(meta.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.vns) + @test meta.vals ≈ v atol = 1e-10 + + # Check that linking and invlinking preserves the values + vi = TypedVarInfo(vi) + meta = vi.metadata + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + v_s = copy(meta.s.vals) + v_m = copy(meta.m.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable (`s`) but not the others (`m`) + spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) + link!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + invlink!!(vi, spl, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + end + @testset "istrans" begin @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) model = demo_constrained() @@ -737,7 +880,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) DynamicPPL.Metadata(), ) selector = DynamicPPL.Selector() - spl = Sampler(MySAlg(), model, selector) + spl = Sampler(MyAlg{(:s,)}(), model, selector) vns = DynamicPPL.TestUtils.varnames(model) vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) @@ -855,4 +998,147 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end end + + @testset "orders" begin + @model empty_model() = x = 1 + + csym = gensym() # unique per model + vn_z1 = @varname z[1] + vn_z2 = @varname z[2] + vn_z3 = @varname z[3] + vn_z4 = @varname z[4] + vn_a1 = @varname a[1] + vn_a2 = @varname a[2] + vn_b = @varname b + + vi = DynamicPPL.VarInfo() + dists = [Categorical([0.7, 0.3]), Normal()] + + spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 2] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 3] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + end + + @testset "varinfo ranges" begin + @model empty_model() = x = 1 + dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] + + function test_varinfo!(vi) + spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) + vn_w = @varname w + randr(vi, vn_w, dists[1], spl2, true) + + vn_x = @varname x + vn_y = @varname y + vn_z = @varname z + vns = [vn_x, vn_y, vn_z] + + spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) + for i in 1:3 + r = randr(vi, vns[i], dists[i], spl1, false) + val = vi[vns[i]] + @test sum(val - r) <= 1e-9 + end + + idcs = DynamicPPL._getidcs(vi, spl1) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 + else + @test length(idcs) == 3 + end + @test length(vi[spl1]) == 7 + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 + else + @test length(idcs) == 1 + end + @test length(vi[spl2]) == 1 + + vn_u = @varname u + randr(vi, vn_u, dists[1], spl2, true) + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 + else + @test length(idcs) == 2 + end + @test length(vi[spl2]) == 2 + end + vi = DynamicPPL.VarInfo() + test_varinfo!(vi) + test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) + end end