Gibbs sampler #2647
base: main
Conversation
Pull Request Overview
This PR implements a GibbsConditional sampler component for Turing.jl that allows users to provide analytical conditional distributions for variables in Gibbs sampling. The implementation enables mixing user-defined conditional distributions with other MCMC samplers within the Gibbs framework.
Key changes:
- Added `GibbsConditional` struct and supporting functions for analytical conditional sampling
- Comprehensive test coverage for the new functionality
- Added example test file demonstrating usage
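The idea the PR implements — alternating draws from user-supplied full conditionals instead of generic MCMC kernels — can be illustrated with a minimal, language-agnostic sketch. This is not the Turing.jl API; the function names and the toy bivariate-normal target are illustrative assumptions only.

```python
import math
import random

def gibbs(n_iters, sample_x_given_y, sample_y_given_x, x0, y0):
    """Generic two-block Gibbs sweep: each step draws one variable
    from its analytical full conditional given the other's current value."""
    x, y = x0, y0
    samples = []
    for _ in range(n_iters):
        x = sample_x_given_y(y)  # user-provided conditional for x
        y = sample_y_given_x(x)  # user-provided conditional for y
        samples.append((x, y))
    return samples

# Toy target: standard bivariate normal with correlation rho, where both
# conditionals are known Gaussians, so no accept/reject step is needed.
rho = 0.5
sd = math.sqrt(1 - rho * rho)
random.seed(42)
draws = gibbs(
    5000,
    lambda y: random.gauss(rho * y, sd),
    lambda x: random.gauss(rho * x, sd),
    0.0, 0.0,
)
```

The point of `GibbsConditional` is analogous: where a closed-form conditional is available, it replaces a generic component sampler inside the Gibbs loop.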
Reviewed Changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/mcmc/gibbs_conditional.jl | Core implementation of GibbsConditional sampler with step functions and variable handling |
| src/mcmc/Inference.jl | Added GibbsConditional export and module inclusion |
| test/mcmc/gibbs.jl | Added comprehensive test suite for GibbsConditional functionality |
| test_gibbs_conditional.jl | Example/demo file showing GibbsConditional usage |
| HISTORY.md | Version history update (unrelated to main feature) |
src/mcmc/gibbs_conditional.jl (Outdated)
```julia
# For GibbsConditional within Gibbs, we need to get all variable values
# Check if we're in a Gibbs context
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
```
Using `hasproperty` to check whether a field exists is fragile and could break with future changes to the model structure. Consider a more robust method like `isdefined`, or checking the model type directly.
```suggestion
global_vi = if isdefined(model, :context) && model.context isa GibbsContext
```
src/mcmc/gibbs_conditional.jl
Outdated
end | ||
end | ||
if !found | ||
error("Could not find variable $S in VarInfo") |
The error message could be more helpful by suggesting what variables are available or providing debugging information about the VarInfo contents.
```suggestion
error("Could not find variable $S in VarInfo. Available variables: $(join([string(DynamicPPL.getsym(k)) for k in keys(state)], \", \")).")
```
src/mcmc/gibbs_conditional.jl
Outdated
else | ||
# Try to find the variable with indices | ||
# This handles cases where the variable might have indices | ||
local updated_vi = state |
The `local` keyword is unnecessary here since `updated_vi` is already in a local scope. This adds visual clutter without functional benefit.
```suggestion
updated_vi = state
```
HISTORY.md (Outdated)
```diff
@@ -1,3 +1,7 @@
+# 0.39.10
+
+Added a compatibility entry for DataStructures v0.19.
```
The HISTORY.md change appears unrelated to the GibbsConditional implementation and should be in a separate commit or PR to maintain a clean version history.
```suggestion
```
```julia
# Sample using GibbsConditional
println("Testing GibbsConditional sampler...")
sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
```
[nitpick] The variable name is redundantly specified in both the Gibbs pair key and the GibbsConditional constructor. Consider whether this duplication is necessary or if the API could be simplified.
```suggestion
sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m))
```
Turing.jl documentation for PR #2647 is available at:
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
```
@@            Coverage Diff             @@
##             main    #2647      +/-   ##
==========================================
- Coverage   84.88%   79.94%    -4.94%
==========================================
  Files          22       23        +1
  Lines        1475     1506       +31
==========================================
- Hits         1252     1204       -48
- Misses        223      302       +79
```
Pull Request Test Coverage Report for Build 16831838756 (Coveralls)
As discussed on Slack, rather than doing a full review, I'm just going to give some high level comments and pointers for where to find more details on some of the relevant context.
```julia
x = c.x
n = length(x)
a_new = a + (n + 1) / 2
b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2
```
Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the `m` in the variance term rather be the mean of `x`?
src/mcmc/gibbs_conditional.jl (Outdated)
```julia
# For GibbsConditional within Gibbs, we need to get all variable values
# Check if we're in a Gibbs context
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
```
The core idea here of finding the possible `GibbsContext` and getting the global varinfo from it is good. However, GibbsContext is a bit weird, in that it's always inserted at the bottom of the context stack. By the context stack I mean the fact that contexts often have child contexts, and thus `model.context` may in fact be many nested contexts. See e.g. how the GibbsContext is set here, by calling `setleafcontext` rather than `setcontext` (line 258 in d75e6f2):

```julia
gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner)
```

So rather than check whether `model.context isa GibbsContext`, I think you'll need to traverse the whole context stack, and check if any of them are a `GibbsContext`, until you hit a leaf context and the stack ends.

Moreover, I think you'll need to check not just for GibbsContext, but also for `ConditionContext` and `FixedContext`, which condition/fix the values of some variables. So all in all, if you go through the whole stack, starting with `model.context` and going through its child contexts, and collect any variables set in `ConditionContext`, `FixedContext`, and `GibbsContext`, that should give you all of the variable values you need. See here for more details on condition and fix: https://github.com/TuringLang/DynamicPPL.jl/blob/1ed8cc8d9f013f46806c88a83e93f7a4c5b891dd/src/contexts.jl#L258

As mentioned on Slack a week or two ago, all this context stack business is likely changing Soon (TM), since @penelopeysm is overhauling `condition` and `fix` over here, TuringLang/DynamicPPL.jl#1010, and as a result we may be able to overhaul `GibbsContext` as well. You could wait for that to be finished first, at least if it looks like getting this to work would be a lot of work.
src/mcmc/gibbs_conditional.jl (Outdated)
```julia
updated = rand(rng, conddist)

# Update the variable in state
# We need to get the actual VarName for this variable
```
The operating principle of the new(ish) Gibbs sampler is that every component sampler only ever sees a VarInfo with the variables that that component sampler is supposed to sample. Thus, you should be able to assume that `updated` includes values for all the variables in `state`, and for nothing else. Hence the below checks and loops I think shouldn't be necessary. The solution might be as simple as `new_state = unflatten(state, updated)`, though there may be details there that I'm not thinking of right now. (What if `state` is linked? But maybe we can guarantee that it's never linked, because the sampler can control it.) Happy to discuss details more if `unflatten` by itself doesn't seem to cut it.
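The semantics of the suggested `unflatten` call — pouring a flat vector of new values back into a structured state using the state's own layout — can be sketched as follows. This is a hedged illustration of the concept, not the DynamicPPL `unflatten` implementation; the dict-based state is an assumption.

```python
def unflatten(template, flat):
    """Rebuild named values from a flat vector, using the template's
    key order and per-key lengths (scalars occupy one slot)."""
    out, i = {}, 0
    for name, old in template.items():
        if isinstance(old, list):
            out[name] = flat[i:i + len(old)]
            i += len(old)
        else:
            out[name] = flat[i]
            i += 1
    assert i == len(flat), "flat vector must match the template exactly"
    return out

state = {"m": 0.0, "x": [1.0, 2.0]}
unflatten(state, [3.0, 4.0, 5.0])  # {"m": 3.0, "x": [4.0, 5.0]}
```

Because the component sampler's `state` contains exactly the variables it samples, a single call like this replaces the per-variable search loops; the linking caveat raised above is the main remaining question.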
src/mcmc/gibbs_conditional.jl (Outdated)
````julia
), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm
````
I think the type parameter `S` shouldn't be necessary once we don't use it any more to construct the VarName that is being sampled. See below comments for more details.
src/mcmc/gibbs_conditional.jl (Outdated)
```julia
end

# Update log joint probability
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext()))
```
I think you shouldn't need this, because the log joint is going to be recomputed anyway by the Gibbs sampler once it's looped over all component samplers. Saves one model evaluation.
src/mcmc/gibbs_conditional.jl (Outdated)
```julia

Perform a single step of GibbsConditional sampling.
"""
function gibbs_step_recursive(
```
I would hope you wouldn't need to overload this or `gibbs_initialstep_recursive`. Also, the below implementation seems to be just a repeat of `step`.
```julia
model = inverse_gdemo(x_obs)

# Sample using GibbsConditional
println("Testing GibbsConditional sampler...")
```
You probably want various `@test` calls rather than a lot of prints.
```julia
# Define a model
@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
```
Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be `Gamma(2, inv(3))`?
WRT: #2547