diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 9acde4f89..fce8d9e30 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -17,10 +17,17 @@ permissions: actions: write contents: read +# Cancel existing tests on the same PR if a new commit is added to a pull request +concurrency: + group: ${{ github.workflow }}-${{ github.ref || github.run_id }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + jobs: test: runs-on: ${{ matrix.runner.os }} strategy: + fail-fast: false + matrix: runner: # Current stable version @@ -58,6 +65,9 @@ jobs: os: macos-latest arch: aarch64 num_threads: 2 + test_group: + - Group1 + - Group2 steps: - uses: actions/checkout@v4 @@ -73,14 +83,14 @@ jobs: - uses: julia-actions/julia-runtest@v1 env: - GROUP: All + GROUP: ${{ matrix.test_group }} JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }} - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 + - uses: codecov/codecov-action@v5 with: - file: lcov.info + files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: true diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index e3dfc3030..36bf4939c 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -14,4 +14,4 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])' + run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])' diff --git a/.github/workflows/JuliaPre.yml b/.github/workflows/JuliaPre.yml index 98b0b0ffa..a9118a4bf 100644 --- a/.github/workflows/JuliaPre.yml +++ b/.github/workflows/JuliaPre.yml @@ -25,5 +25,3 @@ jobs: - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 - env: - GROUP: DynamicPPL diff --git a/Project.toml 
b/Project.toml index 301e485f7..51bac6df5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.0" +version = "0.32.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -29,16 +29,18 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] +DynamicPPLJETExt = ["JET"] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLMooncakeExt = ["Mooncake"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -55,23 +57,16 @@ Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10" +JET = "0.9" LinearAlgebra = "1.6" LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" +Mooncake = "0.4.59" OrderedCollections = "1" Random = "1.6" Requires = "1" -ReverseDiff = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.10" - -[extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/docs/Project.toml b/docs/Project.toml index cf9c4ecaf..069c406f3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -6,6 +6,7 @@ Documenter = 
"e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -18,6 +19,7 @@ Documenter = "1" DocumenterMermaid = "0.1" FillArrays = "0.13, 1" ForwardDiff = "0.10" +JET = "0.9" LogDensityProblems = "2" MCMCChains = "5, 6" StableRNGs = "1" diff --git a/docs/src/api.md b/docs/src/api.md index d7531af0f..d5c6bd690 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -14,12 +14,6 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref @model ``` -One can nest models and call another model inside the model function with [`@submodel`](@ref). - -```@docs -@submodel -``` - ### Type A [`Model`](@ref) can be created by calling the model function, as defined by [`@model`](@ref). @@ -110,6 +104,34 @@ Similarly, we can [`unfix`](@ref) variables, i.e. return them to their original unfix ``` +## Models within models + +One can include models and call another model inside the model function with `left ~ to_submodel(model)`. + +```@docs +to_submodel +``` + +Note that a `[to_submodel](@ref)` is only sampleable; one cannot compute `logpdf` for its realizations. 
+ +In the past, one would instead embed sub-models using [`@submodel`](@ref), which has been deprecated since the introduction of [`to_submodel(model)`](@ref) + +```@docs +@submodel +``` + +In the context of including models within models, it's also useful to prefix the variables in sub-models to avoid variable names clashing: + +```@docs +prefix +``` + +Under the hood, [`to_submodel`](@ref) makes use of the following method to indicate that the model it's wrapping is a model over its return-values rather than something else + +```@docs +returned(::Model) +``` + ## Utilities It is possible to manually increase (or decrease) the accumulated log density from within a model function. @@ -118,10 +140,10 @@ It is possible to manually increase (or decrease) the accumulated log density fr @addlogprob! ``` -Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref). +Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref). ```@docs -generated_quantities +returned(::DynamicPPL.Model, ::NamedTuple) ``` For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using @@ -243,20 +265,24 @@ AbstractVarInfo But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary. -#### `VarInfo` +For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods: ```@docs -VarInfo -TypedVarInfo +DynamicPPL.untyped_varinfo +DynamicPPL.typed_varinfo ``` -One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form. +#### `VarInfo` ```@docs -link! -invlink! 
+VarInfo +TypedVarInfo ``` +One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [transformation page](internals/transformations.md). +The [Transformations section below](#Transformations) describes the methods used for this. +In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. + ```@docs set_flag! unset_flag! @@ -403,6 +429,19 @@ DynamicPPL.loadstate DynamicPPL.initialsampler ``` +Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`. + +```@docs +DynamicPPL.default_varinfo +``` + +There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model: + +```@docs +DynamicPPL.Experimental.determine_suitable_varinfo +DynamicPPL.Experimental.is_suitable_varinfo +``` + ### [Model-Internal Functions](@id model_internal) ```@docs diff --git a/docs/src/internals/varinfo.md b/docs/src/internals/varinfo.md index 4f0480b61..e6e1f2619 100644 --- a/docs/src/internals/varinfo.md +++ b/docs/src/internals/varinfo.md @@ -79,9 +79,7 @@ For example, with the model above we have ```@example varinfo-design # Type-unstable `VarInfo` -varinfo_untyped = DynamicPPL.untyped_varinfo( - demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata() -) +varinfo_untyped = DynamicPPL.untyped_varinfo(demo()) 
typeof(varinfo_untyped.metadata) ``` diff --git a/ext/DynamicPPLJETExt.jl b/ext/DynamicPPLJETExt.jl new file mode 100644 index 000000000..aa95093f2 --- /dev/null +++ b/ext/DynamicPPLJETExt.jl @@ -0,0 +1,53 @@ +module DynamicPPLJETExt + +using DynamicPPL: DynamicPPL +using JET: JET + +function DynamicPPL.Experimental.is_suitable_varinfo( + model::DynamicPPL.Model, + context::DynamicPPL.AbstractContext, + varinfo::DynamicPPL.AbstractVarInfo; + only_ddpl::Bool=true, +) + # Let's make sure that both evaluation and sampling doesn't result in type errors. + f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo, context + ) + # If specified, we only check errors originating somewhere in the DynamicPPL.jl. + # This way we don't just fall back to untyped if the user's code is the issue. + result = if only_ddpl + JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),)) + else + JET.report_call(f, argtypes) + end + return length(JET.get_reports(result)) == 0, result +end + +function DynamicPPL.Experimental._determine_varinfo_jet( + model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true +) + # First we try with the typed varinfo. + varinfo = DynamicPPL.typed_varinfo(model, context) + + # Let's make sure that both evaluation and sampling doesn't result in type errors. + issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo( + model, context, varinfo; only_ddpl + ) + + if !issuccess + # Useful information for debugging. + @debug "Evaluation with typed varinfo failed with the following issues:" + @debug result + end + + # If we didn't fail anywhere, we return the type stable one. + return if issuccess + varinfo + else + # Warn the user that we can't use the type stable one. + @warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
+ DynamicPPL.untyped_varinfo(model, context) + end +end + +end diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index 82f489765..41efcb15c 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -193,7 +193,7 @@ function _predictive_samples_to_chains(predictive_samples) end """ - generated_quantities(model::Model, chain::MCMCChains.Chains) + returned(model::Model, chain::MCMCChains.Chains) Execute `model` for each of the samples in `chain` and return an array of the values returned by the `model` for each sample. @@ -213,12 +213,12 @@ m = demo(data) chain = sample(m, alg, n) # To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples # from the posterior/`chain`: -generated_quantities(m, chain) # <= results in a `Vector` of returned values +returned(m, chain) # <= results in a `Vector` of returned values # from `interesting_quantity(θ, x)` ``` ## Concrete (and simple) ```julia -julia> using DynamicPPL, Turing +julia> using Turing julia> @model function demo(xs) s ~ InverseGamma(2, 3) @@ -237,7 +237,7 @@ julia> model = demo(randn(10)); julia> chain = sample(model, MH(), 10); -julia> generated_quantities(model, chain) +julia> returned(model, chain) 10×1 Array{Tuple{Float64},2}: (2.1964758025119338,) (2.1964758025119338,) @@ -251,9 +251,7 @@ julia> generated_quantities(model, chain) (-0.16489786710222099,) ``` """ -function DynamicPPL.generated_quantities( - model::DynamicPPL.Model, chain_full::MCMCChains.Chains -) +function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Chains) chain = MCMCChains.get_sections(chain_full, :parameters) varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..b86d807bc --- /dev/null +++ b/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,9 @@ +module DynamicPPLMooncakeExt + +using DynamicPPL: 
DynamicPPL, istrans +using Mooncake: Mooncake + +# This is purely an optimisation. +Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} + +end # module diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl deleted file mode 100644 index 3fd174ed1..000000000 --- a/ext/DynamicPPLReverseDiffExt.jl +++ /dev/null @@ -1,26 +0,0 @@ -module DynamicPPLReverseDiffExt - -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ReverseDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ReverseDiff -end - -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction -) where {Tcompile} - return LogDensityProblemsAD.ADgradient( - Val(:ReverseDiff), - ℓ; - compile=Val(Tcompile), - # `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0 - # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473 - # `zero(D)` will return 0 when D is Real. - # here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`. - x=map(identity, DynamicPPL.getparams(ℓ)), - ) -end - -end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e8cee3b08..d94050ff2 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -86,7 +86,6 @@ export AbstractVarInfo, Model, getmissings, getargnames, - generated_quantities, extract_priors, values_as_in_model, # Samplers @@ -122,6 +121,9 @@ export AbstractVarInfo, decondition, fix, unfix, + prefix, + returned, + to_submodel, # Convenience macros @addlogprob!, @submodel, @@ -130,7 +132,8 @@ export AbstractVarInfo, check_model_and_trace, # Deprecated. 
@logprob_str, - @prob_str + @prob_str, + generated_quantities # Reexport using Distributions: loglikelihood @@ -196,30 +199,35 @@ include("values_as_in_model.jl") include("debug_utils.jl") using .DebugUtils +include("experimental.jl") +include("deprecated.jl") + if !isdefined(Base, :get_extension) using Requires end -@static if !isdefined(Base, :get_extension) +# Better error message if users forget to load JET +if isdefined(Base.Experimental, :register_error_hint) function __init__() - @require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include( - "../ext/DynamicPPLChainRulesCoreExt.jl" - ) - @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( - "../ext/DynamicPPLEnzymeCoreExt.jl" - ) - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( - "../ext/DynamicPPLForwardDiffExt.jl" - ) - @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( - "../ext/DynamicPPLMCMCChainsExt.jl" - ) - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( - "../ext/DynamicPPLReverseDiffExt.jl" - ) - @require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include( - "../ext/DynamicPPLZygoteRulesExt.jl" - ) + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ + requires_jet = + exc.f === DynamicPPL.Experimental._determine_varinfo_jet && + length(argtypes) >= 2 && + argtypes[1] <: Model && + argtypes[2] <: AbstractContext + requires_jet |= + exc.f === DynamicPPL.Experimental.is_suitable_varinfo && + length(argtypes) >= 3 && + argtypes[1] <: Model && + argtypes[2] <: AbstractContext && + argtypes[3] <: AbstractVarInfo + if requires_jet + print( + io, + "\n$(exc.f) requires JET.jl to be loaded. 
Please run `using JET` before calling $(exc.f).", + ) + end + end end end diff --git a/src/compiler.jl b/src/compiler.jl index 90220cbf5..c67da6f95 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -178,6 +178,11 @@ function check_tilde_rhs(@nospecialize(x)) end check_tilde_rhs(x::Distribution) = x check_tilde_rhs(x::AbstractArray{<:Distribution}) = x +check_tilde_rhs(x::ReturnedModelWrapper) = x +function check_tilde_rhs(x::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} + model = check_tilde_rhs(x.model) + return Sampleable{typeof(model),AutoPrefix}(model) +end """ unwrap_right_vn(right, vn) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 489c64c57..462012676 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -103,8 +103,17 @@ By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ function tilde_assume!!(context, right, vn, vi) - value, logp, vi = tilde_assume(context, right, vn, vi) - return value, acclogp_assume!!(context, vi, logp) + return if is_rhs_model(right) + # Prefix the variables using the `vn`. + rand_like!!( + right, + should_auto_prefix(right) ? PrefixContext{Symbol(vn)}(context) : context, + vi, + ) + else + value, logp, vi = tilde_assume(context, right, vn, vi) + value, acclogp_assume!!(context, vi, logp) + end end # observe @@ -159,6 +168,11 @@ Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the informati and indices; if needed, these can be accessed through this function, though. """ function tilde_observe!!(context, right, left, vname, vi) + is_rhs_model(right) && throw( + ArgumentError( + "`~` with a model on the right-hand side of an observe statement is not supported", + ), + ) return tilde_observe!!(context, right, left, vi) end @@ -172,6 +186,11 @@ By default, calls `tilde_observe(context, right, left, vi)` and accumulates the probability of `vi` with the returned value. 
""" function tilde_observe!!(context, right, left, vi) + is_rhs_model(right) && throw( + ArgumentError( + "`~` with a model on the right-hand side of an observe statement is not supported", + ), + ) logp, vi = tilde_observe(context, right, left, vi) return left, acclogp_observe!!(context, vi, logp) end @@ -321,8 +340,13 @@ model inputs), accumulate the log probability, and return the sampled value and Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ function dot_tilde_assume!!(context, right, left, vn, vi) + is_rhs_model(right) && throw( + ArgumentError( + "`.~` with a model on the right-hand side is not supported; please use `~`" + ), + ) value, logp, vi = dot_tilde_assume(context, right, left, vn, vi) - return value, acclogp_assume!!(context, vi, logp), vi + return value, acclogp_assume!!(context, vi, logp) end # `dot_assume` @@ -573,6 +597,11 @@ Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the infor name and indices; if needed, these can be accessed through this function, though. """ function dot_tilde_observe!!(context, right, left, vn, vi) + is_rhs_model(right) && throw( + ArgumentError( + "`~` with a model on the right-hand side of an observe statement is not supported", + ), + ) return dot_tilde_observe!!(context, right, left, vi) end @@ -585,6 +614,11 @@ probability, and return the observed value and updated `vi`. Falls back to `dot_tilde_observe(context, right, left, vi)`. 
""" function dot_tilde_observe!!(context, right, left, vi) + is_rhs_model(right) && throw( + ArgumentError( + "`~` with a model on the right-hand side of an observe statement is not supported", + ), + ) logp, vi = dot_tilde_observe(context, right, left, vi) return left, acclogp_observe!!(context, vi, logp) end diff --git a/src/contexts.jl b/src/contexts.jl index 5da4208b5..9eb3d5ccb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -281,6 +281,34 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end +""" + prefix(model::Model, x) + +Return `model` but with all random variables prefixed by `x`. + +If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing. + +# Examples + +```jldoctest +julia> using DynamicPPL: prefix + +julia> @model demo() = x ~ Dirac(1) +demo (generic function with 2 methods) + +julia> rand(prefix(demo(), :my_prefix)) +(var"my_prefix.x" = 1,) + +julia> # One can also use `Val` to avoid runtime overheads. + rand(prefix(demo(), Val(:my_prefix))) +(var"my_prefix.x" = 1,) +``` +""" +prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context)) +function prefix(model::Model, ::Val{x}) where {x} + return contextualize(model, PrefixContext{Symbol(x)}(model.context)) +end + struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx diff --git a/src/deprecated.jl b/src/deprecated.jl new file mode 100644 index 000000000..0bcaae9b7 --- /dev/null +++ b/src/deprecated.jl @@ -0,0 +1 @@ +@deprecate generated_quantities(model, params) returned(model, params) diff --git a/src/experimental.jl b/src/experimental.jl new file mode 100644 index 000000000..84038803c --- /dev/null +++ b/src/experimental.jl @@ -0,0 +1,104 @@ +module Experimental + +using DynamicPPL: DynamicPPL + +# This file only defines the names of the functions, and their docstrings. 
The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency. +""" + is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...) + +Check if the `model` supports evaluation using the provided `context` and `varinfo`. + +!!! warning + Loading JET.jl is required before calling this function. + +# Arguments +- `model`: The model to verify the support for. +- `context`: The context to use for the model evaluation. +- `varinfo`: The varinfo to verify the support for. + +# Keyword Arguments +- `only_ddpl`: If `true`, only consider error reports occurring in the tilde pipeline. Default: `true`. + +# Returns +- `issuccess`: `true` if the model supports the varinfo, otherwise `false`. +- `report`: The result of `report_call` from JET.jl. +""" +function is_suitable_varinfo end + +# Internal hook for JET.jl to overload. +function _determine_varinfo_jet end + +""" + determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true) + +Return a suitable varinfo for the given `model`. + +See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref). + +!!! warning + For full functionality, this requires JET.jl to be loaded. + If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo. + +# Arguments +- `model`: The model for which to determine the varinfo. +- `context`: The context to use for the model evaluation. Default: `SamplingContext()`. + +# Keyword Arguments +- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
+ +# Examples + +```jldoctest +julia> using DynamicPPL.Experimental: determine_suitable_varinfo + +julia> using JET: JET # needs to be loaded for full functionality + +julia> @model function model_with_random_support() + x ~ Bernoulli() + if x + y ~ Normal() + else + z ~ Normal() + end + end +model_with_random_support (generic function with 2 methods) + +julia> model = model_with_random_support(); + +julia> # Typed varinfo cannot handle this random support model properly + # as using a single execution of the model will not see all random variables. + # Hence, this model requires untyped varinfo. + vi = determine_suitable_varinfo(model); +┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo. +└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48 + +julia> vi isa typeof(DynamicPPL.untyped_varinfo(model)) +true + +julia> # In contrast, a simple model with no random support can be handled by typed varinfo. + @model model_with_static_support() = x ~ Normal() +model_with_static_support (generic function with 2 methods) + +julia> vi = determine_suitable_varinfo(model_with_static_support()); + +julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support())) +true +``` +""" +function determine_suitable_varinfo( + model::DynamicPPL.Model, + context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext(); + only_ddpl::Bool=true, +) + # If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that. + return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing + _determine_varinfo_jet(model, context; only_ddpl) + else + # Warn the user. + @warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo." + # Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
+ DynamicPPL.typed_varinfo(model, context) + end +end + +end diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9e86590fa..214369ab0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -144,3 +144,19 @@ function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) end # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) + +# This is important for performance -- one needs to provide `ADGradient` with a vector of +# parameters, or DifferentiationInterface will not have sufficient information to e.g. +# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate +# a tape when using ReverseDiff.jl. +function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) + x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params + return LogDensityProblemsAD.ADgradient(ad, ℓ; x) +end + +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end diff --git a/src/model.jl b/src/model.jl index dfae5fb1d..9fb4dd5a9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -223,15 +223,16 @@ true ## Nested models `condition` of course also supports the use of nested models through -the use of [`@submodel`](@ref). +the use of [`to_submodel`](@ref). ```jldoctest condition julia> @model demo_inner() = m ~ Normal() demo_inner (generic function with 2 methods) julia> @model function demo_outer() - @submodel m = demo_inner() - return m + # By default, `to_submodel` prefixes the variables using the left-hand side of `~`. 
+ inner ~ to_submodel(demo_inner()) + return inner end demo_outer (generic function with 2 methods) @@ -240,63 +241,28 @@ julia> model = demo_outer(); julia> model() ≠ 1.0 true -julia> conditioned_model = model | (m = 1.0, ); +julia> # To condition the variable inside `demo_inner` we need to refer to it as `inner.m`. + conditioned_model = model | (var"inner.m" = 1.0, ); julia> conditioned_model() 1.0 -``` - -But one needs to be careful when prefixing variables in the nested models: - -```jldoctest condition -julia> @model function demo_outer_prefix() - @submodel prefix="inner" m = demo_inner() - return m - end -demo_outer_prefix (generic function with 2 methods) - -julia> # (×) This doesn't work now! - conditioned_model = demo_outer_prefix() | (m = 1.0, ); - -julia> conditioned_model() == 1.0 -false -julia> # (✓) `m` in `demo_inner` is referred to as `inner.m` internally, so we do: - conditioned_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); +julia> # However, it's not possible to condition `inner` directly. + conditioned_model_fail = model | (inner = 1.0, ); -julia> conditioned_model() -1.0 - -julia> # Note that the above `var"..."` is just standard Julia syntax: - keys((var"inner.m" = 1.0, )) -(Symbol("inner.m"),) +julia> conditioned_model_fail() +ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +[...] 
``` And similarly when using `Dict`: ```jldoctest condition -julia> conditioned_model_dict = demo_outer_prefix() | (@varname(var"inner.m") => 1.0); +julia> conditioned_model_dict = model | (@varname(var"inner.m") => 1.0); julia> conditioned_model_dict() 1.0 ``` - -The difference is maybe more obvious once we look at how these different -in their trace/`VarInfo`: - -```jldoctest condition -julia> keys(VarInfo(demo_outer())) -1-element Vector{VarName{:m, typeof(identity)}}: - m - -julia> keys(VarInfo(demo_outer_prefix())) -1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}: - inner.m -``` - -From this we can tell what the correct way to condition `m` within `demo_inner` -is in the two different models. - """ AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values)) function AbstractPPL.condition(model::Model, value, values...) @@ -578,15 +544,15 @@ true ## Nested models `fix` of course also supports the use of nested models through -the use of [`@submodel`](@ref). +the use of [`to_submodel`](@ref), similar to [`condition`](@ref). ```jldoctest fix julia> @model demo_inner() = m ~ Normal() demo_inner (generic function with 2 methods) julia> @model function demo_outer() - @submodel m = demo_inner() - return m + inner ~ to_submodel(demo_inner()) + return inner end demo_outer (generic function with 2 methods) @@ -595,63 +561,36 @@ julia> model = demo_outer(); julia> model() ≠ 1.0 true -julia> fixed_model = model | (m = 1.0, ); +julia> fixed_model = fix(model, var"inner.m" = 1.0, ); julia> fixed_model() 1.0 ``` -But one needs to be careful when prefixing variables in the nested models: - -```jldoctest fix -julia> @model function demo_outer_prefix() - @submodel prefix="inner" m = demo_inner() - return m - end -demo_outer_prefix (generic function with 2 methods) - -julia> # (×) This doesn't work now! 
- fixed_model = demo_outer_prefix() | (m = 1.0, ); - -julia> fixed_model() == 1.0 -false +However, unlike [`condition`](@ref), `fix` can also be used to fix the +return-value of the submodel: -julia> # (✓) `m` in `demo_inner` is referred to as `inner.m` internally, so we do: - fixed_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); +```julia +julia> fixed_model = fix(model, inner = 2.0,); julia> fixed_model() -1.0 - -julia> # Note that the above `var"..."` is just standard Julia syntax: - keys((var"inner.m" = 1.0, )) -(Symbol("inner.m"),) +2.0 ``` And similarly when using `Dict`: ```jldoctest fix -julia> fixed_model_dict = demo_outer_prefix() | (@varname(var"inner.m") => 1.0); +julia> fixed_model_dict = fix(model, @varname(var"inner.m") => 1.0); julia> fixed_model_dict() 1.0 -``` -The difference is maybe more obvious once we look at how these different -in their trace/`VarInfo`: +julia> fixed_model_dict = fix(model, @varname(inner) => 2.0); -```jldoctest fix -julia> keys(VarInfo(demo_outer())) -1-element Vector{VarName{:m, typeof(identity)}}: - m - -julia> keys(VarInfo(demo_outer_prefix())) -1-element Vector{VarName{Symbol("inner.m"), typeof(identity)}}: - inner.m +julia> fixed_model_dict() +2.0 ``` -From this we can tell what the correct way to fix `m` within `demo_inner` -is in the two different models. - ## Difference from `condition` A very similar functionality is also provided by [`condition`](@ref) which, @@ -1051,7 +990,9 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T} evaluate!!( model, SimpleVarInfo{Float64}(OrderedDict()), - SamplingContext(rng, SampleFromPrior(), model.context), + # NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`, + # and b) avoid double-stacking the parent contexts. 
+ SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)), ), ) return values_as(x, T) @@ -1220,9 +1161,9 @@ function predict(model::Model, chain; include_all=false) end """ - generated_quantities(model::Model, parameters::NamedTuple) - generated_quantities(model::Model, values, keys) - generated_quantities(model::Model, values, keys) + returned(model::Model, parameters::NamedTuple) + returned(model::Model, values, keys) + returned(model::Model, values, keys) Execute `model` with variables `keys` set to `values` and return the values returned by the `model`. @@ -1247,18 +1188,257 @@ julia> model = demo(randn(10)); julia> parameters = (; s = 1.0, m_shifted=10.0); -julia> generated_quantities(model, parameters) +julia> returned(model, parameters) (0.0,) -julia> generated_quantities(model, values(parameters), keys(parameters)) +julia> returned(model, values(parameters), keys(parameters)) (0.0,) ``` """ -function generated_quantities(model::Model, parameters::NamedTuple) +function returned(model::Model, parameters::NamedTuple) fixed_model = fix(model, parameters) return fixed_model() end -function generated_quantities(model::Model, values, keys) - return generated_quantities(model, NamedTuple{keys}(values)) +function returned(model::Model, values, keys) + return returned(model, NamedTuple{keys}(values)) end + +""" + is_rhs_model(x) + +Return `true` if `x` is a model or model wrapper, and `false` otherwise. +""" +is_rhs_model(x) = false + +""" + Distributional + +Abstract type for type indicating that something is "distributional". +""" +abstract type Distributional end + +""" + should_auto_prefix(distributional) + +Return `true` if the `distributional` should use automatic prefixing, and `false` otherwise. +""" +function should_auto_prefix end + +""" + is_rhs_model(x) + +Return `true` if the `distributional` is a model, and `false` otherwise. 
+""" +function is_rhs_model end + +""" + Sampleable{M} <: Distributional + +A wrapper around a model indicating it is sampleable. +""" +struct Sampleable{M,AutoPrefix} <: Distributional + model::M +end + +should_auto_prefix(::Sampleable{<:Any,AutoPrefix}) where {AutoPrefix} = AutoPrefix +is_rhs_model(x::Sampleable) = is_rhs_model(x.model) + +# TODO: Export this if it end up having a purpose beyond `to_submodel`. +""" + to_sampleable(model[, auto_prefix]) + +Return a wrapper around `model` indicating it is sampleable. + +# Arguments +- `model::Model`: the model to wrap. +- `auto_prefix::Bool`: whether to prefix the variables in the model. Default: `true`. +""" +to_sampleable(model, auto_prefix::Bool=true) = Sampleable{typeof(model),auto_prefix}(model) + +""" + rand_like!!(model_wrap, context, varinfo) + +Returns a tuple with the first element being the realization and the second the updated varinfo. + +# Arguments +- `model_wrap::ReturnedModelWrapper`: the wrapper of the model to use. +- `context::AbstractContext`: the context to use for evaluation. +- `varinfo::AbstractVarInfo`: the varinfo to use for evaluation. + """ +function rand_like!!( + model_wrap::Sampleable, context::AbstractContext, varinfo::AbstractVarInfo +) + return rand_like!!(model_wrap.model, context, varinfo) +end + +""" + ReturnedModelWrapper + +A wrapper around a model indicating it is a model over its return values. + +This should rarely be constructed explicitly; see [`returned(model)`](@ref) instead. +""" +struct ReturnedModelWrapper{M<:Model} + model::M +end + +is_rhs_model(::ReturnedModelWrapper) = true + +function rand_like!!( + model_wrap::ReturnedModelWrapper, context::AbstractContext, varinfo::AbstractVarInfo +) + # Return's the value and the (possibly mutated) varinfo. + return _evaluate!!(model_wrap.model, varinfo, context) +end + +""" + returned(model) + +Return a `model` wrapper indicating that it is a model over its return-values. 
+""" +returned(model::Model) = ReturnedModelWrapper(model) + +""" + to_submodel(model::Model[, auto_prefix::Bool]) + +Return a model wrapper indicating that it is a sampleable model over the return-values. + +This is mainly meant to be used on the right-hand side of a `~` operator to indicate that +the model can be sampled from but not necessarily evaluated for its log density. + +!!! warning + Note that some other operations that one typically associate with expressions of the form + `left ~ right` such as [`condition`](@ref), will also not work with `to_submodel`. + +!!! warning + To avoid variable names clashing between models, it is recommend leave argument `auto_prefix` equal to `true`. + If one does not use automatic prefixing, then it's recommended to use [`prefix(::Model, input)`](@ref) explicitly. + +# Arguments +- `model::Model`: the model to wrap. +- `auto_prefix::Bool`: whether to automatically prefix the variables in the model using the left-hand + side of the `~` statement. Default: `true`. + +# Examples + +## Simple example +```jldoctest submodel-to_submodel; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2(x, y) + a ~ to_submodel(demo1(x)) + return y ~ Uniform(0, a) + end; +``` + +When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: +```jldoctest submodel-to_submodel +julia> vi = VarInfo(demo2(missing, 0.4)); + +julia> @varname(var\"a.x\") in keys(vi) +true +``` + +The variable `a` is not tracked. However, it will be assigned the return value of `demo1`, +and can be used in subsequent lines of the model, as shown above. 
+```jldoctest submodel-to_submodel +julia> @varname(a) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel +julia> x = vi[@varname(var\"a.x\")]; + +julia> getlogp(vi) ≈ logpdf(Normal(), x) + logpdf(Uniform(0, 1 + abs(x)), 0.4) +true +``` + +## Without automatic prefixing +As mentioned earlier, by default, the `auto_prefix` argument specifies whether to automatically +prefix the variables in the submodel. If `auto_prefix=false`, then the variables in the submodel +will not be prefixed. +```jldoctest submodel-to_submodel-prefix; setup=:(using Distributions) +julia> @model function demo1(x) + x ~ Normal() + return 1 + abs(x) + end; + +julia> @model function demo2_no_prefix(x, z) + a ~ to_submodel(demo1(x), false) + return z ~ Uniform(-a, 1) + end; + +julia> vi = VarInfo(demo2_no_prefix(missing, 0.4)); + +julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x` +true +``` +However, not using prefixing is generally not recommended as it can lead to variable name clashes +unless one is careful. 
For example, if we're re-using the same model twice in a model, not using prefixing +will lead to variable name clashes: However, one can manually prefix using the [`prefix(::Model, input)`](@ref): +```jldoctest submodel-to_submodel-prefix +julia> @model function demo2(x, y, z) + a ~ to_submodel(prefix(demo1(x), :sub1), false) + b ~ to_submodel(prefix(demo1(y), :sub2), false) + return z ~ Uniform(-a, b) + end; + +julia> vi = VarInfo(demo2(missing, missing, 0.4)); + +julia> @varname(var"sub1.x") in keys(vi) +true + +julia> @varname(var"sub2.x") in keys(vi) +true +``` + +Variables `a` and `b` are not tracked, but are assigned the return values of the respective +calls to `demo1`: +```jldoctest submodel-to_submodel-prefix +julia> @varname(a) in keys(vi) +false + +julia> @varname(b) in keys(vi) +false +``` + +We can check that the log joint probability of the model accumulated in `vi` is correct: + +```jldoctest submodel-to_submodel-prefix +julia> sub1_x = vi[@varname(var"sub1.x")]; + +julia> sub2_x = vi[@varname(var"sub2.x")]; + +julia> logprior = logpdf(Normal(), sub1_x) + logpdf(Normal(), sub2_x); + +julia> loglikelihood = logpdf(Uniform(-1 - abs(sub1_x), 1 + abs(sub2_x)), 0.4); + +julia> getlogp(vi) ≈ logprior + loglikelihood +true +``` + +## Usage as likelihood is illegal + +Note that it is illegal to use a `to_submodel` model as a likelihood in another model: + +```jldoctest submodel-to_submodel-illegal; setup=:(using Distributions) +julia> @model inner() = x ~ Normal() +inner (generic function with 2 methods) + +julia> @model illegal_likelihood() = a ~ to_submodel(inner()) +illegal_likelihood (generic function with 2 methods) + +julia> model = illegal_likelihood() | (a = 1.0,); + +julia> model() +ERROR: ArgumentError: `~` with a model on the right-hand side of an observe statement is not supported +[...] 
+""" +to_submodel(model::Model, auto_prefix::Bool=true) = + to_sampleable(returned(model), auto_prefix) diff --git a/src/sampler.jl b/src/sampler.jl index 833aaf7e2..40418114e 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -67,6 +67,20 @@ function AbstractMCMC.step( return vi, nothing end +""" + default_varinfo(rng, model, sampler[, context]) + +Return a default varinfo object for the given `model` and `sampler`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `model::Model`: Model for which we want to create a varinfo object. +- `sampler::AbstractSampler`: Sampler which will make use of the varinfo object. +- `context::AbstractContext`: Context in which the model is evaluated. + +# Returns +- `AbstractVarInfo`: Default varinfo object for the given `model` and `sampler`. +""" function default_varinfo(rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler) return default_varinfo(rng, model, sampler, DefaultContext()) end @@ -126,7 +140,7 @@ By default, `data` is returned. loadstate(data) = data """ - default_chaintype(sampler) + default_chain_type(sampler) Default type of the chain of posterior samples from `sampler`. """ diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 050bf31fc..e5a8e0617 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -4,6 +4,10 @@ Run a Turing `model` nested inside of a Turing model. +!!! warning + This is deprecated and will be removed in a future release. + Use `left ~ to_submodel(model)` instead (see [`to_submodel`](@ref)). + # Examples ```jldoctest submodel; setup=:(using Distributions) @@ -21,6 +25,9 @@ julia> @model function demo2(x, y) When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled: ```jldoctest submodel julia> vi = VarInfo(demo2(missing, 0.4)); +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. 
+│ caller = ip:0x0 +└ @ Core :-1 julia> @varname(x) in keys(vi) true @@ -62,6 +69,10 @@ Valid expressions for `prefix=...` are: The prefix makes it possible to run the same Turing model multiple times while keeping track of all random variables correctly. +!!! warning + This is deprecated and will be removed in a future release. + Use `left ~ to_submodel(model)` instead (see [`to_submodel(model)`](@ref)). + # Examples ## Example models ```jldoctest submodelprefix; setup=:(using Distributions) @@ -81,6 +92,9 @@ When we sample from the model `demo2(missing, missing, 0.4)` random variables `s `sub2.x` will be sampled: ```jldoctest submodelprefix julia> vi = VarInfo(demo2(missing, missing, 0.4)); +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. +│ caller = ip:0x0 +└ @ Core :-1 julia> @varname(var"sub1.x") in keys(vi) true @@ -124,6 +138,9 @@ julia> # When `prefix` is unspecified, no prefix is used. submodel_noprefix (generic function with 2 methods) julia> @varname(x) in keys(VarInfo(submodel_noprefix())) +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. +│ caller = ip:0x0 +└ @ Core :-1 true julia> # Explicitely don't use any prefix. @@ -131,6 +148,9 @@ julia> # Explicitely don't use any prefix. submodel_prefix_false (generic function with 2 methods) julia> @varname(x) in keys(VarInfo(submodel_prefix_false())) +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. +│ caller = ip:0x0 +└ @ Core :-1 true julia> # Automatically determined from `a`. @@ -138,6 +158,9 @@ julia> # Automatically determined from `a`. submodel_prefix_true (generic function with 2 methods) julia> @varname(var"a.x") in keys(VarInfo(submodel_prefix_true())) +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. 
+│ caller = ip:0x0 +└ @ Core :-1 true julia> # Using a static string. @@ -145,6 +168,9 @@ julia> # Using a static string. submodel_prefix_string (generic function with 2 methods) julia> @varname(var"my prefix.x") in keys(VarInfo(submodel_prefix_string())) +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. +│ caller = ip:0x0 +└ @ Core :-1 true julia> # Using string interpolation. @@ -152,6 +178,9 @@ julia> # Using string interpolation. submodel_prefix_interpolation (generic function with 2 methods) julia> @varname(var"inner.x") in keys(VarInfo(submodel_prefix_interpolation())) +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. +│ caller = ip:0x0 +└ @ Core :-1 true julia> # Or using some arbitrary expression. @@ -159,6 +188,9 @@ julia> # Or using some arbitrary expression. submodel_prefix_expr (generic function with 2 methods) julia> @varname(var"3.x") in keys(VarInfo(submodel_prefix_expr())) +┌ Warning: `@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax. +│ caller = ip:0x0 +└ @ Core :-1 true julia> # (×) Automatic prefixing without a left-hand side expression does not work! @@ -207,6 +239,8 @@ function prefix_submodel_context(prefix::Bool, ctx) return ctx end +const SUBMODEL_DEPWARN_MSG = "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax." + function submodel(prefix_expr, expr, ctx=esc(:__context__)) prefix_left, prefix = getargs_assignment(prefix_expr) if prefix_left !== :prefix @@ -225,6 +259,9 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) return if args_assign === nothing ctx = prefix_submodel_context(prefix, ctx) quote + # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. 
+ $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) + $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( $(esc(expr)), $(esc(:__varinfo__)), $(ctx) ) @@ -241,6 +278,9 @@ function submodel(prefix_expr, expr, ctx=esc(:__context__)) ) end quote + # Raise deprecation warning to let user know that we recommend using `left ~ to_submodel(model)`. + $(Base.depwarn)(SUBMODEL_DEPWARN_MSG, Symbol("@submodel")) + $retval, $(esc(:__varinfo__)) = $(_evaluate!!)( $(esc(R)), $(esc(:__varinfo__)), $(ctx) ) diff --git a/src/test_utils/models.jl b/src/test_utils/models.jl index 4aa2aaa42..92a69d9ad 100644 --- a/src/test_utils/models.jl +++ b/src/test_utils/models.jl @@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end -@model function demo_assume_observe_literal() - # `assume` and literal `observe` +@model function demo_assume_multivariate_observe_literal() + # multivariate `assume` and literal `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zeros(2), Diagonal(s)) [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m_dist = MvNormal(zeros(2), Diagonal(s)) return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) +function loglikelihood_true( + model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m +) return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_observe_literal)}, s, m + model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m ) return 
_demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)}) return [@varname(s), @varname(m)] end @@ -377,7 +379,31 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -@model function demo_assume_literal_dot_observe() +@model function demo_assume_observe_literal() + # univariate `assume` and literal `observe` + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + 1.5 ~ Normal(m, sqrt(s)) + 2.0 ~ Normal(m, sqrt(s)) + + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) +end +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) +end +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0) +end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_observe_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end +function varnames(model::Model{typeof(demo_assume_observe_literal)}) + return [@varname(s), @varname(m)] +end + +@model function demo_assume_dot_observe_literal() # `assume` and literal `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @@ -385,18 +411,18 @@ end return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) +function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) +function 
loglikelihood_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m) return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( - model::Model{typeof(demo_assume_literal_dot_observe)}, s, m + model::Model{typeof(demo_assume_dot_observe_literal)}, s, m ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_dot_observe_literal)}) return [@varname(s), @varname(m)] end @@ -574,8 +600,9 @@ const DemoModels = Union{ Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, Model{typeof(demo_assume_dot_observe)}, - Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_dot_observe_literal)}, Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_assume_multivariate_observe_literal)}, Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, @@ -585,7 +612,9 @@ const DemoModels = Union{ } const UnivariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_dot_observe_literal)}, + Model{typeof(demo_assume_observe_literal)}, } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) @@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_index_observe)}, Model{typeof(demo_assume_multivariate_observe)}, Model{typeof(demo_dot_assume_observe_index)}, - Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_assume_multivariate_observe_literal)}, Model{typeof(demo_dot_assume_observe_index_literal)}, Model{typeof(demo_assume_submodel_observe_index_literal)}, Model{typeof(demo_dot_assume_observe_submodel)}, @@ -759,9 
+788,10 @@ const DEMO_MODELS = ( demo_assume_multivariate_observe(), demo_dot_assume_observe_index(), demo_assume_dot_observe(), - demo_assume_observe_literal(), + demo_assume_multivariate_observe_literal(), demo_dot_assume_observe_index_literal(), - demo_assume_literal_dot_observe(), + demo_assume_dot_observe_literal(), + demo_assume_observe_literal(), demo_assume_submodel_observe_index_literal(), demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), diff --git a/src/threadsafe.jl b/src/threadsafe.jl index ec890a674..cedb0efad 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -79,8 +79,6 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) -link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) function link!!( @@ -178,6 +176,12 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<: return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) end +vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) +vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) +function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) + return vector_getranges(vi.varinfo, vns) +end + function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) return set_retained_vns_del_by_spl!(vi.varinfo, spl) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 4cf1f1b02..3ebb505e0 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -164,30 +164,36 @@ function has_varnamedvector(vi::VarInfo) end """ - untyped_varinfo([rng, ]model[, sampler, context]) + untyped_varinfo(model[, context, metadata]) -Return an untyped `VarInfo` 
instance for the model `model`. +Return an untyped varinfo object for the given `model` and `context`. + +# Arguments +- `model::Model`: The model for which to create the varinfo object. +- `context::AbstractContext`: The context in which to evaluate the model. Default: `SamplingContext()`. +- `metadata::Union{Metadata,VarNamedVector}`: The metadata to use for the varinfo object. + Default: `Metadata()`. """ function untyped_varinfo( - rng::Random.AbstractRNG, model::Model, - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(), metadata::Union{Metadata,VarNamedVector}=Metadata(), ) varinfo = VarInfo(metadata) - return last(evaluate!!(model, varinfo, SamplingContext(rng, sampler, context))) -end -function untyped_varinfo( - model::Model, args::Union{AbstractSampler,AbstractContext,Metadata,VarNamedVector}... -) - return untyped_varinfo(Random.default_rng(), model, args...) + return last( + evaluate!!(model, varinfo, hassampler(context) ? context : SamplingContext(context)) + ) end """ - typed_varinfo([rng, ]model[, sampler, context]) + typed_varinfo(model[, context, metadata]) + +Return a typed varinfo object for the given `model`, `sampler` and `context`. -Return a typed `VarInfo` instance for the model `model`. +This simply calls [`DynamicPPL.untyped_varinfo`](@ref) and converts the resulting +varinfo object to a typed varinfo object. + +See also: [`DynamicPPL.untyped_varinfo`](@ref) """ typed_varinfo(args...) = TypedVarInfo(untyped_varinfo(args...)) @@ -198,10 +204,19 @@ function VarInfo( context::AbstractContext=DefaultContext(), metadata::Union{Metadata,VarNamedVector}=Metadata(), ) - return typed_varinfo(rng, model, sampler, context, metadata) + return typed_varinfo(model, SamplingContext(rng, sampler, context), metadata) end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) 
+""" + vector_length(varinfo::VarInfo) + +Return the length of the vector representation of `varinfo`. +""" +vector_length(varinfo::VarInfo) = length(varinfo.metadata) +vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(md::Metadata) = sum(length, md.ranges) + unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) # TODO: deprecate. @@ -626,7 +641,72 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. """ function getranges(vi::VarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[]) + return map(Base.Fix1(getrange, vi), vns) +end + +""" + vector_getrange(varinfo::VarInfo, varname::VarName) + +Return the range corresponding to `varname` in the vector representation of `varinfo`. +""" +vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) +function vector_getrange(vi::TypedVarInfo, vn::VarName) + offset = 0 + for md in values(vi.metadata) + # First, we need to check if `vn` is in `md`. + # In this case, we can just return the corresponding range + offset. + haskey(md, vn) && return getrange(md, vn) .+ offset + # Otherwise, we need to get the cumulative length of the ranges in `md` + # and add it to the offset. + offset += sum(length, md.ranges) + end + # If we reach this point, `vn` is not in `vi.metadata`. + throw(KeyError(vn)) +end + +""" + vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName}) + +Return the range corresponding to `varname` in the vector representation of `varinfo`. +""" +function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) + return map(Base.Fix1(vector_getrange, varinfo), varname) +end +# Specialized version for `TypedVarInfo`. +function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) + # TODO: Does it help if we _don't_ convert to a vector here? 
+ metadatas = collect(values(varinfo.metadata)) + # Extract the offsets. + offsets = cumsum(map(vector_length, metadatas)) + # Extract the ranges from each metadata. + ranges = Vector{UnitRange{Int}}(undef, length(vns)) + # Need to keep track of which ones we've seen. + not_seen = fill(true, length(vns)) + for (i, metadata) in enumerate(metadatas) + vns_metadata = filter(Base.Fix1(haskey, metadata), vns) + # If none of the variables exist in the metadata, we return an empty array. + isempty(vns_metadata) && continue + # Otherwise, we extract the ranges. + offset = i == 1 ? 0 : offsets[i - 1] + for vn in vns_metadata + r_vn = getrange(metadata, vn) + # Get the index, so we return in the same order as `vns`. + # NOTE: There might be duplicates in `vns`, so we need to handle that. + indices = findall(==(vn), vns) + for idx in indices + not_seen[idx] = false + ranges[idx] = r_vn .+ offset + end + end + end + # Raise key error if any of the variables were not found. + if any(not_seen) + inds = findall(not_seen) + # Just use a `convert` to get the same type as the input; don't want to confuse by overly + # specilizing the types in the error message. + throw(KeyError(convert(typeof(vns), vns[inds]))) + end + return ranges end """ @@ -1141,27 +1221,6 @@ function link!!( return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model) end -""" - link!(vi::VarInfo, spl::Sampler) - -Transform the values of the random variables sampled by `spl` in `vi` from the support -of their distributions to the Euclidean space and set their corresponding `"trans"` -flag values to `true`. 
-""" -function link!(vi::VarInfo, spl::AbstractSampler) - Base.depwarn( - "`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", - :link!, - ) - return _link!(vi, spl) -end -function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) - Base.depwarn( - "`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.", - :link!, - ) - return _link!(vi, spl, spaceval) -end function _link!(vi::UntypedVarInfo, spl::AbstractSampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) @@ -1239,29 +1298,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode return maybe_invlink_before_eval!!(t, vi, context, model) end -""" - invlink!(vi::VarInfo, spl::AbstractSampler) - -Transform the values of the random variables sampled by `spl` in `vi` from the -Euclidean space back to the support of their distributions and sets their corresponding -`"trans"` flag values to `false`. -""" -function invlink!(vi::VarInfo, spl::AbstractSampler) - Base.depwarn( - "`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", - :invlink!, - ) - return _invlink!(vi, spl) -end - -function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) - Base.depwarn( - "`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.", - :invlink!, - ) - return _invlink!(vi, spl, spaceval) -end - function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) @@ -1314,13 +1350,13 @@ end function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations - yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) + yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn)) # Determine the new range. 
- start = first(getrange(vi, vn)) + start = first(getrange(md, vn)) # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. - setrange!(vi, vn, start:(start + length(yvec) - 1)) + setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. - setval!(vi, yvec, vn) + setval!(md, yvec, vn) acclogp!!(vi, -logjac) return vi end diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index a5097602d..039b549d6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1036,6 +1036,8 @@ function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {spa return replace_raw_storage(vnv, vals) end +vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) + """ unflatten(vnv::VarNamedVector, vals::AbstractVector) diff --git a/test/Project.toml b/test/Project.toml index 4ceda2b25..b61cc7a29 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,17 +7,21 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Pkg = 
"44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -32,19 +36,22 @@ ADTypes = "1" AbstractMCMC = "5" AbstractPPL = "0.8.4, 0.9" Accessors = "0.1" -Bijectors = "0.13.9, 0.14, 0.15" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" +Bijectors = "0.15.1" Combinatorics = "1" Compat = "4.3.0" +DifferentiationInterface = "0.6" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12" +JET = "0.9" LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" +Mooncake = "0.4.59" ReverseDiff = "1" StableRNGs = "1" Tracker = "0.2.23" diff --git a/test/ad.jl b/test/ad.jl index 6046cfda4..17981cf2a 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -@testset "AD: ForwardDiff and ReverseDiff" begin +@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS f = DynamicPPL.LogDensityFunction(m) rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -17,12 +17,60 @@ θ = convert(Vector{Float64}, varinfo[:]) logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) - @testset "ReverseDiff with compile=$compile" for compile in (false, true) - adtype = ADTypes.AutoReverseDiff(; compile=compile) - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) - @test grad ≈ ref_grad + @testset "$adtype" for adtype in [ + ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoReverseDiff(; compile=true), + ADTypes.AutoMooncake(; config=nothing), + ] + # Mooncake can't currently handle something that is going on in + # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. 
+ if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo + @test_broken 1 == 0 + else + ad_f = LogDensityProblemsAD.ADgradient(adtype, f) + _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) + @test grad ≈ ref_grad + end end end end + + @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin + # Failing model + t = 1:0.05:8 + σ = 0.3 + y = @. rand(sin(t) + Normal(0, σ)) + @model function state_space(y, TT, ::Type{T}=Float64) where {T} + # Priors + α ~ Normal(y[1], 0.001) + τ ~ Exponential(1) + η ~ filldist(Normal(0, 1), TT - 1) + σ ~ Exponential(1) + # create latent variable + x = Vector{T}(undef, TT) + x[1] = α + for t in 2:TT + x[t] = x[t - 1] + η[t - 1] * τ + end + # measurement model + y ~ MvNormal(x, σ^2 * I) + return x + end + model = state_space(y, length(t)) + + # Dummy sampling algorithm for testing. The test case can only be replicated + # with a custom sampler, it doesn't work with SampleFromPrior(). We need to + # overload assume so that model evaluation doesn't fail due to a lack + # of implementation + struct MyEmptyAlg end + DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = () + DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) = + DynamicPPL.assume(dist, vn, vi) + + # Compiling the ReverseDiff tape used to fail here + spl = Sampler(MyEmptyAlg()) + vi = VarInfo(model) + ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl)) + @test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any + end end diff --git a/test/compiler.jl b/test/compiler.jl index f2d7e5852..977c1156c 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -382,6 +382,27 @@ module Issue537 end @test demo2()() == 42 end + @testset "@submodel is deprecated" begin + @model inner() = x ~ Normal() + @model outer() = @submodel x = inner() + @test_logs( + ( + :warn, + "`@submodel model` and `@submodel prefix=... 
model` are deprecated; see `to_submodel` for the up-to-date syntax.", + ), + outer()() + ) + + @model outer_with_prefix() = @submodel prefix = "sub" x = inner() + @test_logs( + ( + :warn, + "`@submodel model` and `@submodel prefix=... model` are deprecated; see `to_submodel` for the up-to-date syntax.", + ), + outer_with_prefix()() + ) + end + @testset "submodel" begin # No prefix, 1 level. @model function demo1(x) @@ -469,7 +490,7 @@ module Issue537 end num_steps = length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - @submodel prefix = "ar1_$i" x = AR1(num_steps, α, μ, σ) + x ~ to_submodel(prefix(AR1(num_steps, α, μ, σ), "ar1_$i"), false) y[i] ~ MvNormal(x, 0.01 * I) end end diff --git a/test/debug_utils.jl b/test/debug_utils.jl index 9ca1bc1ba..294364758 100644 --- a/test/debug_utils.jl +++ b/test/debug_utils.jl @@ -45,14 +45,16 @@ @testset "submodel" begin @model ModelInner() = x ~ Normal() @model function ModelOuterBroken() - @submodel z = ModelInner() + # Without automatic prefixing => `x` s used twice. + z ~ to_submodel(ModelInner(), false) return x ~ Normal() end model = ModelOuterBroken() @test_throws ErrorException check_model(model; error_on_failure=true) @model function ModelOuterWorking() - @submodel prefix = true z = ModelInner() + # With automatic prefixing => `x` is not duplicated. + z ~ to_submodel(ModelInner()) x ~ Normal() return z end @@ -197,7 +199,7 @@ @test retype <: Tuple # Just make sure the following is runnable. 
- @test (DynamicPPL.DebugUtils.model_warntype(model); true) + @test DynamicPPL.DebugUtils.model_warntype(model) isa Any end end end diff --git a/test/ext/DynamicPPLJETExt.jl b/test/ext/DynamicPPLJETExt.jl new file mode 100644 index 000000000..b95107b2d --- /dev/null +++ b/test/ext/DynamicPPLJETExt.jl @@ -0,0 +1,81 @@ +@testset "DynamicPPLJETExt.jl" begin + @testset "determine_suitable_varinfo" begin + @model function demo1() + x ~ Bernoulli() + if x + y ~ Normal() + else + z ~ Normal() + end + end + model = demo1() + @test DynamicPPL.Experimental.determine_suitable_varinfo(model) isa + DynamicPPL.UntypedVarInfo + + @model demo2() = x ~ Normal() + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo2()) isa + DynamicPPL.TypedVarInfo + + @model function demo3() + # Just making sure that nothing strange happens when type inference fails. + x = Vector(undef, 1) + x[1] ~ Bernoulli() + if x[1] + y ~ Normal() + else + z ~ Normal() + end + end + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo3()) isa + DynamicPPL.UntypedVarInfo + + # Evaluation works (and it would even do so in practice), but sampling + # fill fail due to storing `Cauchy{Float64}` in `Vector{Normal{Float64}}`. + @model function demo4() + x ~ Bernoulli() + if x + y ~ Normal() + else + y ~ Cauchy() # different distibution, but same transformation + end + end + @test DynamicPPL.Experimental.determine_suitable_varinfo(demo4()) isa + DynamicPPL.UntypedVarInfo + + # In this model, the type error occurs in the user code rather than in DynamicPPL. + @model function demo5() + x ~ Normal() + xs = Any[] + push!(xs, x) + # `sum(::Vector{Any})` can potentially error unless the dynamic manages to resolve the + # correct `zero` method. As a result, this code will run, but JET will raise this is an issue. + return sum(xs) + end + # Should pass if we're only checking the tilde statements. 
+ @test DynamicPPL.Experimental.determine_suitable_varinfo(demo5()) isa + DynamicPPL.TypedVarInfo + # Should fail if we're including errors in the model body. + @test DynamicPPL.Experimental.determine_suitable_varinfo( + demo5(); only_ddpl=false + ) isa DynamicPPL.UntypedVarInfo + end + + @testset "demo models" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Use debug logging below. + varinfo = DynamicPPL.Experimental.determine_suitable_varinfo(model) + # They should all result in typed. + @test varinfo isa DynamicPPL.TypedVarInfo + # But let's also make sure that they're not lying. + f_eval, argtypes_eval = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo + ) + JET.test_call(f_eval, argtypes_eval) + + f_sample, argtypes_sample = DynamicPPL.DebugUtils.gen_evaluator_call_with_types( + model, varinfo, DynamicPPL.SamplingContext() + ) + JET.test_call(f_sample, argtypes_sample) + end + end +end diff --git a/test/ext/DynamicPPLMCMCChainsExt.jl b/test/ext/DynamicPPLMCMCChainsExt.jl index 25e0d55ba..8693c3b02 100644 --- a/test/ext/DynamicPPLMCMCChainsExt.jl +++ b/test/ext/DynamicPPLMCMCChainsExt.jl @@ -3,7 +3,7 @@ model = demo() chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y])) - chain_generated = @test_nowarn generated_quantities(model, chain) + chain_generated = @test_nowarn returned(model, chain) @test size(chain_generated) == (1000, 1) @test mean(chain_generated) ≈ 0 atol = 0.1 end diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..986057da0 --- /dev/null +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,5 @@ +@testset "DynamicPPLMooncakeExt" begin + Mooncake.TestUtils.test_rule( + StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true + ) +end diff --git a/test/lkj.jl b/test/lkj.jl index b9c20f916..d581cd21b 100644 --- a/test/lkj.jl +++ b/test/lkj.jl @@ -22,14 +22,14 @@ _lkj_atol = 0.05 
model = lkj_prior_demo() # `SampleFromPrior` will sample in constrained space. @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000) + samples = sample(model, SampleFromPrior(), 1_000; progress=false) @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end # `SampleFromUniform` will sample in unconstrained space. @testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000) + samples = sample(model, SampleFromUniform(), 1_000; progress=false) @test mean(map(Base.Fix2(getindex, Colon()), samples)) ≈ target_mean atol = _lkj_atol end @@ -39,7 +39,7 @@ end model = lkj_chol_prior_demo(uplo) # `SampleFromPrior` will sample in unconstrained space. @testset "SampleFromPrior" begin - samples = sample(model, SampleFromPrior(), 1_000) + samples = sample(model, SampleFromPrior(), 1_000; progress=false) # Build correlation matrix from factor corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) @@ -50,7 +50,7 @@ end # `SampleFromUniform` will sample in unconstrained space. 
@testset "SampleFromUniform" begin - samples = sample(model, SampleFromUniform(), 1_000) + samples = sample(model, SampleFromUniform(), 1_000; progress=false) # Build correlation matrix from factor corr_matrices = map(samples) do s M = reshape(s.metadata.vals, (2, 2)) diff --git a/test/model.jl b/test/model.jl index d163f55f0..a19cb29d2 100644 --- a/test/model.jl +++ b/test/model.jl @@ -29,9 +29,11 @@ is_typed_varinfo(::DynamicPPL.AbstractVarInfo) = false is_typed_varinfo(varinfo::DynamicPPL.TypedVarInfo) = true is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true +const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() + @testset "model.jl" begin @testset "convenience functions" begin - model = gdemo_default # defined in test/test_util.jl + model = GDEMO_DEFAULT # sample from model and extract variables vi = VarInfo(model) @@ -55,53 +57,26 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true @test ljoint ≈ lp #### logprior, logjoint, loglikelihood for MCMC chains #### - for model in DynamicPPL.TestUtils.DEMO_MODELS # length(DynamicPPL.TestUtils.DEMO_MODELS)=12 - var_info = VarInfo(model) - vns = DynamicPPL.TestUtils.varnames(model) - syms = unique(DynamicPPL.getsym.(vns)) - - # generate a chain of sample parameter values. + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS N = 200 - vals_OrderedDict = mapreduce(hcat, 1:N) do _ - rand(OrderedDict, model) - end - vals_mat = mapreduce(hcat, 1:N) do i - [vals_OrderedDict[i][vn] for vn in vns] - end - i = 1 - for col in eachcol(vals_mat) - col_flattened = [] - [push!(col_flattened, x...) 
for x in col] - if i == 1 - chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) - else - chain_mat = vcat( - chain_mat, reshape(col_flattened, 1, length(col_flattened)) - ) - end - i += 1 - end - chain_mat = convert(Matrix{Float64}, chain_mat) - - # devise parameter names for chain - sample_values_vec = collect(values(vals_OrderedDict[1])) - symbol_names = [] - chain_sym_map = Dict() - for k in 1:length(keys(var_info)) - vn_parent = keys(var_info)[k] + chain = make_chain_from_prior(model, N) + logpriors = logprior(model, chain) + loglikelihoods = loglikelihood(model, chain) + logjoints = logjoint(model, chain) + + # Construct mapping of varname symbols to varname-parent symbols. + # Here, varname_leaves is used to ensure compatibility with the + # variables stored in the chain + var_info = VarInfo(model) + chain_sym_map = Dict{Symbol,Symbol}() + for vn_parent in keys(var_info) sym = DynamicPPL.getsym(vn_parent) - vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) # `varname_leaves` defined in src/utils.jl + vn_children = DynamicPPL.varname_leaves(vn_parent, var_info[vn_parent]) for vn_child in vn_children chain_sym_map[Symbol(vn_child)] = sym - symbol_names = [symbol_names; Symbol(vn_child)] end end - chain = Chains(chain_mat, symbol_names) - # calculate the pointwise loglikelihoods for the whole chain using the newly written functions - logpriors = logprior(model, chain) - loglikelihoods = loglikelihood(model, chain) - logjoints = logjoint(model, chain) # compare them with true values for i in 1:N samples_dict = Dict() @@ -125,8 +100,21 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end + @testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin + @model function multiple_types(x) + ns ~ filldist(Normal(0, 2.0), 3) + m ~ Uniform(0, 1) + return x ~ Normal(m, 1) + end + model = multiple_types(1) + chain = make_chain_from_prior(model, 10) + loglikelihood(model, chain) + 
logprior(model, chain) + logjoint(model, chain) + end + @testset "rng" begin - model = gdemo_default + model = GDEMO_DEFAULT for sampler in (SampleFromPrior(), SampleFromUniform()) for i in 1:10 @@ -144,13 +132,15 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "defaults without VarInfo, Sampler, and Context" begin - model = gdemo_default + model = GDEMO_DEFAULT Random.seed!(100) - s, m = model() + retval = model() Random.seed!(100) - @test model(Random.default_rng()) == (s, m) + retval2 = model(Random.default_rng()) + @test retval2.s == retval.s + @test retval2.m == retval.m end @testset "nameof" begin @@ -184,7 +174,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "Internal methods" begin - model = gdemo_default + model = GDEMO_DEFAULT # sample from model and extract variables vi = VarInfo(model) @@ -224,7 +214,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end @testset "rand" begin - model = gdemo_default + model = GDEMO_DEFAULT Random.seed!(1776) s, m = model() @@ -309,7 +299,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true end end - @testset "generated_quantities on `LKJCholesky`" begin + @testset "returned() on `LKJCholesky`" begin n = 10 d = 2 model = DynamicPPL.TestUtils.demo_lkjchol(d) @@ -333,7 +323,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true ) # Test! - results = generated_quantities(model, chain) + results = returned(model, chain) for (x_true, result) in zip(xs, results) @test x_true.UL == result.x.UL end @@ -352,7 +342,7 @@ is_typed_varinfo(varinfo::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = true info=(varname_to_symbol=vns_to_syms_with_extra,), ) # Test! 
- results = generated_quantities(model, chain_with_extra) + results = returned(model, chain_with_extra) for (x_true, result) in zip(xs, results) @test x_true.UL == result.x.UL end diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..720ae55aa --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,20 @@ +@testset "model_utils.jl" begin + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + # Check that the values generated by value_iterator_from_chain + # match the values in the original chain + chain = make_chain_from_prior(model, 10) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + # Because value_iterator_from_chain groups varnames with + # the same parent symbol, we have to ungroup them here + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9e4b3a446..8e3bcc3b4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using DynamicPPL using AbstractMCMC using AbstractPPL using Bijectors +using DifferentiationInterface using Distributions using DistributionsAD using Documenter @@ -12,6 +13,8 @@ using ForwardDiff using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains +using Mooncake: Mooncake +using StableRNGs using Tracker using ReverseDiff using Zygote @@ -26,64 +29,60 @@ using Test using Distributions using LinearAlgebra # Diagonal +using JET: JET + using Combinatorics: combinations +using OrderedCollections: OrderedSet using DynamicPPL: getargs_dottilde, getargs_tilde, Selector -const DIRECTORY_DynamicPPL = dirname(dirname(pathof(DynamicPPL))) -const DIRECTORY_Turing_tests = joinpath(DIRECTORY_DynamicPPL, "test", "turing") const GROUP = get(ENV, "GROUP", "All") - 
Random.seed!(100) include("test_util.jl") -@testset "DynamicPPL.jl" begin - if GROUP == "All" || GROUP == "DynamicPPL" - @testset "interface" begin - include("utils.jl") - include("compiler.jl") - include("varnamedvector.jl") - include("varinfo.jl") - include("simple_varinfo.jl") - include("model.jl") - include("sampler.jl") - include("independence.jl") - include("distribution_wrappers.jl") - include("contexts.jl") - include("context_implementations.jl") - include("logdensityfunction.jl") - include("linking.jl") - - include("threadsafe.jl") - - include("serialization.jl") - - include("pointwise_logdensities.jl") - - include("lkj.jl") - - include("debug_utils.jl") - end +@testset verbose = true "DynamicPPL.jl" begin + # The tests are split into two groups so that CI can run in parallel. The + # groups are chosen to make both groups take roughly the same amount of + # time, but beyond that there is no particular reason for the split. + if GROUP == "All" || GROUP == "Group1" + include("utils.jl") + include("compiler.jl") + include("varnamedvector.jl") + include("varinfo.jl") + include("simple_varinfo.jl") + include("model.jl") + include("sampler.jl") + include("independence.jl") + include("distribution_wrappers.jl") + include("logdensityfunction.jl") + include("linking.jl") + include("serialization.jl") + include("pointwise_logdensities.jl") + include("lkj.jl") + end + if GROUP == "All" || GROUP == "Group2" + include("contexts.jl") + include("context_implementations.jl") + include("threadsafe.jl") + include("debug_utils.jl") @testset "compat" begin include(joinpath("compat", "ad.jl")) end - @testset "extensions" begin include("ext/DynamicPPLMCMCChainsExt.jl") + include("ext/DynamicPPLJETExt.jl") end - @testset "ad" begin include("ext/DynamicPPLForwardDiffExt.jl") + include("ext/DynamicPPLMooncakeExt.jl") include("ad.jl") end - @testset "prob and logprob macro" begin @test_throws ErrorException prob"..." @test_throws ErrorException logprob"..." 
end - @testset "doctests" begin DocMeta.setdocmeta!( DynamicPPL, @@ -104,33 +103,11 @@ include("test_util.jl") # Older versions do not have `;;]` but instead just `]` at end of the line # => need to treat `;;]` and `]` as the same, i.e. ignore them if at the end of a line r"(;;){0,1}\]$"m, + # Ignore the source of a warning in the doctest output, since this is dependent on host. + # This is a line that starts with "└ @ " and ends with the line number. + r"└ @ .+:[0-9]+", ] doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end end - - if GROUP == "All" || GROUP == "Downstream" - @testset "turing" begin - try - # activate separate test environment - Pkg.activate(DIRECTORY_Turing_tests) - Pkg.develop(PackageSpec(; path=DIRECTORY_DynamicPPL)) - Pkg.instantiate() - - # make sure that the new environment is considered `using` and `import` statements - # (not added automatically on Julia 1.3, see e.g. PR #209) - if !(joinpath(DIRECTORY_Turing_tests, "Project.toml") in Base.load_path()) - pushfirst!(LOAD_PATH, DIRECTORY_Turing_tests) - end - - include(joinpath("turing", "runtests.jl")) - catch err - err isa Pkg.Resolve.ResolverError || rethrow() - # If we can't resolve that means this is incompatible by SemVer and this is fine - # It means we marked this as a breaking change, so we don't need to worry about - # Mistakenly introducing a breaking change, as we have intentionally made one - @info "Not compatible with this release. No problem." 
exception = err - end - end - end end diff --git a/test/sampler.jl b/test/sampler.jl index 95e838167..e5fe6dc98 100644 --- a/test/sampler.jl +++ b/test/sampler.jl @@ -34,22 +34,20 @@ @testset "init" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - N = 1000 - chain_init = sample(model, SampleFromUniform(), N; progress=false) - - for vn in keys(first(chain_init)) - if AbstractPPL.subsumes(@varname(s), vn) - # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. - dist = InverseGamma(2, 3) - b = DynamicPPL.link_transform(dist) - @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 - elseif AbstractPPL.subsumes(@varname(m), vn) - # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. - @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 - else - error("Unknown variable name: $vn") - end + N = 1000 + chain_init = sample(model, SampleFromUniform(), N; progress=false) + + for vn in keys(first(chain_init)) + if AbstractPPL.subsumes(@varname(s), vn) + # `s ~ InverseGamma(2, 3)` and its unconstrained value will be sampled from Unif[-2,2]. + dist = InverseGamma(2, 3) + b = DynamicPPL.link_transform(dist) + @test mean(mean(b(vi[vn])) for vi in chain_init) ≈ 0 atol = 0.11 + elseif AbstractPPL.subsumes(@varname(m), vn) + # `m ~ Normal(0, sqrt(s))` and its constrained value is the same. + @test mean(mean(vi[vn]) for vi in chain_init) ≈ 0 atol = 0.11 + else + error("Unknown variable name: $vn") end end end diff --git a/test/test_util.jl b/test/test_util.jl index f1325b729..27a68456c 100644 --- a/test/test_util.jl +++ b/test/test_util.jl @@ -43,40 +43,6 @@ function test_model_ad(model, logp_manual) @test back(1)[1] ≈ grad end -""" - test_setval!(model, chain; sample_idx = 1, chain_idx = 1) - -Test `setval!` on `model` and `chain`. 
- -Worth noting that this only supports models containing symbols of the forms -`m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. -""" -function test_setval!(model, chain; sample_idx=1, chain_idx=1) - var_info = VarInfo(model) - spl = SampleFromPrior() - θ_old = var_info[spl] - DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) - θ_new = var_info[spl] - @test θ_old != θ_new - vals = DynamicPPL.values_as(var_info, OrderedDict) - iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) - for (n, v) in mapreduce(collect, vcat, iters) - n = string(n) - if Symbol(n) ∉ keys(chain) - # Assume it's a group - chain_val = vec( - MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] - ) - v_true = vec(v) - else - chain_val = chain[sample_idx, n, chain_idx] - v_true = v - end - - @test v_true == chain_val - end -end - """ short_varinfo_name(vi::AbstractVarInfo) @@ -110,3 +76,36 @@ function modify_value_representation(nt::NamedTuple) end return modified_nt end + +""" + make_chain_from_prior([rng,] model, n_iters) + +Construct an MCMCChains.Chains object by sampling from the prior of `model` for +`n_iters` iterations. +""" +function make_chain_from_prior(rng::Random.AbstractRNG, model::Model, n_iters::Int) + # Sample from the prior + varinfos = [VarInfo(rng, model) for _ in 1:n_iters] + # Extract all varnames found in any dictionary. Doing it this way guards + # against the possibility of having different varnames in different + # dictionaries, e.g. for models that have dynamic variables / array sizes + varnames = OrderedSet{VarName}() + # Convert each varinfo into an OrderedDict of vns => params. + # We have to use varname_and_value_leaves so that each parameter is a scalar + dicts = map(varinfos) do t + vals = DynamicPPL.values_as(t, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + tuples = mapreduce(collect, vcat, iters) + push!(varnames, map(first, tuples)...) 
+ OrderedDict(tuples) + end + # Convert back to list + varnames = collect(varnames) + # Construct matrix of values + vals = [get(dict, vn, missing) for dict in dicts, vn in varnames] + # Construct and return the Chains object + return Chains(vals, varnames) +end +function make_chain_from_prior(model::Model, n_iters::Int) + return make_chain_from_prior(Random.default_rng(), model, n_iters) +end diff --git a/test/turing/Project.toml b/test/turing/Project.toml deleted file mode 100644 index 28341c20b..000000000 --- a/test/turing/Project.toml +++ /dev/null @@ -1,19 +0,0 @@ -[deps] -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" -HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" - -[compat] -Distributions = "0.25" -DynamicPPL = "0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30" -HypothesisTests = "0.11" -MCMCChains = "6" -ReverseDiff = "1.15" -Turing = "0.33, 0.34, 0.35" -julia = "1.7" diff --git a/test/turing/compiler.jl b/test/turing/compiler.jl deleted file mode 100644 index 5c46ab777..000000000 --- a/test/turing/compiler.jl +++ /dev/null @@ -1,348 +0,0 @@ -@testset "compiler.jl" begin - @testset "assume" begin - @model function test_assume() - x ~ Bernoulli(1) - y ~ Bernoulli(x / 2) - return x, y - end - - smc = SMC() - pg = PG(10) - - res1 = sample(test_assume(), smc, 1000) - res2 = sample(test_assume(), pg, 1000) - - check_numerical(res1, [:y], [0.5]; atol=0.1) - check_numerical(res2, [:y], [0.5]; atol=0.1) - - # Check that all xs are 1. 
- @test all(isone, res1[:x]) - @test all(isone, res2[:x]) - end - @testset "beta binomial" begin - prior = Beta(2, 2) - obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] - exact = Beta(prior.α + sum(obs), prior.β + length(obs) - sum(obs)) - meanp = exact.α / (exact.α + exact.β) - - @model function testbb(obs) - p ~ Beta(2, 2) - x ~ Bernoulli(p) - for i in 1:length(obs) - obs[i] ~ Bernoulli(p) - end - return p, x - end - - smc = SMC() - pg = PG(10) - gibbs = Gibbs(HMC(0.2, 3, :p), PG(10, :x)) - - chn_s = sample(testbb(obs), smc, 1000) - chn_p = sample(testbb(obs), pg, 2000) - chn_g = sample(testbb(obs), gibbs, 1500) - - check_numerical(chn_s, [:p], [meanp]; atol=0.05) - check_numerical(chn_p, [:x], [meanp]; atol=0.1) - check_numerical(chn_g, [:x], [meanp]; atol=0.1) - end - @testset "forbid global" begin - xs = [1.5 2.0] - # xx = 1 - - @model function fggibbstest(xs) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - # xx ~ Normal(m, sqrt(s)) # this is illegal - - for i in 1:length(xs) - xs[i] ~ Normal(m, sqrt(s)) - # for xx in xs - # xx ~ Normal(m, sqrt(s)) - end - return s, m - end - - gibbs = Gibbs(PG(10, :s), HMC(0.4, 8, :m)) - chain = sample(fggibbstest(xs), gibbs, 2) - end - @testset "new grammar" begin - x = Float64[1 2] - - @model function gauss(x) - priors = Array{Float64}(undef, 2) - priors[1] ~ InverseGamma(2, 3) # s - priors[2] ~ Normal(0, sqrt(priors[1])) # m - for i in 1:length(x) - x[i] ~ Normal(priors[2], sqrt(priors[1])) - end - return priors - end - - chain = sample(gauss(x), PG(10), 10) - chain = sample(gauss(x), SMC(), 10) - - @model function gauss2(::Type{TV}=Vector{Float64}; x) where {TV} - priors = TV(undef, 2) - priors[1] ~ InverseGamma(2, 3) # s - priors[2] ~ Normal(0, sqrt(priors[1])) # m - for i in 1:length(x) - x[i] ~ Normal(priors[2], sqrt(priors[1])) - end - return priors - end - - @test_throws ErrorException chain = sample(gauss2(; x=x), PG(10), 10) - @test_throws ErrorException chain = sample(gauss2(; x=x), SMC(), 10) - - @test_throws 
ErrorException chain = sample( - gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), PG(10), 10 - ) - @test_throws ErrorException chain = sample( - gauss2(DynamicPPL.TypeWrap{Vector{Float64}}(); x=x), SMC(), 10 - ) - end - @testset "new interface" begin - obs = [0, 1, 0, 1, 1, 1, 1, 1, 1, 1] - - @model function newinterface(obs) - p ~ Beta(2, 2) - for i in 1:length(obs) - obs[i] ~ Bernoulli(p) - end - return p - end - - chain = sample( - newinterface(obs), - HMC(0.75, 3, :p, :x; adtype=AutoForwardDiff(; chunksize=2)), - 100, - ) - end - @testset "no return" begin - @model function noreturn(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in 1:length(x) - x[i] ~ Normal(m, sqrt(s)) - end - end - - chain = sample(noreturn([1.5 2.0]), HMC(0.15, 6), 1000) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6]) - end - @testset "observe" begin - @model function test() - z ~ Normal(0, 1) - x ~ Bernoulli(1) - 1 ~ Bernoulli(x / 2) - 0 ~ Bernoulli(x / 2) - return x - end - - is = IS() - smc = SMC() - pg = PG(10) - - res_is = sample(test(), is, 10000) - res_smc = sample(test(), smc, 1000) - res_pg = sample(test(), pg, 100) - - @test all(isone, res_is[:x]) - @test res_is.logevidence ≈ 2 * log(0.5) - - @test all(isone, res_smc[:x]) - @test res_smc.logevidence ≈ 2 * log(0.5) - - @test all(isone, res_pg[:x]) - end - @testset "sample" begin - alg = Gibbs(HMC(0.2, 3, :m), PG(10, :s)) - chn = sample(gdemo_default, alg, 1000) - end - @testset "vectorization @." begin - @model function vdemo1(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - @. x ~ Normal(m, sqrt(s)) - return s, m - end - - alg = HMC(0.01, 5) - x = randn(100) - res = sample(vdemo1(x), alg, 250) - - @model function vdemo1b(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - @. x ~ Normal(m, $(sqrt(s))) - return s, m - end - - res = sample(vdemo1b(x), alg, 250) - - @model function vdemo2(x) - μ ~ MvNormal(zeros(size(x, 1)), I) - @. 
x ~ $(MvNormal(μ, I)) - end - - D = 2 - alg = HMC(0.01, 5) - res = sample(vdemo2(randn(D, 100)), alg, 250) - - # Vector assumptions - N = 10 - alg = HMC(0.2, 4; adtype=AutoForwardDiff(; chunksize=N)) - - @model function vdemo3() - x = Vector{Real}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) - end - end - - t_loop = @elapsed res = sample(vdemo3(), alg, 1000) - - # Test for vectorize UnivariateDistribution - @model function vdemo4() - x = Vector{Real}(undef, N) - @. x ~ Normal(0, 2) - end - - t_vec = @elapsed res = sample(vdemo4(), alg, 1000) - - @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) - - t_mv = @elapsed res = sample(vdemo5(), alg, 1000) - - println("Time for") - println(" Loop : ", t_loop) - println(" Vec : ", t_vec) - println(" Mv : ", t_mv) - - # Transformed test - @model function vdemo6() - x = Vector{Real}(undef, N) - @. x ~ InverseGamma(2, 3) - end - - sample(vdemo6(), alg, 1000) - - N = 3 - @model function vdemo7() - x = Array{Real}(undef, N, N) - @. x ~ [InverseGamma(2, 3) for i in 1:N] - end - - sample(vdemo7(), alg, 1000) - end - @testset "vectorization .~" begin - @model function vdemo1(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - x .~ Normal(m, sqrt(s)) - return s, m - end - - alg = HMC(0.01, 5) - x = randn(100) - res = sample(vdemo1(x), alg, 250) - - @model function vdemo2(x) - μ ~ MvNormal(zeros(size(x, 1)), I) - return x .~ MvNormal(μ, I) - end - - D = 2 - alg = HMC(0.01, 5) - res = sample(vdemo2(randn(D, 100)), alg, 250) - - # Vector assumptions - N = 10 - alg = HMC(0.2, 4; adtype=AutoForwardDiff(; chunksize=N)) - - @model function vdemo3() - x = Vector{Real}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) - end - end - - t_loop = @elapsed res = sample(vdemo3(), alg, 1000) - - # Test for vectorize UnivariateDistribution - @model function vdemo4() - x = Vector{Real}(undef, N) - return x .~ Normal(0, 2) - end - - t_vec = @elapsed res = sample(vdemo4(), alg, 1000) - - @model vdemo5() = x ~ MvNormal(zeros(N), 4 * I) - - 
t_mv = @elapsed res = sample(vdemo5(), alg, 1000) - - println("Time for") - println(" Loop : ", t_loop) - println(" Vec : ", t_vec) - println(" Mv : ", t_mv) - - # Transformed test - @model function vdemo6() - x = Vector{Real}(undef, N) - return x .~ InverseGamma(2, 3) - end - - sample(vdemo6(), alg, 1000) - - @model function vdemo7() - x = Array{Real}(undef, N, N) - return x .~ [InverseGamma(2, 3) for i in 1:N] - end - - sample(vdemo7(), alg, 1000) - end - @testset "Type parameters" begin - N = 10 - alg = HMC(0.01, 5; adtype=AutoForwardDiff(; chunksize=N)) - x = randn(1000) - @model function vdemo1(::Type{T}=Float64) where {T} - x = Vector{T}(undef, N) - for i in 1:N - x[i] ~ Normal(0, sqrt(4)) - end - end - - t_loop = @elapsed res = sample(vdemo1(), alg, 250) - t_loop = @elapsed res = sample(vdemo1(DynamicPPL.TypeWrap{Float64}()), alg, 250) - - vdemo1kw(; T) = vdemo1(T) - t_loop = @elapsed res = sample( - vdemo1kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 - ) - - @model function vdemo2(::Type{T}=Float64) where {T<:Real} - x = Vector{T}(undef, N) - @. x ~ Normal(0, 2) - end - - t_vec = @elapsed res = sample(vdemo2(), alg, 250) - t_vec = @elapsed res = sample(vdemo2(DynamicPPL.TypeWrap{Float64}()), alg, 250) - - vdemo2kw(; T) = vdemo2(T) - t_vec = @elapsed res = sample( - vdemo2kw(; T=DynamicPPL.TypeWrap{Float64}()), alg, 250 - ) - - @model function vdemo3(::Type{TV}=Vector{Float64}) where {TV<:AbstractVector} - x = TV(undef, N) - @. 
x ~ InverseGamma(2, 3) - end - - sample(vdemo3(), alg, 250) - sample(vdemo3(DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) - - vdemo3kw(; T) = vdemo3(T) - sample(vdemo3kw(; T=DynamicPPL.TypeWrap{Vector{Float64}}()), alg, 250) - end -end diff --git a/test/turing/loglikelihoods.jl b/test/turing/loglikelihoods.jl deleted file mode 100644 index 5f1c41572..000000000 --- a/test/turing/loglikelihoods.jl +++ /dev/null @@ -1,41 +0,0 @@ -@testset "loglikelihoods.jl" begin - @model function demo(xs, y) - s ~ InverseGamma(2, 3) - m ~ Normal(0, √s) - for i in eachindex(xs) - xs[i] ~ Normal(m, √s) - end - - return y ~ Normal(m, √s) - end - - xs = randn(3) - y = randn() - model = demo(xs, y) - chain = sample(model, MH(), MCMCThreads(), 100, 2) - var_to_likelihoods = pointwise_loglikelihoods( - model, MCMCChains.get_sections(chain, :parameters) - ) - @test haskey(var_to_likelihoods, "xs[1]") - @test haskey(var_to_likelihoods, "xs[2]") - @test haskey(var_to_likelihoods, "xs[3]") - @test haskey(var_to_likelihoods, "y") - - for chain_idx in MCMCChains.chains(chain) - for (i, (s, m)) in enumerate(zip(chain[:, :s, chain_idx], chain[:, :m, chain_idx])) - @test logpdf(Normal(m, √s), xs[1]) == var_to_likelihoods["xs[1]"][i, chain_idx] - @test logpdf(Normal(m, √s), xs[2]) == var_to_likelihoods["xs[2]"][i, chain_idx] - @test logpdf(Normal(m, √s), xs[3]) == var_to_likelihoods["xs[3]"][i, chain_idx] - @test logpdf(Normal(m, √s), y) == var_to_likelihoods["y"][i, chain_idx] - end - end - - var_info = VarInfo(model) - results = pointwise_loglikelihoods(model, var_info) - var_to_likelihoods = Dict(string(vn) => ℓ for (vn, ℓ) in results) - s, m = var_info[SampleFromPrior()] - @test [logpdf(Normal(m, √s), xs[1])] == var_to_likelihoods["xs[1]"] - @test [logpdf(Normal(m, √s), xs[2])] == var_to_likelihoods["xs[2]"] - @test [logpdf(Normal(m, √s), xs[3])] == var_to_likelihoods["xs[3]"] - @test [logpdf(Normal(m, √s), y)] == var_to_likelihoods["y"] -end diff --git a/test/turing/model.jl 
b/test/turing/model.jl deleted file mode 100644 index 599fba21b..000000000 --- a/test/turing/model.jl +++ /dev/null @@ -1,27 +0,0 @@ -@testset "model.jl" begin - @testset "setval! & generated_quantities" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS - chain = sample(model, Prior(), 10) - # A simple way of checking that the computation is determinstic: run twice and compare. - res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) - res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) - @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) - end - end - - @testset "value_iterator_from_chain" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS - chain = sample(model, Prior(), 10; progress=false) - for (i, d) in enumerate(value_iterator_from_chain(model, chain)) - for vn in keys(d) - val = DynamicPPL.getvalue(d, vn) - for vn_leaf in DynamicPPL.varname_leaves(vn, val) - val_leaf = DynamicPPL.getvalue(d, vn_leaf) - @test val_leaf == chain[i, Symbol(vn_leaf), 1] - end - end - end - end - end -end diff --git a/test/turing/runtests.jl b/test/turing/runtests.jl deleted file mode 100644 index b3ff72819..000000000 --- a/test/turing/runtests.jl +++ /dev/null @@ -1,24 +0,0 @@ -using DynamicPPL -using Turing -using LinearAlgebra -using ReverseDiff - -using Random -using Test - -setprogress!(false) - -Random.seed!(100) - -# load test utilities -include(joinpath(pathof(DynamicPPL), "..", "..", "test", "test_util.jl")) -include(joinpath(pathof(Turing), "..", "..", "test", "test_utils", "numerical_tests.jl")) - -using .NumericalTests: check_numerical - -@testset "Turing" begin - include("compiler.jl") - include("loglikelihoods.jl") - include("model.jl") - include("varinfo.jl") -end diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl deleted file mode 100644 index e5b8eb79f..000000000 --- a/test/turing/varinfo.jl +++ /dev/null @@ -1,360 +0,0 
@@ -@testset "varinfo.jl" begin - # Declare empty model to make the Sampler constructor work. - @model empty_model() = begin - x = 1 - end - - function randr( - vi::VarInfo, vn::VarName, dist::Distribution, spl::Sampler, count::Bool=false - ) - if !haskey(vi, vn) - r = rand(dist) - push!!(vi, vn, r, dist, spl) - r - elseif is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = rand(dist) - vi[vn] = DynamicPPL.tovec(r) - setorder!(vi, vn, get_num_produce(vi)) - r - else - count && checkindex(vn, vi, spl) - DynamicPPL.updategid!(vi, vn, spl) - vi[vn] - end - end - - @testset "link!" begin - # Test linking spl and vi: - # link!, invlink!, istrans - @model gdemo(x, y) = begin - s ~ InverseGamma(2, 3) - m ~ Uniform(0, 2) - x ~ Normal(m, sqrt(s)) - y ~ Normal(m, sqrt(s)) - end - model = gdemo(1.0, 2.0) - - vi = VarInfo() - meta = vi.metadata - model(vi, SampleFromUniform()) - @test all(x -> !istrans(vi, x), meta.vns) - - alg = HMC(0.1, 5) - spl = Sampler(alg, model) - v = copy(meta.vals) - link!(vi, spl) - @test all(x -> istrans(vi, x), meta.vns) - invlink!(vi, spl) - @test all(x -> !istrans(vi, x), meta.vns) - @test meta.vals == v - - vi = TypedVarInfo(vi) - meta = vi.metadata - alg = HMC(0.1, 5) - spl = Sampler(alg, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - v_s = copy(meta.s.vals) - v_m = copy(meta.m.vals) - link!(vi, spl) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals == v_s - @test meta.m.vals == v_m - - # Transforming only a subset of the variables - link!(vi, spl, Val((:m,))) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) - invlink!(vi, spl, Val((:m,))) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) - @test meta.s.vals == v_s - @test 
meta.m.vals == v_m - end - @testset "orders" begin - csym = gensym() # unique per model - vn_z1 = @varname z[1] - vn_z2 = @varname z[2] - vn_z3 = @varname z[3] - vn_z4 = @varname z[4] - vn_a1 = @varname a[1] - vn_a2 = @varname a[2] - vn_b = @varname b - - vi = VarInfo() - dists = [Categorical([0.7, 0.3]), Normal()] - - spl1 = Sampler(PG(5), empty_model()) - spl2 = Sampler(PG(5), empty_model()) - - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] - @test get_num_produce(vi) == 3 - - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl1) - @test is_flagged(vi, vn_z1, "del") - @test is_flagged(vi, vn_a1, "del") - @test is_flagged(vi, vn_z2, "del") - @test is_flagged(vi, vn_a2, "del") - @test is_flagged(vi, vn_z3, "del") - - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] - @test get_num_produce(vi) == 3 - - vi = empty!!(TypedVarInfo(vi)) - # First iteration, variables are added to vi - # variables samples in order: z1,a1,z2,a2,z3 - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_b, dists[2], spl2) - randr(vi, vn_z2, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 2] - @test 
vi.metadata.b.orders == [2] - @test get_num_produce(vi) == 3 - - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl1) - @test is_flagged(vi, vn_z1, "del") - @test is_flagged(vi, vn_a1, "del") - @test is_flagged(vi, vn_z2, "del") - @test is_flagged(vi, vn_a2, "del") - @test is_flagged(vi, vn_z3, "del") - - increment_num_produce!(vi) - randr(vi, vn_z1, dists[1], spl1) - randr(vi, vn_a1, dists[2], spl1) - increment_num_produce!(vi) - randr(vi, vn_z2, dists[1], spl1) - increment_num_produce!(vi) - randr(vi, vn_z3, dists[1], spl1) - randr(vi, vn_a2, dists[2], spl1) - @test vi.metadata.z.orders == [1, 2, 3] - @test vi.metadata.a.orders == [1, 3] - @test vi.metadata.b.orders == [2] - @test get_num_produce(vi) == 3 - end - @testset "replay" begin - # Generate synthesised data - xs = rand(Normal(0.5, 1), 100) - - # Define model - @model function priorsinarray(xs, ::Type{T}=Float64) where {T} - begin - priors = Vector{T}(undef, 2) - priors[1] ~ InverseGamma(2, 3) - priors[2] ~ Normal(0, sqrt(priors[1])) - for i in 1:length(xs) - xs[i] ~ Normal(priors[2], sqrt(priors[1])) - end - priors - end - end - - # Sampling - chain = sample(priorsinarray(xs), HMC(0.01, 10), 10) - end - @testset "varname" begin - @model function mat_name_test() - p = Array{Any}(undef, 2, 2) - for i in 1:2, j in 1:2 - p[i, j] ~ Normal(0, 1) - end - return p - end - chain = sample(mat_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1, 1]"], [0]; atol=0.25) - - @model function marr_name_test() - p = Array{Array{Any}}(undef, 2) - p[1] = Array{Any}(undef, 2) - p[2] = Array{Any}(undef, 2) - for i in 1:2, j in 1:2 - p[i][j] ~ Normal(0, 1) - end - return p - end - - chain = sample(marr_name_test(), HMC(0.2, 4), 1000) - check_numerical(chain, ["p[1][1]"], [0]; atol=0.25) - end - @testset "varinfo" begin - dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] - function test_varinfo!(vi) - @test getlogp(vi) === 0.0 - vi = setlogp!!(vi, 1) - @test getlogp(vi) === 1.0 - vi 
= acclogp!!(vi, 1) - @test getlogp(vi) === 2.0 - vi = resetlogp!!(vi) - @test getlogp(vi) === 0.0 - - spl2 = Sampler(PG(5, :w, :u), empty_model()) - vn_w = @varname w - randr(vi, vn_w, dists[1], spl2, true) - - vn_x = @varname x - vn_y = @varname y - vn_z = @varname z - vns = [vn_x, vn_y, vn_z] - - spl1 = Sampler(PG(5, :x, :y, :z), empty_model()) - for i in 1:3 - r = randr(vi, vns[i], dists[i], spl1, false) - val = vi[vns[i]] - @test sum(val - r) <= 1e-9 - end - - idcs = DynamicPPL._getidcs(vi, spl1) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 - else - @test length(idcs) == 3 - end - @test length(vi[spl1]) == 7 - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 - else - @test length(idcs) == 1 - end - @test length(vi[spl2]) == 1 - - vn_u = @varname u - randr(vi, vn_u, dists[1], spl2, true) - - idcs = DynamicPPL._getidcs(vi, spl2) - if idcs isa NamedTuple - @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 - else - @test length(idcs) == 2 - end - @test length(vi[spl2]) == 2 - end - vi = VarInfo() - test_varinfo!(vi) - test_varinfo!(empty!!(TypedVarInfo(vi))) - - @model igtest() = begin - x ~ InverseGamma(2, 3) - y ~ InverseGamma(2, 3) - z ~ InverseGamma(2, 3) - w ~ InverseGamma(2, 3) - u ~ InverseGamma(2, 3) - end - - # Test the update of group IDs - g_demo_f = igtest() - - # This test section no longer seems as applicable, considering the - # user will never end up using an UntypedVarInfo. The `VarInfo` - # Varible is also not passed around in the same way as it used to be. 
- - # TODO: Has to be fixed - - #= g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f) - vi = VarInfo() - g_demo_f(vi, SampleFromPrior()) - _, state = @inferred AbstractMCMC.step(Random.default_rng(), g_demo_f, g) - pg, hmc = state.states - @test pg isa TypedVarInfo - @test hmc isa Turing.Inference.HMCState - vi1 = state.vi - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == - [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set{Selector}(), Set{Selector}()] - - @inferred g_demo_f(vi1, hmc) - @test mapreduce(x -> x.gids, vcat, vi1.metadata) == - [Set([pg.selector]), Set([pg.selector]), Set([pg.selector]), Set([hmc.selector]), Set([hmc.selector])] - - g = Sampler(Gibbs(PG(10, :x, :y, :z), HMC(0.4, 8, :w, :u)), g_demo_f) - pg, hmc = g.state.samplers - vi = empty!!(TypedVarInfo(vi)) - @inferred g_demo_f(vi, SampleFromPrior()) - pg.state.vi = vi - step!(Random.default_rng(), g_demo_f, pg, 1) - vi = pg.state.vi - @inferred g_demo_f(vi, hmc) - @test vi.metadata.x.gids[1] == Set([pg.selector]) - @test vi.metadata.y.gids[1] == Set([pg.selector]) - @test vi.metadata.z.gids[1] == Set([pg.selector]) - @test vi.metadata.w.gids[1] == Set([hmc.selector]) - @test vi.metadata.u.gids[1] == Set([hmc.selector]) =# - end - - @testset "Turing#2151: eltype(vi, spl)" begin - # build data - t = 1:0.05:8 - σ = 0.3 - y = @. 
rand(sin(t) + Normal(0, σ)) - - @model function state_space(y, TT, ::Type{T}=Float64) where {T} - # Priors - α ~ Normal(y[1], 0.001) - τ ~ Exponential(1) - η ~ filldist(Normal(0, 1), TT - 1) - σ ~ Exponential(1) - - # create latent variable - x = Vector{T}(undef, TT) - x[1] = α - for t in 2:TT - x[t] = x[t - 1] + η[t - 1] * τ - end - - # measurement model - y ~ MvNormal(x, σ^2 * I) - - return x - end - - n = 10 - model = state_space(y, length(t)) - @test size(sample(model, NUTS(; adtype=AutoReverseDiff(true)), n), 1) == n - end - - if Threads.nthreads() > 1 - @testset "DynamicPPL#684: OrderedDict with multiple types when multithreaded" begin - @model function f(x) - ns ~ filldist(Normal(0, 2.0), 3) - m ~ Uniform(0, 1) - return x ~ Normal(m, 1) - end - model = f(1) - chain = sample(model, NUTS(), MCMCThreads(), 10, 2) - loglikelihood(model, chain) - logprior(model, chain) - logjoint(model, chain) - end - end -end diff --git a/test/varinfo.jl b/test/varinfo.jl index c45fb47e0..9a55cffb9 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -1,3 +1,8 @@ +# Dummy algorithm for testing +# Invoke with: DynamicPPL.Sampler(MyAlg{(:x, :y)}(), ...) +struct MyAlg{space} end +DynamicPPL.getspace(::DynamicPPL.Sampler{MyAlg{space}}) where {space} = space + function check_varinfo_keys(varinfo, vns) if varinfo isa DynamicPPL.SimpleOrThreadSafeSimple{<:NamedTuple} # NOTE: We can't compare the `keys(varinfo_merged)` directly with `vns`, @@ -14,9 +19,29 @@ function check_varinfo_keys(varinfo, vns) end end -# A simple "algorithm" which only has `s` variables in its space. 
-struct MySAlg end -DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) +function randr( + vi::DynamicPPL.VarInfo, + vn::VarName, + dist::Distribution, + spl::DynamicPPL.Sampler, + count::Bool=false, +) + if !haskey(vi, vn) + r = rand(dist) + push!!(vi, vn, r, dist, spl) + r + elseif DynamicPPL.is_flagged(vi, vn, "del") + DynamicPPL.unset_flag!(vi, vn, "del") + r = rand(dist) + vi[vn] = DynamicPPL.tovec(r) + DynamicPPL.setorder!(vi, vn, DynamicPPL.get_num_produce(vi)) + r + else + count && checkindex(vn, vi, spl) + DynamicPPL.updategid!(vi, vn, spl) + vi[vn] + end +end @testset "varinfo.jl" begin @testset "TypedVarInfo with Metadata" begin @@ -130,6 +155,26 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) test_base!!(SimpleVarInfo(Dict())) test_base!!(SimpleVarInfo(DynamicPPL.VarNamedVector())) end + + @testset "get/set/acc/resetlogp" begin + function test_varinfo_logp!(vi) + @test DynamicPPL.getlogp(vi) === 0.0 + vi = DynamicPPL.setlogp!!(vi, 1.0) + @test DynamicPPL.getlogp(vi) === 1.0 + vi = DynamicPPL.acclogp!!(vi, 1.0) + @test DynamicPPL.getlogp(vi) === 2.0 + vi = DynamicPPL.resetlogp!!(vi) + @test DynamicPPL.getlogp(vi) === 0.0 + end + + vi = VarInfo() + test_varinfo_logp!(vi) + test_varinfo_logp!(TypedVarInfo(vi)) + test_varinfo_logp!(SimpleVarInfo()) + test_varinfo_logp!(SimpleVarInfo(Dict())) + test_varinfo_logp!(SimpleVarInfo(DynamicPPL.VarNamedVector())) + end + @testset "flags" begin # Test flag setting: # is_flagged, set_flag!, unset_flag! @@ -187,6 +232,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2]) end + @testset "setval! & setval_and_resample!" begin @model function testmodel(x) n = length(x) @@ -339,6 +385,103 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test vals_prev == vi.metadata.x.vals end + @testset "setval! 
on chain" begin + # Define a helper function + """ + test_setval!(model, chain; sample_idx = 1, chain_idx = 1) + + Test `setval!` on `model` and `chain`. + + Worth noting that this only supports models containing symbols of the forms + `m`, `m[1]`, `m[1, 2]`, not `m[1][1]`, etc. + """ + function test_setval!(model, chain; sample_idx=1, chain_idx=1) + var_info = VarInfo(model) + spl = SampleFromPrior() + θ_old = var_info[spl] + DynamicPPL.setval!(var_info, chain, sample_idx, chain_idx) + θ_new = var_info[spl] + @test θ_old != θ_new + vals = DynamicPPL.values_as(var_info, OrderedDict) + iters = map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)) + for (n, v) in mapreduce(collect, vcat, iters) + n = string(n) + if Symbol(n) ∉ keys(chain) + # Assume it's a group + chain_val = vec( + MCMCChains.group(chain, Symbol(n)).value[sample_idx, :, chain_idx] + ) + v_true = vec(v) + else + chain_val = chain[sample_idx, n, chain_idx] + v_true = v + end + + @test v_true == chain_val + end + end + + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = make_chain_from_prior(model, 10) + # A simple way of checking that the computation is determinstic: run twice and compare. + res1 = returned(model, MCMCChains.get_sections(chain, :parameters)) + res2 = returned(model, MCMCChains.get_sections(chain, :parameters)) + @test all(res1 .== res2) + test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + end + end + + @testset "link!! and invlink!!" 
begin + @model gdemo(x, y) = begin + s ~ InverseGamma(2, 3) + m ~ Uniform(0, 2) + x ~ Normal(m, sqrt(s)) + y ~ Normal(m, sqrt(s)) + end + model = gdemo(1.0, 2.0) + + # Check that instantiating the model does not perform linking + vi = VarInfo() + meta = vi.metadata + model(vi, SampleFromUniform()) + @test all(x -> !istrans(vi, x), meta.vns) + + # Check that linking and invlinking set the `trans` flag accordingly + v = copy(meta.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.vns) + @test meta.vals ≈ v atol = 1e-10 + + # Check that linking and invlinking preserves the values + vi = TypedVarInfo(vi) + meta = vi.metadata + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + v_s = copy(meta.s.vals) + v_m = copy(meta.m.vals) + link!!(vi, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> istrans(vi, x), meta.m.vns) + invlink!!(vi, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + + # Transform only one variable (`s`) but not the others (`m`) + spl = DynamicPPL.Sampler(MyAlg{(:s,)}(), model) + link!!(vi, spl, model) + @test all(x -> istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + invlink!!(vi, spl, model) + @test all(x -> !istrans(vi, x), meta.s.vns) + @test all(x -> !istrans(vi, x), meta.m.vns) + @test meta.s.vals ≈ v_s atol = 1e-10 + @test meta.m.vals ≈ v_m atol = 1e-10 + end + @testset "istrans" begin @model demo_constrained() = x ~ truncated(Normal(), 0, Inf) model = demo_constrained() @@ -737,7 +880,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) DynamicPPL.Metadata(), ) selector = DynamicPPL.Selector() - spl = Sampler(MySAlg(), model, selector) + spl = Sampler(MyAlg{(:s,)}(), model, selector) vns = DynamicPPL.TestUtils.varnames(model) vns_s = filter(vn -> 
DynamicPPL.getsym(vn) === :s, vns) @@ -813,4 +956,189 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test DynamicPPL.istrans(varinfo2, vn) end end + + # NOTE: It is not yet clear if this is something we want from all varinfo types. + # Hence, we only test the `VarInfo` types here. + @testset "vector_getranges for `VarInfo`" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + nt = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, nt, vns; include_threadsafe=true + ) + # Only keep `VarInfo` types. + varinfos = filter( + Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + x = values_as(varinfo, Vector) + + # Let's just check all the subsets of `vns`. + @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in + combinations(vns) + ranges = DynamicPPL.vector_getranges(varinfo, vns_subset) + @test length(ranges) == length(vns_subset) + for (r, vn) in zip(ranges, vns_subset) + @test x[r] == DynamicPPL.tovec(varinfo[vn]) + end + end + + # Let's try some failure cases. + @test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[] + # Non-existent variables. + @test_throws KeyError DynamicPPL.vector_getranges( + varinfo, [VarName{gensym("vn")}()] + ) + @test_throws KeyError DynamicPPL.vector_getranges( + varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()] + ) + # Duplicate variables. 
+ ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2)) + @test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2) + end + end + end + + @testset "orders" begin + @model empty_model() = x = 1 + + csym = gensym() # unique per model + vn_z1 = @varname z[1] + vn_z2 = @varname z[2] + vn_z3 = @varname z[3] + vn_z4 = @varname z[4] + vn_a1 = @varname a[1] + vn_a2 = @varname a[2] + vn_b = @varname b + + vi = DynamicPPL.VarInfo() + dists = [Categorical([0.7, 0.3]), Normal()] + + spl1 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + spl2 = DynamicPPL.Sampler(MyAlg{()}(), empty_model()) + + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 2, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.orders == [1, 1, 2, 2, 3, 3] + @test DynamicPPL.get_num_produce(vi) == 3 + + vi = empty!!(DynamicPPL.TypedVarInfo(vi)) + # First iteration, variables are added to vi + # variables samples in order: z1,a1,z2,a2,z3 + 
DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_b, dists[2], spl2) + randr(vi, vn_z2, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 2] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + + DynamicPPL.reset_num_produce!(vi) + DynamicPPL.set_retained_vns_del_by_spl!(vi, spl1) + @test DynamicPPL.is_flagged(vi, vn_z1, "del") + @test DynamicPPL.is_flagged(vi, vn_a1, "del") + @test DynamicPPL.is_flagged(vi, vn_z2, "del") + @test DynamicPPL.is_flagged(vi, vn_a2, "del") + @test DynamicPPL.is_flagged(vi, vn_z3, "del") + + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z1, dists[1], spl1) + randr(vi, vn_a1, dists[2], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z2, dists[1], spl1) + DynamicPPL.increment_num_produce!(vi) + randr(vi, vn_z3, dists[1], spl1) + randr(vi, vn_a2, dists[2], spl1) + @test vi.metadata.z.orders == [1, 2, 3] + @test vi.metadata.a.orders == [1, 3] + @test vi.metadata.b.orders == [2] + @test DynamicPPL.get_num_produce(vi) == 3 + end + + @testset "varinfo ranges" begin + @model empty_model() = x = 1 + dists = [Normal(0, 1), MvNormal(zeros(2), I), Wishart(7, [1 0.5; 0.5 1])] + + function test_varinfo!(vi) + spl2 = DynamicPPL.Sampler(MyAlg{(:w, :u)}(), empty_model()) + vn_w = @varname w + randr(vi, vn_w, dists[1], spl2, true) + + vn_x = @varname x + vn_y = @varname y + vn_z = @varname z + vns = [vn_x, vn_y, vn_z] + + spl1 = DynamicPPL.Sampler(MyAlg{(:x, :y, :z)}(), empty_model()) + for i in 1:3 + r = randr(vi, vns[i], dists[i], spl1, false) + val = vi[vns[i]] + @test sum(val - r) <= 1e-9 + end + + idcs = DynamicPPL._getidcs(vi, spl1) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 3 + 
else + @test length(idcs) == 3 + end + @test length(vi[spl1]) == 7 + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 1 + else + @test length(idcs) == 1 + end + @test length(vi[spl2]) == 1 + + vn_u = @varname u + randr(vi, vn_u, dists[1], spl2, true) + + idcs = DynamicPPL._getidcs(vi, spl2) + if idcs isa NamedTuple + @test sum(length(getfield(idcs, f)) for f in fieldnames(typeof(idcs))) == 2 + else + @test length(idcs) == 2 + end + @test length(vi[spl2]) == 2 + end + vi = DynamicPPL.VarInfo() + test_varinfo!(vi) + test_varinfo!(empty!!(DynamicPPL.TypedVarInfo(vi))) + end end