Skip to content

Commit

Permalink
Make ZygoteRules and ChainRulesCore weak dependencies (TuringLang#564)
Browse files Browse the repository at this point in the history
* Make ZygoteRules and ChainRulesCore weak dependencies

* Fix format

* Add another non-differentiable to CRC extension

* Perform coverage analysis on all Julia versions
  • Loading branch information
devmotion authored Nov 21, 2023
1 parent 03e4ba2 commit fc6cae9
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 39 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,14 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@latest
- uses: julia-actions/julia-runtest@latest
with:
coverage: ${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1 }}
env:
GROUP: All
JULIA_NUM_THREADS: ${{ matrix.num_threads }}
- uses: julia-actions/julia-processcoverage@v1
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
- uses: codecov/codecov-action@v1
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
with:
file: lcov.info
- uses: coverallsapp/github-action@master
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
path-to-lcov: lcov.info
14 changes: 10 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.24.2"
version = "0.24.3"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -23,12 +23,16 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
AbstractMCMC = "5"
Expand All @@ -54,5 +58,7 @@ Test = "1.6"
julia = "1.6"

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
27 changes: 27 additions & 0 deletions ext/DynamicPPLChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
module DynamicPPLChainRulesCoreExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL, BangBang, Distributions
using ChainRulesCore: ChainRulesCore
else
using ..DynamicPPL: DynamicPPL, BangBang, Distributions
using ..ChainRulesCore: ChainRulesCore
end

# See https://github.com/TuringLang/Turing.jl/issues/1199
ChainRulesCore.@non_differentiable BangBang.push!!(
vi::DynamicPPL.VarInfo,
vn::DynamicPPL.VarName,
r,
dist::Distributions.Distribution,
gidset::Set{DynamicPPL.Selector},
)

ChainRulesCore.@non_differentiable DynamicPPL.updategid!(
vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler
)

# No need + causes issues for some AD backends, e.g. Zygote.
ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x)

end # module
25 changes: 25 additions & 0 deletions ext/DynamicPPLZygoteRulesExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module DynamicPPLZygoteRulesExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL, Distributions
using ZygoteRules: ZygoteRules
else
using ..DynamicPPL: DynamicPPL, Distributions
using ..ZygoteRules: ZygoteRules
end

# https://github.com/TuringLang/Turing.jl/issues/1595
ZygoteRules.@adjoint function DynamicPPL.dot_observe(
spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform},
dists::AbstractArray{<:Distributions.Distribution},
value::AbstractArray,
vi,
)
function dot_observe_fallback(spl, dists, value, vi)
DynamicPPL.increment_num_produce!(vi)
return sum(map(Distributions.loglikelihood, dists, value)), vi
end
return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi)
end

end # module
13 changes: 8 additions & 5 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ using OrderedCollections: OrderedDict

using AbstractMCMC: AbstractMCMC
using BangBang: BangBang, push!!, empty!!, setindex!!
using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: LinearAlgebra, Cholesky
Expand Down Expand Up @@ -171,7 +169,6 @@ include("simple_varinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("test_utils.jl")
Expand All @@ -186,12 +183,18 @@ end

@static if !isdefined(Base, :get_extension)
function __init__()
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
"../ext/DynamicPPLChainRulesCoreExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
end
end

Expand Down
22 changes: 0 additions & 22 deletions src/compat/ad.jl

This file was deleted.

3 changes: 0 additions & 3 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,9 +883,6 @@ end
# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET)

# No need + causes issues for some AD backends, e.g. Zygote.
ChainRulesCore.@non_differentiable infer_nested_eltype(x)

"""
varname_leaves(vn::VarName, val)
Expand Down

0 comments on commit fc6cae9

Please sign in to comment.