Skip to content

Commit

Permalink
Refactor the test
Browse files Browse the repository at this point in the history
  • Loading branch information
sunxd3 committed Feb 2, 2024
1 parent 4b535ba commit 101a2c6
Showing 1 changed file with 31 additions and 15 deletions.
46 changes: 31 additions & 15 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@testset "AD test for model $(repr(m.f))" for m in DynamicPPL.TestUtils.DEMO_MODELS
@testset "ReverseDiff test for model $(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
Expand All @@ -14,25 +14,41 @@
θ = identity.(varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)

@testset "with ADType $adtype" for adtype in (
ADTypes.AutoReverseDiff(false), ADTypes.AutoReverseDiff(true)
@testset "ReverseDiff with compile=$compile" for compile in (
false, true
)
adtype = ADTypes.AutoReverseDiff(; compile=compile)
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ref_grad
end
end
end

if m.f (
DynamicPPL.TestUtils.demo_assume_multivariate_observe,
DynamicPPL.TestUtils.demo_assume_dot_observe,
DynamicPPL.TestUtils.demo_assume_observe_literal,
DynamicPPL.TestUtils.demo_assume_literal_dot_observe,
) &&
varinfo isa Union{DynamicPPL.TypedVarInfo,DynamicPPL.SimpleVarInfo{<:NamedTuple}}
adtype = ADTypes.AutoZygote()
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ref_grad
end
@testset "Zygote test for model $(m.f))" for m in (
DynamicPPL.TestUtils.demo_assume_multivariate_observe(),
DynamicPPL.TestUtils.demo_assume_dot_observe(),
DynamicPPL.TestUtils.demo_assume_observe_literal(),
DynamicPPL.TestUtils.demo_assume_literal_dot_observe(),
)
f = DynamicPPL.LogDensityFunction(m)
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
vns = DynamicPPL.TestUtils.varnames(m)
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)

@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos[[2, 3, 5]] # only test on `TypedVarInfo` and `SimpleVarInfo{<:NamedTuple}`
f = DynamicPPL.LogDensityFunction(m, varinfo)

# use ForwardDiff result as reference
ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
ADTypes.AutoForwardDiff(; chunksize=0), f
)
θ = identity.(varinfo[:])
logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)

adtype = ADTypes.AutoZygote()
ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
_, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
@test grad ref_grad
end
end

0 comments on commit 101a2c6

Please sign in to comment.