Skip to content

Commit

Permalink
Merge pull request #97 from TuringLang/dw/reproducible
Browse files Browse the repository at this point in the history
Ensure ensemble sampling is reproducible
  • Loading branch information
cpfiffer authored Feb 22, 2022
2 parents 6dae58b + 22299e2 commit d46ba93
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 74 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ jobs:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
with:
cache-packages: "false" # caching Conda.jl causes precompilation error
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
env:
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "3.3.0"
version = "3.3.2"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
53 changes: 28 additions & 25 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ function mcmcsample(
models = [deepcopy(model) for _ in interval]
samplers = [deepcopy(sampler) for _ in interval]

# Create a seed for each chunk using the provided random number generator.
seeds = rand(rng, UInt, nchunks)
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -339,25 +339,22 @@ function mcmcsample(

Distributed.@async begin
try
Distributed.@sync for (i, _rng, seed, _model, _sampler) in zip(1:nchunks, rngs, seeds, models, samplers)
Threads.@spawn begin
Distributed.@sync for (i, _rng, _model, _sampler) in zip(1:nchunks, rngs, models, samplers)
chainidxs = if i == nchunks
((i - 1) * chunksize + 1):nchains
else
((i - 1) * chunksize + 1):(i * chunksize)
end
Threads.@spawn for chainidx in chainidxs
# Seed the chunk-specific random number generator with the pre-made seed.
Random.seed!(_rng, seed)

chainidxs = if i == nchunks
((i - 1) * chunksize + 1):nchains
else
((i - 1) * chunksize + 1):(i * chunksize)
end

for chainidx in chainidxs
# Sample a chain and save it to the vector.
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
progress = false, kwargs...)

# Update the progress bar.
progress && put!(channel, true)
end
Random.seed!(_rng, seeds[chainidx])

# Sample a chain and save it to the vector.
chains[chainidx] = StatsBase.sample(_rng, _model, _sampler, N;
progress = false, kwargs...)

# Update the progress bar.
progress && put!(channel, true)
end
end
finally
Expand Down Expand Up @@ -469,12 +466,18 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Sample the chains.
chains = map(
i -> StatsBase.sample(rng, model, sampler, N; progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
kwargs...),
1:nchains
)
chains = map(enumerate(seeds)) do (i, seed)
Random.seed!(rng, seed)
return StatsBase.sample(
rng, model, sampler, N;
progressname = string(progressname, " (Chain ", i, " of ", nchains, ")"),
kwargs...,
)
end

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
Expand Down
124 changes: 76 additions & 48 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,61 +103,59 @@
end
end

if VERSION v"1.3"
@testset "Multithreaded sampling" begin
if Threads.nthreads() == 1
warnregex = r"^Only a single thread available"
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
10, 10)
end
@testset "Multithreaded sampling" begin
if Threads.nthreads() == 1
warnregex = r"^Only a single thread available"
@test_logs (:warn, warnregex) sample(MyModel(), MySampler(), MCMCThreads(),
10, 10)
end

# No dedicated chains type
N = 10_000
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000)
@test chains isa Vector{<:Vector{<:MySample}}
@test length(chains) == 1000
@test all(length(x) == N for x in chains)
# No dedicated chains type
N = 10_000
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000)
@test chains isa Vector{<:Vector{<:MySample}}
@test length(chains) == 1000
@test all(length(x) == N for x in chains)

Random.seed!(1234)
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
chain_type = MyChain)
Random.seed!(1234)
chains = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
chain_type = MyChain)

# test output type and size
@test chains isa Vector{<:MyChain}
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == N, chains)
# test output type and size
@test chains isa Vector{<:MyChain}
@test length(chains) == 1000
@test all(x -> length(x.as) == length(x.bs) == N, chains)

# test some statistical properties
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
# test some statistical properties
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains)

# test reproducibility
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
chain_type = MyChain)
# test reproducibility
Random.seed!(1234)
chains2 = sample(MyModel(), MySampler(), MCMCThreads(), N, 1000;
chain_type = MyChain)

@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains, chains2), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N)

# Unexpected order of arguments.
str = "Number of chains (10) is greater than number of samples per chain (5)"
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
MCMCThreads(), 5, 10;
chain_type = MyChain)
# Unexpected order of arguments.
str = "Number of chains (10) is greater than number of samples per chain (5)"
@test_logs (:warn, str) match_mode=:any sample(MyModel(), MySampler(),
MCMCThreads(), 5, 10;
chain_type = MyChain)

# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
progress = false, chain_type = MyChain)
end
@test all(l.level > Logging.LogLevel(-1) for l in logs)
# Suppress output.
logs, _ = collect_test_logs(; min_level=Logging.LogLevel(-1)) do
sample(MyModel(), MySampler(), MCMCThreads(), 10_000, 1000;
progress = false, chain_type = MyChain)
end
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# Smoke test for nchains < nthreads
if Threads.nthreads() == 2
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
end
# Smoke test for nchains < nthreads
if Threads.nthreads() == 2
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
end
end

Expand Down Expand Up @@ -201,7 +199,7 @@
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains)

# Test reproducibility.
Random.seed!(1234)
Expand Down Expand Up @@ -247,7 +245,7 @@
@test all(x -> isapprox(mean(@view x.as[2:end]), 0.5; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.as[2:end]), 1 / 12; atol=5e-3), chains)
@test all(x -> isapprox(mean(@view x.bs[2:end]), 0; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=5e-2), chains)
@test all(x -> isapprox(var(@view x.bs[2:end]), 1; atol=1e-1), chains)

# Test reproducibility.
Random.seed!(1234)
Expand All @@ -271,6 +269,36 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)
end

@testset "Ensemble sampling: Reproducibility" begin
N = 1_000
nchains = 10

# Serial sampling
Random.seed!(1234)
chains_serial = sample(
MyModel(), MySampler(), MCMCSerial(), N, nchains;
progress=false, chain_type=MyChain
)

# Multi-threaded sampling
Random.seed!(1234)
chains_threads = sample(
MyModel(), MySampler(), MCMCThreads(), N, nchains;
progress=false, chain_type=MyChain
)
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_threads), i in 1:N)

# Multi-core sampling
Random.seed!(1234)
chains_distributed = sample(
MyModel(), MySampler(), MCMCDistributed(), N, nchains;
progress=false, chain_type=MyChain
)
@test all(c1.as[i] === c2.as[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N)
@test all(c1.bs[i] === c2.bs[i] for (c1, c2) in zip(chains_serial, chains_distributed), i in 1:N)
end

@testset "Chain constructors" begin
chain1 = sample(MyModel(), MySampler(), 100; sleepy = true)
chain2 = sample(MyModel(), MySampler(), 100; sleepy = true, chain_type = MyChain)
Expand Down

2 comments on commit d46ba93

@cpfiffer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/55232

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v3.3.2 -m "<description of version>" d46ba9363cdd6d2f5ffae2fa79d38aacce333926
git push origin v3.3.2

Also, note the warning: Version 3.3.2 skips over 3.3.1
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

Please sign in to comment.