Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes and improvements to experimental Gibbs #2231

Merged
merged 40 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0b2279f
moved new Gibbs tests all into a single block
torfjelde Apr 23, 2024
dcad548
initial work on making Gibbs work with `externalsampler`
torfjelde Apr 23, 2024
8e2d7be
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde May 18, 2024
4a609cb
removed references to Setfield.jl
torfjelde May 18, 2024
fc21894
fixed crucial bug in experimental Gibbs sampler
torfjelde May 18, 2024
9910962
added ground-truth comparison for Gibbs sampler on demo models
torfjelde May 18, 2024
b3a4692
added convenience method for performing two sample KS test
torfjelde May 18, 2024
3e17efc
use thinning to avoid OOM issues
torfjelde May 20, 2024
429fc8f
removed incredibly slow testset that didn't really add much
torfjelde May 20, 2024
f6af20e
removed now-redundant testset
torfjelde May 20, 2024
065eef6
use Anderson-Darling test instead of Kolomogorov-Smirnov to better
torfjelde May 20, 2024
c2d23e5
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 4, 2024
99f28f9
more work on testing
torfjelde Jun 6, 2024
6df1ccc
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 6, 2024
b6a907e
fixed tests
torfjelde Jun 6, 2024
a4e223e
Merge remote-tracking branch 'origin/torfjelde/gibbs-new-improv' into…
torfjelde Jun 6, 2024
e1e7386
make failures of `two_sample_ad_tests` a bit more informative
torfjelde Jun 6, 2024
be1ec7f
make failrues of `two_sample_ad_test` produce more informative logs
torfjelde Jun 6, 2024
5f36446
additional information upon `two_sample_ad_test` failure
torfjelde Jun 6, 2024
3be8f8b
rename `two_sample_ad_test` to `two_sample_test` and use KS test instead
torfjelde Jun 6, 2024
dbaf447
added minor test for externalsampler usage
torfjelde Jun 16, 2024
f44c407
also test AdvancedHMC samplers with Gibbs
torfjelde Jun 16, 2024
dd86cfa
forgot to add updates to src/mcmc/abstractmcmc.jl in previous commits
torfjelde Jun 17, 2024
4160577
Merge remote-tracking branch 'origin/torfjelde/gibbs-new-improv' into…
torfjelde Jun 17, 2024
2a7d85b
Merge branch 'master' into torfjelde/gibbs-new-improv
torfjelde Jun 17, 2024
bdc61fe
removed usage of `timeit_testset` macro
torfjelde Jun 17, 2024
d76243e
added temporary fix for externalsampler that needs to be removed once
torfjelde Jun 17, 2024
14f5c89
minor reorg of two testsets
torfjelde Jun 17, 2024
5893d54
set random seeds more aggressively in an attempt to make tests more r…
torfjelde Jun 18, 2024
4a2cea2
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jun 18, 2024
4f30ea5
removed hack, awaiting PR to DynamicPPL
torfjelde Jun 18, 2024
414a077
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jun 26, 2024
89bc2e1
renamed `_getmodel` to `getmodel`, `_setmodel` to `setmodel`, and
torfjelde Jun 26, 2024
3d3c944
missed some instances during rnenaming
torfjelde Jun 26, 2024
e1f1a0e
fixed missing merge in initial step for experimental `Gibbs`
torfjelde Jul 10, 2024
7c4368e
Always reconstruct `ADGradientWrapper` using the `adype` available in…
torfjelde Jul 15, 2024
06357c6
Test Gibbs with different adtype in externalsampler to ensure that works
torfjelde Jul 15, 2024
02f9fad
Update Project.toml
yebai Jul 15, 2024
30ab9e0
Update Project.toml
yebai Jul 15, 2024
d40d82b
Merge branch 'master' into torfjelde/gibbs-new-improv
yebai Jul 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions src/experimental/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return value, broadcast_logpdf(right, values), vi
return value, broadcast_logpdf(right, value), vi

Check warning on line 81 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L81

Added line #L81 was not covered by tests
end

# Otherwise, falls back to the default behavior.
Expand All @@ -90,8 +90,8 @@
)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return values, broadcast_logpdf(right, values), vi
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
return value, broadcast_logpdf(right, value), vi

Check warning on line 94 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L93-L94

