Skip to content

Commit

Permalink
more progress; still need to deal with w being on simplex
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Aug 15, 2024
1 parent 590d37f commit 3afc232
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 48 deletions.
1 change: 1 addition & 0 deletions gibbs_example/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
100 changes: 83 additions & 17 deletions gibbs_example/gibbs.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
using AbstractMCMC
using LogDensityProblems, Distributions, LinearAlgebra, Random
using OrderedCollections

##

# TODO: introduce some kind of parameter format, for instance, a flattened vector
# then define some kind of function to transform the flattened vector into model's representation

# Gibbs sampler: a composite sampler built from per-group sub-samplers.
# `sampler_map` is keyed by a group of variable names (e.g. `(:z,)`) and maps each
# group to the sub-sampler used for it; the groups are visited in the dictionary's
# iteration order by the `AbstractMCMC.step` methods below.
struct Gibbs <: AbstractMCMC.AbstractSampler
    sampler_map::OrderedDict
end

"""
    GibbsState(vi, states)

State carried between Gibbs iterations.

# Fields
- `vi`: the current values of the model variables, keyed by variable name.
- `states`: the latest sub-sampler state for each variable group, keyed by the same
  group tuples as `Gibbs.sampler_map`.
"""
struct GibbsState
    vi::NamedTuple
    states::OrderedDict
end

Expand All @@ -15,31 +21,91 @@ struct GibbsTransition
end

"""
    AbstractMCMC.step(rng, logdensity_model, spl::Gibbs, args...; initial_params, kwargs...)

Perform the initial Gibbs step.

For every variable group in `spl.sampler_map`, the remaining variables are fixed to
their values in `initial_params` via `condition`, and the group's sub-sampler is
initialised on the resulting conditioned model, starting from the flattened initial
values of the group.

# Keywords
- `initial_params::NamedTuple`: initial values for all variables (required).

# Returns
A `(GibbsTransition, GibbsState)` pair; both hold `initial_params`, and the state also
holds the freshly initialised sub-sampler states.
"""
function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    spl::Gibbs,
    args...;
    initial_params::NamedTuple,
    kwargs...,
)
    states = OrderedDict()
    for group in keys(spl.sampler_map)
        sub_spl = spl.sampler_map[group]

        # Everything that is not in the current group is conditioned on.
        vars_to_be_conditioned_on = setdiff(keys(initial_params), group)
        cond_val = NamedTuple{Tuple(vars_to_be_conditioned_on)}(
            Tuple([initial_params[g] for g in vars_to_be_conditioned_on])
        )
        params_val = NamedTuple{Tuple(group)}(Tuple([initial_params[g] for g in group]))

        # Only the sub-sampler's state is kept; its first transition is discarded.
        sub_state = last(
            AbstractMCMC.step(
                rng,
                AbstractMCMC.LogDensityModel(
                    condition(logdensity_model.logdensity, cond_val)
                ),
                sub_spl,
                args...;
                initial_params=flatten(params_val),
                kwargs...,
            ),
        )
        states[group] = sub_state
    end
    return GibbsTransition(initial_params), GibbsState(initial_params, states)
end

"""
    AbstractMCMC.step(rng, logdensity_model, spl::Gibbs, state::GibbsState, args...; kwargs...)

Perform one Gibbs sweep over all variable groups in `spl.sampler_map`.

Before each group is updated, the current values `vi` are refreshed from every stored
sub-sampler state. The group's sub-sampler then takes one step on the model conditioned
on all variables outside the group, and its new state replaces the old one in
`state.states` (which is mutated in place).

# Returns
A `(GibbsTransition, GibbsState)` pair holding the updated variable values.
"""
function AbstractMCMC.step(
    rng::AbstractRNG,
    logdensity_model::AbstractMCMC.LogDensityModel,
    spl::Gibbs,
    state::GibbsState,
    args...;
    kwargs...,
)
    vi = state.vi
    for group in keys(spl.sampler_map)
        # Pull the latest parameter values out of every sub-sampler state so that the
        # conditioning values below reflect updates made earlier in this sweep.
        for (other_group, other_state) in state.states
            vi = merge(vi, unflatten(getparams(other_state), other_group))
        end
        sub_spl = spl.sampler_map[group]
        sub_state = state.states[group]
        group_complement = setdiff(keys(vi), group)
        cond_val = NamedTuple{Tuple(group_complement)}(
            Tuple([vi[g] for g in group_complement])
        )
        sub_state = last(
            AbstractMCMC.step(
                rng,
                AbstractMCMC.LogDensityModel(
                    condition(logdensity_model.logdensity, cond_val)
                ),
                sub_spl,
                sub_state,
                args...;
                kwargs...,
            ),
        )
        # TODO: what values to condition on here? stored where?
        state.states[group] = sub_state
    end
    # `getparams` yields a flat vector (it is passed to `unflatten` above), so it must be
    # turned back into a NamedTuple for its group before it can be merged into `vi`.
    for (group, sub_state) in state.states
        vi = merge(vi, unflatten(getparams(sub_state), group))
    end
    return GibbsTransition(vi), GibbsState(vi, state.states)
end

## tests

# Smoke test: run the Gibbs sampler on the Gaussian mixture model.
# NOTE(review): `x` (the observed data) is not defined in this file's visible part;
# presumably it comes from the data-generation code in gmm.jl — confirm.
gmm = GMM((; x=x))

samples = sample(
    gmm,
    Gibbs(
        OrderedDict(
            (:z,) => PriorMH(product_distribution([Categorical([0.3, 0.7]) for _ in 1:60])),
            (:w,) => PriorMH(Dirichlet(2, 1.0)),
            (:μ, :w) => RWMH(1),
        ),
    ),
    10000;
    initial_params=(z=rand(Categorical([0.3, 0.7]), 60), μ=[0.0, 1.0], w=[0.3, 0.7]),
)
56 changes: 52 additions & 4 deletions gibbs_example/gmm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ function log_joint(; μ, w, z, x)
logp += logpdf(w_prior, w)

