Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use _init_parmas for MCMCThreads and MCMCDistributed too #126

Merged
merged 18 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.4.2"
version = "4.5.0"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand Down
46 changes: 16 additions & 30 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ function mcmcsample(
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)
# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -364,10 +364,10 @@ function mcmcsample(
_sampler,
N;
progress=false,
init_params=if _init_params === nothing
init_params=if init_params === nothing
nothing
else
_init_params[chainidx]
init_params[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -410,6 +410,9 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -499,6 +502,9 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -532,31 +538,11 @@ end
tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)

"""
_first_or_nothing(x, n::Int)

Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.

If `x !== nothing`, then `x` has to contain at least `n` elements.
"""
function _first_or_nothing(x, n::Int)
y = _first(x, n)
length(y) == n || throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
)
return y
end
_first_or_nothing(::Nothing, ::Int) = nothing

# `first(x, n::Int)` requires Julia 1.6
function _first(x, n::Int)
@static if VERSION >= v"1.6.0-DEV.431"
first(x, n)
else
if x isa AbstractVector
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
else
collect(Iterators.take(x, n))
end
check_initial_params(x::Nothing, n::Int) = nothing
function check_initial_params(x, n::Int)
if length(x) != n
throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(x))"),
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
)
end
end
105 changes: 87 additions & 18 deletions test/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,18 @@
end

# initial parameters
init_params = [(b=randn(), a=rand()) for _ in 1:100]
nchains = 100
init_params = [(b=randn(), a=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -184,14 +185,36 @@
MySampler(),
MCMCThreads(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=Iterators.repeated(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So _first_or_nothing doesn't actually care if there are too many initial parameters 😕

IMO it seems like we want to raise an error in this case since it's otherwise very easy to accidentally do the incorrect thing.

)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCThreads(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
)
end

@testset "Multicore sampling" begin
Expand Down Expand Up @@ -274,17 +297,18 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=randn(), b=rand()) for _ in 1:100]
nchains = 100
init_params = [(a=randn(), b=rand()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -296,15 +320,37 @@
MySampler(),
MCMCDistributed(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=Iterators.repeated(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCDistributed(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
)

# Remove workers
rmprocs(pids...)
end
Expand Down Expand Up @@ -360,17 +406,18 @@
@test all(l.level > Logging.LogLevel(-1) for l in logs)

# initial parameters
init_params = [(a=rand(), b=randn()) for _ in 1:100]
nchains = 100
init_params = [(a=rand(), b=randn()) for _ in 1:nchains]
chains = sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
100;
nchains;
progress=false,
init_params=init_params,
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == params.a && chain[1].b == params.b for
(chain, params) in zip(chains, init_params)
Expand All @@ -382,14 +429,36 @@
MySampler(),
MCMCSerial(),
3,
100;
nchains;
progress=false,
init_params=Iterators.repeated(init_params),
init_params=Iterators.repeated(init_params, nchains),
)
@test length(chains) == 100
@test length(chains) == nchains
@test all(
chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains
)

# Too many `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains + 1),
)

# Too few `init_params`
@test_throws ArgumentError sample(
MyModel(),
MySampler(),
MCMCSerial(),
3,
nchains;
progress=false,
init_params=Iterators.repeated(init_params, nchains - 1),
)
end

@testset "Ensemble sampling: Reproducibility" begin
Expand Down