Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompatHelper: bump compat for AbstractMCMC to 5, (keep existing compat) #551

Merged
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.20"
version = "0.24.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -29,21 +29,21 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractMCMC = "5"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
ChainRulesCore = "1"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
LogDensityProblems = "2"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
ZygoteRules = "0.2"
julia = "1.6"

Expand Down
8 changes: 8 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ function _check_varname_indexing(c::MCMCChains.Chains)
error("Chains do not support indexing using $vn.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(ArgumentError("The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing."))
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end
return chain.info[:samplerstate]
end

# A few methods needed.
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
Expand Down
44 changes: 30 additions & 14 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,33 @@ function default_varinfo(
return VarInfo(rng, model, init_sampler, context)
end

function AbstractMCMC.sample(
rng::AbstractRNG,
model::Model,
sampler::Sampler,
N::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
kwargs...
devmotion marked this conversation as resolved.
Show resolved Hide resolved
)
initial_state = loadstate(resume_from)
return AbstractMCMC.mcmcsample(rng, model, sampler, N; chain_type, initial_state, kwargs...)
devmotion marked this conversation as resolved.
Show resolved Hide resolved
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
resume_from=nothing,
init_params=nothing,
initial_params=nothing,
kwargs...,
yebai marked this conversation as resolved.
Show resolved Hide resolved
)
if resume_from !== nothing
state = loadstate(resume_from)
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end

# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if init_params !== nothing
vi = initialize_parameters!!(vi, init_params, spl, model)
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
Expand All @@ -108,15 +115,24 @@ function AbstractMCMC.step(
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
end

"""
loadstate(data)

Load sampler state from `data`.

By default, `data` is returned.
"""
loadstate(data) = data

"""
default_chaintype(sampler)

Default type of the chain of posterior samples from `sampler`.
"""
function loadstate end
default_chain_type(sampler::Sampler) = Any

"""
initialsampler(sampler::Sampler)
Expand All @@ -129,12 +145,12 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
initialsampler(spl::Sampler) = SampleFromPrior()

function initialize_parameters!!(
vi::AbstractVarInfo, init_params, spl::Sampler, model::Model
vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
)
@debug "Using passed-in initial variable values" init_params
@debug "Using passed-in initial variable values" initial_params

# Flatten parameters.
init_theta = mapreduce(vcat, init_params) do x
init_theta = mapreduce(vcat, initial_params) do x
vec([x;])
end

Expand Down
10 changes: 5 additions & 5 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0, 4"
AbstractMCMC = "5"
AbstractPPL = "0.6"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27, 1"
Documenter = "1"
ForwardDiff = "0.10.12"
LogDensityProblems = "2"
MCMCChains = "4.0.4, 5, 6"
MCMCChains = "6.0.4"
MacroTools = "0.5.5"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
StableRNGs = "1"
Tracker = "0.2.23"
Zygote = "0.5.4, 0.6"
Zygote = "0.6"
julia = "1.6"
18 changes: 9 additions & 9 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
model = coinflip()
sampler = Sampler(alg)
lptrue = logpdf(Binomial(25, 0.2), 10)
chain = sample(model, sampler, 1; init_params=0.2, progress=false)
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
@test chain[1].metadata.p.vals == [0.2]
@test getlogp(chain[1]) == lptrue

Expand All @@ -95,7 +95,7 @@
MCMCThreads(),
1,
10;
init_params=fill(0.2, 10),
initial_params=fill(0.2, 10),
progress=false,
)
for c in chains
Expand All @@ -110,7 +110,7 @@
end
model = twovars()
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
chain = sample(model, sampler, 1; init_params=[4, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
@test getlogp(chain[1]) == lptrue
Expand All @@ -122,7 +122,7 @@
MCMCThreads(),
1,
10;
init_params=fill([4, -1], 10),
initial_params=fill([4, -1], 10),
progress=false,
)
for c in chains
Expand All @@ -132,7 +132,7 @@
end

# set only m = -1
chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]

Expand All @@ -143,19 +143,19 @@
MCMCThreads(),
1,
10;
init_params=fill([missing, -1], 10),
initial_params=fill([missing, -1], 10),
progress=false,
)
for c in chains
@test !ismissing(c[1].metadata.s.vals[1])
@test c[1].metadata.m.vals == [-1]
end

# specify `init_params=nothing`
# specify `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(model, sampler, 1; progress=false)
Random.seed!(1234)
chain2 = sample(model, sampler, 1; init_params=nothing, progress=false)
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals

Expand All @@ -164,7 +164,7 @@
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
Random.seed!(1234)
chains2 = sample(
model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
)
for (c1, c2) in zip(chains1, chains2)
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
Expand Down
Loading