Added lines #L93 - L94 were not covered by tests
end

# Otherwise, falls back to the default behavior.
Expand Down Expand Up @@ -144,14 +144,14 @@
Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned.
"""
function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo)
return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))

Check warning on line 147 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L147

Added line #L147 was not covered by tests
end
function DynamicPPL.condition(
function condition_gibbs(

Check warning on line 149 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L149

Added line #L149 was not covered by tests
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo,
varinfos::DynamicPPL.AbstractVarInfo...
)
return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...)
return condition_gibbs(condition_gibbs(context, varinfo), varinfos...)

Check warning on line 154 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L154

Added line #L154 was not covered by tests
end
# Allow calling this on a `DynamicPPL.Model` directly.
function condition_gibbs(model::DynamicPPL.Model, values...)
Expand Down Expand Up @@ -238,6 +238,9 @@
return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs)))
end

# TODO: Remove when no longer needed.
DynamicPPL.getspace(::Gibbs) = ()

Check warning on line 242 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L242

Added line #L242 was not covered by tests

struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
vi::V
states::S
Expand All @@ -252,6 +255,7 @@
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs},
vi_base::DynamicPPL.AbstractVarInfo;
initial_params=nothing,
kwargs...,
)
alg = spl.alg
Expand All @@ -260,15 +264,35 @@

# 1. Run the model once to get the varnames present + initial values to condition on.
vi_base = DynamicPPL.VarInfo(model)

# Simple way of setting the initial parameters: set them in the `vi_base`
# if they are given so they propagate to the subset varinfos used by each sampler.
if initial_params !== nothing
vi_base = DynamicPPL.unflatten(vi_base, initial_params)

Check warning on line 271 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L270-L271

Added lines #L270 - L271 were not covered by tests
end

# Create the varinfos for each sampler.
varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames)
initial_params_all = if initial_params === nothing
fill(nothing, length(varnames))

Check warning on line 277 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L276-L277

Added lines #L276 - L277 were not covered by tests
else
# Extract from the `vi_base`, which should have the values set correctly from above.
map(vi -> vi[:], varinfos)

Check warning on line 280 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L280

Added line #L280 was not covered by tests
end

# 2. Construct a varinfo for every vn + sampler combo.
states_and_varinfos = map(samplers, varinfos) do sampler_local, varinfo_local
states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local

Check warning on line 284 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L284

Added line #L284 was not covered by tests
# Construct the conditional model.
model_local = make_conditional(model, varinfo_local, varinfos)

# Take initial step.
new_state_local = last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...))
new_state_local = last(AbstractMCMC.step(

Check warning on line 289 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L289

Added line #L289 was not covered by tests
rng, model_local, sampler_local;
# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
initial_params=initial_params_local,
kwargs...
))

# Return the new state and the invlinked `varinfo`.
vi_local_state = Turing.Inference.varinfo(new_state_local)
Expand All @@ -284,7 +308,7 @@
varinfos = map(last, states_and_varinfos)

# Update the base varinfo from the first varinfo and replace it.
varinfos_new = DynamicPPL.setindex!!(varinfos, vi_base, 1)
varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1)

Check warning on line 311 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L311

Added line #L311 was not covered by tests
# Merge the updated initial varinfo with the rest of the varinfos + update the logp.
vi = DynamicPPL.setlogp!!(
reduce(merge, varinfos_new),
Expand Down Expand Up @@ -365,12 +389,7 @@
end

# TODO: Remove `rng`?
"""
recompute_logprob!!(rng, model, sampler, state)

Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
"""
function recompute_logprob!!(
function Turing.Inference.recompute_logprob!!(

Check warning on line 392 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L392

Added line #L392 was not covered by tests
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler,
Expand Down Expand Up @@ -436,7 +455,7 @@
state_local,
state_previous
)
current_state_local = recompute_logprob!!(
state_local = Turing.Inference.recompute_logprob!!(

Check warning on line 458 in src/experimental/gibbs.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental/gibbs.jl#L458

Added line #L458 was not covered by tests
rng,
model_local,
sampler_local,
Expand All @@ -450,7 +469,7 @@
rng,
model_local,
sampler_local,
current_state_local;
state_local;
kwargs...,
),
)
Expand Down
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@
end
end

DynamicPPL.getspace(::ExternalSampler) = ()

Check warning on line 117 in src/mcmc/Inference.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/Inference.jl#L117

Added line #L117 was not covered by tests

"""
requires_unconstrained_space(sampler::ExternalSampler)

