Skip to content

Commit

Permalink
remove _first_or_nothing and just check if init_params is of the righ…
Browse files Browse the repository at this point in the history
…t length
  • Loading branch information
torfjelde committed Sep 14, 2023
1 parent 2e6e23d commit e897b8a
Showing 1 changed file with 15 additions and 37 deletions.
52 changes: 15 additions & 37 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ function mcmcsample(
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)
# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -364,10 +364,10 @@ function mcmcsample(
_sampler,
N;
progress=false,
init_params=if _init_params === nothing
init_params=if init_params === nothing
nothing
else
_init_params[chainidx]
init_params[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -410,8 +410,8 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)
# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand Down Expand Up @@ -469,10 +469,10 @@ function mcmcsample(
# Return the new chain.
return chain
end
chains = if _init_params === nothing
chains = if init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, _init_params)
Distributed.pmap(sample_chain, pool, seeds, init_params)
end
finally
# Stop updating the progress bar.
Expand Down Expand Up @@ -502,8 +502,8 @@ function mcmcsample(
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)
# Ensure that initial parameters are `nothing` or of the correct length
check_initial_params(init_params, nchains)

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)
Expand All @@ -525,10 +525,10 @@ function mcmcsample(
)
end

chains = if _init_params === nothing
chains = if init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, _init_params)
map(sample_chain, 1:nchains, seeds, init_params)
end

# Concatenate the chains together.
Expand All @@ -538,31 +538,9 @@ end
tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)

"""
_first_or_nothing(x, n::Int)
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.
If `x !== nothing`, then `x` has to contain at least `n` elements.
"""
function _first_or_nothing(x, n::Int)
y = _first(x, n)
length(y) == n || throw(
check_initial_params(x::Nothing, n::Int) = nothing
function check_initial_params(x, n::Int)
length(x) == n || throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
)
return y
end
_first_or_nothing(::Nothing, ::Int) = nothing

# `first(x, n::Int)` requires Julia 1.6
function _first(x, n::Int)
@static if VERSION >= v"1.6.0-DEV.431"
first(x, n)
else
if x isa AbstractVector
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
else
collect(Iterators.take(x, n))
end
end
end

0 comments on commit e897b8a

Please sign in to comment.