Skip to content

Commit

Permalink
Merge branch 'master' into mhauru/fix-test-seed
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm authored Oct 30, 2024
2 parents 591be1a + f6fdc91 commit e6623d5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 37 deletions.
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,4 +70,4 @@ StatsBase = "0.33, 0.34"
StatsFuns = "0.9.5, 1"
TimerOutputs = "0.5"
Zygote = "0.5.4, 0.6"
julia = "1.3"
julia = "1.10"
27 changes: 8 additions & 19 deletions test/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,14 @@ using Turing
@testset "rng" begin
model = gdemo_default

# multithreaded sampling with PG causes segfaults on Julia 1.5.4
# https://github.com/TuringLang/Turing.jl/issues/1571
samplers = @static if VERSION <= v"1.5.3" || VERSION >= v"1.6.0"
(
HMC(0.1, 7; adtype=adbackend),
PG(10),
IS(),
MH(),
Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
else
(
HMC(0.1, 7; adtype=adbackend),
IS(),
MH(),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
end
samplers = (
HMC(0.1, 7; adtype=adbackend),
PG(10),
IS(),
MH(),
Gibbs(PG(3, :s), HMC(0.4, 8, :m; adtype=adbackend)),
Gibbs(HMC(0.1, 5, :s; adtype=adbackend), ESS(:m)),
)
for sampler in samplers
Random.seed!(5)
chain1 = sample(model, sampler, MCMCThreads(), 1000, 4)
Expand Down
25 changes: 8 additions & 17 deletions test/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,23 +137,14 @@ end
)

@testset "inference" begin
if adtype isa AutoReverseDiff &&
model.f === DynamicPPL.TestUtils.demo_assume_index_observe &&
VERSION < v"1.8"
# Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/612
@test_throws UndefRefError sample(
model, sampler_ext, 5_000; sample_kwargs...
)
else
DynamicPPL.TestUtils.test_sampler(
[model],
sampler_ext,
5_000;
rtol=0.2,
sampler_name="AdvancedHMC",
sample_kwargs...,
)
end
DynamicPPL.TestUtils.test_sampler(
[model],
sampler_ext,
5_000;
rtol=0.2,
sampler_name="AdvancedHMC",
sample_kwargs...,
)
end
end
end
Expand Down

0 comments on commit e6623d5

Please sign in to comment.