Skip to content

Commit

Permalink
Merge branch 'master' into compathelper/new_version/2024-01-26-00-08-…
Browse files Browse the repository at this point in the history
…39-774-02165247943
  • Loading branch information
sunxd3 authored Feb 16, 2024
2 parents 3129115 + abcf584 commit a594316
Show file tree
Hide file tree
Showing 11 changed files with 202 additions and 21 deletions.
34 changes: 21 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.24.5"
version = "0.24.7"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Expand All @@ -14,6 +15,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -22,19 +24,8 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[compat]
ADTypes = "0.2"
AbstractMCMC = "5"
AbstractPPL = "0.7"
BangBang = "0.3, 0.4"
Expand All @@ -47,6 +38,7 @@ DocStringExtensions = "0.9"
EnzymeCore = "0.6"
LinearAlgebra = "1.6"
LogDensityProblems = "2"
LogDensityProblemsAD = "1.7.0"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Expand All @@ -57,8 +49,24 @@ Test = "1.6"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLReverseDiffExt = ["ReverseDiff"]
DynamicPPLZygoteRulesExt = ["ZygoteRules"]

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
54 changes: 54 additions & 0 deletions ext/DynamicPPLForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
module DynamicPPLForwardDiffExt

if isdefined(Base, :get_extension)
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ForwardDiff
else
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ..ForwardDiff
end

getchunksize(::ADTypes.AutoForwardDiff{chunk}) where {chunk} = chunk

standardtag(::ADTypes.AutoForwardDiff{<:Any,Nothing}) = true
standardtag(::ADTypes.AutoForwardDiff) = false

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoForwardDiff, ℓ::DynamicPPL.LogDensityFunction
)
θ = DynamicPPL.getparams(ℓ)
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

# Define configuration for ForwardDiff.
tag = if standardtag(ad)
ForwardDiff.Tag(DynamicPPL.DynamicPPLTag(), eltype(θ))
else
ForwardDiff.Tag(f, eltype(θ))
end
chunk_size = getchunksize(ad)
chunk = if chunk_size == 0
ForwardDiff.Chunk(θ)
else
ForwardDiff.Chunk(length(θ), chunk_size)
end

return LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓ; chunk, tag, x=θ)
end

# Allow Turing tag in gradient etc. calls of the log density function
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
::DynamicPPL.LogDensityFunction,
::AbstractArray{W},
) where {V,W}
return true
end
function ForwardDiff.checktag(
::Type{ForwardDiff.Tag{DynamicPPL.DynamicPPLTag,V}},
::Base.Fix1{typeof(LogDensityProblems.logdensity),<:DynamicPPL.LogDensityFunction},
::AbstractArray{W},
) where {V,W}
return true
end

end # module
26 changes: 26 additions & 0 deletions ext/DynamicPPLReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
module DynamicPPLReverseDiffExt

if isdefined(Base, :get_extension)
using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ReverseDiff
else
using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD
using ..ReverseDiff
end

function LogDensityProblemsAD.ADgradient(
ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
)
return LogDensityProblemsAD.ADgradient(
Val(:ReverseDiff),
ℓ;
compile=Val(ad.compile),
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
# `zero(D)` will return 0 when D is Real.
# here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
x=map(identity, DynamicPPL.getparams(ℓ)),
)
end

end # module
12 changes: 12 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@ using Distributions
using OrderedCollections: OrderedDict

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes
using BangBang: BangBang, push!!, empty!!, setindex!!
using MacroTools: MacroTools
using ConstructionBase: ConstructionBase
using Setfield: Setfield
using LogDensityProblems: LogDensityProblems
using LogDensityProblemsAD: LogDensityProblemsAD

using LinearAlgebra: LinearAlgebra, Cholesky

Expand Down Expand Up @@ -189,13 +191,23 @@ end
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
"../ext/DynamicPPLForwardDiffExt.jl"
)
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
"../ext/DynamicPPLReverseDiffExt.jl"
)
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
"../ext/DynamicPPLZygoteRulesExt.jl"
)
end
end

# Standard tag: Improves stacktraces
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct DynamicPPLTag end

end # module
12 changes: 12 additions & 0 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,22 @@ require_particles(spl::Sampler) = false

# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline.
function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp_assume!!(NodeTrait(acclogp_assume!!, context), context, vi, logp)
end
function acclogp_assume!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp_assume!!(childcontext(context), vi, logp)
end
function acclogp_assume!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp!!(context, vi, logp)
end

