From 420e588e5ecca1da8e05d608044f2cf0c0bdd2a4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Oct 2023 18:39:59 +0100 Subject: [PATCH 1/6] Revert "Merge pull request #119 from TuringLang/torfjelde/initial-state" This reverts commit 8d45ff49780e1aee2f02ad568eb81908f85980b1, reversing changes made to d5218159232bc3b035ad9c789e874ac68b5643d5. --- Project.toml | 2 - docs/src/api.md | 8 +-- src/AbstractMCMC.jl | 1 - src/sample.jl | 100 +++++++++--------------------- test/sample.jl | 146 +++++++++++--------------------------------- test/utils.jl | 6 +- 6 files changed, 68 insertions(+), 195 deletions(-) diff --git a/Project.toml b/Project.toml index a796c429..90117048 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ version = "4.5.0" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" ConsoleProgressMonitor = "88cd18e8-d9cc-4ea6-8889-5259c0d15c8b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" LoggingExtras = "e6f89c97-d47a-5376-807f-9c37f3926c36" @@ -22,7 +21,6 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" [compat] BangBang = "0.3.19" ConsoleProgressMonitor = "0.1" -FillArrays = "1" LogDensityProblems = "2" LoggingExtras = "0.4, 0.5, 1" ProgressLogging = "0.1" diff --git a/docs/src/api.md b/docs/src/api.md index aabf8d6f..f0c2c158 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -75,16 +75,14 @@ Common keyword arguments for regular and parallel sampling are: where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. -- `initial_state` (default: `nothing`): if `initial_state !== nothing`, the first call to [`AbstractMCMC.step`](@ref) - is passed `initial_state` as the `state` argument. !!! info The common keyword arguments `progress`, `chain_type`, and `callback` are not supported by the iterator [`AbstractMCMC.steps`](@ref) and the transducer [`AbstractMCMC.Sample`](@ref). There is no "official" way for providing initial parameter values yet. -However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `initial_params` keyword argument for setting the initial values when sampling a single chain. -To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `initial_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `initial_params` (default: `nothing`): if `initial_params isa AbstractArray`, then the `i`th element of `initial_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `initial_params = FillArrays.Fill(x, N)`. +However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. +To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): +- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. diff --git a/src/AbstractMCMC.jl b/src/AbstractMCMC.jl index dc464d42..64f20f97 100644 --- a/src/AbstractMCMC.jl +++ b/src/AbstractMCMC.jl @@ -8,7 +8,6 @@ using ProgressLogging: ProgressLogging using StatsBase: StatsBase using TerminalLoggers: TerminalLoggers using Transducers: Transducers -using FillArrays: FillArrays using Distributed: Distributed using Logging: Logging diff --git a/src/sample.jl b/src/sample.jl index 58217f39..6c9c32ae 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -103,7 +103,6 @@ function mcmcsample( discard_initial=0, thinning=1, chain_type::Type=Any, - initial_state=nothing, kwargs..., ) # Check the number of requested samples. @@ -123,11 +122,7 @@ function mcmcsample( end # Obtain the initial sample and state. - sample, state = if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, initial_state; kwargs...) - end + sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. for i in 1:discard_initial @@ -216,7 +211,6 @@ function mcmcsample( callback=nothing, discard_initial=0, thinning=1, - initial_state=nothing, kwargs..., ) @@ -226,11 +220,7 @@ function mcmcsample( @ifwithprogresslogger progress name = progressname begin # Obtain the initial sample and state. - sample, state = if initial_state === nothing - step(rng, model, sampler; kwargs...) - else - step(rng, model, sampler, state; kwargs...) - end + sample, state = step(rng, model, sampler; kwargs...) # Discard initial samples. for _ in 1:discard_initial @@ -298,8 +288,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(min(nchains, Threads.nthreads())) threads)", - initial_params=nothing, - initial_state=nothing, + init_params=nothing, kwargs..., ) # Check if actually multiple threads are used. @@ -323,9 +312,8 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(initial_params, nchains) - check_initial_state(initial_state, nchains) + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -376,15 +364,10 @@ function mcmcsample( _sampler, N; progress=false, - initial_params=if initial_params === nothing - nothing - else - initial_params[chainidx] - end, - initial_state=if initial_state === nothing + init_params=if init_params === nothing nothing else - initial_state[chainidx] + init_params[chainidx] end, kwargs..., ) @@ -414,8 +397,7 @@ function mcmcsample( nchains::Integer; progress=PROGRESS[], progressname="Sampling ($(Distributed.nworkers()) processes)", - initial_params=nothing, - initial_state=nothing, + init_params=nothing, kwargs..., ) # Check if actually multiple processes are used. @@ -428,14 +410,8 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(initial_params, nchains) - check_initial_state(initial_state, nchains) - - _initial_params = - initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params - _initial_state = - initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -472,7 +448,7 @@ function mcmcsample( Distributed.@async begin try - function sample_chain(seed, initial_params, initial_state) + function sample_chain(seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -483,8 +459,7 @@ function mcmcsample( sampler, N; progress=false, - initial_params=initial_params, - initial_state=initial_state, + init_params=init_params, kwargs..., ) @@ -494,9 +469,11 @@ function mcmcsample( # Return the new chain. return chain end - chains = Distributed.pmap( - sample_chain, pool, seeds, _initial_params, _initial_state - ) + chains = if init_params === nothing + Distributed.pmap(sample_chain, pool, seeds) + else + Distributed.pmap(sample_chain, pool, seeds, init_params) + end finally # Stop updating the progress bar. progress && put!(channel, false) @@ -517,8 +494,7 @@ function mcmcsample( N::Integer, nchains::Integer; progressname="Sampling", - initial_params=nothing, - initial_state=nothing, + init_params=nothing, kwargs..., ) # Check if the number of chains is larger than the number of samples @@ -526,20 +502,14 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters and states are `nothing` or of the correct length - check_initial_params(initial_params, nchains) - check_initial_state(initial_state, nchains) - - _initial_params = - initial_params === nothing ? FillArrays.Fill(nothing, nchains) : initial_params - _initial_state = - initial_state === nothing ? FillArrays.Fill(nothing, nchains) : initial_state + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) # Sample the chains. - function sample_chain(i, seed, initial_params, initial_state) + function sample_chain(i, seed, init_params=nothing) # Seed a new random number generator with the pre-made seed. Random.seed!(rng, seed) @@ -550,13 +520,16 @@ function mcmcsample( sampler, N; progressname=string(progressname, " (Chain ", i, " of ", nchains, ")"), - initial_params=initial_params, - initial_state=initial_state, + init_params=init_params, kwargs..., ) end - chains = map(sample_chain, 1:nchains, seeds, _initial_params, _initial_state) + chains = if init_params === nothing + map(sample_chain, 1:nchains, seeds) + else + map(sample_chain, 1:nchains, seeds, init_params) + end # Concatenate the chains together. return chainsstack(tighten_eltype(chains)) @@ -570,6 +543,7 @@ tighten_eltype(x::Vector{Any}) = map(identity, x) "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", ), ) + check_initial_params(::Nothing, n) = nothing function check_initial_params(x::AbstractArray, n) if length(x) != n @@ -582,21 +556,3 @@ function check_initial_params(x::AbstractArray, n) return nothing end - -@nospecialize check_initial_state(x, n) = throw( - ArgumentError( - "initial states must be specified as a vector of length equal to the number of chains or `nothing`", - ), -) -check_initial_state(::Nothing, n) = nothing -function check_initial_state(x::AbstractArray, n) - if length(x) != n - throw( - ArgumentError( - "incorrect number of initial states (expected $n, received $(length(x))" - ), - ) - end - - return nothing -end diff --git a/test/sample.jl b/test/sample.jl index dcc87526..22f4b26d 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -28,7 +28,7 @@ # initial parameters chain = sample( - MyModel(), MySampler(), 3; progress=false, initial_params=(b=3.2, a=-1.8) + MyModel(), MySampler(), 3; progress=false, init_params=(b=3.2, a=-1.8) ) @test chain[1].a == -1.8 @test chain[1].b == 3.2 @@ -163,7 +163,7 @@ # initial parameters nchains = 100 - initial_params = [(b=randn(), a=rand()) for _ in 1:nchains] + init_params = [(b=randn(), a=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -171,15 +171,15 @@ 3, nchains; progress=false, - initial_params=initial_params, + init_params=init_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, initial_params) + (chain, params) in zip(chains, init_params) ) - initial_params = (a=randn(), b=rand()) + init_params = (a=randn(), b=rand()) chains = sample( MyModel(), MySampler(), @@ -187,15 +187,14 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains), + init_params=FillArrays.Fill(init_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for - chain in chains + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) - # Too many `initial_params` + # Too many `init_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -203,10 +202,10 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains + 1), + init_params=FillArrays.Fill(init_params, nchains + 1), ) - # Too few `initial_params` + # Too few `init_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -214,7 +213,7 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains - 1), + init_params=FillArrays.Fill(init_params, nchains - 1), ) end @@ -299,7 +298,7 @@ # initial parameters nchains = 100 - initial_params = [(a=randn(), b=rand()) for _ in 1:nchains] + init_params = [(a=randn(), b=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -307,15 +306,15 @@ 3, nchains; progress=false, - initial_params=initial_params, + init_params=init_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, initial_params) + (chain, params) in zip(chains, init_params) ) - initial_params = (b=randn(), a=rand()) + init_params = (b=randn(), a=rand()) chains = sample( MyModel(), MySampler(), @@ -323,15 +322,14 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains), + init_params=FillArrays.Fill(init_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for - chain in chains + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) - # Too many `initial_params` + # Too many `init_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -339,10 +337,10 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains + 1), + init_params=FillArrays.Fill(init_params, nchains + 1), ) - # Too few `initial_params` + # Too few `init_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -350,7 +348,7 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains - 1), + init_params=FillArrays.Fill(init_params, nchains - 1), ) # Remove workers @@ -360,21 +358,13 @@ @testset "Serial sampling" begin # No dedicated chains type N = 10_000 - chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; progress=false) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000) @test chains isa Vector{<:Vector{<:MySample}} @test length(chains) == 1000 @test all(length(x) == N for x in chains) Random.seed!(1234) - chains = sample( - MyModel(), - MySampler(), - MCMCSerial(), - N, - 1000; - chain_type=MyChain, - progress=false, - ) + chains = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) # Test output type and size. @test chains isa Vector{<:MyChain} @@ -390,15 +380,7 @@ # Test reproducibility. Random.seed!(1234) - chains2 = sample( - MyModel(), - MySampler(), - MCMCSerial(), - N, - 1000; - chain_type=MyChain, - progress=false, - ) + chains2 = sample(MyModel(), MySampler(), MCMCSerial(), N, 1000; chain_type=MyChain) @test all(ismissing(c.as[1]) for c in chains2) @test all(c1.as[i] == c2.as[i] for (c1, c2) in zip(chains, chains2), i in 2:N) @test all(c1.bs[i] == c2.bs[i] for (c1, c2) in zip(chains, chains2), i in 1:N) @@ -425,7 +407,7 @@ # initial parameters nchains = 100 - initial_params = [(a=rand(), b=randn()) for _ in 1:nchains] + init_params = [(a=rand(), b=randn()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), @@ -433,15 +415,15 @@ 3, nchains; progress=false, - initial_params=initial_params, + init_params=init_params, ) @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for - (chain, params) in zip(chains, initial_params) + (chain, params) in zip(chains, init_params) ) - initial_params = (b=rand(), a=randn()) + init_params = (b=rand(), a=randn()) chains = sample( MyModel(), MySampler(), @@ -449,15 +431,14 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains), + init_params=FillArrays.Fill(init_params, nchains), ) @test length(chains) == nchains @test all( - chain[1].a == initial_params.a && chain[1].b == initial_params.b for - chain in chains + chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) - # Too many `initial_params` + # Too many `init_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -465,10 +446,10 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains + 1), + init_params=FillArrays.Fill(init_params, nchains + 1), ) - # Too few `initial_params` + # Too few `init_params` @test_throws ArgumentError sample( MyModel(), MySampler(), @@ -476,7 +457,7 @@ 3, nchains; progress=false, - initial_params=FillArrays.Fill(initial_params, nchains - 1), + init_params=FillArrays.Fill(init_params, nchains - 1), ) end @@ -674,63 +655,4 @@ ) @test it_array == collect(1:size(chain, 1)) end - - @testset "Providing initial state" begin - function record_state( - rng, model, sampler, sample, state, i; states_channel, kwargs... - ) - return put!(states_channel, state) - end - - initial_state = 10 - - @testset "sample" begin - n = 10 - states_channel = Channel{Int}(n) - chain = sample( - MyModel(), - MySampler(), - n; - initial_state=initial_state, - callback=record_state, - states_channel=states_channel, - ) - - # Extract the states. - states = [take!(states_channel) for _ in 1:n] - @test length(states) == n - for i in 1:n - @test states[i] == initial_state + i - end - end - - @testset "sample with $mode" for mode in - [MCMCSerial(), MCMCThreads(), MCMCDistributed()] - nchains = 4 - initial_state = 10 - states_channel = if mode === MCMCDistributed() - # Need to use `RemoteChannel` for this. - RemoteChannel(() -> Channel{Int}(nchains)) - else - Channel{Int}(nchains) - end - chain = sample( - MyModel(), - MySampler(), - mode, - 1, - nchains; - initial_state=FillArrays.Fill(initial_state, nchains), - callback=record_state, - states_channel=states_channel, - ) - - # Extract the states. - states = [take!(states_channel) for _ in 1:nchains] - @test length(states) == nchains - for i in 1:nchains - @test states[i] == initial_state + 1 - end - end - end end diff --git a/test/utils.jl b/test/utils.jl index 1e29a473..f69fcdab 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -22,12 +22,12 @@ function AbstractMCMC.step( sampler::MySampler, state::Union{Nothing,Integer}=nothing; loggers=false, - initial_params=nothing, + init_params=nothing, kwargs..., ) # sample `a` is missing in the first step if not provided - a, b = if state === nothing && initial_params !== nothing - initial_params.a, initial_params.b + a, b = if state === nothing && init_params !== nothing + init_params.a, init_params.b else (state === nothing ? missing : rand(rng)), randn(rng) end From 8919bfa9682f3a912d37b340d09c9a008047e256 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Oct 2023 18:43:13 +0100 Subject: [PATCH 2/6] Revert "Update Project.toml" This reverts commit d5218159232bc3b035ad9c789e874ac68b5643d5. --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 90117048..40f90c96 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.5.0" +version = "4.6.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" From 2be66bc9d39c60ea3d9a131b05aa8e9d26a26d1a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Oct 2023 18:43:19 +0100 Subject: [PATCH 3/6] Revert "Merge pull request #126 from TuringLang/torfjelde/init-params-fix" This reverts commit 084f80916c937ae2dffd5b55380f9fcc817a23f0, reversing changes made to caeade2abe60b6803201cd341f7d62595465f6b2. --- Project.toml | 5 +-- docs/src/api.md | 2 +- src/sample.jl | 54 +++++++++++++----------- test/runtests.jl | 1 - test/sample.jl | 105 ++++++++--------------------------------------- 5 files changed, 50 insertions(+), 117 deletions(-) diff --git a/Project.toml b/Project.toml index 40f90c96..7960600c 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.6.0" +version = "4.4.2" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -30,10 +30,9 @@ Transducers = "0.4.30" julia = "1.6" [extras] -FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["FillArrays", "IJulia", "Statistics", "Test"] +test = ["IJulia", "Statistics", "Test"] diff --git a/docs/src/api.md b/docs/src/api.md index f0c2c158..629e4c37 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -82,7 +82,7 @@ Common keyword arguments for regular and parallel sampling are: There is no "official" way for providing initial parameter values yet. However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. +- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. diff --git a/src/sample.jl b/src/sample.jl index 6c9c32ae..dc951ca2 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -312,8 +312,8 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters are `nothing` or of the correct length - check_initial_params(init_params, nchains) + # Ensure that initial parameters are `nothing` or indexable + _init_params = _first_or_nothing(init_params, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -364,10 +364,10 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if init_params === nothing + init_params=if _init_params === nothing nothing else - init_params[chainidx] + _init_params[chainidx] end, kwargs..., ) @@ -410,9 +410,6 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or of the correct length - check_initial_params(init_params, nchains) - # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -502,9 +499,6 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end - # Ensure that initial parameters are `nothing` or of the correct length - check_initial_params(init_params, nchains) - # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -538,21 +532,31 @@ end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) -@nospecialize check_initial_params(x, n) = throw( - ArgumentError( - "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", - ), -) +""" + _first_or_nothing(x, n::Int) -check_initial_params(::Nothing, n) = nothing -function check_initial_params(x::AbstractArray, n) - if length(x) != n - throw( - ArgumentError( - "incorrect number of initial parameters (expected $n, received $(length(x))" - ), - ) - end +Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`. - return nothing +If `x !== nothing`, then `x` has to contain at least `n` elements. +""" +function _first_or_nothing(x, n::Int) + y = _first(x, n) + length(y) == n || throw( + ArgumentError("not enough initial parameters (expected $n, received $(length(y))"), + ) + return y +end +_first_or_nothing(::Nothing, ::Int) = nothing + +# `first(x, n::Int)` requires Julia 1.6 +function _first(x, n::Int) + @static if VERSION >= v"1.6.0-DEV.431" + first(x, n) + else + if x isa AbstractVector + @inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))] + else + collect(Iterators.take(x, n)) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 909ae8b3..75aac0f1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ using IJulia using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger -using FillArrays: FillArrays using Transducers using Distributed diff --git a/test/sample.jl b/test/sample.jl index 22f4b26d..261cc1ef 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -162,18 +162,17 @@ end # initial parameters - nchains = 100 - init_params = [(b=randn(), a=rand()) for _ in 1:nchains] + init_params = [(b=randn(), a=rand()) for _ in 1:100] chains = sample( MyModel(), MySampler(), MCMCThreads(), 3, - nchains; + 100; progress=false, init_params=init_params, ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -185,36 +184,14 @@ MySampler(), MCMCThreads(), 3, - nchains; + 100; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + init_params=Iterators.repeated(init_params), ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) - - # Too many `init_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCThreads(), - 3, - nchains; - progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), - ) - - # Too few `init_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCThreads(), - 3, - nchains; - progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), - ) end @testset "Multicore sampling" begin @@ -297,18 +274,17 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - nchains = 100 - init_params = [(a=randn(), b=rand()) for _ in 1:nchains] + init_params = [(a=randn(), b=rand()) for _ in 1:100] chains = sample( MyModel(), MySampler(), MCMCDistributed(), 3, - nchains; + 100; progress=false, init_params=init_params, ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -320,37 +296,15 @@ MySampler(), MCMCDistributed(), 3, - nchains; + 100; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + init_params=Iterators.repeated(init_params), ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) - # Too many `init_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCDistributed(), - 3, - nchains; - progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), - ) - - # Too few `init_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCDistributed(), - 3, - nchains; - progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), - ) - # Remove workers rmprocs(pids...) end @@ -406,18 +360,17 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - nchains = 100 - init_params = [(a=rand(), b=randn()) for _ in 1:nchains] + init_params = [(a=rand(), b=randn()) for _ in 1:100] chains = sample( MyModel(), MySampler(), MCMCSerial(), 3, - nchains; + 100; progress=false, init_params=init_params, ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -429,36 +382,14 @@ MySampler(), MCMCSerial(), 3, - nchains; + 100; progress=false, - init_params=FillArrays.Fill(init_params, nchains), + init_params=Iterators.repeated(init_params), ) - @test length(chains) == nchains + @test length(chains) == 100 @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) - - # Too many `init_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCSerial(), - 3, - nchains; - progress=false, - init_params=FillArrays.Fill(init_params, nchains + 1), - ) - - # Too few `init_params` - @test_throws ArgumentError sample( - MyModel(), - MySampler(), - MCMCSerial(), - 3, - nchains; - progress=false, - init_params=FillArrays.Fill(init_params, nchains - 1), - ) end @testset "Ensemble sampling: Reproducibility" begin From e99eb65e25b7f9f34d14715977b7661b49d1041c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Oct 2023 18:44:32 +0100 Subject: [PATCH 4/6] Revert "Update callback signature in docs (#130)" This reverts commit caeade2abe60b6803201cd341f7d62595465f6b2. --- docs/src/api.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 629e4c37..52c2c2e1 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -71,8 +71,8 @@ Common keyword arguments for regular and parallel sampling are: - `progress` (default: `AbstractMCMC.PROGRESS[]` which is `true` initially): toggles progress logging - `chain_type` (default: `Any`): determines the type of the returned chain - `callback` (default: `nothing`): if `callback !== nothing`, then - `callback(rng, model, sampler, sample, state, iteration)` is called after every sampling step, - where `sample` is the most recent sample of the Markov chain and `state` and `iteration` are the current state and iteration of the sampler + `callback(rng, model, sampler, sample, iteration)` is called after every sampling step, + where `sample` is the most recent sample of the Markov chain and `iteration` is the current iteration - `discard_initial` (default: `0`): number of initial samples that are discarded - `thinning` (default: `1`): factor by which to thin samples. From 4b38987b02bf5b40efce34959cb00b3c8ed26fa5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Oct 2023 18:44:33 +0100 Subject: [PATCH 5/6] Revert "CompatHelper: bump compat for Documenter to 1 for package docs, (keep existing compat) (#127)" This reverts commit 1c589e8d4e054dd3b30b4a6cdb7bc25577ae8dce. --- docs/Project.toml | 2 +- docs/make.jl | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index f74dfb58..555443ab 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -3,5 +3,5 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] -Documenter = "1" +Documenter = "0.27" julia = "1" diff --git a/docs/make.jl b/docs/make.jl index 9395d2a0..66d7619c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -9,6 +9,7 @@ makedocs(; format=Documenter.HTML(), modules=[AbstractMCMC], pages=["Home" => "index.md", "api.md", "design.md"], + strict=true, checkdocs=:exports, ) From c4471ad7232373ca260d41b7030cce3344b0743f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Oct 2023 18:46:33 +0100 Subject: [PATCH 6/6] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 7960600c..702f3639 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.2" +version = "4.5.1" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"