diff --git a/gibbs_example/Project.toml b/gibbs_example/Project.toml index 1e8d8677..81b2b669 100644 --- a/gibbs_example/Project.toml +++ b/gibbs_example/Project.toml @@ -6,5 +6,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl index e201a099..1c68643f 100644 --- a/gibbs_example/gibbs.jl +++ b/gibbs_example/gibbs.jl @@ -1,12 +1,18 @@ +using AbstractMCMC using LogDensityProblems, Distributions, LinearAlgebra, Random using OrderedCollections +## + +# TODO: introduce some kind of parameter format, for instance, a flattened vector +# then define some kind of function to transform the flattened vector into model's representation + struct Gibbs <: AbstractMCMC.AbstractSampler sampler_map::OrderedDict end struct GibbsState - values::NamedTuple + vi::NamedTuple states::OrderedDict end @@ -15,31 +21,91 @@ struct GibbsTransition end function AbstractMCMC.step( - rng::AbstractRNG, model, sampler::Gibbs, args...; initial_params::NamedTuple, kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + spl::Gibbs, + args...; + initial_params::NamedTuple, + kwargs..., ) states = OrderedDict() - for group in keys(sampler.sampler_map) - sampler = sampler.sampler_map[group] - cond_val = NamedTuple{group}([initial_params[g] for g in group]...) - trans, state = AbstractMCMC.step( - rng, condition(model, cond_val), sampler, args...; kwargs... + for group in keys(spl.sampler_map) + sub_spl = spl.sampler_map[group] + + vars_to_be_conditioned_on = setdiff(keys(initial_params), group) + cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}( + Tuple([initial_params[g] for g in vars_to_be_conditioned_on]) + ) + params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group])) + sub_state = last( + AbstractMCMC.step( + rng, + AbstractMCMC.LogDensityModel( + condition(logdensity_model.logdensity, cond_val) + ), + sub_spl, + args...; + initial_params=flatten(params_val), + kwargs..., + ), ) - states[group] = state + states[group] = sub_state end return GibbsTransition(initial_params), GibbsState(initial_params, states) end function AbstractMCMC.step( - rng::AbstractRNG, model, sampler::Gibbs, state::GibbsState, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + spl::Gibbs, + state::GibbsState, + args...; + kwargs..., ) - for group in collect(keys(sampler.sampler_map)) - sampler = sampler.sampler_map[group] - state = state.states[group] - trans, state = AbstractMCMC.step( - rng, condition(model, state.values[group]), sampler, state, args...; kwargs... + vi = state.vi + for group in keys(spl.sampler_map) + for (group, sub_state) in state.states + vi = merge(vi, unflatten(getparams(sub_state), group)) + end + sub_spl = spl.sampler_map[group] + sub_state = state.states[group] + group_complement = setdiff(keys(vi), group) + cond_val = NamedTuple{Tuple(group_complement)}( + Tuple([vi[g] for g in group_complement]) + ) + sub_state = last( + AbstractMCMC.step( + rng, + AbstractMCMC.LogDensityModel( + condition(logdensity_model.logdensity, cond_val) + ), + sub_spl, + sub_state, + args...; + kwargs..., + ), ) - # TODO: what values to condition on here? stored where? - state.states[group] = state + state.states[group] = sub_state end - return nothing + for sub_state in values(state.states) + vi = merge(vi, getparams(sub_state)) + end + return GibbsTransition(vi), GibbsState(vi, state.states) end + +## tests + +gmm = GMM((; x=x)) + +samples = sample( + gmm, + Gibbs( + OrderedDict( + (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), + (:w,) => PriorMH(Dirichlet(2, 1.0)), + (:μ, :w) => RWMH(1), + ), + ), + 10000; + initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]), +) diff --git a/gibbs_example/gmm.jl b/gibbs_example/gmm.jl index b40b4579..7a8eab79 100644 --- a/gibbs_example/gmm.jl +++ b/gibbs_example/gmm.jl @@ -29,6 +29,7 @@ function log_joint(; μ, w, z, x) logp += logpdf(w_prior, w) z_prior = Categorical(w) + logp += sum([logpdf(z_prior, z[i]) for i in 1:N]) obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ] @@ -43,33 +44,80 @@ function condition(gmm::GMM, conditioned_values::NamedTuple) return ConditionedGMM(gmm.data, conditioned_values) end -function _logdensity(gmm::ConditionedGMM{(:μ, :w)}, params) +function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params) return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x ) end + function _logdensity(gmm::ConditionedGMM{(:z,)}, params) return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x) end function LogDensityProblems.logdensity( - gmm::ConditionedGMM{(:μ, :w)}, params_vec::AbstractVector + gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, + params_vec::AbstractVector, ) + @assert length(params_vec) == 60 return _logdensity(gmm, (; z=params_vec)) end function LogDensityProblems.logdensity( gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector ) + @assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))" return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4])) end -function LogDensityProblems.dimension(gmm::ConditionedGMM{(:μ, :w)}) - return size(gmm.data.x, 1) +function LogDensityProblems.dimension(gmm::GMM) + return 4 + size(gmm.data.x, 1) +end + +function LogDensityProblems.dimension( + gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}} +) + return 4 end + function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)}) return size(gmm.data.x, 1) end +function LogDensityProblems.capabilities(::GMM) + return LogDensityProblems.LogDensityOrder{0}() +end + +function LogDensityProblems.capabilities(::ConditionedGMM) + return LogDensityProblems.LogDensityOrder{0}() +end + +function flatten(nt::NamedTuple) + if Set(keys(nt)) == Set([:μ, :w]) + return vcat(nt.μ, nt.w) + elseif Set(keys(nt)) == Set([:z]) + return nt.z + else + error() + end +end + +function unflatten(vec::AbstractVector, group::Tuple) + if Set(group) == Set([:μ, :w]) + return (; μ=vec[1:2], w=vec[3:4]) + elseif Set(group) == Set([:z]) + return (; z=vec) + else + error() + end +end + +# sampler's states to internal representation +# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?) + +# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed +function recompute_logprob!!(gmm::ConditionedGMM, vals, state) + return setlogp!(state, _logdensity(gmm, vals)) +end + ## test using Turing # data generation diff --git a/gibbs_example/mh.jl b/gibbs_example/mh.jl index 39d1c69f..b2fa91dc 100644 --- a/gibbs_example/mh.jl +++ b/gibbs_example/mh.jl @@ -1,13 +1,9 @@ -struct RWMH <: AbstractMCMC.AbstractSampler - σ -end - -struct MHTransition{T} where {T} - params::T +struct MHTransition{T} + params::Vector{T} end -struct MHState{T} where {T} - params::T +struct MHState{T} + params::Vector{T} logp::Float64 end @@ -16,21 +12,43 @@ setparams!!(state::MHState, params) = MHState(params, state.logp) getlogp(state::MHState) = state.logp setlogp!!(state::MHState, logp) = MHState(state.params, logp) -function AbstractMCMC.step(rng::AbstractRNG, logdensity, sampler::RWMH, args...; kwargs...) - params = rand(rng, LogDensityProblems.dimension(logdensity)) - return MHTransition(params), - MHState(params, LogDensityProblems.logdensity(logdensity, params)) +struct RWMH <: AbstractMCMC.AbstractSampler + σ::Float64 end function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::RWMH, state::MHState, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::RWMH, + args...; + initial_params, + kwargs..., ) - params = getparams(state) - proposal_dist = MvNormal(params, sampler.σ) - proposal = rand(rng, proposal_dist) - logp_proposal = logpdf(proposal_dist, proposal) - accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state))) - if accepted + return MHTransition(initial_params), + MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), + ) +end + +function AbstractMCMC.step( + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::RWMH, + state::MHState, + args...; + kwargs..., +) + params = state.params + proposal_dist = MvNormal(zeros(length(params)), sampler.σ) + proposal = params .+ rand(rng, proposal_dist) + logp_proposal = only( + LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) + ) + + log_acceptance_ratio = min(0, logp_proposal - getlogp(state)) + + if log(rand(rng)) < log_acceptance_ratio return MHTransition(proposal), MHState(proposal, logp_proposal) else return MHTransition(params), MHState(params, getlogp(state)) @@ -38,27 +56,104 @@ function AbstractMCMC.step( end struct PriorMH <: AbstractMCMC.AbstractSampler - prior_dist + prior_dist::Distribution end function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::PriorMH, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::PriorMH, + args...; + initial_params, + kwargs..., ) - params = rand(rng, sampler.prior_dist) - return MHTransition(params), MHState(params, logdensity(params)) + return MHTransition(initial_params), + MHState( + initial_params, + only(LogDensityProblems.logdensity(logdensity_model.logdensity, initial_params)), + ) end function AbstractMCMC.step( - rng::AbstractRNG, logdensity, sampler::PriorMH, state::MHState, args...; kwargs... + rng::AbstractRNG, + logdensity_model::AbstractMCMC.LogDensityModel, + sampler::PriorMH, + state::MHState, + args...; + kwargs..., ) params = getparams(state) proposal_dist = sampler.prior_dist proposal = rand(rng, proposal_dist) - logp_proposal = logpdf(proposal_dist, proposal) - accepted = log(rand(rng)) < log1pexp(min(0, logp_proposal - getlogp(state))) - if accepted + logp_proposal = only( + LogDensityProblems.logdensity(logdensity_model.logdensity, proposal) + ) + + log_acceptance_ratio = min( + 0, + logp_proposal - getlogp(state) + logpdf(proposal_dist, params) - + logpdf(proposal_dist, proposal), + ) + + if log(rand(rng)) < log_acceptance_ratio return MHTransition(proposal), MHState(proposal, logp_proposal) else return MHTransition(params), MHState(params, getlogp(state)) end end + +## tests + +# for RWMH +# sample from Normal(10, 1) +struct NormalLogDensity end +LogDensityProblems.logdensity(l::NormalLogDensity, x) = logpdf(Normal(10, 1), only(x)) +LogDensityProblems.dimension(l::NormalLogDensity) = 1 +function LogDensityProblems.capabilities(::NormalLogDensity) + return LogDensityProblems.LogDensityOrder{1}() +end + +# for PriorMH +# sample from Categorical([0.2, 0.5, 0.3]) +struct CategoricalLogDensity end +function LogDensityProblems.logdensity(l::CategoricalLogDensity, x) + return logpdf(Categorical([0.2, 0.6, 0.2]), only(x)) +end +LogDensityProblems.dimension(l::CategoricalLogDensity) = 1 +function LogDensityProblems.capabilities(::CategoricalLogDensity) + return LogDensityProblems.LogDensityOrder{0}() +end + +## + +using StatsPlots + +samples = AbstractMCMC.sample( + Random.default_rng(), NormalLogDensity(), RWMH(1), 100000; initial_params=[0.0] +) +_samples = map(t -> only(t.params), samples) + +histogram( + _samples; + normalize=:pdf, + label="Samples", + title="RWMH Sampling of Normal(10, 1)", +) +plot!(Normal(10, 1); linewidth=2, label="Ground Truth") + +samples = AbstractMCMC.sample( + Random.default_rng(), + CategoricalLogDensity(), + PriorMH(product_distribution([Categorical([0.3, 0.3, 0.4])])), + 100000; + initial_params=[1], +) +_samples = map(t -> only(t.params), samples) + +histogram( + _samples; + normalize=:probability, + label="Samples", + title="MH From Prior Sampling of Categorical([0.3, 0.3, 0.4])", +) +plot!(Categorical([0.2, 0.6, 0.2]); linewidth=2, label="Ground Truth")