Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reversion of commits v4.4.2..v4.5.0 #133

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "A lightweight interface for common MCMC methods."
version = "4.5.0"
version = "4.5.1"

[deps]
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36"
Expand All @@ -22,7 +21,6 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
[compat]
BangBang = "0.3.19"
ConsoleProgressMonitor = "0.1"
FillArrays = "1"
LogDensityProblems = "2"
LoggingExtras = "0.4, 0.5, 1"
ProgressLogging = "0.1"
Expand All @@ -32,10 +30,9 @@ Transducers = "0.4.30"
julia = "1.6"

[extras]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["FillArrays", "IJulia", "Statistics", "Test"]
test = ["IJulia", "Statistics", "Test"]
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[compat]
Documenter = "1"
Documenter = "0.27"
julia = "1"
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ makedocs(;
format=Documenter.HTML(),
modules=[AbstractMCMC],
pages=["Home" => "index.md", "api.md", "design.md"],
strict=true,
checkdocs=:exports,
)

Expand Down
12 changes: 5 additions & 7 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,18 @@ Common keyword arguments for regular and parallel sampling are:
- `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging
- `chain_type` (default: `Any`): determines the type of the returned chain
- `callback` (default: `nothing`): if `callback !== nothing`, then
`callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler
`callback(rng, model, sampler, sample, iteration)` is called after every sampling step,
where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration
- `discard_initial` (default: `0`): number of initial samples that are discarded
- `thinning` (default: `1`): factor by which to thin samples.
- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref)
is passed `initial_state` as the `state` argument.

!!! info
The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref).

