From 05f91cf88a641c73ec9555fb780be5c2453156e1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:33:52 +0100 Subject: [PATCH 1/6] Fix reproducibility of ensemble sampling --- src/sample.jl | 53 +++++++++++++++++++++++++++------------------------ 1 file changed, 28 insertions(+), 25 deletions(-) diff --git a/src/sample.jl b/src/sample.jl index e783760f..e1ff1be7 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -305,8 +305,8 @@ function mcmcsample( models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] - # Create a seed for each chunk using the provided random number generator. - seeds = rand(rng, UInt, nchunks) + # Create a seed for each chain using the provided random number generator. + seeds = rand(rng, UInt, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -339,25 +339,22 @@ function mcmcsample( Distributed.@async begin try - Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers) - Threads.@spawn begin + Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers) + chainidxs = if i == nchunks + ((i - 1) * chunksize + 1):nchains + else + ((i - 1) * chunksize + 1):(i * chunksize) + end + Threads.@spawn for chainidx in chainidxs # Seed the chunk-specific random number generator with the pre-made seed. - Random.seed!(_rng, seed) - - chainidxs = if i == nchunks - ((i - 1) * chunksize + 1):nchains - else - ((i - 1) * chunksize + 1):(i * chunksize) - end - - for chainidx in chainidxs - # Sample a chain and save it to the vector. - chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; - progress = false, kwargs...) - - # Update the progress bar. - progress && put!(channel, true) - end + Random.seed!(_rng, seeds[chainidx]) + + # Sample a chain and save it to the vector. + chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N; + progress = false, kwargs...) + + # Update the progress bar. + progress && put!(channel, true) end end finally @@ -469,12 +466,18 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Create a seed for each chain using the provided random number generator. + seeds = rand(rng, UInt, nchains) + # Sample the chains. - chains = map( - i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), - kwargs...), - 1:nchains - ) + chains = map(enumerate(seeds)) do (i, seed) + Random.seed!(rng, seed) + return StatsBase.sample( + rng, model, sampler, N; + progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"), + kwargs..., + ) + end # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) From dca3aa09a4171c7460966747628c694bf8e08c64 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:34:35 +0100 Subject: [PATCH 2/6] Add tests --- test/sample.jl | 120 ++++++++++++++++++++++++++++++------------------- 1 file changed, 74 insertions(+), 46 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index 6e876d48..d8a04712 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -103,61 +103,59 @@ end end - if VERSION ≥ v"1.3" - @testset "Multithreaded sampling" begin - if Threads.nthreads() == 1 - warnregex = r"^Only a single thread available" - @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), - 10, 10) - end + @testset "Multithreaded sampling" begin + if Threads.nthreads() == 1 + warnregex = r"^Only a single thread available" + @test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(), + 10, 10) + end - # No dedicated chains type - N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) - @test chains isa Vector{<:Vector{<:MySample}} - @test length(chains) == 1000 - @test all(length(x) == N for x in chains) + # No dedicated chains type + N = 10_000 + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000) + @test chains isa Vector{<:Vector{<:MySample}} + @test length(chains) == 1000 + @test all(length(x) == N for x in chains) - Random.seed!(1234) - chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + Random.seed!(1234) + chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; + chain_type = MyChain) - # test output type and size - @test chains isa Vector{<:MyChain} - @test length(chains) == 1000 - @test all(x -> length(x.as) == length(x.bs) == N, chains) + # test output type and size + @test chains isa Vector{<:MyChain} + @test length(chains) == 1000 + @test all(x -> length(x.as) == length(x.bs) == N, chains) - # test some statistical properties - @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) - @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + # test some statistical properties + @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) + @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) - # test reproducibility - Random.seed!(1234) - chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; - chain_type = MyChain) + # test reproducibility + Random.seed!(1234) + chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000; + chain_type = MyChain) - @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) - # Unexpected order of arguments. - str = "Number of chains (10) is greater than number of samples per chain (5)" - @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), - MCMCThreads(), 5, 10; - chain_type = MyChain) + # Unexpected order of arguments. + str = "Number of chains (10) is greater than number of samples per chain (5)" + @test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(), + MCMCThreads(), 5, 10; + chain_type = MyChain) - # Suppress output. - logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do - sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; - progress = false, chain_type = MyChain) - end - @test all(l.level > Logging.LogLevel(-1) for l in logs) + # Suppress output. + logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do + sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000; + progress = false, chain_type = MyChain) + end + @test all(l.level > Logging.LogLevel(-1) for l in logs) - # Smoke test for nchains < nthreads - if Threads.nthreads() == 2 - sample(MyModel(), MySampler(), MCMCThreads(), N, 1) - end + # Smoke test for nchains < nthreads + if Threads.nthreads() == 2 + sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end end @@ -271,6 +269,36 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) end + @testset "Ensemble sampling: Reproducibility" begin + N = 1_000 + nchains = 10 + + # Serial sampling + Random.seed!(1234) + chains_serial = sample( + MyModel(), MySampler(), MCMCSerial(), N, nchains; + progress=false, chain_type=MyChain + ) + + # Multi-threaded sampling + Random.seed!(1234) + chains_threads = sample( + MyModel(), MySampler(), MCMCThreads(), N, nchains; + progress=false, chain_type=MyChain + ) + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N) + + # Multi-core sampling + Random.seed!(1234) + chains_distributed = sample( + MyModel(), MySampler(), MCMCDistributed(), N, nchains; + progress=false, chain_type=MyChain + ) + @test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) + @test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N) + end + @testset "Chain constructors" begin chain1 = sample(MyModel(), MySampler(), 100; sleepy = true) chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain) From 25d8c391322aa037e0398210c8e56df1da1d8e3d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:34:42 +0100 Subject: [PATCH 3/6] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 794deb24..7d3a2ca9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.0" +version = "3.3.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From d24e6b04bc3596e933eafd2b5ad0942b55081426 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 22:56:26 +0100 Subject: [PATCH 4/6] Fix CI --- .github/workflows/CI.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 069ef82b..d5b1273a 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -46,6 +46,8 @@ jobs: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - uses: julia-actions/cache@v1 + with: + cache-packages: "false" # caching Conda.jl causes precompilation error - uses: julia-actions/julia-buildpkg@latest - uses: julia-actions/julia-runtest@latest env: From 20ffc79d51e3ca2bc0bb75efa2366389c7fc08f1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 21 Feb 2022 23:34:00 +0100 Subject: [PATCH 5/6] Increase tolerances --- test/sample.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/sample.jl b/test/sample.jl index d8a04712..407330d2 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -130,7 +130,7 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # test reproducibility Random.seed!(1234) @@ -199,7 +199,7 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # Test reproducibility. Random.seed!(1234) @@ -245,7 +245,7 @@ @test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains) @test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains) @test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains) - @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains) + @test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains) # Test reproducibility. Random.seed!(1234) From 22299e27d56530b367593f2e595cac8e659f9362 Mon Sep 17 00:00:00 2001 From: Cameron Pfiffer Date: Tue, 22 Feb 2022 11:05:42 -0800 Subject: [PATCH 6/6] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7d3a2ca9..bf356aa1 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.1" +version = "3.3.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"