diff --git a/Project.toml b/Project.toml index 7960600c..40f90c96 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "80f14c24-f653-4e6a-9b94-39d6b0f70001" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "A lightweight interface for common MCMC methods." -version = "4.4.2" +version = "4.6.0" [deps] BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -30,9 +30,10 @@ Transducers = "0.4.30" julia = "1.6" [extras] +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IJulia = "7073ff75-c697-5162-941a-fcdaad2a7d2a" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["IJulia", "Statistics", "Test"] +test = ["FillArrays", "IJulia", "Statistics", "Test"] diff --git a/docs/src/api.md b/docs/src/api.md index 52c2c2e1..d89b078a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -82,7 +82,7 @@ Common keyword arguments for regular and parallel sampling are: There is no "official" way for providing initial parameter values yet. However, multiple packages such as [EllipticalSliceSampling.jl](https://github.com/TuringLang/EllipticalSliceSampling.jl) and [AdvancedMH.jl](https://github.com/TuringLang/AdvancedMH.jl) support an `init_params` keyword argument for setting the initial values when sampling a single chain. To ensure that sampling multiple chains "just works" when sampling of a single chain is implemented, [we decided to support `init_params` in the default implementations of the ensemble methods](https://github.com/TuringLang/AbstractMCMC.jl/pull/94): -- `init_params` (default: `nothing`): if set to `init_params !== nothing`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = Iterators.repeated(x)` or `init_params = FillArrays.Fill(x, N)`. +- `init_params` (default: `nothing`): if `init_params isa AbstractArray`, then the `i`th element of `init_params` is used as initial parameters of the `i`th chain. If one wants to use the same initial parameters `x` for every chain, one can specify e.g. `init_params = FillArrays.Fill(x, N)`. Progress logging can be enabled and disabled globally with `AbstractMCMC.setprogress!(progress)`. diff --git a/src/sample.jl b/src/sample.jl index dc951ca2..6c9c32ae 100644 --- a/src/sample.jl +++ b/src/sample.jl @@ -312,8 +312,8 @@ function mcmcsample( # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) - # Ensure that initial parameters are `nothing` or indexable - _init_params = _first_or_nothing(init_params, nchains) + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) # Set up a chains vector. chains = Vector{Any}(undef, nchains) @@ -364,10 +364,10 @@ function mcmcsample( _sampler, N; progress=false, - init_params=if _init_params === nothing + init_params=if init_params === nothing nothing else - _init_params[chainidx] + init_params[chainidx] end, kwargs..., ) @@ -410,6 +410,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -499,6 +502,9 @@ function mcmcsample( @warn "Number of chains ($nchains) is greater than number of samples per chain ($N)" end + # Ensure that initial parameters are `nothing` or of the correct length + check_initial_params(init_params, nchains) + # Create a seed for each chain using the provided random number generator. seeds = rand(rng, UInt, nchains) @@ -532,31 +538,21 @@ end tighten_eltype(x) = x tighten_eltype(x::Vector{Any}) = map(identity, x) -""" - _first_or_nothing(x, n::Int) - -Return the first `n` elements of collection `x`, or `nothing` if `x === nothing`. - -If `x !== nothing`, then `x` has to contain at least `n` elements. -""" -function _first_or_nothing(x, n::Int) - y = _first(x, n) - length(y) == n || throw( - ArgumentError("not enough initial parameters (expected $n, received $(length(y))"), - ) - return y -end -_first_or_nothing(::Nothing, ::Int) = nothing +@nospecialize check_initial_params(x, n) = throw( + ArgumentError( + "initial parameters must be specified as a vector of length equal to the number of chains or `nothing`", + ), +) -# `first(x, n::Int)` requires Julia 1.6 -function _first(x, n::Int) - @static if VERSION >= v"1.6.0-DEV.431" - first(x, n) - else - if x isa AbstractVector - @inbounds x[firstindex(x):min(firstindex(x) + n - 1, lastindex(x))] - else - collect(Iterators.take(x, n)) - end +check_initial_params(::Nothing, n) = nothing +function check_initial_params(x::AbstractArray, n) + if length(x) != n + throw( + ArgumentError( + "incorrect number of initial parameters (expected $n, received $(length(x))" + ), + ) end + + return nothing end diff --git a/test/runtests.jl b/test/runtests.jl index 75aac0f1..909ae8b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using IJulia using LogDensityProblems using LoggingExtras: TeeLogger, EarlyFilteredLogger using TerminalLoggers: TerminalLogger +using FillArrays: FillArrays using Transducers using Distributed diff --git a/test/sample.jl b/test/sample.jl index 261cc1ef..22f4b26d 100644 --- a/test/sample.jl +++ b/test/sample.jl @@ -162,17 +162,18 @@ end # initial parameters - init_params = [(b=randn(), a=rand()) for _ in 1:100] + nchains = 100 + init_params = [(b=randn(), a=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), MCMCThreads(), 3, - 100; + nchains; progress=false, init_params=init_params, ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -184,14 +185,36 @@ MySampler(), MCMCThreads(), 3, - 100; + nchains; progress=false, - init_params=Iterators.repeated(init_params), + init_params=FillArrays.Fill(init_params, nchains), ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCThreads(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains - 1), + ) end @testset "Multicore sampling" begin @@ -274,17 +297,18 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - init_params = [(a=randn(), b=rand()) for _ in 1:100] + nchains = 100 + init_params = [(a=randn(), b=rand()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), MCMCDistributed(), 3, - 100; + nchains; progress=false, init_params=init_params, ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -296,15 +320,37 @@ MySampler(), MCMCDistributed(), 3, - 100; + nchains; progress=false, - init_params=Iterators.repeated(init_params), + init_params=FillArrays.Fill(init_params, nchains), ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCDistributed(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains - 1), + ) + # Remove workers rmprocs(pids...) end @@ -360,17 +406,18 @@ @test all(l.level > Logging.LogLevel(-1) for l in logs) # initial parameters - init_params = [(a=rand(), b=randn()) for _ in 1:100] + nchains = 100 + init_params = [(a=rand(), b=randn()) for _ in 1:nchains] chains = sample( MyModel(), MySampler(), MCMCSerial(), 3, - 100; + nchains; progress=false, init_params=init_params, ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == params.a && chain[1].b == params.b for (chain, params) in zip(chains, init_params) @@ -382,14 +429,36 @@ MySampler(), MCMCSerial(), 3, - 100; + nchains; progress=false, - init_params=Iterators.repeated(init_params), + init_params=FillArrays.Fill(init_params, nchains), ) - @test length(chains) == 100 + @test length(chains) == nchains @test all( chain[1].a == init_params.a && chain[1].b == init_params.b for chain in chains ) + + # Too many `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains + 1), + ) + + # Too few `init_params` + @test_throws ArgumentError sample( + MyModel(), + MySampler(), + MCMCSerial(), + 3, + nchains; + progress=false, + init_params=FillArrays.Fill(init_params, nchains - 1), + ) end @testset "Ensemble sampling: Reproducibility" begin