diff --git a/src/Turing.jl b/src/Turing.jl index 11dcbdb6f..33286a665 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -73,7 +73,6 @@ export @model, # modelling Prior, # Sampling from the prior MH, # classic sampling - RWMH, Emcee, ESS, Gibbs, diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 73f190dcc..45ae434a4 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -99,6 +99,24 @@ Wrap a sampler so it can be used as an inference algorithm. """ externalsampler(sampler::AbstractSampler) = ExternalSampler(sampler) +""" + ESLogDensityFunction + +A log density function for the External sampler. + +""" +const ESLogDensityFunction{M<:Model,S<:Sampler{<:ExternalSampler},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.DefaultContext} +function LogDensityProblems.logdensity(f::ESLogDensityFunction, x::NamedTuple) + return DynamicPPL.logjoint(f.model, DynamicPPL.unflatten(f.varinfo, x)) +end + +# TODO: make a nicer `set_namedtuple!` and move these functions to DynamicPPL. +function DynamicPPL.unflatten(vi::TypedVarInfo, θ::NamedTuple) + set_namedtuple!(deepcopy(vi), θ) + return vi +end +DynamicPPL.unflatten(vi::SimpleVarInfo, θ::NamedTuple) = SimpleVarInfo(θ, vi.logp, vi.transformation) + # Algorithm for sampling from the prior struct Prior <: InferenceAlgorithm end diff --git a/src/inference/mh.jl b/src/inference/mh.jl index ddbeaa2c5..dd97efd18 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -188,6 +188,20 @@ function MH(space...) return MH{tuple(syms...), typeof(proposals)}(proposals) end +# Some of the proposals require working in unconstrained space. +transform_maybe(proposal::AMH.Proposal) = proposal +function transform_maybe(proposal::AMH.RandomWalkProposal) + return AMH.RandomWalkProposal(Bijectors.transformed(proposal.proposal)) +end + +function MH(model::Model; proposal_type=AMH.StaticProposal) + priors = DynamicPPL.extract_priors(model) + props = Tuple([proposal_type(prop) for prop in values(priors)]) + vars = Tuple(map(Symbol, collect(keys(priors)))) + priors = map(transform_maybe, NamedTuple{vars}(props)) + return AMH.MetropolisHastings(priors) +end + ##################### # Utility functions # ##################### @@ -346,6 +360,7 @@ end function should_link(varinfo, sampler, proposal::AdvancedMH.RandomWalkProposal) return true end +# FIXME: This won't be hit unless `vals` are all the exactly same concrete type of `AdvancedMH.RandomWalkProposal`! function should_link( varinfo, sampler, diff --git a/test/inference/mh.jl b/test/inference/mh.jl index 8e52aec9b..94f9aa992 100644 --- a/test/inference/mh.jl +++ b/test/inference/mh.jl @@ -17,6 +17,12 @@ s4 = Gibbs(MH(:m), MH(:s)) c4 = sample(gdemo_default, s4, N) + + s5 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.RandomWalkProposal)) + c5 = sample(gdemo_default, s5, N) + + s6 = externalsampler(MH(gdemo_default, proposal_type=AdvancedMH.StaticProposal)) + c6 = sample(gdemo_default, s6, N) end @numerical_testset "mh inference" begin Random.seed!(125)