Skip to content

Commit

Permalink
Support init_params in ensemble methods (#94)
Browse files Browse the repository at this point in the history
* Support `init_params` in ensemble methods

* Fix typo

* Fix typo

* Add documentation

* Support `Iterators.Repeated`

* Breaking release

* Fix and simplify docs setup

* Remove deprecations

* Reduce tasks on Windows

* Generalize to arbitrary collections

* Use Blue style
  • Loading branch information
devmotion authored Mar 7, 2022
1 parent 4994a79 commit 3de7393
Show file tree
Hide file tree
Showing 9 changed files with 187 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "3.3.1"
version = "4.0.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
5 changes: 5 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ are:
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `thinning` (default: `1`): factor by which to thin samples.

There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

```@docs
Expand Down
1 change: 0 additions & 1 deletion src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,5 @@ include("interface.jl")
include("sample.jl")
include("stepper.jl")
include("transducer.jl")
include("deprecations.jl")

end # module AbstractMCMC
2 changes: 0 additions & 2 deletions src/deprecations.jl

This file was deleted.

76 changes: 71 additions & 5 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ function mcmcsample(
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
init_params=nothing,
kwargs...,
)
# Check if actually multiple threads are used.
Expand All @@ -298,14 +299,17 @@ function mcmcsample(
# Copy the random number generator, model, and sample for each thread
nchunks = min(nchains, Threads.nthreads())
chunksize = cld(nchains, nchunks)
interval = 1:min(nchains, Threads.nthreads())
interval = 1:nchunks
rngs = [deepcopy(rng) for _ in interval]
models = [deepcopy(model) for _ in interval]
samplers = [deepcopy(sampler) for _ in interval]

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)

Expand Down Expand Up @@ -350,7 +354,17 @@ function mcmcsample(

# Sample a chain and save it to the vector.
chains[chainidx] = StatsBase.sample(
_rng, _model, _sampler, N; progress=false, kwargs...
_rng,
_model,
_sampler,
N;
progress=false,
init_params=if _init_params === nothing
nothing
else
_init_params[chainidx]
end,
kwargs...,
)

# Update the progress bar.
Expand Down Expand Up @@ -378,6 +392,7 @@ function mcmcsample(
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(Distributed.nworkers()) processes)",
init_params=nothing,
kwargs...,
)
# Check if actually multiple processes are used.
Expand Down Expand Up @@ -425,13 +440,19 @@ function mcmcsample(

Distributed.@async begin
try
chains = Distributed.pmap(pool, seeds) do seed
function sample_chain(seed, init_params=nothing)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

# Sample a chain.
chain = StatsBase.sample(
rng, model, sampler, N; progress=false, kwargs...
rng,
model,
sampler,
N;
progress=false,
init_params=init_params,
kwargs...,
)

# Update the progress bar.
Expand All @@ -440,6 +461,11 @@ function mcmcsample(
# Return the new chain.
return chain
end
chains = if init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
end
finally
# Stop updating the progress bar.
progress && put!(channel, false)
Expand All @@ -460,6 +486,7 @@ function mcmcsample(
N::Integer,
nchains::Integer;
progressname="Sampling",
init_params=nothing,
kwargs...,
)
# Check if the number of chains is larger than the number of samples
Expand All @@ -471,21 +498,60 @@ function mcmcsample(
seeds = rand(rng, UInt, nchains)

# Sample the chains.
chains = map(enumerate(seeds)) do (i, seed)
function sample_chain(i, seed, init_params=nothing)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

# Sample a chain.
return StatsBase.sample(
rng,
model,
sampler,
N;
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
init_params=init_params,
kwargs...,
)
end

chains = if init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
end

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
end

tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)

"""
_first_or_nothing(x, n::Int)
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
If `x !== nothing`, then `x` has to contain at least `n` elements.
"""
function _first_or_nothing(x, n::Int)
y = _first(x, n)
length(y) == n || throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
)
return y
end
_first_or_nothing(::Nothing, ::Int) = nothing

# `first(x, n::Int)` requires Julia 1.6
function _first(x, n::Int)
@static if VERSION >= v"1.6.0-DEV.431"
first(x, n)
else
if x isa AbstractVector
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
else
collect(Iterators.take(x, n))
end
end
end
4 changes: 0 additions & 4 deletions test/deprecations.jl

This file was deleted.

1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,4 @@ include("utils.jl")
include("sample.jl")
include("stepper.jl")
include("transducer.jl")
include("deprecations.jl")
end
103 changes: 103 additions & 0 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
@test var(x.a for x in tail_chain) 1 / 12 atol = 5e-3
@test mean(x.b for x in tail_chain) 0.0 atol = 5e-2
@test var(x.b for x in tail_chain) 1 atol = 6e-2

# initial parameters
chain = sample(
MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8)
)
@test chain[1].a == -1.8
@test chain[1].b == 3.2
end

@testset "Juno" begin
Expand Down Expand Up @@ -168,6 +175,38 @@
if Threads.nthreads() == 2
sample(MyModel(), MySampler(), MCMCThreads(), N, 1)
end

# initial parameters
init_params = [(b=randn(), a=rand()) for _ in 1:100]
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
100;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
)

init_params = (a=randn(), b=rand())
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
100;
progress=false,
init_params=Iterators.repeated(init_params),
)
@test length(chains) == 100
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)
end

@testset "Multicore sampling" begin
Expand Down Expand Up @@ -244,6 +283,38 @@
)
end
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=randn(), b=rand()) for _ in 1:100]
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
100;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
)

init_params = (b=randn(), a=rand())
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
100;
progress=false,
init_params=Iterators.repeated(init_params),
)
@test length(chains) == 100
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)
end

@testset "Serial sampling" begin
Expand Down Expand Up @@ -295,6 +366,38 @@
)
end
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=rand(), b=randn()) for _ in 1:100]
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
100;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
)

init_params = (b=rand(), a=randn())
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
100;
progress=false,
init_params=Iterators.repeated(init_params),
)
@test length(chains) == 100
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)
end

@testset "Ensemble sampling: Reproducibility" begin
Expand Down
10 changes: 7 additions & 3 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,15 @@ function AbstractMCMC.step(
state::Union{Nothing,Integer}=nothing;
sleepy=false,
loggers=false,
init_params=nothing,
kwargs...,
)
# sample `a` is missing in the first step
a = state === nothing ? missing : rand(rng)
b = randn(rng)
# sample `a` is missing in the first step if not provided
a, b = if state === nothing && init_params !== nothing
init_params.a, init_params.b
else
(state === nothing ? missing : rand(rng)), randn(rng)
end

loggers && push!(LOGGERS, Logging.current_logger())
sleepy && sleep(0.001)
Expand Down

2 comments on commit 3de7393

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/56117

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v4.0.0 -m "<description of version>" 3de7393b8b8e76330f53505b27d2b928ef178681
git push origin v4.0.0

Please sign in to comment.