z_prior = Categorical(w)

logp += sum([logpdf(z_prior, z[i]) for i in 1:N])

obs_priors = [MvNormal(fill(μₖ, D), I) for μₖ in μ]
Expand All @@ -43,33 +44,80 @@ function condition(gmm::GMM, conditioned_values::NamedTuple)
return ConditionedGMM(gmm.data, conditioned_values)
end

# Log joint density of the model when `μ` and `w` are conditioned on (in either key
# order) and `z` is the free parameter, read from `params.z`.
function _logdensity(gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}, params)
    return log_joint(;
        μ=gmm.conditioned_values.μ, w=gmm.conditioned_values.w, z=params.z, x=gmm.data.x
    )
end

# Log joint density of the model when `z` is conditioned on; the free parameters `μ`
# and `w` are read from `params`.
function _logdensity(gmm::ConditionedGMM{(:z,)}, params)
    fixed_z = gmm.conditioned_values.z
    observations = gmm.data.x
    return log_joint(; μ=params.μ, w=params.w, z=fixed_z, x=observations)
end

# Vector-input log density for the model conditioned on `μ` and `w` (either key order).
# `params_vec` is interpreted as the assignment vector `z`.
function LogDensityProblems.logdensity(
    gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}},
    params_vec::AbstractVector,
)
    # NOTE(review): 60 is hard-coded; presumably the number of observations — confirm.
    @assert length(params_vec) == 60
    return _logdensity(gmm, (; z=params_vec))
end
# Vector-input log density for the model conditioned on `z`.
# `params_vec` holds `μ` in entries 1:2 followed by `w` in entries 3:4.
function LogDensityProblems.logdensity(
    gmm::ConditionedGMM{(:z,)}, params_vec::AbstractVector
)
    @assert length(params_vec) == 4 "length(params_vec) = $(length(params_vec))"
    μ_part = params_vec[1:2]
    w_part = params_vec[3:4]
    return _logdensity(gmm, (; μ=μ_part, w=w_part))
end

# Dimension of the unconditioned model: 4 entries (two for `μ` and two for `w`, matching
# the 1:2 / 3:4 layout used by `flatten`/`unflatten`) plus one per row of `gmm.data.x`.
function LogDensityProblems.dimension(gmm::GMM)
    return 4 + size(gmm.data.x, 1)
end

# Dimension of the model conditioned on `μ` and `w` (either key order).
# The free parameter is then the assignment vector `z`, so the dimension is the length
# of `z` rather than the combined length of the conditioned variables.
# NOTE(review): assumes `z` has one entry per row of `gmm.data.x` — confirm.
function LogDensityProblems.dimension(
    gmm::Union{ConditionedGMM{(:μ, :w)},ConditionedGMM{(:w, :μ)}}
)
    return size(gmm.data.x, 1)
end

# Dimension of the model conditioned on `z`.
# The free parameters are then `μ` and `w`, which `logdensity` for this type reads as a
# length-4 vector (`μ` in 1:2, `w` in 3:4), so the dimension is 4.
function LogDensityProblems.dimension(gmm::ConditionedGMM{(:z,)})
    return 4
end

# The unconditioned model only provides log-density evaluations (order 0: no gradient).
LogDensityProblems.capabilities(::GMM) = LogDensityProblems.LogDensityOrder{0}()

# Conditioned models likewise only provide log-density evaluations (order 0).
LogDensityProblems.capabilities(::ConditionedGMM) = LogDensityProblems.LogDensityOrder{0}()

"""
    flatten(nt::NamedTuple)

Convert a parameter group given as a `NamedTuple` into a flat vector.

Supported groups (key order does not matter):
- `μ` and `w`: returns `vcat(nt.μ, nt.w)`.
- `z`: returns `nt.z` itself (no copy).

Throws an `ErrorException` for any other set of keys.
"""
function flatten(nt::NamedTuple)
    names = keys(nt)
    if issetequal(names, (:μ, :w))
        return vcat(nt.μ, nt.w)
    elseif issetequal(names, (:z,))
        return nt.z
    else
        error("flatten: unsupported parameter group $(names)")
    end
end

"""
    unflatten(vec::AbstractVector, group::Tuple)

Convert a flat parameter vector back into a `NamedTuple` for the variable `group`.

Supported groups (order within the tuple does not matter):
- `(:μ, :w)`: entries 1:2 become `μ` and entries 3:4 become `w`.
- `(:z,)`: the whole vector becomes `z`.

Throws an `ErrorException` for any other group.
"""
function unflatten(vec::AbstractVector, group::Tuple)
    if issetequal(group, (:μ, :w))
        return (; μ=vec[1:2], w=vec[3:4])
    elseif issetequal(group, (:z,))
        return (; z=vec)
    else
        error("unflatten: unsupported parameter group $(group)")
    end
end

# sampler's states to internal representation
# ? who gets to define the output of `getparams`? (maybe have a `getparams(T, state)`?)

# The point here is that the parameter values are unchanged, but because the conditioning context has changed, the log-probability needs to be recomputed.
# Evaluate the log density of `vals` under the conditioned model `gmm` and store it in
# `state` through `setlogp!`, returning whatever `setlogp!` returns.
function recompute_logprob!!(gmm::ConditionedGMM, vals, state)
    new_logp = _logdensity(gmm, vals)
    return setlogp!(state, new_logp)
end

## test using Turing

# data generation
Expand Down
Loading

0 comments on commit 3afc232

Please sign in to comment.