function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp_observe!!(NodeTrait(acclogp_observe!!, context), context, vi, logp)
end
function acclogp_observe!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp_observe!!(childcontext(context), vi, logp)
end
function acclogp_observe!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp!!(context, vi, logp)
end

Expand Down
7 changes: 6 additions & 1 deletion src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,12 @@ end

unflatten(svi::SimpleVarInfo, spl::AbstractSampler, x::AbstractVector) = unflatten(svi, x)
function unflatten(svi::SimpleVarInfo, x::AbstractVector)
return Setfield.@set svi.values = unflatten(svi.values, x)
logp = getlogp(svi)
vals = unflatten(svi.values, x)
T = eltype(x)
return SimpleVarInfo{typeof(vals),T,typeof(svi.transformation)}(
vals, T(logp), svi.transformation
)
end

function BangBang.empty!!(vi::SimpleVarInfo)
Expand Down
22 changes: 15 additions & 7 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,7 @@ const VarInfoOrThreadSafeVarInfo{Tmeta} = Union{
# multiple times.
transformation(vi::VarInfo) = DynamicTransformation()

function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector)
new_vi = deepcopy(old_vi)
new_vi[spl] = x
return new_vi
end

function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector)
function VarInfo(old_vi::VarInfo, spl, x::AbstractVector)
md = newmetadata(old_vi.metadata, Val(getspace(spl)), x)
return VarInfo(
md, Base.RefValue{eltype(x)}(getlogp(old_vi)), Ref(get_num_produce(old_vi))
Expand Down Expand Up @@ -147,6 +141,20 @@ function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext
return VarInfo(rng, model, SampleFromPrior(), context)
end

# TODO: Remove `space` argument when no longer needed. Ref: https://github.com/TuringLang/DynamicPPL.jl/issues/573
function newmetadata(metadata::Metadata, space, x)
return Metadata(
metadata.idcs,
metadata.vns,
metadata.ranges,
x,
metadata.dists,
metadata.gids,
metadata.orders,
metadata.flags,
)
end

@generated function newmetadata(
metadata::NamedTuple{names}, ::Val{space}, x
) where {names,space}
Expand Down
6 changes: 6 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Expand All @@ -11,10 +12,12 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -23,6 +26,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2"
AbstractMCMC = "5"
AbstractPPL = "0.7"
Bijectors = "0.13"
Expand All @@ -33,8 +37,10 @@ Documenter = "1"
EnzymeCore = "0.6"
ForwardDiff = "0.10.12"
LogDensityProblems = "2"
LogDensityProblemsAD = "1"
MCMCChains = "6.0.4"
MacroTools = "0.5.5"
ReverseDiff = "1"
Setfield = "1"
StableRNGs = "1"
Tracker = "0.2.23"
Expand Down
28 changes: 28 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
@testset "AD: ForwardDiff and ReverseDiff" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
ADTypes.AutoForwardDiff(; chunksize=0), f
)
# convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
# reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
θ = convert(Vector{Float64}, varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)

@testset "ReverseDiff with compile=$compile" for compile in (false, true)
adtype = ADTypes.AutoReverseDiff(; compile=compile)
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ref_grad
end
end
end
end
14 changes: 14 additions & 0 deletions test/ext/DynamicPPLForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testset "tag" begin
for chunksize in (0, 1, 10)
ad = ADTypes.AutoForwardDiff(; chunksize=chunksize)
standardtag = if !isdefined(Base, :get_extension)
DynamicPPL.DynamicPPLForwardDiffExt.standardtag
else
Base.get_extension(DynamicPPL, :DynamicPPLForwardDiffExt).standardtag
end
@test standardtag(ad)
for tag in (false, 0, 1)
@test !standardtag(AutoForwardDiff(; chunksize=chunksize, tag=tag))
end
end
end
8 changes: 8 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using ADTypes
using DynamicPPL
using AbstractMCMC
using AbstractPPL
Expand All @@ -6,9 +7,11 @@ using Distributions
using DistributionsAD
using Documenter
using ForwardDiff
using LogDensityProblems, LogDensityProblemsAD
using MacroTools
using MCMCChains
using Tracker
using ReverseDiff
using Zygote
using Setfield
using Compat
Expand Down Expand Up @@ -64,6 +67,11 @@ include("test_util.jl")
include("ext/DynamicPPLMCMCChainsExt.jl")
end

@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")
include("ad.jl")
end

@testset "doctests" begin
DocMeta.setdocmeta!(
DynamicPPL,
Expand Down

0 comments on commit a594316

Please sign in to comment.