Skip to content

Commit

Permalink
Using JET.jl to determine if typed varinfo is okay (TuringLang#728)
Browse files Browse the repository at this point in the history
* fixed calls to `to_linked_internal_transform`

* fixed incorrect call to `acclogp_assume!!`

* added `determine_varinfo` and an implementation using JET for this

* made filtering for errors only in the tilde pipeline optional

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed incorrect comment

* added test for the branch we were currently imssing

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* renamed `determine_varinfo` to `determine_suitable_varinfo` with
fallback to current behavior + `supports_varinfo` to `is_suitable_varinfo`

* removed now-redundant init used with Requires.jl, since this is no
longer needed on Julia 1.10 and onwards + added error hint for when
JET.jl has not been loaded

* `determine_suitable_varinfo` now only performs checks using the
provided context, but uses `SamplingContext` by default (as this
should be a stricter check than just evaluation)

* formatting

* updated error hint

* added def of `untyped_varinfo` which takes just `model` and `context`

* fixed incorrect call to `untyped_varinfo` in `_determine_varinfo_jet`

* explicitly call `typed_varinfo` when we want such a thing rather than
the ambiguous `VarINfo`

* `typed_varinfo` and `untyped_varinfo` handles wrapping passed context
in sampling context now so no need to handle this explicitly elsewhere

* use `determine_suitable_varinfo` in `LogDensityFunction` when not constructed

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* formatting

* fixed a bug in `DynamicPPLJETExt.is_tilde_instance`

* updated docs

* Update docs/src/internals/varinfo.md

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added back def of `untyped_varinfo` that shouldn't have been removed +
fixed call in docs

* minor codestyle improvement

* temporary hack to debug what's happening

* more debugging

* use the `target_modules` kwarg in `report_call` instead of manually
filtering the frames

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* more debugging

* more debugging

* more debugging: try with new bijectors.jl

* formatting

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* removed the hacky debugging stuff used for the CI

* removed now-redudant filtering methods since we use JET's own filters

* bump Bijectors.jl compat entry to 0.15.1 in test so JET.jl tests pass

* moved the JET.jl-dependent experimental `determine_varinfo` into a
separate `Experimental` module, as discussed

* forgot to add the experimenta.jl file in previous commit

* reverted changes to `default_varinfo` and `LogDensityFunction`

* added a bunch of docs for introduced and existing methods

Added docs for `determine_suitable_varinfo` and existing methods that should be
documented, e.g. `untyped_varinfo`, `typed_varinfo`, and `default_varinfo`

* added doctests to `determine_suitable_varinfo`

* added JET.jl as a dep to docs

* fixed referencing in docs

* fixed docstring

* fixed doctest

* Update Project.toml

* applied suggestions from @mhauru

Co-authored-by: Markus Hauru <[email protected]>

* fixed doctests

* finally fixed doctests

* removed unnecessary `typed_varinfo` and `untyped_varinfo` methods

* added filter to ignore source of warnings in doctest

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Tor Fjelde <[email protected]>
  • Loading branch information
4 people authored Dec 10, 2024
1 parent 0548ddf commit 145f471
Show file tree
Hide file tree
Showing 12 changed files with 329 additions and 38 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
Expand All @@ -37,6 +38,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
Expand All @@ -55,6 +57,7 @@ Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
ForwardDiff = "0.10"
JET = "0.9"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
Expand Down
2 changes: 2 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -18,6 +19,7 @@ Documenter = "1"
DocumenterMermaid = "0.1"
FillArrays = "0.13, 1"
ForwardDiff = "0.10"
JET = "0.9"
LogDensityProblems = "2"
MCMCChains = "5, 6"
StableRNGs = "1"
20 changes: 20 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,13 @@ AbstractVarInfo

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:

```@docs
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
```

#### `VarInfo`

```@docs
Expand Down Expand Up @@ -425,6 +432,19 @@ DynamicPPL.loadstate
DynamicPPL.initialsampler
```

Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.

```@docs
DynamicPPL.default_varinfo
```

There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:

```@docs
DynamicPPL.Experimental.determine_suitable_varinfo
DynamicPPL.Experimental.is_suitable_varinfo
```

### [Model-Internal Functions](@id model_internal)

```@docs
Expand Down
4 changes: 1 addition & 3 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ For example, with the model above we have

```@example varinfo-design
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
)
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
typeof(varinfo_untyped.metadata)
```

Expand Down
53 changes: 53 additions & 0 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module DynamicPPLJETExt

using DynamicPPL: DynamicPPL
using JET: JET

function DynamicPPL.Experimental.is_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo;
only_ddpl::Bool=true,
)
# Let's make sure that both evaluation and sampling doesn't result in type errors.
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
model, varinfo, context
)
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
# This way we don't just fall back to untyped if the user's code is the issue.
result = if only_ddpl
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
else
JET.report_call(f, argtypes)
end
return length(JET.get_reports(result)) == 0, result
end

function DynamicPPL.Experimental._determine_varinfo_jet(
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
model, context, varinfo; only_ddpl
)

if !issuccess
# Useful information for debugging.
@debug "Evaluaton with typed varinfo failed with the following issues:"
@debug result
end

# If we didn't fail anywhere, we return the type stable one.
return if issuccess
varinfo
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
end
end

end
41 changes: 22 additions & 19 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,32 +199,35 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("experimental.jl")
include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end

@static if !isdefined(Base, :get_extension)
# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/DynamicPPLForwardDiffExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/DynamicPPLReverseDiffExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_jet =
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
length(argtypes) >= 2 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext
requires_jet |=
exc.f === DynamicPPL.Experimental.is_suitable_varinfo &&
length(argtypes) >= 3 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractContext &&
argtypes[3] <: AbstractVarInfo
if requires_jet
print(
io,
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).",
)
end
end
end
end

Expand Down
104 changes: 104 additions & 0 deletions src/experimental.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
module Experimental

using DynamicPPL: DynamicPPL

# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
"""
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
Check if the `model` supports evaluation using the provided `context` and `varinfo`.
!!! warning
Loading JET.jl is required before calling this function.
# Arguments
- `model`: The model to verify the support for.
- `context`: The context to use for the model evaluation.
- `varinfo`: The varinfo to verify the support for.
# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
# Returns
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
- `report`: The result of `report_call` from JET.jl.
"""
function is_suitable_varinfo end

# Internal hook for JET.jl to overload.
function _determine_varinfo_jet end

"""
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)
Return a suitable varinfo for the given `model`.
See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).
!!! warning
For full functionality, this requires JET.jl to be loaded.
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.
# Arguments
- `model`: The model for which to determine the varinfo.
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
# Keyword Arguments
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
# Examples
```jldoctest
julia> using DynamicPPL.Experimental: determine_suitable_varinfo
julia> using JET: JET # needs to be loaded for full functionality
julia> @model function model_with_random_support()
x ~ Bernoulli()
if x
y ~ Normal()
else
z ~ Normal()
end
end
model_with_random_support (generic function with 2 methods)
julia> model = model_with_random_support();
julia> # Typed varinfo cannot handle this random support model properly
# as using a single execution of the model will not see all random variables.
# Hence, this this model requires untyped varinfo.
vi = determine_suitable_varinfo(model);
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48
julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
true
julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
@model model_with_static_support() = x ~ Normal()
model_with_static_support (generic function with 2 methods)
julia> vi = determine_suitable_varinfo(model_with_static_support());
julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
true
```
"""
function determine_suitable_varinfo(
model::DynamicPPL.Model,
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext();
only_ddpl::Bool=true,
)
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
_determine_varinfo_jet(model, context; only_ddpl)
else
# Warn the user.
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
DynamicPPL.typed_varinfo(model, context)
end
end

end
16 changes: 15 additions & 1 deletion src/sampler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,20 @@ function AbstractMCMC.step(
return vi, nothing
end

"""
default_varinfo(rng, model, sampler[, context])
Return a default varinfo object for the given `model` and `sampler`.
# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `model::Model`: Model for which we want to create a varinfo object.
- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object.
- `context::AbstractContext`: Context in which the model is evaluated.
# Returns
- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`.
"""
function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler)
return default_varinfo(rng, model, sampler, DefaultContext())
end
Expand Down Expand Up @@ -126,7 +140,7 @@ By default, `data` is returned.
loadstate(data) = data

"""
default_chaintype(sampler)
default_chain_type(sampler)
Default type of the chain of posterior samples from `sampler`.
"""
Expand Down
34 changes: 20 additions & 14 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,30 +164,36 @@ function has_varnamedvector(vi::VarInfo)
end

"""
untyped_varinfo([rng, ]model[, sampler, context])
untyped_varinfo(model[, context, metadata])
Return an untyped `VarInfo` instance for the model `model`.
Return an untyped varinfo object for the given `model` and `context`.
# Arguments
- `model::Model`: The model for which to create the varinfo object.
- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`.
- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object.
Default: `Metadata()`.
"""
function untyped_varinfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
context::AbstractContext=DefaultContext(),
context::AbstractContext=SamplingContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
varinfo = VarInfo(metadata)
return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context)))
end
function untyped_varinfo(
model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}...
)
return untyped_varinfo(Random.default_rng(), model, args...)
return last(
evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context))
)
end

"""
typed_varinfo([rng, ]model[, sampler, context])
typed_varinfo(model[, context, metadata])
Return a typed varinfo object for the given `model`, `sampler` and `context`.
This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting
varinfo object to a typed varinfo object.
Return a typed `VarInfo` instance for the model `model`.
See also: [`DynamicPPL.untyped_varinfo`](@ref)
"""
typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...))

Expand All @@ -198,7 +204,7 @@ function VarInfo(
context::AbstractContext=DefaultContext(),
metadata::Union{Metadata,VarNamedVector}=Metadata(),
)
return typed_varinfo(rng, model, sampler, context, metadata)
return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata)
end
VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...)

Expand Down
Loading

0 comments on commit 145f471

Please sign in to comment.