Conversation

Member

@penelopeysm penelopeysm commented Sep 24, 2025

This PR is being tested against this DynamicPPL branch: TuringLang/DynamicPPL.jl#1057

It should be noted that due to the changes in DynamicPPL's src/sampler.jl, the results of running MCMC sampling on this branch will pretty much always differ from those on the main branch. Thus there is no (easy) way to test full reproducibility of MCMC results (we have to rely instead on statistics for converged chains).

TODO:

  • pMCMC (it at least runs and gives sensible results on simple models, proper tests will have to wait for CI to run)
  • Gibbs (same as above)
  • fix initial_params argument for most samplers to require AbstractInitStrategy
  • fix tests
  • changelog

Separate PRs:

  • use InitStrategy for optimisation as well

    Note that the three pre-existing InitStrategies can be used directly with optimisation. However, to handle constraints properly, it seems necessary to introduce a new subtype of AbstractInitStrategy. I think this should be a separate PR because it's a fair bit of work.

  • fix docs for that argument, wherever it is (there's probably some in AbstractMCMC but it should probably be documented on the main site) EDIT: https://turinglang.org/docs/usage/sampling-options/#specifying-initial-parameters

@penelopeysm penelopeysm marked this pull request as draft September 24, 2025 18:06
Comment on lines -448 to -453
# Get the initial values for this component sampler.
initial_params_local = if initial_params === nothing
    nothing
else
    DynamicPPL.subset(vi, varnames)[:]
end
Member Author

@penelopeysm penelopeysm Sep 24, 2025

I was quite pleased with this discovery. Previously the initial params had to be subsetted to be the correct length for the conditioned model. That's not only a faff, but I also get a bit scared whenever there's direct VarInfo manipulation like this.

Now, if you use InitFromParams with a NamedTuple/Dict that has extra params, the extra params are just ignored. So no need to subset it at all, just pass it through directly!
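
A minimal sketch of what this looks like from the user's side (the model, the Gibbs/MH setup, and the `Turing.DynamicPPL` qualification are illustrative assumptions, and it naturally only works with the DynamicPPL branch above):

```julia
using Turing
using StableRNGs

@model function demo()
    x ~ Normal()
    y ~ Normal()
end

# Each Gibbs component only updates one of the two variables, but the full set of
# initial parameters can be handed to the whole sampler: InitFromParams simply
# ignores the entries a component doesn't need, so no subsetting is required.
chn = sample(
    StableRNG(468),
    demo(),
    Gibbs(:x => MH(), :y => MH()),
    1000;
    initial_params=Turing.DynamicPPL.InitFromParams((x=0.1, y=-0.1)),
)
```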

Comment on lines -181 to -182
# TODO(DPPL0.38/penelopeysm): This function should no longer be needed
# once InitContext is merged.
Member Author

unfortunately set_namedtuple! is used elsewhere in this file (though it won't appear in this diff) so we can't delete it (yet)

Comment on lines 406 to 416
function DynamicPPL.tilde_assume!!(
    context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
)
    # Just defer to `SampleFromPrior`.
    retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
    return retval
    # Allow MH to sample new variables from the prior if it's not already present in the
    # VarInfo.
    dispatch_ctx = if haskey(vi, vn)
        DynamicPPL.DefaultContext()
    else
        DynamicPPL.InitContext(context.rng, DynamicPPL.InitFromPrior())
    end
    return DynamicPPL.tilde_assume!!(dispatch_ctx, right, vn, vi)
Member Author

The behaviour of SampleFromPrior used to be: if the key was already present in the VarInfo, don't actually sample; if it was absent, sample. This if/else replicates the old behaviour.

sampler::S
varinfo::V
evaluator::E
resample::Bool
Member Author

For pMCMC, this Boolean field essentially replaces the del flag. Instead of set_all_del and unset_all_del, we construct a new TracedModel with this set to true or false respectively.
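
Roughly, the pattern is the following (a standalone sketch with a stand-in struct, not the PR's actual helpers; the field names are taken from the snippet above):

```julia
# Stand-in for Turing's internal TracedModel, with the fields shown above.
struct TracedModelSketch{S,V,E}
    sampler::S
    varinfo::V
    evaluator::E
    resample::Bool
end

# Instead of mutating a global "del" flag via set_all_del/unset_all_del, just
# rebuild the (immutable) wrapper with `resample` toggled.
with_resampling(tm) = TracedModelSketch(tm.sampler, tm.varinfo, tm.evaluator, true)
without_resampling(tm) = TracedModelSketch(tm.sampler, tm.varinfo, tm.evaluator, false)
```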

Comment on lines -111 to +116
@test sample(StableRNG(23), xy(), spl_xy, num_samples).value ≈
    sample(StableRNG(23), x12(), spl_x, num_samples).value
chn1 = sample(StableRNG(23), xy(), spl_xy, num_samples)
chn2 = sample(StableRNG(23), x12(), spl_x, num_samples)

@test mean(chn1[:z]) ≈ mean(chn2[:z]) atol = 0.05
@test mean(chn1[:x]) ≈ mean(chn2["x[1]"]) atol = 0.05
@test mean(chn1[:y]) ≈ mean(chn2["x[2]"]) atol = 0.05
Member Author

The values are no longer exactly the same (it has something to do with initialisation behaviour, which is different for the two models). But we can still check that the results are sensibly similar, which is arguably more meaningful anyway: it means that ESS not only works on both models but also converges consistently regardless of how the model is specified.

Member

While we are at it, shall we add a ≈ -3.0 on the last line?

It's a shame that these don't match exactly any more, but I guess if something about the order of sampling changes in Gibbs or some such then what can you do.

Member

Seems you can't chain comparisons under @test when using atol. I'm superficial enough that that makes me lean towards not checking against -3.0.
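
For what it's worth, a tiny self-contained illustration of the limitation and the two-line workaround (the numbers are made up):

```julia
using Test

a, b = -3.01, -2.98

# A chained comparison can't take the atol keyword through @test; the line
# below errors at macro-expansion time:
#   @test a ≈ b ≈ -3.0 atol = 0.05

# Splitting it into two single comparisons works fine:
@test a ≈ b atol = 0.05
@test b ≈ -3.0 atol = 0.05
```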

Contributor

Turing.jl documentation for PR #2676 is available at:
https://TuringLang.github.io/Turing.jl/previews/PR2676/


codecov bot commented Sep 24, 2025

Codecov Report

❌ Patch coverage is 16.09195% with 73 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (breaking@ff8d01e). Learn more about missing BASE report.

Files with missing lines        Patch %   Lines
src/mcmc/particle_mcmc.jl         0.00%   25 Missing ⚠️
src/mcmc/emcee.jl                 0.00%   11 Missing ⚠️
src/mcmc/mh.jl                    0.00%   11 Missing ⚠️
src/mcmc/is.jl                    0.00%    8 Missing ⚠️
src/mcmc/repeat_sampler.jl        0.00%    6 Missing ⚠️
src/mcmc/external_sampler.jl      0.00%    3 Missing ⚠️
src/mcmc/ess.jl                   0.00%    2 Missing ⚠️
src/mcmc/hmc.jl                  50.00%    2 Missing ⚠️
ext/TuringOptimExt.jl             0.00%    1 Missing ⚠️
src/mcmc/Inference.jl             0.00%    1 Missing ⚠️
... and 3 more
Additional details and impacted files
@@             Coverage Diff             @@
##             breaking    #2676   +/-   ##
===========================================
  Coverage            ?   18.52%           
===========================================
  Files               ?       22           
  Lines               ?     1382           
  Branches            ?        0           
===========================================
  Hits                ?      256           
  Misses              ?     1126           
  Partials            ?        0           


Comment on lines +133 to +140
seed = if dist isa GeneralizedExtremeValue
    # GEV is prone to giving really wacky results that are quite
    # seed-dependent.
    StableRNG(469)
else
    StableRNG(468)
end
chn = sample(seed, m(), HMC(0.05, 20), n_samples)
Member Author

@penelopeysm penelopeysm Sep 25, 2025

Case in point:

julia> using Turing, StableRNGs

julia> dist = GeneralizedExtremeValue(0, 1, 0.5); @model m() = x ~ dist
m (generic function with 2 methods)

julia> mean(dist)
1.5449077018110322

julia> mean(sample(StableRNG(468), m(), HMC(0.05, 20), 10000; progress=false))
Mean
  parameters      mean
      Symbol   Float64

           x    3.9024


julia> mean(sample(StableRNG(469), m(), HMC(0.05, 20), 10000; progress=false))
Mean
  parameters      mean
      Symbol   Float64

           x    1.5868

@penelopeysm penelopeysm marked this pull request as ready for review September 25, 2025 13:24
@penelopeysm
Member Author

For the record, 11 failing CI jobs is the expected number:

  • 8x failing jobs because [sources] is not understood on 1.10
  • 3x failing jobs because of Libtask on 1.12

There is also the failing job caused by the base Julia segfault (#2655), but that's on 1.10 so it overlaps with the first category.

@penelopeysm
Member Author

@mhauru, I haven't run CI against the latest revisions, like the removal of the del flag, but I think this is meaty enough as it stands, and any adjustments arising from that PR (like renaming islinked) should be quite trivial.

Member

@mhauru mhauru left a comment

Good stuff, some minor comments.

I'm wondering about how to merge this. Should we review the code here, but then hold off merging to breaking until all the 0.38 compat fixes are in and a release of 0.38 is out, so that all the temporary [sources] stuff etc. can go and we can see tests pass?

Comment on lines +4 to +5
new_context = DynamicPPL.setleafcontext(model.context, DynamicPPL.InitContext())
new_model = DynamicPPL.contextualize(model, new_context)
Member

Isn't there a short-hand for these two lines now?

end
DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()

function DynamicPPL.tilde_assume!!(
Member

Why is this needed? Doesn't MH just use init to get initial values like any other sampler, and then evaluate logpdfs on proposed steps?

Also, this looks like dynamic dispatch, since the context type depends on haskey(vi, vn), which could be a performance issue.

trng = get_trace_local_rng_maybe(ctx.rng)
resample = get_trace_local_resampled_maybe(true)

dispatch_ctx = if ~haskey(vi, vn) || resample
Member

Like with MH, I wonder about dynamic dispatch here. Might be unavoidable and/or inconsequential in this case though.

return DynamicPPL.init_strategy(spl.sampler)
end

function AbstractMCMC.sample(
Member

I don't see why these became necessary now. Is it something about the type hierarchy around Sampler?

N = 1000
model = normal()
chain = sample(StableRNG(468), model, alg, N)
ref = reference(N)
Member

Getting quite nit-picky, but could the RNG be an argument to reference, to ensure the two instances of StableRNG(468) remain in sync?
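
Something like the following, perhaps (`illustrative_normal` and `illustrative_reference` are stand-ins for the test's model and reference helper, not the suite's actual functions):

```julia
using Turing, StableRNGs

@model illustrative_normal() = x ~ Normal()

# Stand-in reference sampler; the point is only that it takes the RNG explicitly.
illustrative_reference(rng, n) = rand(rng, Normal(), n)

# Construct the seed in one place so the two draws can't silently drift apart.
seed = 468
N = 1000
chain = sample(StableRNG(seed), illustrative_normal(), MH(), N)
ref = illustrative_reference(StableRNG(seed), N)
```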

julia = "1.10"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}
Member

Leaving this comment just as a reminder that this needs to be removed before merging.

Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "breaking"}
Member

Likewise this.

You still need to use the `initial_params` keyword argument to `sample`, but the allowed values are different.
For almost all samplers in Turing.jl (except `Emcee`) this should now be a `DynamicPPL.AbstractInitStrategy`.

TODO LINK TO DPPL DOCS WHEN THIS IS LIVE
Member

Likewise a reminder comment.


- `InitFromPrior()`: Sample from the prior distribution. This is the default for most samplers in Turing.jl (if you don't specify `initial_params`).
- `InitFromUniform(a, b)`: Sample uniformly from `[a, b]` in linked space. This is the default for Hamiltonian samplers. If `a` and `b` are not specified it defaults to `[-2, 2]`, which preserves the behaviour in previous versions (and mimics that of Stan).
- `InitFromParams(p)`: Explicitly provide a set of initial parameters. **Note: `p` must be either a `NamedTuple` or a `Dict{<:VarName}`; it can no longer be a `Vector`.** Parameters must be provided in unlinked space, even if the sampler later performs linking.
Member

Do I recall correctly that you did end up implementing the option of providing an unwrapped NamedTuple or Dict as well?

Oh, also, just came to mind: Does it need to be a Dict, or can it be an AbstractDict?
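
To make the changelog text above concrete, a sketch of what the new keyword values look like (the model and sampler choices are illustrative, and whether these strategy types are re-exported by Turing or need the `Turing.DynamicPPL` prefix is an assumption):

```julia
using Turing

@model function gdemo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    1.5 ~ Normal(m, sqrt(s))
end

model = gdemo()

# Default for most samplers: draw the starting point from the prior.
sample(model, MH(), 100; initial_params=Turing.DynamicPPL.InitFromPrior())

# Default for Hamiltonian samplers: uniform on [-2, 2] in linked space.
sample(model, NUTS(), 100; initial_params=Turing.DynamicPPL.InitFromUniform(-2, 2))

# Explicit starting point: a NamedTuple (or Dict{<:VarName}) in unlinked space;
# a plain Vector is no longer accepted.
sample(model, NUTS(), 100; initial_params=Turing.DynamicPPL.InitFromParams((s=1.0, m=0.0)))
```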
