diff --git a/Project.toml b/Project.toml index 21d501e9f..14a7e1b1c 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -37,6 +38,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] +DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLMooncakeExt = ["Mooncake"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] @@ -55,6 +57,7 @@ Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" +JET = "0.9" LinearAlgebra = "1.6" LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" diff --git a/docs/Project.toml b/docs/Project.toml index cf9c4ecaf..069c406f3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -18,6 +19,7 @@ Documenter = "1" DocumenterMermaid = "0.1" FillArrays = "0.13, 1" ForwardDiff = "0.10" +JET = "0.9" LogDensityProblems = "2" MCMCChains = "5, 6" StableRNGs = "1" diff --git a/docs/src/api.md b/docs/src/api.md index 7448812bf..5a226e73b 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -265,6 +265,13 @@ AbstractVarInfo But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. +For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods: + +```@docs +DynamicPPL.untyped_varinfo +DynamicPPL.typed_varinfo +``` + #### `VarInfo` ```@docs @@ -425,6 +432,19 @@ DynamicPPL.loadstate DynamicPPL.initialsampler ``` +Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. + +```@docs +DynamicPPL.default_varinfo +``` + +There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: + +```@docs +DynamicPPL.Experimental.determine_suitable_varinfo +DynamicPPL.Experimental.is_suitable_varinfo +``` + ### [Model-Internal Functions](@id model_internal) ```@docs diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index 4f0480b61..e6e1f2619 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -79,9 +79,7 @@ For example, with the model above we have ```@example varinfo-design # Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.untyped_varinfo( - demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() -) +varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) typeof(varinfo_untyped.metadata) ``` diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl new file mode 100644 index 000000000..aa95093f2 --- /dev/null +++ b/ext/DynamicPPLJETExt.jl @@ -0,0 +1,53 @@ +module DynamicPPLJETExt + +using DynamicPPL: DynamicPPL +using JET: JET + +function DynamicPPL.Experimental.is_suitable_varinfo( + model::DynamicPPL.Model, + context::DynamicPPL.AbstractContext, + varinfo::DynamicPPL.AbstractVarInfo; + only_ddpl::Bool=true, +) + # Let's make sure that both evaluation and sampling doesn't result in type errors. + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo, context + ) + # If specified, we only check errors originating somewhere in the DynamicPPL.jl. + # This way we don't just fall back to untyped if the user's code is the issue. + result = if only_ddpl + JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) + else + JET.report_call(f, argtypes) + end + return length(JET.get_reports(result)) == 0, result +end + +function DynamicPPL.Experimental._determine_varinfo_jet( + model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true +) + # First we try with the typed varinfo. + varinfo = DynamicPPL.typed_varinfo(model, context) + + # Let's make sure that both evaluation and sampling doesn't result in type errors. + issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( + model, context, varinfo; only_ddpl + ) + + if !issuccess + # Useful information for debugging. + @debug "Evaluaton with typed varinfo failed with the following issues:" + @debug result + end + + # If we didn't fail anywhere, we return the type stable one. + return if issuccess + varinfo + else + # Warn the user that we can't use the type stable one. + @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo." + DynamicPPL.untyped_varinfo(model, context) + end +end + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index d4e13d456..a44a4123c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -199,32 +199,35 @@ include("values_as_in_model.jl") include("debug_utils.jl") using .DebugUtils +include("experimental.jl") include("deprecated.jl") if !isdefined(Base, :get_extension) using Requires end -@static if !isdefined(Base, :get_extension) +# Better error message if users forget to load JET +if isdefined(Base.Experimental, :register_error_hint) function __init__() - @require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include( - "../ext/DynamicPPLChainRulesCoreExt.jl" - ) - @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( - "../ext/DynamicPPLEnzymeCoreExt.jl" - ) - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( - "../ext/DynamicPPLForwardDiffExt.jl" - ) - @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( - "../ext/DynamicPPLMCMCChainsExt.jl" - ) - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( - "../ext/DynamicPPLReverseDiffExt.jl" - ) - @require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include( - "../ext/DynamicPPLZygoteRulesExt.jl" - ) + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ + requires_jet = + exc.f === DynamicPPL.Experimental._determine_varinfo_jet && + length(argtypes) >= 2 && + argtypes[1] <: Model && + argtypes[2] <: AbstractContext + requires_jet |= + exc.f === DynamicPPL.Experimental.is_suitable_varinfo && + length(argtypes) >= 3 && + argtypes[1] <: Model && + argtypes[2] <: AbstractContext && + argtypes[3] <: AbstractVarInfo + if requires_jet + print( + io, + "\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).", + ) + end + end end end diff --git a/src/experimental.jl b/src/experimental.jl new file mode 100644 index 000000000..84038803c --- /dev/null +++ b/src/experimental.jl @@ -0,0 +1,104 @@ +module Experimental + +using DynamicPPL: DynamicPPL + +# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. +""" + is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) + +Check if the `model` supports evaluation using the provided `context` and `varinfo`. + +!!! warning + Loading JET.jl is required before calling this function. + +# Arguments +- `model`: The model to verify the support for. +- `context`: The context to use for the model evaluation. +- `varinfo`: The varinfo to verify the support for. + +# Keyword Arguments +- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`. + +# Returns +- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. +- `report`: The result of `report_call` from JET.jl. +""" +function is_suitable_varinfo end + +# Internal hook for JET.jl to overload. +function _determine_varinfo_jet end + +""" + determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) + +Return a suitable varinfo for the given `model`. + +See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). + +!!! warning + For full functionality, this requires JET.jl to be loaded. + If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. + +# Arguments +- `model`: The model for which to determine the varinfo. +- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. + +# Keyword Arguments +- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl. + +# Examples + +```jldoctest +julia> using DynamicPPL.Experimental: determine_suitable_varinfo + +julia> using JET: JET # needs to be loaded for full functionality + +julia> @model function model_with_random_support() + x ~ Bernoulli() + if x + y ~ Normal() + else + z ~ Normal() + end + end +model_with_random_support (generic function with 2 methods) + +julia> model = model_with_random_support(); + +julia> # Typed varinfo cannot handle this random support model properly + # as using a single execution of the model will not see all random variables. + # Hence, this this model requires untyped varinfo. + vi = determine_suitable_varinfo(model); +┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. +└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 + +julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) +true + +julia> # In contrast, a simple model with no random support can be handled by typed varinfo. + @model model_with_static_support() = x ~ Normal() +model_with_static_support (generic function with 2 methods) + +julia> vi = determine_suitable_varinfo(model_with_static_support()); + +julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) +true +``` +""" +function determine_suitable_varinfo( + model::DynamicPPL.Model, + context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); + only_ddpl::Bool=true, +) + # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. + return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing + _determine_varinfo_jet(model, context; only_ddpl) + else + # Warn the user. + @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." + # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat). + DynamicPPL.typed_varinfo(model, context) + end +end + +end diff --git a/src/sampler.jl b/src/sampler.jl index 833aaf7e2..40418114e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -67,6 +67,20 @@ function AbstractMCMC.step( return vi, nothing end +""" + default_varinfo(rng, model, sampler[, context]) + +Return a default varinfo object for the given `model` and `sampler`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `model::Model`: Model for which we want to create a varinfo object. +- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. +- `context::AbstractContext`: Context in which the model is evaluated. + +# Returns +- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. +""" function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) return default_varinfo(rng, model, sampler, DefaultContext()) end @@ -126,7 +140,7 @@ By default, `data` is returned. loadstate(data) = data """ - default_chaintype(sampler) + default_chain_type(sampler) Default type of the chain of posterior samples from `sampler`. """ diff --git a/src/varinfo.jl b/src/varinfo.jl index 2c07d4298..bf2dd08c8 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -164,30 +164,36 @@ function has_varnamedvector(vi::VarInfo) end """ - untyped_varinfo([rng, ]model[, sampler, context]) + untyped_varinfo(model[, context, metadata]) -Return an untyped `VarInfo` instance for the model `model`. +Return an untyped varinfo object for the given `model` and `context`. + +# Arguments +- `model::Model`: The model for which to create the varinfo object. +- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`. +- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object. + Default: `Metadata()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(), metadata::Union{Metadata,VarNamedVector}=Metadata(), ) varinfo = VarInfo(metadata) - return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) -end -function untyped_varinfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... -) - return untyped_varinfo(Random.default_rng(), model, args...) + return last( + evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context)) + ) end """ - typed_varinfo([rng, ]model[, sampler, context]) + typed_varinfo(model[, context, metadata]) + +Return a typed varinfo object for the given `model`, `sampler` and `context`. + +This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting +varinfo object to a typed varinfo object. -Return a typed `VarInfo` instance for the model `model`. +See also: [`DynamicPPL.untyped_varinfo`](@ref) """ typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) @@ -198,7 +204,7 @@ function VarInfo( context::AbstractContext=DefaultContext(), metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(rng, model, sampler, context, metadata) + return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) diff --git a/test/Project.toml b/test/Project.toml index e536fbfa8..1e4ec34fa 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" @@ -33,7 +34,7 @@ ADTypes = "1" AbstractMCMC = "5" AbstractPPL = "0.8.4, 0.9" Accessors = "0.1" -Bijectors = "0.13.9, 0.14, 0.15" +Bijectors = "0.15.1" Combinatorics = "1" Compat = "4.3.0" DifferentiationInterface = "0.6" diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl new file mode 100644 index 000000000..b95107b2d --- /dev/null +++ b/test/ext/DynamicPPLJETExt.jl @@ -0,0 +1,81 @@ +@testset "DynamicPPLJETExt.jl" begin + @testset "determine_suitable_varinfo" begin + @model function demo1() + x ~ Bernoulli() + if x + y ~ Normal() + else + z ~ Normal() + end + end + model = demo1() + @test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa + DynamicPPL.UntypedVarInfo + + @model demo2() = x ~ Normal() + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa + DynamicPPL.TypedVarInfo + + @model function demo3() + # Just making sure that nothing strange happens when type inference fails. + x = Vector(undef, 1) + x[1] ~ Bernoulli() + if x[1] + y ~ Normal() + else + z ~ Normal() + end + end + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa + DynamicPPL.UntypedVarInfo + + # Evaluation works (and it would even do so in practice), but sampling + # fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. + @model function demo4() + x ~ Bernoulli() + if x + y ~ Normal() + else + y ~ Cauchy() # different distibution, but same transformation + end + end + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa + DynamicPPL.UntypedVarInfo + + # In this model, the type error occurs in the user code rather than in DynamicPPL. + @model function demo5() + x ~ Normal() + xs = Any[] + push!(xs, x) + # `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the + # correct `zero` method. As a result, this code will run, but JET will raise this is an issue. + return sum(xs) + end + # Should pass if we're only checking the tilde statements. + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa + DynamicPPL.TypedVarInfo + # Should fail if we're including errors in the model body. + @test DynamicPPL.Experimental.determine_suitable_varinfo( + demo5(); only_ddpl=false + ) isa DynamicPPL.UntypedVarInfo + end + + @testset "demo models" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Use debug logging below. + varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) + # They should all result in typed. + @test varinfo isa DynamicPPL.TypedVarInfo + # But let's also make sure that they're not lying. + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f_eval, argtypes_eval) + + f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo, DynamicPPL.SamplingContext() + ) + JET.test_call(f_sample, argtypes_sample) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index a4fdabf22..aea02a337 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,8 @@ using Test using Distributions using LinearAlgebra # Diagonal +using JET: JET + using Combinatorics: combinations using DynamicPPL: getargs_dottilde, getargs_tilde, Selector @@ -74,6 +76,7 @@ include("test_util.jl") @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") + include("ext/DynamicPPLJETExt.jl") end @testset "ad" begin @@ -107,6 +110,9 @@ include("test_util.jl") # Older versions do not have `;;]` but instead just `]` at end of the line # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line r"(;;){0,1}\]$"m, + # Ignore the source of a warning in the doctest output, since this is dependent on host. + # This is a line that starts with "└ @ " and ends with the line number. + r"└ @ .+:[0-9]+", ] doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end