There is no "official" way for providing initial parameter values yet.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`.
However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain.
To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94):
- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`.

Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`.

Expand Down
1 change: 0 additions & 1 deletion src/AbstractMCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ using ProgressLogging: ProgressLogging
using StatsBase: StatsBase
using TerminalLoggers: TerminalLoggers
using Transducers: Transducers
using FillArrays: FillArrays

using Distributed: Distributed
using Logging: Logging
Expand Down
134 changes: 47 additions & 87 deletions src/sample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@
discard_initial=0,
thinning=1,
chain_type::Type=Any,
initial_state=nothing,
kwargs...,
)
# Check the number of requested samples.
Expand All @@ -123,11 +122,7 @@
end

# Obtain the initial sample and state.
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, initial_state; kwargs...)
end
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for i in 1:discard_initial
Expand Down Expand Up @@ -216,7 +211,6 @@
callback=nothing,
discard_initial=0,
thinning=1,
initial_state=nothing,
kwargs...,
)

Expand All @@ -226,11 +220,7 @@

@ifwithprogresslogger progress name = progressname begin
# Obtain the initial sample and state.
sample, state = if initial_state === nothing
step(rng, model, sampler; kwargs...)
else
step(rng, model, sampler, state; kwargs...)
end
sample, state = step(rng, model, sampler; kwargs...)

# Discard initial samples.
for _ in 1:discard_initial
Expand Down Expand Up @@ -298,8 +288,7 @@
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)",
initial_params=nothing,
initial_state=nothing,
init_params=nothing,
kwargs...,
)
# Check if actually multiple threads are used.
Expand All @@ -323,9 +312,8 @@
# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)
# Ensure that initial parameters are `nothing` or indexable
_init_params = _first_or_nothing(init_params, nchains)

# Set up a chains vector.
chains = Vector{Any}(undef, nchains)
Expand Down Expand Up @@ -376,15 +364,10 @@
_sampler,
N;
progress=false,
initial_params=if initial_params === nothing
nothing
else
initial_params[chainidx]
end,
initial_state=if initial_state === nothing
init_params=if _init_params === nothing
nothing
else
initial_state[chainidx]
_init_params[chainidx]
end,
kwargs...,
)
Expand Down Expand Up @@ -414,8 +397,7 @@
nchains::Integer;
progress=PROGRESS[],
progressname="Sampling ($(Distributed.nworkers()) processes)",
initial_params=nothing,
initial_state=nothing,
init_params=nothing,
kwargs...,
)
# Check if actually multiple processes are used.
Expand All @@ -428,15 +410,6 @@
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

Expand Down Expand Up @@ -472,7 +445,7 @@

Distributed.@async begin
try
function sample_chain(seed, initial_params, initial_state)
function sample_chain(seed, init_params=nothing)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

Expand All @@ -483,8 +456,7 @@
sampler,
N;
progress=false,
initial_params=initial_params,
initial_state=initial_state,
init_params=init_params,
kwargs...,
)

Expand All @@ -494,9 +466,11 @@
# Return the new chain.
return chain
end
chains = Distributed.pmap(
sample_chain, pool, seeds, _initial_params, _initial_state
)
chains = if init_params === nothing
Distributed.pmap(sample_chain, pool, seeds)
else
Distributed.pmap(sample_chain, pool, seeds, init_params)
end
finally
# Stop updating the progress bar.
progress && put!(channel, false)
Expand All @@ -517,29 +491,19 @@
N::Integer,
nchains::Integer;
progressname="Sampling",
initial_params=nothing,
initial_state=nothing,
init_params=nothing,
kwargs...,
)
# Check if the number of chains is larger than the number of samples
if nchains > N
@warn "Number of chains ($nchains) is greater than number of samples per chain ($N)"
end

# Ensure that initial parameters and states are `nothing` or of the correct length
check_initial_params(initial_params, nchains)
check_initial_state(initial_state, nchains)

_initial_params =
initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params
_initial_state =
initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state

# Create a seed for each chain using the provided random number generator.
seeds = rand(rng, UInt, nchains)

# Sample the chains.
function sample_chain(i, seed, initial_params, initial_state)
function sample_chain(i, seed, init_params=nothing)
# Seed a new random number generator with the pre-made seed.
Random.seed!(rng, seed)

Expand All @@ -550,13 +514,16 @@
sampler,
N;
progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"),
initial_params=initial_params,
initial_state=initial_state,
init_params=init_params,
kwargs...,
)
end

chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state)
chains = if init_params === nothing
map(sample_chain, 1:nchains, seeds)
else
map(sample_chain, 1:nchains, seeds, init_params)
end

# Concatenate the chains together.
return chainsstack(tighten_eltype(chains))
Expand All @@ -565,38 +532,31 @@
tighten_eltype(x) = x
tighten_eltype(x::Vector{Any}) = map(identity, x)

@nospecialize check_initial_params(x, n) = throw(
ArgumentError(
"initial parameters must be specified as a vector of length equal to the number of chains or `nothing`",
),
)
check_initial_params(::Nothing, n) = nothing
function check_initial_params(x::AbstractArray, n)
if length(x) != n
throw(
ArgumentError(
"incorrect number of initial parameters (expected $n, received $(length(x))"
),
)
end
"""
_first_or_nothing(x, n::Int)

return nothing
end
Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`.

@nospecialize check_initial_state(x, n) = throw(
ArgumentError(
"initial states must be specified as a vector of length equal to the number of chains or `nothing`",
),
)
check_initial_state(::Nothing, n) = nothing
function check_initial_state(x::AbstractArray, n)
if length(x) != n
throw(
ArgumentError(
"incorrect number of initial states (expected $n, received $(length(x))"
),
)
If `x !== nothing`, then `x` has to contain at least `n` elements.
"""
function _first_or_nothing(x, n::Int)
y = _first(x, n)
length(y) == n || throw(
ArgumentError("not enough initial parameters (expected $n, received $(length(y))"),
)
return y
end
_first_or_nothing(::Nothing, ::Int) = nothing

# `first(x, n::Int)` requires Julia 1.6
function _first(x, n::Int)
@static if VERSION >= v"1.6.0-DEV.431"

Check warning on line 553 in src/sample.jl

View check run for this annotation

Codecov / codecov/patch

src/sample.jl#L553

Added line #L553 was not covered by tests
first(x, n)
else
if x isa AbstractVector
@inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))]
else
collect(Iterators.take(x, n))
end
end

return nothing
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using IJulia
using LogDensityProblems
using LoggingExtras: TeeLogger, EarlyFilteredLogger
using TerminalLoggers: TerminalLogger
using FillArrays: FillArrays
using Transducers

using Distributed
Expand Down
Loading
Loading