Skip to content

Commit

Permalink
Merge branch 'compathelper/new_version/2023-10-26-00-09-06-954-015731…
Browse files Browse the repository at this point in the history
…51329' into compathelper/new_version/2023-10-27-00-09-36-947-01221212180
  • Loading branch information
yebai authored Nov 1, 2023
2 parents 52235d8 + 9a2d2e5 commit a7f8658
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 36 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,21 @@ MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractMCMC = "5"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1.5.4"
ChainRulesCore = "1"
Compat = "4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
LogDensityProblems = "2"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
ZygoteRules = "0.2"
julia = "1.6"

Expand Down
12 changes: 12 additions & 0 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ function _check_varname_indexing(c::MCMCChains.Chains)
error("Chains do not support indexing using $vn.")
end

# Load state from a `Chains`: By convention, it is stored in `:samplerstate` metadata
function DynamicPPL.loadstate(chain::MCMCChains.Chains)
if !haskey(chain.info, :samplerstate)
throw(
ArgumentError(
"The chain object does not contain the final state of the sampler: Metadata `:samplerstate` missing.",
),
)
end
return chain.info[:samplerstate]
end

# A few methods needed.
function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains)
return _has_varname_to_symbol(chain.info)
Expand Down
46 changes: 30 additions & 16 deletions src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,26 +80,31 @@ function default_varinfo(
return VarInfo(rng, model, init_sampler, context)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG,
function AbstractMCMC.sample(
rng::AbstractRNG,
model::Model,
spl::Sampler;
sampler::Sampler,
N::Integer;
chain_type=default_chain_type(sampler),
resume_from=nothing,
init_params=nothing,
kwargs...,
)
if resume_from !== nothing
state = loadstate(resume_from)
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
end
initial_state = loadstate(resume_from)
return AbstractMCMC.mcmcsample(
rng, model, sampler, N; chain_type, initial_state, kwargs...
)
end

# initial step: general interface for resuming and
function AbstractMCMC.step(
rng::Random.AbstractRNG, model::Model, spl::Sampler; initial_params=nothing, kwargs...
)
# Sample initial values.
vi = default_varinfo(rng, model, spl)

# Update the parameters if provided.
if init_params !== nothing
vi = initialize_parameters!!(vi, init_params, spl, model)
if initial_params !== nothing
vi = initialize_parameters!!(vi, initial_params, spl, model)

# Update joint log probability.
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
Expand All @@ -108,15 +113,24 @@ function AbstractMCMC.step(
vi = last(evaluate!!(model, vi, DefaultContext()))
end

return initialstep(rng, model, spl, vi; init_params=init_params, kwargs...)
return initialstep(rng, model, spl, vi; initial_params, kwargs...)
end

"""
loadstate(data)
Load sampler state from `data`.
By default, `data` is returned.
"""
loadstate(data) = data

"""
default_chaintype(sampler)
Default type of the chain of posterior samples from `sampler`.
"""
function loadstate end
default_chain_type(sampler::Sampler) = Any

"""
initialsampler(sampler::Sampler)
Expand All @@ -129,12 +143,12 @@ By default, it returns an instance of [`SampleFromPrior`](@ref).
initialsampler(spl::Sampler) = SampleFromPrior()

function initialize_parameters!!(
vi::AbstractVarInfo, init_params, spl::Sampler, model::Model
vi::AbstractVarInfo, initial_params, spl::Sampler, model::Model
)
@debug "Using passed-in initial variable values" init_params
@debug "Using passed-in initial variable values" initial_params

# Flatten parameters.
init_theta = mapreduce(vcat, init_params) do x
init_theta = mapreduce(vcat, initial_params) do x
vec([x;])
end

Expand Down
10 changes: 5 additions & 5 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "2.1, 3.0, 4, 5"
AbstractMCMC = "5"
AbstractPPL = "0.6"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27, 1"
Documenter = "1"
ForwardDiff = "0.10.12"
LogDensityProblems = "2"
MCMCChains = "4.0.4, 5, 6"
MCMCChains = "6.0.4"
MacroTools = "0.5.5"
Setfield = "0.7.1, 0.8, 1"
Setfield = "1"
StableRNGs = "1"
Tracker = "0.2.23"
Zygote = "0.5.4, 0.6"
Zygote = "0.6"
julia = "1.6"
18 changes: 9 additions & 9 deletions test/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
model = coinflip()
sampler = Sampler(alg)
lptrue = logpdf(Binomial(25, 0.2), 10)
chain = sample(model, sampler, 1; init_params=0.2, progress=false)
chain = sample(model, sampler, 1; initial_params=0.2, progress=false)
@test chain[1].metadata.p.vals == [0.2]
@test getlogp(chain[1]) == lptrue

Expand All @@ -95,7 +95,7 @@
MCMCThreads(),
1,
10;
init_params=fill(0.2, 10),
initial_params=fill(0.2, 10),
progress=false,
)
for c in chains
Expand All @@ -110,7 +110,7 @@
end
model = twovars()
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
chain = sample(model, sampler, 1; init_params=[4, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[4, -1], progress=false)
@test chain[1].metadata.s.vals == [4]
@test chain[1].metadata.m.vals == [-1]
@test getlogp(chain[1]) == lptrue
Expand All @@ -122,7 +122,7 @@
MCMCThreads(),
1,
10;
init_params=fill([4, -1], 10),
initial_params=fill([4, -1], 10),
progress=false,
)
for c in chains
Expand All @@ -132,7 +132,7 @@
end

# set only m = -1
chain = sample(model, sampler, 1; init_params=[missing, -1], progress=false)
chain = sample(model, sampler, 1; initial_params=[missing, -1], progress=false)
@test !ismissing(chain[1].metadata.s.vals[1])
@test chain[1].metadata.m.vals == [-1]

Expand All @@ -143,19 +143,19 @@
MCMCThreads(),
1,
10;
init_params=fill([missing, -1], 10),
initial_params=fill([missing, -1], 10),
progress=false,
)
for c in chains
@test !ismissing(c[1].metadata.s.vals[1])
@test c[1].metadata.m.vals == [-1]
end

# specify `init_params=nothing`
# specify `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(model, sampler, 1; progress=false)
Random.seed!(1234)
chain2 = sample(model, sampler, 1; init_params=nothing, progress=false)
chain2 = sample(model, sampler, 1; initial_params=nothing, progress=false)
@test chain1[1].metadata.m.vals == chain2[1].metadata.m.vals
@test chain1[1].metadata.s.vals == chain2[1].metadata.s.vals

Expand All @@ -164,7 +164,7 @@
chains1 = sample(model, sampler, MCMCThreads(), 1, 10; progress=false)
Random.seed!(1234)
chains2 = sample(
model, sampler, MCMCThreads(), 1, 10; init_params=nothing, progress=false
model, sampler, MCMCThreads(), 1, 10; initial_params=nothing, progress=false
)
for (c1, c2) in zip(chains1, chains2)
@test c1[1].metadata.m.vals == c2[1].metadata.m.vals
Expand Down

0 comments on commit a7f8658

Please sign in to comment.