
Commit

Apply Tor's comments
sunxd3 committed Feb 4, 2024
1 parent 91d082e commit 7b84ba1
Showing 3 changed files with 31 additions and 21 deletions.
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
 name = "DynamicPPL"
 uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
-version = "0.24.6"
+version = "0.24.7"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ext/DynamicPPLReverseDiffExt.jl (9 changes: 8 additions & 1 deletion)
@@ -12,7 +12,14 @@ function LogDensityProblemsAD.ADgradient(
     ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
 )
     return LogDensityProblemsAD.ADgradient(
-        Val(:ReverseDiff), ℓ; compile=Val(ad.compile), x=identity.(DynamicPPL.getparams(ℓ))
+        Val(:ReverseDiff),
+        ℓ;
+        compile=Val(ad.compile),
+        # `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
+        # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
+        # `zero(D)` will return 0 when D is Real.
+        # here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`.
+        x=map(identity, DynamicPPL.getparams(ℓ)),
     )
 end
 
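The fix above leans on a general Julia behaviour: mapping `identity` over an abstractly typed vector lets the result's element type be recomputed from the actual elements. A minimal standalone sketch of that behaviour (the vector `xs` is invented for illustration and only stands in for whatever `DynamicPPL.getparams(ℓ)` returns):

# Illustrative values only; `xs` plays the role of the parameter vector.
xs = Real[1.0, 2.0, 3.0]    # Vector{Real}: abstract element type

zero(eltype(xs))            # 0 (an Int): `zero(Real)` falls back to `convert(Real, 0)`,
                            # the integer seed the code comment above warns about
ys = map(identity, xs)      # every element is a Float64, so the result narrows
typeof(ys)                  # Vector{Float64}
zero(eltype(ys))            # 0.0, so ReverseDiff seeds its gradient buffer with floats

If the parameters were genuinely of mixed types (say a stray `Int` among `Float64`s), `map(identity, ...)` could only narrow to their type join, which is why the code comment says it will "possibly" concretize the type.
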
test/ad.jl (41 changes: 22 additions & 19 deletions)
@@ -1,25 +1,28 @@
-@testset "Testing AD by comparing gradient using ForwardDiff and ReverseDiff for model $(m.f)" for m in
-                                                                   DynamicPPL.TestUtils.DEMO_MODELS
-    f = DynamicPPL.LogDensityFunction(m)
-    rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
-    vns = DynamicPPL.TestUtils.varnames(m)
-    varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
+@testset "AD: ForwardDiff and ReverseDiff" begin
+    @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
+        f = DynamicPPL.LogDensityFunction(m)
+        rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
+        vns = DynamicPPL.TestUtils.varnames(m)
+        varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
-    @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-        f = DynamicPPL.LogDensityFunction(m, varinfo)
+        @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
+            f = DynamicPPL.LogDensityFunction(m, varinfo)
 
-        # use ForwardDiff result as reference
-        ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
-            ADTypes.AutoForwardDiff(; chunksize=0), f
-        )
-        θ = identity.(varinfo[:])
-        logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
+            # use ForwardDiff result as reference
+            ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
+                ADTypes.AutoForwardDiff(; chunksize=0), f
+            )
+            # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
+            # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
+            θ = convert(Vector{Float64}, varinfo[:])
+            logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
 
-        @testset "ReverseDiff with compile=$compile" for compile in (false, true)
-            adtype = ADTypes.AutoReverseDiff(; compile=compile)
-            ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
-            _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
-            @test grad ≈ ref_grad
+            @testset "ReverseDiff with compile=$compile" for compile in (false, true)
+                adtype = ADTypes.AutoReverseDiff(; compile=compile)
+                ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
+                _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
+                @test grad ≈ ref_grad
+            end
         end
     end
 end
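The substantive change inside the test is the switch from `θ = identity.(varinfo[:])` to an explicit `convert(Vector{Float64}, varinfo[:])`. A short standalone sketch of the difference (the `mixed` vector is a made-up example, not something produced by DynamicPPL):

# Hypothetical parameter vector with an abstract element type and mixed contents.
mixed = Real[1.0, 2]

identity.(mixed)                 # still Vector{Real}: broadcasting narrows only when every
                                 # element already has the same concrete type
convert(Vector{Float64}, mixed)  # Vector{Float64} [1.0, 2.0]: the element type is forced,
                                 # so ReverseDiff initializes its gradients with 0.0, not 0

The extension-side `map(identity, ...)` helps when the parameters already share one concrete type; the test-side `convert` pins the element type down regardless, which is what the new comment in the diff describes.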