Expand Down
99 changes: 96 additions & 3 deletions src/mcmc/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,52 @@
return transition_to_turing(parent(f), transition)
end

"""
getmodel(f)

Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

Check warning on line 26 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L25-L26

Added lines #L25 - L26 were not covered by tests

# FIXME: We'll have to overload this for every AD backend since some of the AD backends
# will cache certain parts of a given model, e.g. the tape, which results in a discrepancy
# between the primal (forward) and dual (backward).
"""
setmodel(f, model)

Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.

!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model)
return Accessors.@set f.ℓ = setmodel(f.ℓ, model)

Check warning on line 42 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model

Check warning on line 45 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L44-L45

Added lines #L44 - L45 were not covered by tests
end

function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
return varinfo_from_logdensityfn(parent(f))

Check warning on line 49 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L48-L49

Added lines #L48 - L49 were not covered by tests
end
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo

Check warning on line 51 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L51

Added line #L51 was not covered by tests

function varinfo(state::TuringState)
θ = getparams(getmodel(state.logdensity), state.state)

Check warning on line 54 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L53-L54

Added lines #L53 - L54 were not covered by tests
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)

Check warning on line 56 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L56

Added line #L56 was not covered by tests
end

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
return getparams(model, state.transition)

Check warning on line 64 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L63-L64

Added lines #L63 - L64 were not covered by tests
end
getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
Expand All @@ -33,13 +75,59 @@
return Accessors.@set f.ℓ = setvarinfo(f.ℓ, varinfo)
end

"""
recompute_logprob!!(rng, model, sampler, state)

Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
"""
function recompute_logprob!!(

Check warning on line 83 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L83

Added line #L83 was not covered by tests
rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:ExternalSampler},
state,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should state have type TuringState, given function body assumes state has fields logdensity and state?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@torfjelde any comment on this?

)
# Re-using the log-density function from the `state` and updating only the `model` field,
# since the `model` might now contain different conditioning values.
f = setmodel(state.logdensity, model)

Check warning on line 91 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L91

Added line #L91 was not covered by tests
# Recompute the log-probability with the new `model`.
state_inner = recompute_logprob!!(

Check warning on line 93 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L93

Added line #L93 was not covered by tests
rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
)
return state_to_turing(f, state_inner)

Check warning on line 96 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L96

Added line #L96 was not covered by tests
end

function recompute_logprob!!(

Check warning on line 99 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L99

Added line #L99 was not covered by tests
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedHMC.AbstractHMCSampler,
state::AdvancedHMC.HMCState,
)
# Construct hamiltionian.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)

Check warning on line 106 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L106

Added line #L106 was not covered by tests
# Re-compute the log-probability and gradient.
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(

Check warning on line 108 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L108

Added line #L108 was not covered by tests
hamiltonian, state.transition.z.θ, state.transition.z.r
)
end

function recompute_logprob!!(

Check warning on line 113 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L113

Added line #L113 was not covered by tests
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedMH.MetropolisHastings,
state::AdvancedMH.Transition,
)
logdensity = model.logdensity
return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params)

Check warning on line 120 in src/mcmc/abstractmcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/abstractmcmc.jl#L119-L120

Added lines #L119 - L120 were not covered by tests
end

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler};
initial_state=nothing,
initial_params=nothing,
kwargs...
kwargs...,
)
alg = sampler_wrapper.alg
sampler = alg.sampler
Expand Down Expand Up @@ -69,7 +157,12 @@
)
else
transition_inner, state_inner = AbstractMCMC.step(
rng, AbstractMCMC.LogDensityModel(f), sampler, initial_state; initial_params, kwargs...
rng,
AbstractMCMC.LogDensityModel(f),
sampler,
initial_state;
initial_params,
kwargs...,
)
end
# Update the `state`
Expand All @@ -81,7 +174,7 @@
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler},
state::TuringState;
kwargs...
kwargs...,
)
sampler = sampler_wrapper.alg.sampler
f = state.logdensity
Expand Down
Loading
Loading