diff --git a/Project.toml b/Project.toml index 6a7cda61b..501179d1f 100644 --- a/Project.toml +++ b/Project.toml @@ -29,21 +29,21 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" DynamicPPLMCMCChainsExt = ["MCMCChains"] [compat] -AbstractMCMC = "2, 3.0, 4" +AbstractMCMC = "5" AbstractPPL = "0.6" BangBang = "0.3" Bijectors = "0.13" -ChainRulesCore = "0.9.7, 0.10, 1" -ConstructionBase = "1.5.4" +ChainRulesCore = "1" Compat = "4" -Distributions = "0.23.8, 0.24, 0.25" -DocStringExtensions = "0.8, 0.9" +ConstructionBase = "1.5.4" +Distributions = "0.25" +DocStringExtensions = "0.9" LogDensityProblems = "2" MCMCChains = "6" MacroTools = "0.5.6" OrderedCollections = "1" Requires = "1" -Setfield = "0.7.1, 0.8, 1" +Setfield = "1" ZygoteRules = "0.2" julia = "1.6" diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 2630e9d1b..8c598a6a8 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -14,6 +14,18 @@ function _check_varname_indexing(c::MCMCChains.Chains) error("Chains do not support indexing using $vn.") end +# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata +function DynamicPPL.loadstate(chain::MCMCChains.Chains) + if !haskey(chain.info, :samplerstate) + throw( + ArgumentError( + "The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing.", + ), + ) + end + return chain.info[:samplerstate] +end + # A few methods needed. function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) return _has_varname_to_symbol(chain.info) diff --git a/src/sampler.jl b/src/sampler.jl index 3a4daf0b1..17f0ffca3 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -80,26 +80,31 @@ function default_varinfo( return VarInfo(rng, model, init_sampler, context) end -# initial step: general interface for resuming and -function AbstractMCMC.step( - rng::Random.AbstractRNG, +function AbstractMCMC.sample( + rng::AbstractRNG, model::Model, - spl::Sampler; + sampler::Sampler, + N::Integer; + chain_type=default_chain_type(sampler), resume_from=nothing, - init_params=nothing, kwargs..., ) - if resume_from !== nothing - state = loadstate(resume_from) - return AbstractMCMC.step(rng, model, spl, state; kwargs...) - end + initial_state = loadstate(resume_from) + return AbstractMCMC.mcmcsample( + rng, model, sampler, N; chain_type, initial_state, kwargs... + ) +end +# initial step: general interface for resuming and +function AbstractMCMC.step( + rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs... +) # Sample initial values. vi = default_varinfo(rng, model, spl) # Update the parameters if provided. - if init_params !== nothing - vi = initialize_parameters!!(vi, init_params, spl, model) + if initial_params !== nothing + vi = initialize_parameters!!(vi, initial_params, spl, model) # Update joint log probability. # This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588 @@ -108,15 +113,24 @@ function AbstractMCMC.step( vi = last(evaluate!!(model, vi, DefaultContext())) end - return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...) + return initialstep(rng, model, spl, vi; initial_params, kwargs...) end """ loadstate(data) Load sampler state from `data`. + +By default, `data` is returned. +""" +loadstate(data) = data + +""" + default_chaintype(sampler) + +Default type of the chain of posterior samples from `sampler`. """ -function loadstate end +default_chain_type(sampler::Sampler) = Any """ initialsampler(sampler::Sampler) @@ -129,12 +143,12 @@ By default, it returns an instance of [`SampleFromPrior`](@ref). initialsampler(spl::Sampler) = SampleFromPrior() function initialize_parameters!!( - vi::AbstractVarInfo, init_params, spl::Sampler, model::Model + vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model ) - @debug "Using passed-in initial variable values" init_params + @debug "Using passed-in initial variable values" initial_params # Flatten parameters. - init_theta = mapreduce(vcat, init_params) do x + init_theta = mapreduce(vcat, initial_params) do x vec([x;]) end diff --git a/test/Project.toml b/test/Project.toml index 4a7654c03..483684368 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,19 +22,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractMCMC = "2.1, 3.0, 4, 5" +AbstractMCMC = "5" AbstractPPL = "0.6" Bijectors = "0.13" Compat = "4.3.0" Distributions = "0.25" DistributionsAD = "0.6.3" -Documenter = "0.26.1, 0.27, 1" +Documenter = "1" ForwardDiff = "0.10.12" LogDensityProblems = "2" -MCMCChains = "4.0.4, 5, 6" +MCMCChains = "6.0.4" MacroTools = "0.5.5" -Setfield = "0.7.1, 0.8, 1" +Setfield = "1" StableRNGs = "1" Tracker = "0.2.23" -Zygote = "0.5.4, 0.6" +Zygote = "0.6" julia = "1.6" diff --git a/test/sampler.jl b/test/sampler.jl index 25a655507..b52a9c921 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -84,7 +84,7 @@ model = coinflip() sampler = Sampler(alg) lptrue = logpdf(Binomial(25, 0.2), 10) - chain = sample(model, sampler, 1; init_params=0.2, progress=false) + chain = sample(model, sampler, 1; initial_params=0.2, progress=false) @test chain[1].metadata.p.vals == [0.2] @test getlogp(chain[1]) == lptrue @@ -95,7 +95,7 @@ MCMCThreads(), 1, 10; - init_params=fill(0.2, 10), + initial_params=fill(0.2, 10), progress=false, ) for c in chains @@ -110,7 +110,7 @@ end model = twovars() lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) - chain = sample(model, sampler, 1; init_params=[4, -1], progress=false) + chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false) @test chain[1].metadata.s.vals == [4] @test chain[1].metadata.m.vals == [-1] @test getlogp(chain[1]) == lptrue @@ -122,7 +122,7 @@ MCMCThreads(), 1, 10; - init_params=fill([4, -1], 10), + initial_params=fill([4, -1], 10), progress=false, ) for c in chains @@ -132,7 +132,7 @@ end # set only m = -1 - chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false) + chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false) @test !ismissing(chain[1].metadata.s.vals[1]) @test chain[1].metadata.m.vals == [-1] @@ -143,7 +143,7 @@ MCMCThreads(), 1, 10; - init_params=fill([missing, -1], 10), + initial_params=fill([missing, -1], 10), progress=false, ) for c in chains @@ -151,11 +151,11 @@ @test c[1].metadata.m.vals == [-1] end - # specify `init_params=nothing` + # specify `initial_params=nothing` Random.seed!(1234) chain1 = sample(model, sampler, 1; progress=false) Random.seed!(1234) - chain2 = sample(model, sampler, 1; init_params=nothing, progress=false) + chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false) @test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals @test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals @@ -164,7 +164,7 @@ chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false) Random.seed!(1234) chains2 = sample( - model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false + model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false ) for (c1, c2) in zip(chains1, chains2) @test c1[1].metadata.m.vals == c2[1].metadata.m.vals