From 31e8f708026690e1794036f30cc950810d1d4ca3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Oct 2023 15:35:54 +0100 Subject: [PATCH] fix + test for compiled ReverseDiff without linking --- Project.toml | 2 +- src/essential/ad.jl | 2 +- test/essential/ad.jl | 14 ++++++++++++++ 3 files changed, 16 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5ccbbf752..2ef9edc8b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.29.2" +version = "0.29.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/essential/ad.jl b/src/essential/ad.jl index 4f07c1eab..5dda65d0f 100644 --- a/src/essential/ad.jl +++ b/src/essential/ad.jl @@ -118,7 +118,7 @@ end for cache in (:true, :false) @eval begin function LogDensityProblemsAD.ADgradient(::ReverseDiffAD{$cache}, ℓ::Turing.LogDensityFunction) - return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache)) + return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile=Val($cache), x=DynamicPPL.getparams(ℓ)) end end end diff --git a/test/essential/ad.jl b/test/essential/ad.jl index c00f76f12..359101257 100644 --- a/test/essential/ad.jl +++ b/test/essential/ad.jl @@ -198,4 +198,18 @@ end end end + + @testset "ReverseDiff compiled without linking" begin + f = DynamicPPL.LogDensityFunction(gdemo_default) + θ = DynamicPPL.getparams(f) + + f_rd = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{false}(), f) + f_rd_compiled = LogDensityProblemsAD.ADgradient(Turing.Essential.ReverseDiffAD{true}(), f) + + ℓ, ℓ_grad = LogDensityProblems.logdensity_and_gradient(f_rd, θ) + ℓ_compiled, ℓ_grad_compiled = LogDensityProblems.logdensity_and_gradient(f_rd_compiled, θ) + + @test ℓ == ℓ_compiled + @test ℓ_grad == ℓ_grad_compiled + end end