From 0f7976c399522026475abccd03efa87fe2544759 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 30 Apr 2024 12:56:40 +0100 Subject: [PATCH] ADTypes Interop (#127) * ADTypes interop * Bump patch --- Project.toml | 6 +++-- ext/TapirLogDensityProblemsADExt.jl | 23 ++++++++++++------- src/Tapir.jl | 1 + .../logdensityproblemsad_interop.jl | 4 +++- 4 files changed, 23 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index aff28c7b3..fdff8a624 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.1" +version = "0.2.2" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" @@ -25,6 +26,7 @@ TapirLogDensityProblemsADExt = "LogDensityProblemsAD" TapirSpecialFunctionsExt = "SpecialFunctions" [compat] +ADTypes = "1" BenchmarkTools = "1" ChainRulesCore = "1" DiffRules = "1" @@ -40,7 +42,7 @@ Setfield = "1" SpecialFunctions = "2" StableRNGs = "1" TemporalGPs = "0.6" -Turing = "0.31" +Turing = "0.31.3" julia = "1" [extras] diff --git a/ext/TapirLogDensityProblemsADExt.jl b/ext/TapirLogDensityProblemsADExt.jl index a6ed57dab..8a918d67a 100644 --- a/ext/TapirLogDensityProblemsADExt.jl +++ b/ext/TapirLogDensityProblemsADExt.jl @@ -4,10 +4,12 @@ module TapirLogDensityProblemsADExt if isdefined(Base, :get_extension) + using ADTypes using LogDensityProblemsAD: ADGradientWrapper import LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity import Tapir else + using ADTypes using ..LogDensityProblemsAD: ADGradientWrapper import ..LogDensityProblemsAD: ADgradient, logdensity_and_gradient, dimension, logdensity import ..Tapir @@ -15,13 +17,13 @@ end struct TapirGradientLogDensity{Trule, L} <: ADGradientWrapper rule::Trule - l::L + ℓ::L end -dimension(∇l::TapirGradientLogDensity) = dimension(Tapir.primal(∇l.l)) +dimension(∇l::TapirGradientLogDensity) = dimension(Tapir.primal(∇l.ℓ)) function logdensity(∇l::TapirGradientLogDensity, x::Vector{Float64}) - return logdensity(Tapir.primal(∇l.l), x) + return logdensity(Tapir.primal(∇l.ℓ), x) end """ @@ -29,13 +31,15 @@ end Gradient using algorithmic/automatic differentiation via Tapir. """ -function ADgradient(::Val{:Tapir}, l) - primal_sig = Tuple{typeof(logdensity), typeof(l), Vector{Float64}} +function ADgradient(::Val{:Tapir}, ℓ) + primal_sig = Tuple{typeof(logdensity), typeof(ℓ), Vector{Float64}} rule = Tapir.build_rrule(Tapir.TapirInterpreter(), primal_sig) - return TapirGradientLogDensity(rule, Tapir.uninit_fcodual(l)) + return TapirGradientLogDensity(rule, Tapir.uninit_fcodual(ℓ)) end -Base.show(io::IO, ∇ℓ::TapirGradientLogDensity) = print(io, "Tapir AD wrapper for ", ∇ℓ.ℓ) +function Base.show(io::IO, ∇ℓ::TapirGradientLogDensity) + return print(io, "Tapir AD wrapper for ", Tapir.primal(∇ℓ.ℓ)) +end # We only test Tapir with `Float64`s at the minute, so make strong assumptions about the # types supported in order to prevent silent errors. @@ -46,10 +50,13 @@ end function logdensity_and_gradient(∇l::TapirGradientLogDensity, x::Vector{Float64}) dx = zeros(length(x)) - y, pb!! = ∇l.rule(Tapir.zero_fcodual(logdensity), ∇l.l, Tapir.CoDual(x, dx)) + y, pb!! = ∇l.rule(Tapir.zero_fcodual(logdensity), ∇l.ℓ, Tapir.CoDual(x, dx)) @assert Tapir.primal(y) isa Float64 pb!!(1.0) return Tapir.primal(y), dx end +# Interop with ADTypes. +ADgradient(::ADTypes.AutoTapir, ℓ) = ADgradient(Val(:Tapir), ℓ) + end diff --git a/src/Tapir.jl b/src/Tapir.jl index 1e4183a44..07339aaf1 100644 --- a/src/Tapir.jl +++ b/src/Tapir.jl @@ -3,6 +3,7 @@ module Tapir const CC = Core.Compiler using + ADTypes, DiffRules, ExprTools, Graphs, diff --git a/test/integration_testing/logdensityproblemsad_interop.jl b/test/integration_testing/logdensityproblemsad_interop.jl index 73009bd93..13b7cd6d5 100644 --- a/test/integration_testing/logdensityproblemsad_interop.jl +++ b/test/integration_testing/logdensityproblemsad_interop.jl @@ -1,4 +1,4 @@ -using LogDensityProblemsAD +using ADTypes, LogDensityProblemsAD using LogDensityProblemsAD: logdensity_and_gradient, capabilities, dimension, logdensity # Copied over from LogDensityProblemsAD test suite. @@ -19,4 +19,6 @@ test_gradient(x) = -2 .* x @test isapprox(logdensity_and_gradient(∇l, x)[1], logdensity(TestLogDensity2(), x)) @test isapprox(logdensity_and_gradient(∇l, x)[2], test_gradient(x)) end + + @test ADgradient(ADTypes.AutoTapir(), l) isa typeof(∇l) end