-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enzyme support #85
base: master
Are you sure you want to change the base?
Enzyme support #85
Changes from all commits
0021bee
62868e4
77699c3
548a6bb
93ef317
63e66ac
c45b16e
080cccc
0599383
328d7e2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
module AbstractDifferentiationEnzymeExt | ||
|
||
if isdefined(Base, :get_extension) | ||
import AbstractDifferentiation as AD | ||
using Enzyme: Enzyme | ||
else | ||
import ..AbstractDifferentiation as AD | ||
using ..Enzyme: Enzyme | ||
end | ||
|
||
struct Mutating{F} | ||
f::F | ||
end | ||
function (f::Mutating)(y, xs...) | ||
y .= f.f(xs...) | ||
return y | ||
end | ||
|
||
AD.@primitive function value_and_pullback_function(b::AD.EnzymeReverseBackend, f, xs...) | ||
y = f(xs...) | ||
return y, | ||
Δ -> begin | ||
Δ_xs = zero.(xs) | ||
dup = if y isa Real | ||
if Δ isa Real | ||
Enzyme.Duplicated([y], [Δ]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This seems a bit strange - that's not something an Enzyme user would do AFAIK. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ya it's a quick and dirty hack to get it running, needs to be optimised There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe one can reuse some of the things I did in TuringLang/DistributionsAD.jl#254. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is a real or tuple of real, this should be an active argument [in reverse mode] |
||
elseif Δ isa Tuple{Real} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The tuple issue hits again... |
||
Enzyme.Duplicated([y], [Δ[1]]) | ||
else | ||
throw(ArgumentError("Unsupported cotangent type.")) | ||
end | ||
else | ||
if Δ isa AbstractArray{<:Real} | ||
Enzyme.Duplicated(y, Δ) | ||
elseif Δ isa Tuple{AbstractArray{<:Real}} | ||
Enzyme.Duplicated(y, Δ[1]) | ||
else | ||
throw(ArgumentError("Unsupported cotangent type.")) | ||
end | ||
end | ||
Enzyme.autodiff( | ||
Enzyme.Reverse, | ||
Mutating(f), | ||
Enzyme.Const, | ||
dup, | ||
Enzyme.Duplicated.(xs, Δ_xs)..., | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That means users of AbstractDifferentiation miss a major feature of Enzyme. But maybe it's unavoidable and the current design of AbstractDifferentiation can't support it and the wrapper will always be less performant than Enzyme? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. let's brainstorm solutions, I think it's possible to support partial pullback with an extended API |
||
) | ||
return Δ_xs | ||
end | ||
end | ||
function AD.pushforward_function(::AD.EnzymeReverseBackend, f, xs...) | ||
return AD.pushforward_function(AD.EnzymeForwardBackend(), f, xs...) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This creates an inconsistency with the behaviour of other backends where it is guaranteed that the specified backend is used for every operation. I think the better design might be to have dedicated Reverse+Forward wrappers that allow to specify different backends for forward and reverse mode operations and pick the best mode for every call. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree. This was done to make some failed tests pass which likely fail due to an Enzyme correctness issue. We should change this before merge. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the test case of the correctness issue? Can you open an issue with it? |
||
end | ||
|
||
AD.@primitive function pushforward_function(b::AD.EnzymeForwardBackend, f, xs...) | ||
return ds -> Tuple(Enzyme.autodiff(Enzyme.Forward, f, Enzyme.Duplicated.(xs, ds)...)) | ||
end | ||
function AD.value_and_pullback_function(::AD.EnzymeForwardBackend, f, xs...) | ||
return AD.value_and_pullback_function(AD.EnzymeReverseBackend(), f, xs...) | ||
end | ||
|
||
end # module |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -82,3 +82,23 @@ Note, however, that the behaviour of this backend is not equivalent to `ZygoteBa | |||||
To be able to use this backend, you have to load Zygote. | ||||||
""" | ||||||
struct ZygoteBackend <: AbstractReverseMode end | ||||||
|
||||||
""" | ||||||
EnzymeReverseBackend | ||||||
|
||||||
AD backend that uses reverse mode of Enzyme.jl. | ||||||
|
||||||
!!! note | ||||||
To be able to use this backend, you have to load Enzyme. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only true on Julia >= 1.9 I think? |
||||||
""" | ||||||
struct EnzymeReverseBackend <: AbstractReverseMode end | ||||||
|
||||||
""" | ||||||
EnzymeForwardBackend | ||||||
|
||||||
AD backend that uses forward mode of Enzyme.jl. | ||||||
|
||||||
!!! note | ||||||
To be able to use this backend, you have to load Enzyme. | ||||||
""" | ||||||
struct EnzymeForwardBackend <: AbstractForwardMode end | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [JuliaFormatter] reported by reviewdog 🐶
Suggested change
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import AbstractDifferentiation as AD | ||
using Test | ||
using Enzyme | ||
|
||
backends = [ | ||
"EnzymeForwardBackend" => AD.EnzymeForwardBackend(), | ||
"EnzymeReverseBackend" => AD.EnzymeReverseBackend(), | ||
] | ||
|
||
@testset "$name" for (name, backend) in backends | ||
if name == "EnzymeForwardBackend" | ||
@test backend isa AD.AbstractForwardMode | ||
else | ||
@test backend isa AD.AbstractReverseMode | ||
end | ||
|
||
@testset "Derivative" begin | ||
test_derivatives(backend; multiple_inputs=false) | ||
end | ||
@testset "Gradient" begin | ||
test_gradients(backend; multiple_inputs=false) | ||
end | ||
@testset "Jacobian" begin | ||
test_jacobians(backend; multiple_inputs=false) | ||
end | ||
# @testset "Hessian" begin | ||
# test_hessians(backend, multiple_inputs = false) | ||
# end | ||
@testset "jvp" begin | ||
test_jvp(backend; multiple_inputs=false, vaugmented=true) | ||
end | ||
@testset "j′vp" begin | ||
test_j′vp(backend; multiple_inputs=false) | ||
end | ||
@testset "Lazy Derivative" begin | ||
test_lazy_derivatives(backend; multiple_inputs=false) | ||
end | ||
@testset "Lazy Gradient" begin | ||
test_lazy_gradients(backend; multiple_inputs=false) | ||
end | ||
@testset "Lazy Jacobian" begin | ||
test_lazy_jacobians(backend; multiple_inputs=false, vaugmented=true) | ||
end | ||
# @testset "Lazy Hessian" begin | ||
# test_lazy_hessians(backend, multiple_inputs = false) | ||
# end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should use ReverseSplitMode here, and call the augmented forward pass for that result, use the reverse pass (and tape created from aug) for the reverse pass.