forked from TuringLang/DynamicPPL.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Using JET.jl to determine if typed varinfo is okay (TuringLang#728)
* fixed calls to `to_linked_internal_transform` * fixed incorrect call to `acclogp_assume!!` * added `determine_varinfo` and an implementation using JET for this * made filtering for errors only in the tilde pipeline optional * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed incorrect comment * added test for the branch we were currently imssing * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renamed `determine_varinfo` to `determine_suitable_varinfo` with fallback to current behavior + `supports_varinfo` to `is_suitable_varinfo` * removed now-redundant init used with Requires.jl, since this is no longer needed on Julia 1.10 and onwards + added error hint for when JET.jl has not been loaded * `determine_suitable_varinfo` now only performs checks using the provided context, but uses `SamplingContext` by default (as this should be a stricter check than just evaluation) * formatting * updated error hint * added def of `untyped_varinfo` which takes just `model` and `context` * fixed incorrect call to `untyped_varinfo` in `_determine_varinfo_jet` * explicitly call `typed_varinfo` when we want such a thing rather than the ambiguous `VarINfo` * `typed_varinfo` and `untyped_varinfo` handles wrapping passed context in sampling context now so no need to handle this explicitly elsewhere * use `determine_suitable_varinfo` in `LogDensityFunction` when not constructed * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * formatting * fixed a bug in `DynamicPPLJETExt.is_tilde_instance` * updated docs * Update docs/src/internals/varinfo.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * added back def of `untyped_varinfo` that shouldn't have been removed + fixed call in docs * minor codestyle improvement * temporary hack to debug what's happening * more debugging * use the `target_modules` kwarg in `report_call` instead of manually filtering the frames * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * more debugging * more debugging * more debugging: try with new bijectors.jl * formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * removed the hacky debugging stuff used for the CI * removed now-redudant filtering methods since we use JET's own filters * bump Bijectors.jl compat entry to 0.15.1 in test so JET.jl tests pass * moved the JET.jl-dependent experimental `determine_varinfo` into a separate `Experimental` module, as discussed * forgot to add the experimenta.jl file in previous commit * reverted changes to `default_varinfo` and `LogDensityFunction` * added a bunch of docs for introduced and existing methods Added docs for `determine_suitable_varinfo` and existing methods that should be documented, e.g. `untyped_varinfo`, `typed_varinfo`, and `default_varinfo` * added doctests to `determine_suitable_varinfo` * added JET.jl as a dep to docs * fixed referencing in docs * fixed docstring * fixed doctest * Update Project.toml * applied suggestions from @mhauru Co-authored-by: Markus Hauru <[email protected]> * fixed doctests * finally fixed doctests * removed unnecessary `typed_varinfo` and `untyped_varinfo` methods * added filter to ignore source of warnings in doctest --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Hauru <[email protected]> Co-authored-by: Tor Fjelde <[email protected]>
- Loading branch information
1 parent
0548ddf
commit 145f471
Showing
12 changed files
with
329 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
module DynamicPPLJETExt | ||
|
||
using DynamicPPL: DynamicPPL | ||
using JET: JET | ||
|
||
function DynamicPPL.Experimental.is_suitable_varinfo( | ||
model::DynamicPPL.Model, | ||
context::DynamicPPL.AbstractContext, | ||
varinfo::DynamicPPL.AbstractVarInfo; | ||
only_ddpl::Bool=true, | ||
) | ||
# Let's make sure that both evaluation and sampling doesn't result in type errors. | ||
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( | ||
model, varinfo, context | ||
) | ||
# If specified, we only check errors originating somewhere in the DynamicPPL.jl. | ||
# This way we don't just fall back to untyped if the user's code is the issue. | ||
result = if only_ddpl | ||
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) | ||
else | ||
JET.report_call(f, argtypes) | ||
end | ||
return length(JET.get_reports(result)) == 0, result | ||
end | ||
|
||
function DynamicPPL.Experimental._determine_varinfo_jet( | ||
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true | ||
) | ||
# First we try with the typed varinfo. | ||
varinfo = DynamicPPL.typed_varinfo(model, context) | ||
|
||
# Let's make sure that both evaluation and sampling doesn't result in type errors. | ||
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( | ||
model, context, varinfo; only_ddpl | ||
) | ||
|
||
if !issuccess | ||
# Useful information for debugging. | ||
@debug "Evaluaton with typed varinfo failed with the following issues:" | ||
@debug result | ||
end | ||
|
||
# If we didn't fail anywhere, we return the type stable one. | ||
return if issuccess | ||
varinfo | ||
else | ||
# Warn the user that we can't use the type stable one. | ||
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." | ||
DynamicPPL.untyped_varinfo(model, context) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
module Experimental | ||
|
||
using DynamicPPL: DynamicPPL | ||
|
||
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. | ||
""" | ||
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) | ||
Check if the `model` supports evaluation using the provided `context` and `varinfo`. | ||
!!! warning | ||
Loading JET.jl is required before calling this function. | ||
# Arguments | ||
- `model`: The model to verify the support for. | ||
- `context`: The context to use for the model evaluation. | ||
- `varinfo`: The varinfo to verify the support for. | ||
# Keyword Arguments | ||
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`. | ||
# Returns | ||
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. | ||
- `report`: The result of `report_call` from JET.jl. | ||
""" | ||
function is_suitable_varinfo end | ||
|
||
# Internal hook for JET.jl to overload. | ||
function _determine_varinfo_jet end | ||
|
||
""" | ||
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) | ||
Return a suitable varinfo for the given `model`. | ||
See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). | ||
!!! warning | ||
For full functionality, this requires JET.jl to be loaded. | ||
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. | ||
# Arguments | ||
- `model`: The model for which to determine the varinfo. | ||
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. | ||
# Keyword Arguments | ||
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. | ||
# Examples | ||
```jldoctest | ||
julia> using DynamicPPL.Experimental: determine_suitable_varinfo | ||
julia> using JET: JET # needs to be loaded for full functionality | ||
julia> @model function model_with_random_support() | ||
x ~ Bernoulli() | ||
if x | ||
y ~ Normal() | ||
else | ||
z ~ Normal() | ||
end | ||
end | ||
model_with_random_support (generic function with 2 methods) | ||
julia> model = model_with_random_support(); | ||
julia> # Typed varinfo cannot handle this random support model properly | ||
# as using a single execution of the model will not see all random variables. | ||
# Hence, this this model requires untyped varinfo. | ||
vi = determine_suitable_varinfo(model); | ||
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. | ||
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 | ||
julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) | ||
true | ||
julia> # In contrast, a simple model with no random support can be handled by typed varinfo. | ||
@model model_with_static_support() = x ~ Normal() | ||
model_with_static_support (generic function with 2 methods) | ||
julia> vi = determine_suitable_varinfo(model_with_static_support()); | ||
julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) | ||
true | ||
``` | ||
""" | ||
function determine_suitable_varinfo( | ||
model::DynamicPPL.Model, | ||
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); | ||
only_ddpl::Bool=true, | ||
) | ||
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. | ||
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing | ||
_determine_varinfo_jet(model, context; only_ddpl) | ||
else | ||
# Warn the user. | ||
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." | ||
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). | ||
DynamicPPL.typed_varinfo(model, context) | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.