From b924a177f1e4a431a737325ec9676cb68d47ddef Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:18:10 +0000 Subject: [PATCH] Move the content of `ad.jl` from `Turing.jl` to here (#571) * initialize moving, still need to move tests * Move tests, tests are not fixed yet * Make `ADTypes` a direct dep * Add `ad.jl` for testing * Remove `ADTypes` ext from `require` * Put `ADgradient` code to extensions * Add testing code * Bug fix and adding tests * Update src/simple_varinfo.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * renaming a testset * add require for ReverseDiff extension * fix UUID * fix typo * Also use the original transformation * Fix 1.6 compat * Fix typo * Fix typo, again * Update test/ad.jl Co-authored-by: Tor Erlend Fjelde * Fix errors * Refactor the test * Update ad.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * disable Zygote testing * Change testset description * Update test/ad.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Tor Erlend Fjelde * Apply Tor's comments --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Tor Erlend Fjelde --- Project.toml | 40 ++++++++++++--------- ext/DynamicPPLForwardDiffExt.jl | 54 ++++++++++++++++++++++++++++ ext/DynamicPPLReverseDiffExt.jl | 26 ++++++++++++++ src/DynamicPPL.jl | 12 +++++++ src/simple_varinfo.jl | 7 +++- src/varinfo.jl | 22 ++++++++---- test/Project.toml | 4 +++ test/ad.jl | 28 +++++++++++++++ test/ext/DynamicPPLForwardDiffExt.jl | 14 ++++++++ test/runtests.jl | 8 +++++ 10 files changed, 191 insertions(+), 24 deletions(-) create mode 100644 ext/DynamicPPLForwardDiffExt.jl create mode 100644 ext/DynamicPPLReverseDiffExt.jl create mode 100644 test/ad.jl create mode 100644 
test/ext/DynamicPPLForwardDiffExt.jl diff --git a/Project.toml b/Project.toml index 305b1c52c..6510e7ea0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,8 +1,9 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.24.6" +version = "0.24.7" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" @@ -14,6 +15,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -22,19 +24,8 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" - -[extensions] -DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] -DynamicPPLEnzymeCoreExt = ["EnzymeCore"] -DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLZygoteRulesExt = ["ZygoteRules"] - [compat] +ADTypes = "0.2" AbstractMCMC = "5" AbstractPPL = "0.7" BangBang = "0.3" @@ -45,20 +36,37 @@ ConstructionBase = "1.5.4" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6" +LinearAlgebra = "1.6" LogDensityProblems = "2" +LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" OrderedCollections = "1" +Random = "1.6" Requires = "1" Setfield = "1" -ZygoteRules = "0.2" -LinearAlgebra = "1.6" -Random = "1.6" 
Test = "1.6" +ZygoteRules = "0.2" julia = "1.6" +[extensions] +DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] +DynamicPPLEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLForwardDiffExt = ["ForwardDiff"] +DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLReverseDiffExt = ["ReverseDiff"] +DynamicPPLZygoteRulesExt = ["ZygoteRules"] + [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" + +[weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" diff --git a/ext/DynamicPPLForwardDiffExt.jl b/ext/DynamicPPLForwardDiffExt.jl new file mode 100644 index 000000000..10371b3fe --- /dev/null +++ b/ext/DynamicPPLForwardDiffExt.jl @@ -0,0 +1,54 @@ +module DynamicPPLForwardDiffExt + +if isdefined(Base, :get_extension) + using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ForwardDiff +else + using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..ForwardDiff +end + +getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk + +standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true +standardtag(::ADTypes.AutoForwardDiff) = false + +function LogDensityProblemsAD.ADgradient( + ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction +) + θ = DynamicPPL.getparams(ℓ) + f = Base.Fix1(LogDensityProblems.logdensity, ℓ) + + # ForwardDiff configuration; chunk size `nothing` or 0 means auto-select.
+ tag = if standardtag(ad) + ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ)) + else + ForwardDiff.Tag(f, eltype(θ)) + end + chunk_size = getchunksize(ad) + chunk = if chunk_size === nothing || chunk_size == 0 + ForwardDiff.Chunk(θ) + else + ForwardDiff.Chunk(length(θ), chunk_size) + end + + return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ) +end + +# Allow DynamicPPL tag in gradient etc. calls of the log density function +function ForwardDiff.checktag( + ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, + ::DynamicPPL.LogDensityFunction, + ::AbstractArray{W}, +) where {V,W} + return true +end +function ForwardDiff.checktag( + ::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}}, + ::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction}, + ::AbstractArray{W}, +) where {V,W} + return true +end + +end # module diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl new file mode 100644 index 000000000..b2b378d45 --- /dev/null +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -0,0 +1,26 @@ +module DynamicPPLReverseDiffExt + +if isdefined(Base, :get_extension) + using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ReverseDiff +else + using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..ReverseDiff +end + +function LogDensityProblemsAD.ADgradient( + ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction +) + return LogDensityProblemsAD.ADgradient( + Val(:ReverseDiff), + ℓ; + compile=Val(ad.compile), + # `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0 + # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473 + # `zero(D)` will return 0 when D is Real. + # here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
+ x=map(identity, DynamicPPL.getparams(ℓ)), + ) +end + +end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 9d7eb6b7d..ce6605250 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -8,11 +8,13 @@ using Distributions using OrderedCollections: OrderedDict using AbstractMCMC: AbstractMCMC +using ADTypes: ADTypes using BangBang: BangBang, push!!, empty!!, setindex!! using MacroTools: MacroTools using ConstructionBase: ConstructionBase using Setfield: Setfield using LogDensityProblems: LogDensityProblems +using LogDensityProblemsAD: LogDensityProblemsAD using LinearAlgebra: LinearAlgebra, Cholesky @@ -189,13 +191,23 @@ end @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( "../ext/DynamicPPLEnzymeCoreExt.jl" ) + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include( + "../ext/DynamicPPLForwardDiffExt.jl" + ) @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( "../ext/DynamicPPLMCMCChainsExt.jl" ) + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include( + "../ext/DynamicPPLReverseDiffExt.jl" + ) @require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include( "../ext/DynamicPPLZygoteRulesExt.jl" ) end end +# Standard tag: Improves stacktraces +# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ +struct DynamicPPLTag end + end # module diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 93c211483..ad37130d6 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -250,7 +250,12 @@ end unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x) function unflatten(svi::SimpleVarInfo, x::AbstractVector) - return Setfield.@set svi.values = unflatten(svi.values, x) + logp = getlogp(svi) + vals = unflatten(svi.values, x) + T = eltype(x) + return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}( + vals, T(logp), svi.transformation + ) end function 
BangBang.empty!!(vi::SimpleVarInfo) diff --git a/src/varinfo.jl b/src/varinfo.jl index 24316aed7..c8c46ee27 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -112,13 +112,7 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{ # multiple times. transformation(vi::VarInfo) = DynamicTransformation() -function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) - new_vi = deepcopy(old_vi) - new_vi[spl] = x - return new_vi -end - -function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) +function VarInfo(old_vi::VarInfo, spl, x::AbstractVector) md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) return VarInfo( md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi)) @@ -147,6 +141,20 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext return VarInfo(rng, model, SampleFromPrior(), context) end +# TODO: Remove `space` argument when no longer needed. Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573 +function newmetadata(metadata::Metadata, space, x) + return Metadata( + metadata.idcs, + metadata.vns, + metadata.ranges, + x, + metadata.dists, + metadata.gids, + metadata.orders, + metadata.flags, + ) +end + @generated function newmetadata( metadata::NamedTuple{names}, ::Val{space}, x ) where {names,space} diff --git a/test/Project.toml b/test/Project.toml index 80c227920..73cc134ee 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,4 +1,5 @@ [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" @@ -11,10 +12,12 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = 
"c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -23,6 +26,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] +ADTypes = "0.2" AbstractMCMC = "5" AbstractPPL = "0.7" Bijectors = "0.13" diff --git a/test/ad.jl b/test/ad.jl new file mode 100644 index 000000000..6046cfda4 --- /dev/null +++ b/test/ad.jl @@ -0,0 +1,28 @@ +@testset "AD: ForwardDiff and ReverseDiff" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + f = DynamicPPL.LogDensityFunction(m) + rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) + vns = DynamicPPL.TestUtils.varnames(m) + varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) + + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + f = DynamicPPL.LogDensityFunction(m, varinfo) + + # use ForwardDiff result as reference + ad_forwarddiff_f = LogDensityProblemsAD.ADgradient( + ADTypes.AutoForwardDiff(; chunksize=0), f + ) + # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0 + # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489 + θ = convert(Vector{Float64}, varinfo[:]) + logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) + + @testset "ReverseDiff with compile=$compile" for compile in (false, true) + adtype = ADTypes.AutoReverseDiff(; compile=compile) + ad_f = LogDensityProblemsAD.ADgradient(adtype, f) + _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) + @test grad ≈ ref_grad + end + end + end +end diff --git a/test/ext/DynamicPPLForwardDiffExt.jl 
b/test/ext/DynamicPPLForwardDiffExt.jl new file mode 100644 index 000000000..1227a8c95 --- /dev/null +++ b/test/ext/DynamicPPLForwardDiffExt.jl @@ -0,0 +1,14 @@ +@testset "tag" begin + for chunksize in (0, 1, 10) + ad = ADTypes.AutoForwardDiff(; chunksize=chunksize) + standardtag = if !isdefined(Base, :get_extension) + DynamicPPL.DynamicPPLForwardDiffExt.standardtag + else + Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag + end + @test standardtag(ad) + for tag in (false, 0, 1) + @test !standardtag(ADTypes.AutoForwardDiff(; chunksize=chunksize, tag=tag)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 43d68386c..9e11e2ef4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,3 +1,4 @@ +using ADTypes using DynamicPPL using AbstractMCMC using AbstractPPL @@ -6,9 +7,11 @@ using Distributions using DistributionsAD using Documenter using ForwardDiff +using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains using Tracker +using ReverseDiff using Zygote using Setfield using Compat @@ -64,6 +67,11 @@ include("test_util.jl") include("ext/DynamicPPLMCMCChainsExt.jl") end + @testset "ad" begin + include("ext/DynamicPPLForwardDiffExt.jl") + include("ad.jl") + end + @testset "doctests" begin DocMeta.setdocmeta!( DynamicPPL,