Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use _init_parmas for MCMCThreads and MCMCDistributed too #126

Merged
merged 18 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.4.2"
version = "4.6.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand All @@ -30,9 +30,10 @@ Transducers = "0.4.30"
julia = "1.6"

[extras]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["IJulia", "Statistics", "Test"]
test = ["FillArrays", "IJulia", "Statistics", "Test"]
2 changes: 1 addition & 1 deletion docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Common keyword arguments for regular and parallel sampling are:
There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.
- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

Expand Down
54 changes: 25 additions & 29 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)
# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -364,10 +364,10 @@
_sampler,
N;
progress=false,
init_params=if _init_params === nothing
init_params=if init_params === nothing
nothing
else
_init_params[chainidx]
init_params[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -410,6 +410,9 @@
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -499,6 +502,9 @@
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -532,31 +538,21 @@
tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)

"""
_first_or_nothing(x, n::Int)

Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.

If `x !== nothing`, then `x` has to contain at least `n` elements.
"""
function _first_or_nothing(x, n::Int)
y = _first(x, n)
length(y) == n || throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
)
return y
end
_first_or_nothing(::Nothing, ::Int) = nothing
@nospecialize check_initial_params(x, n) = throw(

Check warning on line 541 in src/sample.jl

View check run for this annotation

Codecov / codecov/patch

src/sample.jl#L541

Added line #L541 was not covered by tests
ArgumentError(
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
),
)

# `first(x, n::Int)` requires Julia 1.6
function _first(x, n::Int)
@static if VERSION >= v"1.6.0-DEV.431"
first(x, n)
else
if x isa AbstractVector
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
else
collect(Iterators.take(x, n))
end
check_initial_params(::Nothing, n) = nothing
function check_initial_params(x::AbstractArray, n)
if length(x) != n
throw(
ArgumentError(
"incorrect number of initial parameters (expected $n, received $(length(x))"
),
)
end

return nothing
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using IJulia
using LogDensityProblems
using LoggingExtras: TeeLogger, EarlyFilteredLogger
using TerminalLoggers: TerminalLogger
using FillArrays: FillArrays
using Transducers

using Distributed
Expand Down
105 changes: 87 additions & 18 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,18 @@
end

# initial parameters
init_params = [(b=randn(), a=rand()) for _ in 1:100]
nchains = 100
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -184,14 +185,36 @@
MySampler(),
MCMCThreads(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=FillArrays.Fill(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains - 1),
)
end

@testset "Multicore sampling" begin
Expand Down Expand Up @@ -274,17 +297,18 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=randn(), b=rand()) for _ in 1:100]
nchains = 100
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -296,15 +320,37 @@
MySampler(),
MCMCDistributed(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=FillArrays.Fill(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains - 1),
)

# Remove workers
rmprocs(pids...)
end
Expand Down Expand Up @@ -360,17 +406,18 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=rand(), b=randn()) for _ in 1:100]
nchains = 100
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -382,14 +429,36 @@
MySampler(),
MCMCSerial(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=FillArrays.Fill(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=FillArrays.Fill(init_params, nchains - 1),
)
end

@testset "Ensemble sampling: Reproducibility" begin
Expand Down
Loading