Skip to content

Commit

Permalink
Re-add test_setval! from test/turing/model.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Nov 29, 2024
1 parent 2fbf57b commit b52aa73
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 34 deletions.
34 changes: 0 additions & 34 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual)
@test back(1)[1] grad
end

"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
Test `setval!` on `model` and `chain`.
Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

"""
short_varinfo_name(vi::AbstractVarInfo)
Expand Down
48 changes: 48 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
test_base!!(SimpleVarInfo(Dict()))
test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector()))
end

@testset "flags" begin
# Test flag setting:
# is_flagged, set_flag!, unset_flag!
Expand Down Expand Up @@ -187,6 +188,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
setgid!(vi, gid2, vn)
@test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2])
end

@testset "setval! & setval_and_resample!" begin
@model function testmodel(x)
n = length(x)
Expand Down Expand Up @@ -339,6 +341,52 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,)
@test vals_prev == vi.metadata.x.vals
end

@testset "setval! on chain" begin
# Define a helper function
"""
test_setval!(model, chain; sample_idx = 1, chain_idx = 1)
Test `setval!` on `model` and `chain`.
Worth noting that this only supports models containing symbols of the forms
`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc.
"""
function test_setval!(model, chain; sample_idx=1, chain_idx=1)
var_info = VarInfo(model)
spl = SampleFromPrior()
θ_old = var_info[spl]
DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx)
θ_new = var_info[spl]
@test θ_old != θ_new
vals = DynamicPPL.values_as(var_info, OrderedDict)
iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals))
for (n, v) in mapreduce(collect, vcat, iters)
n = string(n)
if Symbol(n) keys(chain)
# Assume it's a group
chain_val = vec(
MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx]
)
v_true = vec(v)
else
chain_val = chain[sample_idx, n, chain_idx]
v_true = v
end

@test v_true == chain_val
end
end

@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
chain = make_chain_from_prior(model, 10)
# A simple way of checking that the computation is determinstic: run twice and compare.
res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters))
res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters))
@test all(res1 .== res2)
test_setval!(model, MCMCChains.get_sections(chain, :parameters))
end
end

@testset "istrans" begin
@model demo_constrained() = x ~ truncated(Normal(), 0, Inf)
model = demo_constrained()
Expand Down

0 comments on commit b52aa73

Please sign in to comment.