disallow Active
mcabbott committed Nov 10, 2024
1 parent dde8e52 commit db67dcf
Showing 2 changed files with 24 additions and 10 deletions.
22 changes: 16 additions & 6 deletions src/gradient.jl
@@ -33,16 +33,27 @@ function gradient(f, args...; zero::Bool=true)
a isa EnzymeCore.Duplicated && return _enzyme_gradient(f, map(_ensure_enzyme, args)...; zero)
end
for a in args
a isa EnzymeCore.Const && throw(ArgumentError(
"The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
))
_ensure_noenzyme(a)
end
Zygote.gradient(f, args...)
end

# Given one Duplicated, we wrap everything else in Const before calling Enzyme
_ensure_enzyme(x::EnzymeCore.Duplicated) = x
_ensure_enzyme(x::EnzymeCore.Const) = x
_ensure_enzyme(x) = EnzymeCore.Const(x)
_ensure_enzyme(x::EnzymeCore.Active) = throw(ArgumentError(
"The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and `Const`."
))

# Without any Duplicated, check for no stray Enzyme types before calling Zygote
_ensure_noenzyme(::EnzymeCore.Const) = throw(ArgumentError(
"The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
))
_ensure_noenzyme(::EnzymeCore.Active) = throw(ArgumentError(
"The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and `Const`."
))
_ensure_noenzyme(_) = nothing
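
As an illustration of what these helpers accept and reject, a minimal sketch, assuming both Flux and Enzyme are loaded so this extension is active; the `Dense` model and data here are hypothetical, not part of the diff:

using Flux, Enzyme

model = Dense(3 => 2)
dup = Duplicated(model)   # the model plus allocated space for its gradient
x = [1f0, 2f0, 3f0]

Flux.gradient((m, x) -> sum(m(x)), dup, Const(x))    # Enzyme path: one Duplicated, rest Const
Flux.gradient((m, x) -> sum(m(x)), dup, x)           # plain `x` is wrapped in Const automatically
Flux.gradient((m, x) -> sum(m(x)), Const(model), x)  # ArgumentError: needs at least one Duplicated
Flux.gradient((m, z) -> sum(m.bias) / z, dup, Active(3f0))  # ArgumentError: Active is disallowed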

"""
gradient(f, args::Union{Const,Duplicated}...)
@@ -54,6 +65,7 @@ Only available when Enzyme is loaded!
This method is used when at least one argument is of type `Duplicated`,
and all unspecified arguments are wrapped in `Const`.
Note that Enzyme's `Active` is not supported.
Besides returning the gradient, this is also stored within the `Duplicated` object.
Calling `Enzyme.Duplicated(model)` allocates space for the gradient,
@@ -153,9 +165,7 @@ function withgradient(f, args...; zero::Bool=true)
a isa EnzymeCore.Duplicated && return _enzyme_withgradient(f, map(_ensure_enzyme, args)...; zero)
end
for a in args
a isa EnzymeCore.Const && throw(ArgumentError(
"The method `withgradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`."
))
_ensure_noenzyme(a)
end
Zygote.withgradient(f, args...)
end
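
For the documented happy path, a short usage sketch; the model, data, and loss below are illustrative assumptions, not from this commit, and it assumes the usual `(; val, grad)` return shape of `withgradient`:

using Flux, Enzyme

model = Dense(2 => 1)
dup = Duplicated(model)   # allocates space for the gradient, as the docstring above notes
x = rand(Float32, 2, 8)
y = rand(Float32, 1, 8)

out = Flux.withgradient((m, x, y) -> sum(abs2, m(x) .- y), dup, Const(x), Const(y))
out.val    # the loss value
out.grad   # the gradient; it is also stored in the shadow half of `dup` (its `dval` field)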
12 changes: 8 additions & 4 deletions test/ext_enzyme/enzyme.jl
@@ -181,8 +181,12 @@ end
@test Flux.setup(Adam(), m1) == Flux.setup(Adam(), m1.val)

# At least one Duplicated is required:
@test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1))
@test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1), [1,2,3f0])
@test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1))
@test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1), [1,2,3f0])
@test_throws ArgumentError Flux.gradient(m -> sum(m.bias), Const(m1.val))
@test_throws ArgumentError Flux.gradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
@test_throws ArgumentError Flux.withgradient(m -> sum(m.bias), Const(m1.val))
@test_throws ArgumentError Flux.withgradient((m,x) -> sum(m(x)), Const(m1.val), [1,2,3f0])
# Active is disallowed:
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
end
