From 3de7393b8b8e76330f53505b27d2b928ef178681 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 7 Mar 2022 08:39:09 +0100 Subject: [PATCH] Support `init_params` in ensemble methods (#94) * Support `init_params` in ensemble methods * Fix typo * Fix typo * Add documentation * Support `Iterators.Repeated` * Breaking release * Fix and simplify docs setup * Remove deprecations * Reduce tasks on Windows * Generalize to arbitrary collections * Use Blue style --- Project.toml | 2 +- docs/src/api.md | 5 +++ src/AbstractMCMC.jl | 1 - src/deprecations.jl | 2 - src/sample.jl | 76 ++++++++++++++++++++++++++++--- test/deprecations.jl | 4 -- test/runtests.jl | 1 - test/sample.jl | 103 +++++++++++++++++++++++++++++++++++++++++++ test/utils.jl | 10 +++-- 9 files changed, 187 insertions(+), 17 deletions(-) delete mode 100644 src/deprecations.jl delete mode 100644 test/deprecations.jl diff --git a/Project.toml b/Project.toml index 7d3a2ca9..9690fd8d 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "3.3.1" +version = "4.0.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" diff --git a/docs/src/api.md b/docs/src/api.md index 8dcf55f4..c7451cc5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -53,6 +53,11 @@ are: - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. +There is no "official" way for providing initial parameter values yet. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. 
+To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `init_params` (default: `nothing`): if `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`. + Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. ```@docs diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index 686924a8..3e8e2ff2 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -84,6 +84,5 @@ include("interface.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") -include("deprecations.jl") end # module AbstractMCMC diff --git a/src/deprecations.jl b/src/deprecations.jl deleted file mode 100644 index 128f16d1..00000000 --- a/src/deprecations.jl +++ /dev/null @@ -1,2 +0,0 @@ -# Deprecate the old name AbstractMCMCParallel in favor of AbstractMCMCEnsemble -Base.@deprecate_binding AbstractMCMCParallel AbstractMCMCEnsemble false diff --git a/src/sample.jl b/src/sample.jl index b6fad3fe..01548470 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -283,6 +283,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", + init_params=nothing, kwargs..., ) # Check if actually multiple threads are used. 
@@ -298,7 +299,7 @@ function mcmcsample( # Copy the random number generator, model, and sample for each thread nchunks = min(nchains, Threads.nthreads()) chunksize = cld(nchains, nchunks) - interval = 1:min(nchains, Threads.nthreads()) + interval = 1:nchunks rngs = [deepcopy(rng) for _ in interval] models = [deepcopy(model) for _ in interval] samplers = [deepcopy(sampler) for _ in interval] @@ -306,6 +307,9 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) + # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -350,7 +354,17 @@ function mcmcsample( # Sample a chain and save it to the vector. chains[chainidx] = StatsBase.sample( - _rng, _model, _sampler, N; progress=false, kwargs... + _rng, + _model, + _sampler, + N; + progress=false, + init_params=if _init_params === nothing + nothing + else + _init_params[chainidx] + end, + kwargs..., ) # Update the progress bar. @@ -378,6 +392,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", + init_params=nothing, kwargs..., ) # Check if actually multiple processes are used. @@ -425,13 +440,19 @@ function mcmcsample( Distributed.@async begin try - chains = Distributed.pmap(pool, seeds) do seed + function sample_chain(seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) # Sample a chain. chain = StatsBase.sample( - rng, model, sampler, N; progress=false, kwargs... + rng, + model, + sampler, + N; + progress=false, + init_params=init_params, + kwargs..., ) # Update the progress bar. @@ -440,6 +461,11 @@ function mcmcsample( # Return the new chain. 
return chain end + chains = if init_params === nothing + Distributed.pmap(sample_chain, pool, seeds) + else + Distributed.pmap(sample_chain, pool, seeds, init_params) + end finally # Stop updating the progress bar. progress && put!(channel, false) @@ -460,6 +486,7 @@ function mcmcsample( N::Integer, nchains::Integer; progressname="Sampling", + init_params=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples @@ -471,21 +498,60 @@ function mcmcsample( seeds = rand(rng, UInt, nchains) # Sample the chains. - chains = map(enumerate(seeds)) do (i, seed) + function sample_chain(i, seed, init_params=nothing) + # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) + + # Sample a chain. return StatsBase.sample( rng, model, sampler, N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), + init_params=init_params, kwargs..., ) end + chains = if init_params === nothing + map(sample_chain, 1:nchains, seeds) + else + map(sample_chain, 1:nchains, seeds, init_params) + end + # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) + +""" + _first_or_nothing(x, n::Int) + +Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`. + +If `x !== nothing`, then `x` has to contain at least `n` elements. 
+""" +function _first_or_nothing(x, n::Int) + y = _first(x, n) + length(y) == n || throw( + ArgumentError("not enough initial parameters (expected $n, received $(length(y)))"), + ) + return y +end +_first_or_nothing(::Nothing, ::Int) = nothing + +# `first(x, n::Int)` requires Julia 1.6 +function _first(x, n::Int) + @static if VERSION >= v"1.6.0-DEV.431" + first(x, n) + else + if x isa AbstractVector + @inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))] + else + collect(Iterators.take(x, n)) + end + end +end diff --git a/test/deprecations.jl b/test/deprecations.jl deleted file mode 100644 index dd53cb42..00000000 --- a/test/deprecations.jl +++ /dev/null @@ -1,4 +0,0 @@ -@testset "deprecations.jl" begin - @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler()) - @test_deprecated AbstractMCMC.transitions(MySample(1, 2.0), MyModel(), MySampler(), 3) -end diff --git a/test/runtests.jl b/test/runtests.jl index e8f09589..3baef78c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,5 +22,4 @@ include("utils.jl") include("sample.jl") include("stepper.jl") include("transducer.jl") - include("deprecations.jl") end diff --git a/test/sample.jl b/test/sample.jl index debb2238..f5a69c12 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -25,6 +25,13 @@ @test var(x.a for x in tail_chain) ≈ 1 / 12 atol = 5e-3 @test mean(x.b for x in tail_chain) ≈ 0.0 atol = 5e-2 @test var(x.b for x in tail_chain) ≈ 1 atol = 6e-2 + + # initial parameters + chain = sample( + MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) + ) + @test chain[1].a == -1.8 + @test chain[1].b == 3.2 end @testset "Juno" begin @@ -168,6 +175,38 @@ if Threads.nthreads() == 2 sample(MyModel(), MySampler(), MCMCThreads(), N, 1) end + + # initial parameters + init_params = [(b=randn(), a=rand()) for _ in 1:100] + chains = sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + 100; + progress=false, + init_params=init_params, + ) + @test 
length(chains) == 100 + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (a=randn(), b=rand()) + chains = sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + 100; + progress=false, + init_params=Iterators.repeated(init_params), + ) + @test length(chains) == 100 + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) end @testset "Multicore sampling" begin @@ -244,6 +283,38 @@ ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # initial parameters + init_params = [(a=randn(), b=rand()) for _ in 1:100] + chains = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + 100; + progress=false, + init_params=init_params, + ) + @test length(chains) == 100 + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (b=randn(), a=rand()) + chains = sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + 100; + progress=false, + init_params=Iterators.repeated(init_params), + ) + @test length(chains) == 100 + @test all( + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains + ) end @testset "Serial sampling" begin @@ -295,6 +366,38 @@ ) end @test all(l.level > Logging.LogLevel(-1) for l in logs) + + # initial parameters + init_params = [(a=rand(), b=randn()) for _ in 1:100] + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + 100; + progress=false, + init_params=init_params, + ) + @test length(chains) == 100 + @test all( + chain[1].a == params.a && chain[1].b == params.b for + (chain, params) in zip(chains, init_params) + ) + + init_params = (b=rand(), a=randn()) + chains = sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + 100; + progress=false, + init_params=Iterators.repeated(init_params), + ) + @test length(chains) == 100 + @test all( + chain[1].a == init_params.a && chain[1].b == 
init_params.b for chain in chains + ) end @testset "Ensemble sampling: Reproducibility" begin diff --git a/test/utils.jl b/test/utils.jl index 32474639..67ba1481 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -23,11 +23,15 @@ function AbstractMCMC.step( state::Union{Nothing,Integer}=nothing; sleepy=false, loggers=false, + init_params=nothing, kwargs..., ) - # sample `a` is missing in the first step - a = state === nothing ? missing : rand(rng) - b = randn(rng) + # sample `a` is missing in the first step if not provided + a, b = if state === nothing && init_params !== nothing + init_params.a, init_params.b + else + (state === nothing ? missing : rand(rng)), randn(rng) + end loggers && push!(LOGGERS, Logging.current_logger()) sleepy && sleep(0.001)