Skip to content

Commit

Permalink
Merge branch 'master' into dw/zygoterules_chainrulescore
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Nov 21, 2023
2 parents b823bdc + 03e4ba2 commit b28f6bb
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 1 deletion.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.24.2"
version = "0.24.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -24,11 +24,13 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

Expand All @@ -42,6 +44,7 @@ Compat = "4"
ConstructionBase = "1.5.4"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6"
LogDensityProblems = "2"
MCMCChains = "6"
MacroTools = "0.5.6"
Expand All @@ -56,5 +59,6 @@ julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
13 changes: 13 additions & 0 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DynamicPPLEnzymeCoreExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using EnzymeCore
else
using ..DynamicPPL: DynamicPPL
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true

end
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ end
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Expand Down
3 changes: 3 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ using DynamicPPL:
hasconditioned_nested,
getconditioned_nested

using EnzymeCore

# Dummy context to test nested behaviors.
struct ParentContext{C<:AbstractContext} <: AbstractContext
context::C
Expand Down Expand Up @@ -252,6 +254,7 @@ end
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "FixedContext" begin
Expand Down

0 comments on commit b28f6bb

Please sign in to comment.