diff --git a/test/ad.jl b/test/ad.jl index 6046cfda4..f2efa685f 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -25,4 +25,43 @@ end end end + + @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin + # Failing model + t = 1:0.05:8 + σ = 0.3 + y = @. rand(sin(t) + Normal(0, σ)) + @model function state_space(y, TT, ::Type{T}=Float64) where {T} + # Priors + α ~ Normal(y[1], 0.001) + τ ~ Exponential(1) + η ~ filldist(Normal(0, 1), TT - 1) + σ ~ Exponential(1) + # create latent variable + x = Vector{T}(undef, TT) + x[1] = α + for t in 2:TT + x[t] = x[t - 1] + η[t - 1] * τ + end + # measurement model + y ~ MvNormal(x, σ^2 * I) + return x + end + model = state_space(y, length(t)) + + # Dummy sampling algorithm for testing. The test case can only be replicated + # with a custom sampler, it doesn't work with SampleFromPrior(). We need to + # overload assume so that model evaluation doesn't fail due to a lack + # of implementation + struct MyEmptyAlg end + DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () + DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = + DynamicPPL.assume(dist, vn, vi) + + # Compiling the ReverseDiff tape used to fail here + spl = Sampler(MyEmptyAlg()) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) + @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + end end diff --git a/test/model.jl b/test/model.jl index d163f55f0..c72e91066 100644 --- a/test/model.jl +++ b/test/model.jl @@ -55,50 +55,8 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12 - var_info = VarInfo(model) - vns = DynamicPPL.TestUtils.varnames(model) - syms = unique(DynamicPPL.getsym.(vns)) - - # generate a chain of sample parameter values. - N = 200 - vals_OrderedDict = mapreduce(hcat, 1:N) do _ - rand(OrderedDict, model) - end - vals_mat = mapreduce(hcat, 1:N) do i - [vals_OrderedDict[i][vn] for vn in vns] - end - i = 1 - for col in eachcol(vals_mat) - col_flattened = [] - [push!(col_flattened, x...) for x in col] - if i == 1 - chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) - else - chain_mat = vcat( - chain_mat, reshape(col_flattened, 1, length(col_flattened)) - ) - end - i += 1 - end - chain_mat = convert(Matrix{Float64}, chain_mat) - - # devise parameter names for chain - sample_values_vec = collect(values(vals_OrderedDict[1])) - symbol_names = [] - chain_sym_map = Dict() - for k in 1:length(keys(var_info)) - vn_parent = keys(var_info)[k] - sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl - for vn_child in vn_children - chain_sym_map[Symbol(vn_child)] = sym - symbol_names = [symbol_names; Symbol(vn_child)] - end - end - chain = Chains(chain_mat, symbol_names) - - # calculate the pointwise loglikelihoods for the whole chain using the newly written functions + for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 200) logpriors = logprior(model, chain) loglikelihoods = loglikelihood(model, chain) logjoints = logjoint(model, chain) @@ -125,6 +83,19 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin + @model function multiple_types(x) + ns ~ filldist(Normal(0, 2.0), 3) + m ~ Uniform(0, 1) + return x ~ Normal(m, 1) + end + model = multiple_types(1) + chain = make_chain_from_prior(model, 10) + loglikelihood(model, chain) + logprior(model, chain) + logjoint(model, chain) + end + @testset "rng" begin model = gdemo_default diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..4d3a08412 --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,16 @@ +@testset "model_utils.jl" begin + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index bd9d52858..e1405f1ee 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ using Distributions using LinearAlgebra # Diagonal using Combinatorics: combinations +using OrderedCollections: OrderedSet using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -48,15 +49,10 @@ include("test_util.jl") include("context_implementations.jl") include("logdensityfunction.jl") include("linking.jl") - include("threadsafe.jl") - include("serialization.jl") - include("pointwise_logdensities.jl") - include("lkj.jl") - include("debug_utils.jl") end diff --git a/test/test_util.jl b/test/test_util.jl index 0611c594f..27a68456c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual) @test back(1)[1] ≈ grad end -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. - -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end -end - """ short_varinfo_name(vi::AbstractVarInfo) diff --git a/test/varinfo.jl b/test/varinfo.jl index c45fb47e0..12a7651d5 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,8 @@ +# Dummy algorithm for testing +# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) +struct MyAlg{space} end +DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space + function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -14,9 +19,29 @@ function check_varinfo_keys(varinfo, vns) end end -# A simple "algorithm" which only has `s` variables in its space. -struct MySAlg end -DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) +function randr( + vi::DynamicPPL.VarInfo, + vn::VarName, + dist::Distribution, + spl::DynamicPPL.Sampler, + count::Bool=false, +) + if !haskey(vi, vn) + r = rand(dist) + push!!(vi, vn, r, dist, spl) + r + elseif DynamicPPL.is_flagged(vi, vn, "del") + DynamicPPL.unset_flag!(vi, vn, "del") + r = rand(dist) + vi[vn] = DynamicPPL.tovec(r) + DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + r + else + count && checkindex(vn, vi, spl) + DynamicPPL.updategid!(vi, vn, spl) + vi[vn] + end +end @testset "varinfo.jl" begin @testset "TypedVarInfo with Metadata" begin @@ -130,6 +155,26 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + + @testset "get/set/acc/resetlogp" begin + function test_varinfo_logp!(vi) + @test DynamicPPL.getlogp(vi) === 0 + vi = DynamicPPL.setlogp!!(vi, 1) + @test DynamicPPL.getlogp(vi) === 1 + vi = DynamicPPL.acclogp!!(vi, 1) + @test DynamicPPL.getlogp(vi) === 2 + vi = DynamicPPL.resetlogp!!(vi) + @test DynamicPPL.getlogp(vi) === 0 + end + + vi = VarInfo() + test_varinfo_logp!(vi) + test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(SimpleVarInfo()) + test_varinfo_logp!(SimpleVarInfo(Dict())) + test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -187,6 +232,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end + @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -339,6 +385,103 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test vals_prev == vi.metadata.x.vals end + @testset "setval! on chain" begin + # Define a helper function + """ + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + + Test `setval!` on `model` and `chain`. + + Worth noting that this only supports models containing symbols of the forms + `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. + """ + function test_setval!(model, chain; sample_idx=1, chain_idx=1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v + end + + @test v_true == chain_val + end + end + + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + # A simple way of checking that the computation is determinstic: run twice and compare. + res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) + res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) + @test all(res1 .== res2) + test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + end + end + + @testset "link!! and invlink!!" begin + @model gdemo(x, y) = begin + s ~ InverseGamma(2, 3) + m ~ Uniform(0, 2) + x ~ Normal(m, sqrt(s)) + y ~ Normal(m, sqrt(s)) + end + model = gdemo(1.0, 2.0) + + # Check that instantiating the model does not perform linking + vi = VarInfo() + meta = vi.metadata + model(vi, SampleFromUniform()) + @test all(x -> !istrans(vi, x), meta.vns) + + # Check that linking and invlinking set the `trans` flag accordingly + v = copy(meta.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.vns) + @test meta.vals ≈ v atol = 1e-10 + + # Check that linking and invlinking preserves the values + vi = TypedVarInfo(vi) + meta = vi.metadata + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + v_s = copy(meta.s.vals) + v_m = copy(meta.m.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable (`s`) but not the others (`m`) + spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) + link!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + invlink!!(vi, spl, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + end + @testset "istrans" begin @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) model = demo_constrained() @@ -737,7 +880,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) DynamicPPL.Metadata(), ) selector = DynamicPPL.Selector() - spl = Sampler(MySAlg(), model, selector) + spl = Sampler(MyAlg{(:s,)}(), model, selector) vns = DynamicPPL.TestUtils.varnames(model) vns_s = filter(vn -> DynamicPPL.getsym(vn) === :s, vns) @@ -813,4 +956,145 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test DynamicPPL.istrans(varinfo2, vn) end end + + @testset "orders" begin + @model empty_model() = x = 1 + + csym = gensym() # unique per model + vn_z1 = @varname z[1] + vn_z2 = @varname z[2] + vn_z3 = @varname z[3] + vn_z4 = @varname z[4] + vn_a1 = @varname a[1] + vn_a2 = @varname a[2] + vn_b = @varname b + + vi = DynamicPPL.VarInfo() + dists = [Categorical([0.7, 0.3]), Normal()] + + spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 2] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 3] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + end + + @testset "varinfo" begin + dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] + function test_varinfo!(vi) + spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) + vn_w = @varname w + randr(vi, vn_w, dists[1], spl2, true) + + vn_x = @varname x + vn_y = @varname y + vn_z = @varname z + vns = [vn_x, vn_y, vn_z] + + spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) + for i in 1:3 + r = randr(vi, vns[i], dists[i], spl1, false) + val = vi[vns[i]] + @test sum(val - r) <= 1e-9 + end + + idcs = DynamicPPL._getidcs(vi, spl1) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 + else + @test length(idcs) == 3 + end + @test length(vi[spl1]) == 7 + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 + else + @test length(idcs) == 1 + end + @test length(vi[spl2]) == 1 + + vn_u = @varname u + randr(vi, vn_u, dists[1], spl2, true) + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 + else + @test length(idcs) == 2 + end + @test length(vi[spl2]) == 2 + end + vi = DynamicPPL.VarInfo() + test_varinfo!(vi) + test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) + end end