From 67ff8e80d4e2018e6a6d9c3e40397a2522f89595 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 15 Aug 2024 14:01:08 +0100 Subject: [PATCH] results is wrong --- gibbs_example/gibbs.jl | 26 +++++++++------- gibbs_example/gmm.jl | 71 ++++++++---------------------------------- 2 files changed, 28 insertions(+), 69 deletions(-) diff --git a/gibbs_example/gibbs.jl b/gibbs_example/gibbs.jl index 1c68643f..2cd9643d 100644 --- a/gibbs_example/gibbs.jl +++ b/gibbs_example/gibbs.jl @@ -4,9 +4,6 @@ using OrderedCollections ## -# TODO: introduce some kind of parameter format, for instance, a flattened vector -# then define some kind of function to transform the flattened vector into model's representation - struct Gibbs <: AbstractMCMC.AbstractSampler sampler_map::OrderedDict end @@ -73,12 +70,12 @@ function AbstractMCMC.step( cond_val = NamedTuple{Tuple(group_complement)}( Tuple([vi[g] for g in group_complement]) ) + cond_logdensity = condition(logdensity_model.logdensity, cond_val) + sub_state = recompute_logprob!!(cond_logdensity, getparams(sub_state), sub_state) sub_state = last( AbstractMCMC.step( rng, - AbstractMCMC.LogDensityModel( - condition(logdensity_model.logdensity, cond_val) - ), + AbstractMCMC.LogDensityModel(cond_logdensity), sub_spl, sub_state, args...; @@ -87,8 +84,8 @@ function AbstractMCMC.step( ) state.states[group] = sub_state end - for sub_state in values(state.states) - vi = merge(vi, getparams(sub_state)) + for (group, sub_state) in state.states + vi = merge(vi, unflatten(getparams(sub_state), group)) end return GibbsTransition(vi), GibbsState(vi, state.states) end @@ -103,9 +100,16 @@ samples = sample( OrderedDict( (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])), (:w,) => PriorMH(Dirichlet(2, 1.0)), - (:μ, :w) => RWMH(1), + (:μ,) => RWMH(1), ), ), - 10000; + 100000; initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]), -) +); + +z_samples = [sample.values.z for sample in samples][20001:end] +μ_samples = [sample.values.μ for sample in samples][20001:end] +w_samples = [sample.values.w for sample in samples][20001:end] + +mean(μ_samples) +mean(w_samples) diff --git a/gibbs_example/gmm.jl b/gibbs_example/gmm.jl index 7a8eab79..a843eafc 100644 --- a/gibbs_example/gmm.jl +++ b/gibbs_example/gmm.jl @@ -44,42 +44,16 @@ function condition(gmm::GMM, conditioned_values::NamedTuple) return ConditionedGMM(gmm.data, conditioned_values) end -function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params) - return log_joint(; - μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x - ) -end - -function _logdensity(gmm::ConditionedGMM{(:z,)}, params) - return log_joint(; μ=params.μ, w=params.w, z=gmm.conditioned_values.z, x=gmm.data.x) -end - -function LogDensityProblems.logdensity( - gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, - params_vec::AbstractVector, -) - @assert length(params_vec) == 60 - return _logdensity(gmm, (; z=params_vec)) -end -function LogDensityProblems.logdensity( - gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector -) - @assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))" - return _logdensity(gmm, (; μ=params_vec[1:2], w=params_vec[3:4])) -end - -function LogDensityProblems.dimension(gmm::GMM) - return 4 + size(gmm.data.x, 1) -end - -function LogDensityProblems.dimension( - gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}} -) - return 4 -end - -function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)}) - return size(gmm.data.x, 1) +function LogDensityProblems.logdensity(gmm::ConditionedGMM{names}, params::AbstractVector) where {names} + if Set(names) == Set([:μ, :w]) # conditioned on μ, w, so params are z + return log_joint(; μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params, x=gmm.data.x) + elseif Set(names) == Set([:z, :w]) # conditioned on z, w, so params are μ + return log_joint(; μ=params, w=gmm.conditioned_values.w, z=gmm.conditioned_values.z, x=gmm.data.x) + elseif Set(names) == Set([:z, :μ]) # conditioned on z, μ, so params are w + return log_joint(; μ=gmm.conditioned_values.μ, w=params, z=gmm.conditioned_values.z, x=gmm.data.x) + else + error("Unsupported conditioning configuration.") + end end function LogDensityProblems.capabilities(::GMM) @@ -91,41 +65,22 @@ function LogDensityProblems.capabilities(::ConditionedGMM) end function flatten(nt::NamedTuple) - if Set(keys(nt)) == Set([:μ, :w]) - return vcat(nt.μ, nt.w) - elseif Set(keys(nt)) == Set([:z]) - return nt.z - else - error() - end + return only(values(nt)) end function unflatten(vec::AbstractVector, group::Tuple) - if Set(group) == Set([:μ, :w]) - return (; μ=vec[1:2], w=vec[3:4]) - elseif Set(group) == Set([:z]) - return (; z=vec) - else - error() - end + return NamedTuple((only(group) => vec,)) end -# sampler's states to internal representation -# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?) - -# the point here is that the parameter values are not changed, but because the context was changed, the logprob need to be recomputed function recompute_logprob!!(gmm::ConditionedGMM, vals, state) - return setlogp!(state, _logdensity(gmm, vals)) + return setlogp!!(state, LogDensityProblems.logdensity(gmm, vals)) end ## test using Turing # data generation -using Distributions using FillArrays -using LinearAlgebra -using Random w = [0.5, 0.5] μ = [-3.5, 0.5]