diff --git a/Project.toml b/Project.toml index a796c429..702f3639 100644 --- a/Project.toml +++ b/Project.toml @@ -3,13 +3,12 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.5.0" +version = "4.5.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" @@ -22,7 +21,6 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" -FillArrays = "1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" @@ -32,10 +30,9 @@ Transducers = "0.4.30" julia = "1.6" [extras] -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["FillArrays", "IJulia", "Statistics", "Test"] +test = ["IJulia", "Statistics", "Test"] diff --git a/docs/Project.toml b/docs/Project.toml index f74dfb58..555443ab 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,5 +3,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -Documenter = "1" +Documenter = "0.27" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index 9395d2a0..66d7619c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,6 +9,7 @@ makedocs(; format=Documenter.HTML(), modules=[AbstractMCMC], pages=["Home" => "index.md", "api.md", "design.md"], + strict=true, checkdocs=:exports, ) diff --git a/docs/src/api.md b/docs/src/api.md index aabf8d6f..52c2c2e1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -71,20 +71,18 @@ Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step, - where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler + `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, + where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. -- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref) - is passed `initial_state` as the `state` argument. !!! info The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). There is no "official" way for providing initial parameter values yet. -However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain. -To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. +To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index dc464d42..64f20f97 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -8,7 +8,6 @@ using ProgressLogging: ProgressLogging using StatsBase: StatsBase using TerminalLoggers: TerminalLoggers using Transducers: Transducers -using FillArrays: FillArrays using Distributed: Distributed using Logging: Logging diff --git a/src/sample.jl b/src/sample.jl index 58217f39..dc951ca2 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -103,7 +103,6 @@ function mcmcsample( discard_initial=0, thinning=1, chain_type::Type=Any, - initial_state=nothing, kwargs..., ) # Check the number of requested samples. @@ -123,11 +122,7 @@ function mcmcsample( end # Obtain the initial sample and state. - sample, state = if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end + sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. for i in 1:discard_initial @@ -216,7 +211,6 @@ function mcmcsample( callback=nothing, discard_initial=0, thinning=1, - initial_state=nothing, kwargs..., ) @@ -226,11 +220,7 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. - sample, state = if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. for _ in 1:discard_initial @@ -298,8 +288,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", - initial_params=nothing, - initial_state=nothing, + init_params=nothing, kwargs..., ) # Check if actually multiple threads are used. @@ -323,9 +312,8 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(initial_params, nchains) - check_initial_state(initial_state, nchains) + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -376,15 +364,10 @@ function mcmcsample( _sampler, N; progress=false, - initial_params=if initial_params === nothing - nothing - else - initial_params[chainidx] - end, - initial_state=if initial_state === nothing + init_params=if _init_params === nothing nothing else - initial_state[chainidx] + _init_params[chainidx] end, kwargs..., ) @@ -414,8 +397,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", - initial_params=nothing, - initial_state=nothing, + init_params=nothing, kwargs..., ) # Check if actually multiple processes are used. @@ -428,15 +410,6 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(initial_params, nchains) - check_initial_state(initial_state, nchains) - - _initial_params = - initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params - _initial_state = - initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state - # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -472,7 +445,7 @@ function mcmcsample( Distributed.@async begin try - function sample_chain(seed, initial_params, initial_state) + function sample_chain(seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -483,8 +456,7 @@ function mcmcsample( sampler, N; progress=false, - initial_params=initial_params, - initial_state=initial_state, + init_params=init_params, kwargs..., ) @@ -494,9 +466,11 @@ function mcmcsample( # Return the new chain. return chain end - chains = Distributed.pmap( - sample_chain, pool, seeds, _initial_params, _initial_state - ) + chains = if init_params === nothing + Distributed.pmap(sample_chain, pool, seeds) + else + Distributed.pmap(sample_chain, pool, seeds, init_params) + end finally # Stop updating the progress bar. progress && put!(channel, false) @@ -517,8 +491,7 @@ function mcmcsample( N::Integer, nchains::Integer; progressname="Sampling", - initial_params=nothing, - initial_state=nothing, + init_params=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples @@ -526,20 +499,11 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(initial_params, nchains) - check_initial_state(initial_state, nchains) - - _initial_params = - initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params - _initial_state = - initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state - # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) # Sample the chains. - function sample_chain(i, seed, initial_params, initial_state) + function sample_chain(i, seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -550,13 +514,16 @@ function mcmcsample( sampler, N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), - initial_params=initial_params, - initial_state=initial_state, + init_params=init_params, kwargs..., ) end - chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state) + chains = if init_params === nothing + map(sample_chain, 1:nchains, seeds) + else + map(sample_chain, 1:nchains, seeds, init_params) + end # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) @@ -565,38 +532,31 @@ end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) -@nospecialize check_initial_params(x, n) = throw( - ArgumentError( - "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", - ), -) -check_initial_params(::Nothing, n) = nothing -function check_initial_params(x::AbstractArray, n) - if length(x) != n - throw( - ArgumentError( - "incorrect number of initial parameters (expected $n, received $(length(x))" - ), - ) - end +""" + _first_or_nothing(x, n::Int) - return nothing -end +Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`. -@nospecialize check_initial_state(x, n) = throw( - ArgumentError( - "initial states must be specified as a vector of length equal to the number of chains or `nothing`", - ), -) -check_initial_state(::Nothing, n) = nothing -function check_initial_state(x::AbstractArray, n) - if length(x) != n - throw( - ArgumentError( - "incorrect number of initial states (expected $n, received $(length(x))" - ), - ) +If `x !== nothing`, then `x` has to contain at least `n` elements. +""" +function _first_or_nothing(x, n::Int) + y = _first(x, n) + length(y) == n || throw( + ArgumentError("not enough initial parameters (expected $n, received $(length(y))"), + ) + return y +end +_first_or_nothing(::Nothing, ::Int) = nothing + +# `first(x, n::Int)` requires Julia 1.6 +function _first(x, n::Int) + @static if VERSION >= v"1.6.0-DEV.431" + first(x, n) + else + if x isa AbstractVector + @inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))] + else + collect(Iterators.take(x, n)) + end end - - return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 909ae8b3..75aac0f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ using IJulia using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger -using FillArrays: FillArrays using Transducers using Distributed diff --git a/test/sample.jl b/test/sample.jl index dcc87526..261cc1ef 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -28,7 +28,7 @@ # initial parameters chain = sample( - MyModel(), MySampler(), 3; progress=false, initial_params=(b=3.2, a=-1.8) + MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) ) @test chain[1].a == -1.8 @test chain[1].b == 3.2 @@ -162,59 +162,35 @@ end # initial parameters - nchains = 100 - initial_params = [(b=randn(), a=rand()) for _ in 1:nchains] + init_params = [(b=randn(), a=rand()) for _ in 1:100] chains = sample( MyModel(), MySampler(), MCMCThreads(), 3, - nchains; + 100; progress=false, - initial_params=initial_params, + init_params=init_params, ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, initial_params) + (chain, params) in zip(chains, init_params) ) - initial_params = (a=randn(), b=rand()) + init_params = (a=randn(), b=rand()) chains = sample( MyModel(), MySampler(), MCMCThreads(), 3, - nchains; + 100; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains), + init_params=Iterators.repeated(init_params), ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for - chain in chains - ) - - # Too many `initial_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCThreads(), - 3, - nchains; - progress=false, - initial_params=FillArrays.Fill(initial_params, nchains + 1), - ) - - # Too few `initial_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCThreads(), - 3, - nchains; - progress=false, - initial_params=FillArrays.Fill(initial_params, nchains - 1), + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) end @@ -298,59 +274,35 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - nchains = 100 - initial_params = [(a=randn(), b=rand()) for _ in 1:nchains] + init_params = [(a=randn(), b=rand()) for _ in 1:100] chains = sample( MyModel(), MySampler(), MCMCDistributed(), 3, - nchains; + 100; progress=false, - initial_params=initial_params, + init_params=init_params, ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, initial_params) + (chain, params) in zip(chains, init_params) ) - initial_params = (b=randn(), a=rand()) + init_params = (b=randn(), a=rand()) chains = sample( MyModel(), MySampler(), MCMCDistributed(), 3, - nchains; + 100; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains), + init_params=Iterators.repeated(init_params), ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for - chain in chains - ) - - # Too many `initial_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCDistributed(), - 3, - nchains; - progress=false, - initial_params=FillArrays.Fill(initial_params, nchains + 1), - ) - - # Too few `initial_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCDistributed(), - 3, - nchains; - progress=false, - initial_params=FillArrays.Fill(initial_params, nchains - 1), + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) # Remove workers @@ -360,21 +312,13 @@ @testset "Serial sampling" begin # No dedicated chains type N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; progress=false) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000) @test chains isa Vector{<:Vector{<:MySample}} @test length(chains) == 1000 @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample( - MyModel(), - MySampler(), - MCMCSerial(), - N, - 1000; - chain_type=MyChain, - progress=false, - ) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -390,15 +334,7 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample( - MyModel(), - MySampler(), - MCMCSerial(), - N, - 1000; - chain_type=MyChain, - progress=false, - ) + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) @test all(ismissing(c.as[1]) for c in chains2) @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @@ -424,59 +360,35 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - nchains = 100 - initial_params = [(a=rand(), b=randn()) for _ in 1:nchains] + init_params = [(a=rand(), b=randn()) for _ in 1:100] chains = sample( MyModel(), MySampler(), MCMCSerial(), 3, - nchains; + 100; progress=false, - initial_params=initial_params, + init_params=init_params, ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, initial_params) + (chain, params) in zip(chains, init_params) ) - initial_params = (b=rand(), a=randn()) + init_params = (b=rand(), a=randn()) chains = sample( MyModel(), MySampler(), MCMCSerial(), 3, - nchains; + 100; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains), + init_params=Iterators.repeated(init_params), ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for - chain in chains - ) - - # Too many `initial_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCSerial(), - 3, - nchains; - progress=false, - initial_params=FillArrays.Fill(initial_params, nchains + 1), - ) - - # Too few `initial_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCSerial(), - 3, - nchains; - progress=false, - initial_params=FillArrays.Fill(initial_params, nchains - 1), + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) end @@ -674,63 +586,4 @@ ) @test it_array == collect(1:size(chain, 1)) end - - @testset "Providing initial state" begin - function record_state( - rng, model, sampler, sample, state, i; states_channel, kwargs... - ) - return put!(states_channel, state) - end - - initial_state = 10 - - @testset "sample" begin - n = 10 - states_channel = Channel{Int}(n) - chain = sample( - MyModel(), - MySampler(), - n; - initial_state=initial_state, - callback=record_state, - states_channel=states_channel, - ) - - # Extract the states. - states = [take!(states_channel) for _ in 1:n] - @test length(states) == n - for i in 1:n - @test states[i] == initial_state + i - end - end - - @testset "sample with $mode" for mode in - [MCMCSerial(), MCMCThreads(), MCMCDistributed()] - nchains = 4 - initial_state = 10 - states_channel = if mode === MCMCDistributed() - # Need to use `RemoteChannel` for this. - RemoteChannel(() -> Channel{Int}(nchains)) - else - Channel{Int}(nchains) - end - chain = sample( - MyModel(), - MySampler(), - mode, - 1, - nchains; - initial_state=FillArrays.Fill(initial_state, nchains), - callback=record_state, - states_channel=states_channel, - ) - - # Extract the states. - states = [take!(states_channel) for _ in 1:nchains] - @test length(states) == nchains - for i in 1:nchains - @test states[i] == initial_state + 1 - end - end - end end diff --git a/test/utils.jl b/test/utils.jl index 1e29a473..f69fcdab 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -22,12 +22,12 @@ function AbstractMCMC.step( sampler::MySampler, state::Union{Nothing,Integer}=nothing; loggers=false, - initial_params=nothing, + init_params=nothing, kwargs..., ) # sample `a` is missing in the first step if not provided - a, b = if state === nothing && initial_params !== nothing - initial_params.a, initial_params.b + a, b = if state === nothing && init_params !== nothing + init_params.a, init_params.b else (state === nothing ? missing : rand(rng)), randn(rng) end