Skip to content

Commit

Permalink
Re-add almost all integration tests into DPPL test suite proper
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Dec 7, 2024
1 parent 9494f3a commit 6eebe2f
Show file tree
Hide file tree
Showing 6 changed files with 386 additions and 93 deletions.
39 changes: 39 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,43 @@
end
end
end

@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
# Failing model
t = 1:0.05:8
σ = 0.3
y = @. rand(sin(t) + Normal(0, σ))
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
# Priors
α ~ Normal(y[1], 0.001)
τ ~ Exponential(1)
η ~ filldist(Normal(0, 1), TT - 1)
σ ~ Exponential(1)
# create latent variable
x = Vector{T}(undef, TT)
x[1] = α
for t in 2:TT
x[t] = x[t - 1] + η[t - 1] * τ
end
# measurement model
y ~ MvNormal(x, σ^2 * I)
return x
end
model = state_space(y, length(t))

# Dummy sampling algorithm for testing. The test case can only be replicated
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
# overload assume so that model evaluation doesn't fail due to a lack
# of implementation
struct MyEmptyAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
DynamicPPL.assume(dist, vn, vi)

# Compiling the ReverseDiff tape used to fail here
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
end
end
90 changes: 40 additions & 50 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false
is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true
is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true

const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()

@testset "model.jl" begin
@testset "convenience functions" begin
model = gdemo_default # defined in test/test_util.jl
model = GDEMO_DEFAULT # defined in test/test_util.jl

# sample from model and extract variables
vi = VarInfo(model)
Expand All @@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
@test ljoint lp

#### logprior, logjoint, loglikelihood for MCMC chains ####
for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12
var_info = VarInfo(model)
vns = DynamicPPL.TestUtils.varnames(model)
syms = unique(DynamicPPL.getsym.(vns))

# generate a chain of sample parameter values.
@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS
N = 200
vals_OrderedDict = mapreduce(hcat, 1:N) do _
rand(OrderedDict, model)
end
vals_mat = mapreduce(hcat, 1:N) do i
[vals_OrderedDict[i][vn] for vn in vns]
end
i = 1
for col in eachcol(vals_mat)
col_flattened = []
[push!(col_flattened, x...) for x in col]
if i == 1
chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened)))
else
chain_mat = vcat(
chain_mat, reshape(col_flattened, 1, length(col_flattened))
)
end
i += 1
end
chain_mat = convert(Matrix{Float64}, chain_mat)

# devise parameter names for chain
sample_values_vec = collect(values(vals_OrderedDict[1]))
symbol_names = []
chain_sym_map = Dict()
for k in 1:length(keys(var_info))
vn_parent = keys(var_info)[k]
chain = make_chain_from_prior(model, N)
logpriors = logprior(model, chain)
loglikelihoods = loglikelihood(model, chain)
logjoints = logjoint(model, chain)

# Construct mapping of varname symbols to varname-parent symbols.
# Here, varname_leaves is used to ensure compatibility with the
# variables stored in the chain
var_info = VarInfo(model)
chain_sym_map = Dict{Symbol,Symbol}()
for vn_parent in keys(var_info)
sym = DynamicPPL.getsym(vn_parent)
vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl
vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent])
for vn_child in vn_children
chain_sym_map[Symbol(vn_child)] = sym
symbol_names = [symbol_names; Symbol(vn_child)]
end
end
chain = Chains(chain_mat, symbol_names)

# calculate the pointwise loglikelihoods for the whole chain using the newly written functions
logpriors = logprior(model, chain)
loglikelihoods = loglikelihood(model, chain)
logjoints = logjoint(model, chain)
# compare them with true values
for i in 1:N
samples_dict = Dict()
Expand All @@ -125,8 +100,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end
end

@testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin
@model function multiple_types(x)
ns ~ filldist(Normal(0, 2.0), 3)
m ~ Uniform(0, 1)
return x ~ Normal(m, 1)
end
model = multiple_types(1)
chain = make_chain_from_prior(model, 10)
loglikelihood(model, chain)
logprior(model, chain)
logjoint(model, chain)
end

@testset "rng" begin
model = gdemo_default
model = GDEMO_DEFAULT

for sampler in (SampleFromPrior(), SampleFromUniform())
for i in 1:10
Expand All @@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end

@testset "defaults without VarInfo, Sampler, and Context" begin
model = gdemo_default
model = GDEMO_DEFAULT

Random.seed!(100)
s, m = model()
retval = model()

Random.seed!(100)
@test model(Random.default_rng()) == (s, m)
retval2 = model(Random.default_rng())
@test retval2.s == retval.s
@test retval2.m == retval.m
end

@testset "nameof" begin
Expand Down Expand Up @@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end

@testset "Internal methods" begin
model = gdemo_default
model = GDEMO_DEFAULT

# sample from model and extract variables
vi = VarInfo(model)
Expand Down Expand Up @@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end

@testset "rand" begin
model = gdemo_default
model = GDEMO_DEFAULT

Random.seed!(1776)
s, m = model()
Expand Down Expand Up @@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
end
end

@testset "generated_quantities on `LKJCholesky`" begin
@testset "returned() on `LKJCholesky`" begin
n = 10
d = 2
model = DynamicPPL.TestUtils.demo_lkjchol(d)
Expand All @@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
)

# Test!
results = generated_quantities(model, chain)
results = returned(model, chain)
for (x_true, result) in zip(xs, results)
@test x_true.UL == result.x.UL
end
Expand All @@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true
info=(varname_to_symbol=vns_to_syms_with_extra,),
)
# Test!
results = generated_quantities(model, chain_with_extra)
results = returned(model, chain_with_extra)
for (x_true, result) in zip(xs, results)
@test x_true.UL == result.x.UL
end
Expand Down
16 changes: 16 additions & 0 deletions test/model_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@testset "model_utils.jl" begin
@testset "value_iterator_from_chain" begin
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
chain = make_chain_from_prior(model, 10)
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
for vn in keys(d)
val = DynamicPPL.getvalue(d, vn)
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
end
end
end
end
end
end
6 changes: 1 addition & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using Distributions
using LinearAlgebra # Diagonal

using Combinatorics: combinations
using OrderedCollections: OrderedSet

using DynamicPPL: getargs_dottilde, getargs_tilde, Selector

Expand All @@ -50,15 +51,10 @@ include("test_util.jl")
include("context_implementations.jl")
include("logdensityfunction.jl")
include("linking.jl")

include("threadsafe.jl")

include("serialization.jl")

include("pointwise_logdensities.jl")

include("lkj.jl")

include("debug_utils.jl")
end

Expand Down
34 changes: 0 additions & 34 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual)
@test back(1)[1] grad
end

"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
Test `setval!` on `model` and `chain`.
Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

"""
short_varinfo_name(vi::AbstractVarInfo)
Expand Down
Loading

0 comments on commit 6eebe2f

Please sign in to comment.