diff --git a/Project.toml b/Project.toml index 9db693d..11f1470 100644 --- a/Project.toml +++ b/Project.toml @@ -5,27 +5,11 @@ version = "0.6.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -[weakdeps] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - -[extensions] -AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore" -AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences" -AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"] -AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"] -AbstractDifferentiationTrackerExt = "Tracker" -AbstractDifferentiationZygoteExt = "Zygote" - [compat] ChainRulesCore = "1" DiffResults = "1" @@ -38,6 +22,16 @@ ReverseDiff = "1" Tracker = "0.2" Zygote = "0.6" julia = "1.6" +Enzyme = "0.12" + +[extensions] +AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore" +AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences" +AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"] +AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"] +AbstractDifferentiationTrackerExt = "Tracker" +AbstractDifferentiationZygoteExt = "Zygote" +AbstractDifferentiationEnzymeExt = "Enzyme" [extras] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -50,6 +44,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" [targets] -test = ["ChainRulesCore", "DiffResults", "Documenter", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote"] +test = ["ChainRulesCore", "DiffResults", "Documenter", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Test", "Tracker", "Zygote", "Enzyme"] \ No newline at end of file diff --git a/ext/AbstractDifferentiationEnzymeExt.jl b/ext/AbstractDifferentiationEnzymeExt.jl new file mode 100644 index 0000000..e1396b8 --- /dev/null +++ b/ext/AbstractDifferentiationEnzymeExt.jl @@ -0,0 +1,62 @@ +module AbstractDifferentiationEnzymeExt + +if isdefined(Base, :get_extension) + import AbstractDifferentiation as AD + using Enzyme: Enzyme +else + import ..AbstractDifferentiation as AD + using ..Enzyme: Enzyme +end + +struct Mutating{F} + f::F +end +function (f::Mutating)(y, xs...) + y .= f.f(xs...) + return y +end + +AD.@primitive function value_and_pullback_function(b::AD.EnzymeReverseBackend, f, xs...) + y = f(xs...) + return y, + Δ -> begin + Δ_xs = zero.(xs) + dup = if y isa Real + if Δ isa Real + Enzyme.Duplicated([y], [Δ]) + elseif Δ isa Tuple{Real} + Enzyme.Duplicated([y], [Δ[1]]) + else + throw(ArgumentError("Unsupported cotangent type.")) + end + else + if Δ isa AbstractArray{<:Real} + Enzyme.Duplicated(y, Δ) + elseif Δ isa Tuple{AbstractArray{<:Real}} + Enzyme.Duplicated(y, Δ[1]) + else + throw(ArgumentError("Unsupported cotangent type.")) + end + end + Enzyme.autodiff( + Enzyme.Reverse, + Mutating(f), + Enzyme.Const, + dup, + Enzyme.Duplicated.(xs, Δ_xs)..., + ) + return Δ_xs + end +end +function AD.pushforward_function(::AD.EnzymeReverseBackend, f, xs...) + return AD.pushforward_function(AD.EnzymeForwardBackend(), f, xs...) +end + +AD.@primitive function pushforward_function(b::AD.EnzymeForwardBackend, f, xs...) + return ds -> Tuple(Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Duplicated.(xs, ds)...)) +end +function AD.value_and_pullback_function(::AD.EnzymeForwardBackend, f, xs...) + return AD.value_and_pullback_function(AD.EnzymeReverseBackend(), f, xs...) +end + +end # module diff --git a/ext/AbstractDifferentiationFiniteDifferencesExt.jl b/ext/AbstractDifferentiationFiniteDifferencesExt.jl index 1199d28..3ea1a98 100644 --- a/ext/AbstractDifferentiationFiniteDifferencesExt.jl +++ b/ext/AbstractDifferentiationFiniteDifferencesExt.jl @@ -13,8 +13,9 @@ end Create an AD backend that uses forward mode with FiniteDifferences.jl. """ -AD.FiniteDifferencesBackend() = - AD.FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1)) +function AD.FiniteDifferencesBackend() + return AD.FiniteDifferencesBackend(FiniteDifferences.central_fdm(5, 1)) +end function AD.jacobian(ba::AD.FiniteDifferencesBackend, f, xs...) return FiniteDifferences.jacobian(ba.method, f, xs...) diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 389f4c8..755c7f5 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -336,6 +336,8 @@ function value_and_pullback_function(ab::AbstractBackend, f, xs...) if ws isa Tuple @assert length(vs) == length(ws) return sum(Base.splat(_dot), zip(ws, vs)) + elseif ws isa Tuple && length(ws) == 1 + return _dot(vs, only(ws)) else return _dot(vs, ws) end diff --git a/src/backends.jl b/src/backends.jl index 397eff5..4d77284 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -82,3 +82,23 @@ Note, however, that the behaviour of this backend is not equivalent to `ZygoteBa To be able to use this backend, you have to load Zygote. """ struct ZygoteBackend <: AbstractReverseMode end + +""" + EnzymeReverseBackend + +AD backend that uses reverse mode of Enzyme.jl. + +!!! note + To be able to use this backend, you have to load Enzyme. +""" +struct EnzymeReverseBackend <: AbstractReverseMode end + +""" + EnzymeForwardBackend + +AD backend that uses forward mode of Enzyme.jl. + +!!! note + To be able to use this backend, you have to load Enzyme. +""" +struct EnzymeForwardBackend <: AbstractForwardMode end \ No newline at end of file diff --git a/test/enzyme.jl b/test/enzyme.jl new file mode 100644 index 0000000..d586303 --- /dev/null +++ b/test/enzyme.jl @@ -0,0 +1,47 @@ +import AbstractDifferentiation as AD +using Test +using Enzyme + +backends = [ + "EnzymeForwardBackend" => AD.EnzymeForwardBackend(), + "EnzymeReverseBackend" => AD.EnzymeReverseBackend(), +] + +@testset "$name" for (name, backend) in backends + if name == "EnzymeForwardBackend" + @test backend isa AD.AbstractForwardMode + else + @test backend isa AD.AbstractReverseMode + end + + @testset "Derivative" begin + test_derivatives(backend; multiple_inputs=false) + end + @testset "Gradient" begin + test_gradients(backend; multiple_inputs=false) + end + @testset "Jacobian" begin + test_jacobians(backend; multiple_inputs=false) + end + # @testset "Hessian" begin + # test_hessians(backend, multiple_inputs = false) + # end + @testset "jvp" begin + test_jvp(backend; multiple_inputs=false, vaugmented=true) + end + @testset "j′vp" begin + test_j′vp(backend; multiple_inputs=false) + end + @testset "Lazy Derivative" begin + test_lazy_derivatives(backend; multiple_inputs=false) + end + @testset "Lazy Gradient" begin + test_lazy_gradients(backend; multiple_inputs=false) + end + @testset "Lazy Jacobian" begin + test_lazy_jacobians(backend; multiple_inputs=false, vaugmented=true) + end + # @testset "Lazy Hessian" begin + # test_lazy_hessians(backend, multiple_inputs = false) + # end +end diff --git a/test/runtests.jl b/test/runtests.jl index 4716c57..abfbc49 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,4 +11,5 @@ using Test include("finitedifferences.jl") include("tracker.jl") include("ruleconfig.jl") + include("enzyme.jl") end