Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CompatHelper: bump compat for AbstractMCMC to 5, (keep existing compat) #551

Merged
14 changes: 7 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
AbstractMCMC = "5"
AbstractPPL = "0.7"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
ChainRulesCore = "1"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
LogDensityProblems = "2"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
ZygoteRules = "0.2"
julia = "1.6"

Expand Down
12 changes: 12 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@
error("Chains do not support indexing using $vn.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(

Check warning on line 20 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L18-L20

Added lines #L18 - L20 were not covered by tests
ArgumentError(
"The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing.",
),
)
end
return chain.info[:samplerstate]

Check warning on line 26 in ext/DynamicPPLMCMCChainsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLMCMCChainsExt.jl#L26

Added line #L26 was not covered by tests
end

# A few methods needed.
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
Expand Down
44 changes: 29 additions & 15 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,31 @@ function default_varinfo(
return VarInfo(rng, model, init_sampler, context)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
function AbstractMCMC.sample(
rng::Random.AbstractRNG,
model::Model,
spl::Sampler;
sampler::Sampler,
N::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
init_params=nothing,
initial_state = loadstate(resume_from),
yebai marked this conversation as resolved.
Show resolved Hide resolved
kwargs...,
)
if resume_from !== nothing
state = loadstate(resume_from)
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end
return AbstractMCMC.mcmcsample(
rng, model, sampler, N; chain_type, initial_state, kwargs...
)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if init_params !== nothing
vi = initialize_parameters!!(vi, init_params, spl, model)
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
Expand All @@ -108,15 +113,24 @@ function AbstractMCMC.step(
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
end

"""
loadstate(data)

Load sampler state from `data`.

By default, `data` is returned.
"""
loadstate(data) = data

"""
default_chaintype(sampler)

Default type of the chain of posterior samples from `sampler`.
"""
function loadstate end
default_chain_type(sampler::Sampler) = Any

"""
initialsampler(sampler::Sampler)
Expand All @@ -129,12 +143,12 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
initialsampler(spl::Sampler) = SampleFromPrior()

function initialize_parameters!!(
vi::AbstractVarInfo, init_params, spl::Sampler, model::Model
vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
)
@debug "Using passed-in initial variable values" init_params
@debug "Using passed-in initial variable values" initial_params

# Flatten parameters.
init_theta = mapreduce(vcat, init_params) do x
init_theta = mapreduce(vcat, initial_params) do x
vec([x;])
end

Expand Down
12 changes: 6 additions & 6 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.6"
AbstractMCMC = "5"
AbstractPPL = "0.7"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27, 1"
Documenter = "1"
ForwardDiff = "0.10.12"
LogDensityProblems = "2"
MCMCChains = "4.0.4, 5, 6"
MCMCChains = "6.0.4"
MacroTools = "0.5.5"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
StableRNGs = "1"
Tracker = "0.2.23"
Zygote = "0.5.4, 0.6"
Zygote = "0.6"
julia = "1.6"
18 changes: 9 additions & 9 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
model = coinflip()
sampler = Sampler(alg)
lptrue = logpdf(Binomial(25, 0.2), 10)
chain = sample(model, sampler, 1; init_params=0.2, progress=false)
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
@test chain[1].metadata.p.vals == [0.2]
@test getlogp(chain[1]) == lptrue

Expand All @@ -95,7 +95,7 @@
MCMCThreads(),
1,
10;
init_params=fill(0.2, 10),
initial_params=fill(0.2, 10),
progress=false,
)
for c in chains
Expand All @@ -110,7 +110,7 @@
end
model = twovars()
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
chain = sample(model, sampler, 1; init_params=[4, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
@test getlogp(chain[1]) == lptrue
Expand All @@ -122,7 +122,7 @@
MCMCThreads(),
1,
10;
init_params=fill([4, -1], 10),
initial_params=fill([4, -1], 10),
progress=false,
)
for c in chains
Expand All @@ -132,7 +132,7 @@
end

# set only m = -1
chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]

Expand All @@ -143,19 +143,19 @@
MCMCThreads(),
1,
10;
init_params=fill([missing, -1], 10),
initial_params=fill([missing, -1], 10),
progress=false,
)
for c in chains
@test !ismissing(c[1].metadata.s.vals[1])
@test c[1].metadata.m.vals == [-1]
end

# specify `init_params=nothing`
# specify `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(model, sampler, 1; progress=false)
Random.seed!(1234)
chain2 = sample(model, sampler, 1; init_params=nothing, progress=false)
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals

Expand All @@ -164,7 +164,7 @@
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
Random.seed!(1234)
chains2 = sample(
model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
)
for (c1, c2) in zip(chains1, chains2)
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
Expand Down
Loading