Gibbs sampler #2647

Open: wants to merge 7 commits into main
Conversation

AoifeHughes (Contributor)

WRT: #2547

@AoifeHughes AoifeHughes requested a review from Copilot August 7, 2025 08:15
@AoifeHughes AoifeHughes self-assigned this Aug 7, 2025
Copilot AI left a comment


Pull Request Overview

This PR implements a GibbsConditional sampler component for Turing.jl that allows users to provide analytical conditional distributions for variables in Gibbs sampling. The implementation enables mixing user-defined conditional distributions with other MCMC samplers within the Gibbs framework.

Key changes:

  • Added GibbsConditional struct and supporting functions for analytical conditional sampling
  • Comprehensive test coverage for the new functionality
  • Added example test file demonstrating usage
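Pieced together from the snippets reviewed below, the intended user-facing API looks roughly like this (a sketch: the model follows the PR's `inverse_gdemo` example, but the conditional bodies here are illustrative placeholders, not the correct full conditionals):

```julia
using Turing

@model function inverse_gdemo(x)
    λ ~ Gamma(2, 3)
    m ~ Normal(0, sqrt(1 / λ))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(1 / λ))
    end
end

# A conditional is a function from the other variables' values (a NamedTuple
# with fields like c.m and c.x) to a Distribution. These bodies are
# placeholders; the real full conditionals come from the model's conjugacy.
cond_λ(c) = Gamma(2.0, 1.0)
cond_m(c) = Normal(0.0, 1.0)

sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
chain = sample(inverse_gdemo(randn(100)), sampler, 1000)
```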

Reviewed Changes

Copilot reviewed 5 out of 6 changed files in this pull request and generated 5 comments.

File                           Description
src/mcmc/gibbs_conditional.jl  Core implementation of GibbsConditional sampler with step functions and variable handling
src/mcmc/Inference.jl          Added GibbsConditional export and module inclusion
test/mcmc/gibbs.jl             Added comprehensive test suite for GibbsConditional functionality
test_gibbs_conditional.jl      Example/demo file showing GibbsConditional usage
HISTORY.md                     Version history update (unrelated to main feature)


# For GibbsConditional within Gibbs, we need to get all variable values
# Check if we're in a Gibbs context
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
Copilot AI (Aug 7, 2025):

Using hasproperty for checking if a field exists is fragile and could break with future changes to the model structure. Consider using a more robust method like isdefined or checking the model type directly.

Suggested change
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext
global_vi = if isdefined(model, :context) && model.context isa GibbsContext

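For context on the `hasproperty` vs `isdefined` distinction the comment draws on: `hasproperty` consults the (possibly overloaded) property interface, while `isdefined` checks a struct's actual fields, so the two can disagree when `getproperty` is customized. A minimal illustration (the `Model` type here is hypothetical, not Turing's):

```julia
struct Model
    context::Any
end

# Customize the property interface without changing the underlying fields.
Base.propertynames(::Model) = (:ctx,)
Base.getproperty(m::Model, s::Symbol) =
    s === :ctx ? getfield(m, :context) : error("no property $s")

m = Model(:gibbs)
isdefined(m, :context)    # true: :context is a real field of the struct
hasproperty(m, :context)  # false: the overloaded property interface hides it
```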

end
end
if !found
error("Could not find variable $S in VarInfo")
Copilot AI (Aug 7, 2025):

The error message could be more helpful by suggesting what variables are available or providing debugging information about the VarInfo contents.

Suggested change
error("Could not find variable $S in VarInfo")
error("Could not find variable $S in VarInfo. Available variables: $(join([string(DynamicPPL.getsym(k)) for k in keys(state)], ", ")).")


else
# Try to find the variable with indices
# This handles cases where the variable might have indices
local updated_vi = state
Copilot AI (Aug 7, 2025):

The local keyword is unnecessary here since updated_vi is already in a local scope. This adds visual clutter without functional benefit.

Suggested change
local updated_vi = state
updated_vi = state


HISTORY.md Outdated
@@ -1,3 +1,7 @@
# 0.39.10

Added a compatibility entry for DataStructures v0.19.
Copilot AI (Aug 7, 2025):

The HISTORY.md change appears unrelated to the GibbsConditional implementation and should be in a separate commit or PR to maintain clean version history.

Suggested change
Added a compatibility entry for DataStructures v0.19.



# Sample using GibbsConditional
println("Testing GibbsConditional sampler...")
sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
Copilot AI (Aug 7, 2025):

[nitpick] The variable name is redundantly specified in both the Gibbs pair key and GibbsConditional constructor. Consider if this duplication is necessary or if the API could be simplified.

Suggested change
sampler = Gibbs(:λ => GibbsConditional(:λ, cond_λ), :m => GibbsConditional(:m, cond_m))
sampler = Gibbs(:λ => GibbsConditional(cond_λ), :m => GibbsConditional(cond_m))


@AoifeHughes AoifeHughes changed the base branch from breaking to main August 7, 2025 08:16
github-actions bot commented Aug 7, 2025

Turing.jl documentation for PR #2647 is available at:
https://TuringLang.github.io/Turing.jl/previews/PR2647/


codecov bot commented Aug 7, 2025

Codecov Report

❌ Patch coverage is 0% with 31 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.94%. Comparing base (d75e6f2) to head (94b723d).

Files with missing lines Patch % Lines
src/mcmc/gibbs_conditional.jl 0.00% 31 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2647      +/-   ##
==========================================
- Coverage   84.88%   79.94%   -4.94%     
==========================================
  Files          22       23       +1     
  Lines        1475     1506      +31     
==========================================
- Hits         1252     1204      -48     
- Misses        223      302      +79     



coveralls commented Aug 7, 2025

Pull Request Test Coverage Report for Build 16831838756

Details

  • 0 of 31 (0.0%) changed or added relevant lines in 1 file are covered.
  • 48 unchanged lines in 2 files lost coverage.
  • Overall coverage decreased (-4.9%) to 80.053%

Changes missing coverage         Covered  Changed/Added  %
src/mcmc/gibbs_conditional.jl    0        31             0.0%

Files with coverage reduction    New missed lines  %
src/mcmc/repeat_sampler.jl       10                50.0%
src/mcmc/gibbs.jl                38                68.91%

Change from base Build 16800936663: -4.9%
Covered Lines: 1204
Relevant Lines: 1504


mhauru (Member) left a comment


As discussed on Slack, rather than doing a full review, I'm just going to give some high level comments and pointers for where to find more details on some of the relevant context.

x = c.x
n = length(x)
a_new = a + (n + 1) / 2
b_new = b + sum((x[i] - m)^2 for i in 1:n) / 2 + m^2 / 2

Likewise comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the m in the variance term rather be the mean of x?


# For GibbsConditional within Gibbs, we need to get all variable values
# Check if we're in a Gibbs context
global_vi = if hasproperty(model, :context) && model.context isa GibbsContext

The core idea here of finding the possible GibbsContext and getting the global varinfo from it is good. However, GibbsContext is a bit weird, in that it's always inserted at the bottom of the context stack. By the context stack I mean the fact that contexts often have child contexts, and thus model.context may in fact be many nested contexts. See e.g. how the GibbsContext is set here, by calling setleafcontext rather than setcontext:

gibbs_context = DynamicPPL.setleafcontext(model.context, gibbs_context_inner)

So rather than check whether model.context isa GibbsContext, I think you'll need to traverse the whole context stack, and check if any of them are a GibbsContext, until you hit a leaf context and the stack ends.

Moreover, I think you'll need to check not just for GibbsContext, but also for ConditionContext and FixedContext, which condition/fix the values of some variables. So all in all, if you go through the whole stack, starting with model.context and going through its child contexts, and collect any variables set in ConditionContext, FixedContext, and GibbsContext, that should give you all of the variable values you need. See here for more details on condition and fix: https://github.com/TuringLang/DynamicPPL.jl/blob/1ed8cc8d9f013f46806c88a83e93f7a4c5b891dd/src/contexts.jl#L258

As mentioned on Slack a week or two ago, all this context stack business is likely changing Soon (TM), since @penelopeysm is overhauling condition and fix over here, TuringLang/DynamicPPL.jl#1010, and as a result we may be able to overhaul GibbsContext as well. You could wait for that to be finished first, at least if it looks like getting this to work would be a lot of work.
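The traversal described above could be sketched with DynamicPPL's context primitives (a sketch only: it assumes the current `NodeTrait`/`childcontext` API, which the linked overhaul may change, and `collect_value_contexts` is a hypothetical helper name):

```julia
# Walk the context stack from `context` down to the leaf, collecting every
# context that carries variable values: GibbsContext, ConditionContext,
# and FixedContext.
function collect_value_contexts(context)
    found = Any[]
    while true
        if context isa Union{GibbsContext,DynamicPPL.ConditionContext,DynamicPPL.FixedContext}
            push!(found, context)
        end
        # Stop once we hit a leaf; otherwise descend to the child context.
        DynamicPPL.NodeTrait(context) isa DynamicPPL.IsLeaf && return found
        context = DynamicPPL.childcontext(context)
    end
end
```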

updated = rand(rng, conddist)

# Update the variable in state
# We need to get the actual VarName for this variable

The operating principle of the new(ish) Gibbs sampler is that every component sampler only ever sees a VarInfo with the variables that component sampler is supposed to sample. Thus, you should be able to assume that updated includes values for all the variables in state, and for nothing else. Hence the below checks and loops I think shouldn't be necessary. The solution might be as simple as new_state = unflatten(state, updated), though there may be details there that I'm not thinking of right now. (What if state is linked? But maybe we can guarantee that it's never linked, because the sampler can control it.) Happy to discuss details more if unflatten by itself doesn't seem to cut it.
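The suggested simplification might look like the following (a sketch; it assumes `state` is an unlinked VarInfo holding exactly this sampler's variables, and that `unflatten` accepts the drawn values as a vector):

```julia
updated = rand(rng, conddist)                  # draw from the analytical conditional
# unflatten expects a flat vector of values, so wrap a scalar draw.
vals = updated isa AbstractVector ? updated : [updated]
new_state = DynamicPPL.unflatten(state, vals)  # write the draw back into the VarInfo
```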

), 1000)
```
"""
struct GibbsConditional{S,C} <: InferenceAlgorithm

I think the type parameter S shouldn't be necessary once we don't use it any more to construct the VarName that is being sampled. See below comments for more details.

end

# Update log joint probability
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, DynamicPPL.DefaultContext()))

I think you shouldn't need this, because the log joint is going to be recomputed anyway by the Gibbs sampler once it's looped over all component samplers. Saves one model evaluation.


Perform a single step of GibbsConditional sampling.
"""
function gibbs_step_recursive(

I would hope you wouldn't need to overload this or gibbs_initialstep_recursive. Also, the below implementation seems to be just a repeat of step.

model = inverse_gdemo(x_obs)

# Sample using GibbsConditional
println("Testing GibbsConditional sampler...")

You probably want various @test calls rather than a lot of prints.
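Concretely, the prints in the demo file could become assertions along these lines (a sketch; `expected_m`, `expected_λ`, and the tolerances are illustrative values the test would need to set up):

```julia
using Test, Statistics

@testset "GibbsConditional on inverse_gdemo" begin
    chain = sample(inverse_gdemo(x_obs), sampler, 5_000)
    # Compare posterior means against whatever ground truth the test defines.
    @test mean(chain[:m]) ≈ expected_m atol = 0.2
    @test mean(chain[:λ]) ≈ expected_λ atol = 0.2
end
```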

```julia
# Define a model
@model function inverse_gdemo(x)
λ ~ Gamma(2, 3)

Comparing to https://github.com/TuringLang/Turing.jl/blob/v0.35.5/src/mcmc/gibbs_conditional.jl, should the distribution be Gamma(2, inv(3))?

@mhauru mhauru self-requested a review August 8, 2025 13:49