Skip to content

Commit

Permalink
Re-add almost all integration tests into DPPL test suite proper
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 4, 2024
1 parent ae14716 commit 675b40f
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 87 deletions.
39 changes: 39 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,43 @@
end
end
end

@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
# Failing model
t = 1:0.05:8
σ = 0.3
y = @. rand(sin(t) + Normal(0, σ))
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
# Priors
α ~ Normal(y[1], 0.001)
τ ~ Exponential(1)
η ~ filldist(Normal(0, 1), TT - 1)
σ ~ Exponential(1)
# create latent variable
x = Vector{T}(undef, TT)
x[1] = α
for t in 2:TT
x[t] = x[t - 1] + η[t - 1] * τ
end
# measurement model
y ~ MvNormal(x, σ^2 * I)
return x
end
model = state_space(y, length(t))

# Dummy sampling algorithm for testing. The test case can only be replicated
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
# overload assume so that model evaluation doesn't fail due to a lack
# of implementation
struct MyEmptyAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
DynamicPPL.assume(dist, vn, vi)

# Compiling the ReverseDiff tape used to fail here
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
end
end
59 changes: 15 additions & 44 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,50 +55,8 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
@test ljoint lp

#### logprior, logjoint, loglikelihood for MCMC chains ####
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
var_info = VarInfo(model)
vns = DynamicPPL.TestUtils.varnames(model)
syms = unique(DynamicPPL.getsym.(vns))

# generate a chain of sample parameter values.
N = 200
vals_OrderedDict = mapreduce(hcat, 1:N) do _
rand(OrderedDict, model)
end
vals_mat = mapreduce(hcat, 1:N) do i
[vals_OrderedDict[i][vn] for vn in vns]
end
i = 1
for col in eachcol(vals_mat)
col_flattened = []
[push!(col_flattened, x...) for x in col]
if i == 1
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
else
chain_mat = vcat(
chain_mat, reshape(col_flattened, 1, length(col_flattened))
)
end
i += 1
end
chain_mat = convert(Matrix{Float64}, chain_mat)

# devise parameter names for chain
sample_values_vec = collect(values(vals_OrderedDict[1]))
symbol_names = []
chain_sym_map = Dict()
for k in 1:length(keys(var_info))
vn_parent = keys(var_info)[k]
sym = DynamicPPL.getsym(vn_parent)
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
for vn_child in vn_children
chain_sym_map[Symbol(vn_child)] = sym
symbol_names = [symbol_names; Symbol(vn_child)]
end
end
chain = Chains(chain_mat, symbol_names)

# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
for model in DynamicPPL.TestUtils.DEMO_MODELS
chain = make_chain_from_prior(model, 200)
logpriors = logprior(model, chain)
loglikelihoods = loglikelihood(model, chain)
logjoints = logjoint(model, chain)
Expand All @@ -125,6 +83,19 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end
end

@testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin
@model function multiple_types(x)
ns ~ filldist(Normal(0, 2.0), 3)
m ~ Uniform(0, 1)
return x ~ Normal(m, 1)
end
model = multiple_types(1)
chain = make_chain_from_prior(model, 10)
loglikelihood(model, chain)
logprior(model, chain)
logjoint(model, chain)
end

@testset "rng" begin
model = gdemo_default

Expand Down
16 changes: 16 additions & 0 deletions test/model_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@testset "model_utils.jl" begin
@testset "value_iterator_from_chain" begin
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
chain = make_chain_from_prior(model, 10)
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
for vn in keys(d)
val = DynamicPPL.getvalue(d, vn)
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
end
end
end
end
end
end
6 changes: 1 addition & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ using Distributions
using LinearAlgebra # Diagonal

using Combinatorics: combinations
using OrderedCollections: OrderedSet

using DynamicPPL: getargs_dottilde, getargs_tilde, Selector

Expand All @@ -48,15 +49,10 @@ include("test_util.jl")
include("context_implementations.jl")
include("logdensityfunction.jl")
include("linking.jl")

include("threadsafe.jl")

include("serialization.jl")

include("pointwise_logdensities.jl")

include("lkj.jl")

include("debug_utils.jl")
end

Expand Down
34 changes: 0 additions & 34 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual)
@test back(1)[1] grad
end

"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
Test `setval!` on `model` and `chain`.
Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

"""
short_varinfo_name(vi::AbstractVarInfo)
Expand Down
Loading

0 comments on commit 675b40f

Please sign in to comment.