Skip to content

Commit

Permalink
Reduce iterations in ESS tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Dec 6, 2024
1 parent c0a4ee9 commit a8b5cdd
Showing 1 changed file with 45 additions and 34 deletions.
79 changes: 45 additions & 34 deletions test/mcmc/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@ using Distributions: Normal, sample
using DynamicPPL: DynamicPPL
using DynamicPPL: Sampler
using Random: Random
using StableRNGs: StableRNG
using Test: @test, @testset
using Turing

@testset "ESS" begin
@info "Starting ESS tests"

@model function demo(x)
m ~ Normal()
return x ~ Normal(m, 0.5)
Expand All @@ -25,7 +28,7 @@ using Turing

@testset "ESS constructor" begin
Random.seed!(0)
N = 500
N = 10

s1 = ESS()
s2 = ESS(:m)
Expand All @@ -43,41 +46,49 @@ using Turing
end

@testset "ESS inference" begin
Random.seed!(1)
chain = sample(demo_default, ESS(), 5_000)
check_numerical(chain, [:m], [0.8]; atol=0.1)

Random.seed!(1)
chain = sample(demodot_default, ESS(), 5_000)
check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1)

Random.seed!(100)
alg = Gibbs(CSMC(15, :s), ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)

# MoGtest
Random.seed!(125)
alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, alg, 6000)
check_MoGtest_default(chain; atol=0.1)

# Different "equivalent" models.
# NOTE: Because `ESS` only supports "single" variables with
# Gaussian priors, we restrict ourselves to this subspace by conditioning
# on the non-Gaussian variables in `DEMO_MODELS`.
models_conditioned = map(DynamicPPL.TestUtils.DEMO_MODELS) do model
# Condition on the non-Gaussian random variables.
model | (s=DynamicPPL.TestUtils.posterior_mean(model).s,)
@info "Starting ESS inference tests"
rng = StableRNG(23)

@testset "demo_default" begin
chain = sample(rng, demo_default, ESS(), 500)
check_numerical(chain, [:m], [0.8]; atol=0.1)
end

@testset "demodot_default" begin
chain = sample(rng, demodot_default, ESS(), 500)
check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8]; atol=0.1)
end

@testset "gdemo with CSMC + ESS" begin
alg = Gibbs(CSMC(15, :s), ESS(:m))
chain = sample(rng, gdemo(1.5, 2.0), alg, 2000)
check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]; atol=0.1)
end

@testset "MoGtest_default with CSMC + ESS" begin
alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2))
chain = sample(rng, MoGtest_default, alg, 2000)
check_MoGtest_default(chain; atol=0.1)
end

DynamicPPL.TestUtils.test_sampler(
models_conditioned,
DynamicPPL.Sampler(ESS()),
10_000;
# Filter out the varnames we've conditioned on.
varnames_filter=vn -> DynamicPPL.getsym(vn) != :s,
)
@testset "TestModels" begin
# Different "equivalent" models.
# NOTE: Because `ESS` only supports "single" variables with
# Gaussian priors, we restrict ourselves to this subspace by conditioning
# on the non-Gaussian variables in `DEMO_MODELS`.
models_conditioned = map(DynamicPPL.TestUtils.DEMO_MODELS) do model
# Condition on the non-Gaussian random variables.
model | (s=DynamicPPL.TestUtils.posterior_mean(model).s,)
end

DynamicPPL.TestUtils.test_sampler(
models_conditioned,
DynamicPPL.Sampler(ESS()),
2000;
# Filter out the varnames we've conditioned on.
varnames_filter=vn -> DynamicPPL.getsym(vn) != :s,
)
end
end
end

Expand Down

0 comments on commit a8b5cdd

Please sign in to comment.