diff --git a/Project.toml b/Project.toml index df27d35b4..46b2bc046 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.2" +version = "0.24.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -24,11 +24,13 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] +DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLMCMCChainsExt = ["MCMCChains"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] @@ -42,6 +44,7 @@ Compat = "4" ConstructionBase = "1.5.4" Distributions = "0.25" DocStringExtensions = "0.9" +EnzymeCore = "0.6" LogDensityProblems = "2" MCMCChains = "6" MacroTools = "0.5.6" @@ -56,5 +59,6 @@ julia = "1.6" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl new file mode 100644 index 000000000..f83d6e8f7 --- /dev/null +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -0,0 +1,13 @@ +module DynamicPPLEnzymeCoreExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL + using EnzymeCore +else + using ..DynamicPPL: DynamicPPL + using ..EnzymeCore +end + +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f4e6ca04f..9d7eb6b7d 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -186,6 +186,9 @@ end @require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include( "../ext/DynamicPPLChainRulesCoreExt.jl" ) + @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( + "../ext/DynamicPPLEnzymeCoreExt.jl" + ) @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( "../ext/DynamicPPLMCMCChainsExt.jl" ) diff --git a/test/Project.toml b/test/Project.toml index 16c793956..878d3c1d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/contexts.jl b/test/contexts.jl index 9b0427cd0..d04aecb52 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -16,6 +16,8 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested +using EnzymeCore + # Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C @@ -252,6 +254,7 @@ end @test SamplingContext(Random.default_rng(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context + @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) end @testset "FixedContext" begin