Add more `Duplicated` methods for Enzyme.jl support #2471
base: master
Conversation
Force-pushed 60acf27 to 6310548
The docs failure here looks real, but I'm not sure why.
Fixed the docs. I think we need to own … But while I fix that, any objections to the interface? We seem to be merging things in a hurry now...
CUDA test failure is like this (and one more); why now?
The interface looks reasonable to me, and Flux owning …
Can you clarify what you mean by "the careful internal plumbing"? If you mean how …
Edit: now requires FluxML/Optimisers.jl#192
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##           master    #2471       +/-   ##
===========================================
+ Coverage   33.54%   62.31%   +28.77%
===========================================
  Files          31       33        +2
  Lines        1881     1993      +112
===========================================
+ Hits          631     1242      +611
+ Misses       1250      751      -499

☔ View full report in Codecov by Sentry.
The gradient code looked pretty straightforward, but there was also … The ideal path right now would be landing FluxML/Optimisers.jl#192 and then removing …
Re gradient, one quirk is that Enzyme has the rule that anything not marked `Const` gets a gradient:

julia> Enzyme.gradient(Reverse, dot, [1, 2.], [3, 4.])
([3.0, 4.0], [1.0, 2.0])

julia> Enzyme.gradient(Reverse, dot, [1, 2.], Const([3, 4.]))
([3.0, 4.0], nothing)

while this PR's rule is that, once one thing is `Duplicated`, everything unmarked is treated as an implicit `Const`:

julia> Flux.gradient(dot, [1, 2.], [3, 4.])  # Zygote
([3.0, 4.0], [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Duplicated([3, 4.], [NaN, NaN]))  # implicit Const
(nothing, [1.0, 2.0])

julia> Flux.gradient(dot, [1, 2.], Const([3, 4.]))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl requires at least one `Duplicated` argument, not just `Const`.

IDK if that's too weird.

Edit, one more quirk... now fixed:

julia> Flux.gradient((x,y) -> sum((x .* y).^2), [1,2,3.], 4.0)  # Zygote
([32.0, 64.0, 96.0], 112.0)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), 4.0)  # implicit Const
([32.0, 64.0, 96.0], nothing)

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Active(4.0))
ERROR: ArgumentError: The method `gradient(f, xs...)` using Enzyme.jl does not support `Active`, only `Duplicated` and `Const`.

julia> Flux.gradient((x,y) -> sum((x .* y).^2), Duplicated([1,2,3.], zeros(3)), Duplicated(4.0, 0.0))  # now an error, 2636454
ERROR: `Flux.gradient(f, Duplicated(x), ...)` expects `x` to contain mutable parameter arrays.
One question this PR opens is how other AD might slot in. Mooncake.jl similarly prefers pre-allocated space for gradients, but uses different types. The obvious interface would be to have some container like so:

struct MoonPair{X,DX}; x::X; dx::DX; end
MoonPair(x) = MoonPair(x, Mooncake.zero_codual(x))

Flux could own this. That would be easy, in terms of e.g. defining … Mooncake.jl could also own it, but then …
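To make the dispatch idea concrete, here is a toy sketch of how such a wrapper could act purely as a backend selector; every name below is hypothetical and nothing in it is part of this PR or of Mooncake.jl:

```julia
# Toy sketch only: the wrapper carries the value plus pre-allocated gradient storage,
# and its *type* is what selects the AD backend, the same role Duplicated plays for Enzyme.
struct MoonPair{X,DX}
    x::X
    dx::DX
end
MoonPair(x::AbstractArray) = MoonPair(x, zero(x))   # stand-in for Mooncake.zero_codual(x)

select_backend(args...) = any(a -> a isa MoonPair, args) ? :mooncake : :default

select_backend([1.0, 2.0])             # :default  (no wrapper, fall back to the Zygote path)
select_backend(MoonPair([1.0, 2.0]))   # :mooncake (wrapper signals the backend)
```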
ext/FluxEnzymeExt/FluxEnzymeExt.jl
Outdated

forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
tape, result, shadow_result = forward(Const(f), args...)
reverse(Const(f), args..., _sensitivity(result), tape)
why not just call autodiff here?
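For reference, this is roughly what that would look like for a plain real-valued loss (a sketch, not the PR's code; it assumes `f` returns a real number so `ReverseWithPrimal` applies directly):

```julia
using Enzyme

f(x) = sum(abs2, x)
x = Duplicated([1.0, 2.0, 3.0], zeros(3))

# One combined call returns the primal and accumulates the gradient into the shadow:
derivs, loss = autodiff(ReverseWithPrimal, Const(f), Active, x)
loss    # 14.0
x.dval  # [2.0, 4.0, 6.0]
```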
This was adapted from the ReverseSplitWithPrimal doc example. My understanding was that it needs that in order to construct `_sensitivity(result)` after seeing what `f` returns.
What is the expected return type of `f`? If it's a float it should be fine to just use autodiff directly.
If not a float (as it looks like below), it might be more efficient to do something like the following (since split mode will introduce overhead).

@inline asfloat(x) = error("""`Flux.withgradient(f, xs...)` expects that `y = f(xs...)` is a real number,
    or else a Tuple or NamedTuple whose first element is a real number.""")
@inline asfloat(x::Real) = x
@inline asfloat(x::Tuple) = asfloat(x[1])
@inline asfloat(x::NamedTuple) = asfloat(x[1])

function return_asfloat(f, args...)
    return asfloat(@inline f(args...))
end

autodiff(Reverse, Const(return_asfloat), Active, Const(f), ...)
The point is to pass things other than the loss out of the function:

julia> withgradient([1,2,4]) do x
         z = 1 ./ x
         sum(z), string("aux output: z = ", z)
       end
(val = (1.75, "aux output: z = [1.0, 0.5, 0.25]"), grad = ([-1.0, -0.25, -0.0625],))
It's possible that trying to infer the return type of `f` would pay... when it is shown to return a float you could call ReverseWithPrimal. How big is the overhead of this ReverseSplitWithPrimal?

Edit: with examples from the other thread I can't measure it, <1% maybe?
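One cheap way to try that idea, if it ever turned out to matter (a sketch using best-effort inference, not something this PR does):

```julia
# Base.promote_op is a best-effort inference query, so treat the answer as a heuristic only:
returns_real(f, args...) = Base.promote_op(f, map(typeof, args)...) <: Real

loss_only(x) = sum(abs2, x)
loss_plus_aux(x) = (sum(abs2, x), x)

returns_real(loss_only, [1.0, 2.0])      # true  -> could take the plain ReverseWithPrimal path
returns_real(loss_plus_aux, [1.0, 2.0])  # false -> keep the ReverseSplitWithPrimal path
```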
Ah, well in that case maybe something like the following (assuming you know the return type):

function byref(out, f, args...)
    res = f(args...)
    out[] = asfloat(res)
    return res
end

dout = DuplicatedNoNeed(Ref(0.0), Ref(1.0))
autodiff(Reverse, Const(byref), Const, dout, ....)
This looks possible. But where is it better? Is the overhead of ReverseSplitWithPrimal running it, or in generating code or something (which I didn't try to time)?
Running it, though how much that matters depends on the code being differentiated.
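For anyone wanting to check this on their own loss, a rough way to compare the two paths (assumes BenchmarkTools is available; as noted above, the numbers will depend heavily on the function being differentiated):

```julia
using Enzyme, BenchmarkTools

f(x) = sum(abs2, x)
x = Duplicated(rand(1000), zeros(1000))

# Combined mode: one call returns derivatives and primal.
@btime autodiff(ReverseWithPrimal, Const($f), Active, $x)

# Split mode, including thunk construction, as in the code under review.
@btime begin
    fwd, rev = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof($f)}, Active, typeof($x))
    tape, result, shadow = fwd(Const($f), $x)
    rev(Const($f), $x, 1.0, tape)
end
```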
FWIW, I tried this Ref approach in ecef1f0, but couldn't get all cases to work. Of the cases in the docstring text, complicated ones which return `(loss, aux...)` worked, but simple ones returning just the loss did not. The `@show val` always worked, but `result = autodiff` did not always.
Longer term it might be worth making sure that the design is flexible enough for Reactant integration as well (x-ref https://lux.csail.mit.edu/stable/manual/compiling_lux_models). Lux has already shown relatively big speedups, at least on CPU.
MLDataDevices has …
…ier commits are a mess now, probably
Co-authored-by: Carlo Lucibello <[email protected]>
sorry for this rebase mess
What does this break?
whoever does …
Ah, that is true. Broken with a helpful message, but I agree.
We should specialize …
Certainly it should not save the gradient. I guess there's an open question here whether gradient/setup/state/loadmodel! should deal with an outer …
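To illustrate the "deal with an outer wrapper" option, a hypothetical sketch; none of these helper names exist in Flux or Optimisers.jl:

```julia
using Flux, Enzyme, Optimisers

# Hypothetical helpers: look through the Enzyme wrapper so that downstream
# functions only ever see the plain model.
_strip_dual(m::Duplicated) = m.val
_strip_dual(m) = m

setup_either(rule, model) = Optimisers.setup(rule, _strip_dual(model))

model = Dense(2 => 3)
dup = Duplicated(model, Enzyme.make_zero(model))

st1 = setup_either(Adam(), dup)    # accepts the wrapped model
st2 = setup_either(Adam(), model)  # and the plain one, giving equivalent state
```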
+1
ext/FluxEnzymeExt/FluxEnzymeExt.jl
Outdated

function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool=true)
  for x in args
    zero && x isa Duplicated && make_zero!(x.dval)
    _check_mutable(x)
  end
  Enzyme.autodiff(Reverse, f, Active, args...)
  map(_grad_or_nothing, args)
end
Some failure cases:
julia> using Flux, Enzyme
# Zygote
julia> Flux.gradient(sum ∘ LayerNorm(3), zeros(3))
([0.0, 0.0, 0.0],)
julia> Flux.gradient(|>, zeros(3), sum ∘ LayerNorm(3))
([0.0, 0.0, 0.0], (outer = nothing, inner = (λ = nothing, diag = (scale = Float32[0.0, 0.0, 0.0], bias = Float32[1.0, 1.0, 1.0], σ = nothing), ϵ = 0.0, size = nothing, affine = nothing)))
# Enzyme
julia> Flux.gradient(sum ∘ LayerNorm(3), Duplicated(zeros(3), zeros(3)))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for: store {} addrspace(10)* %.fca.0.0.0.extract, {} addrspace(10)* addrspace(10)* %.innerparm.sroa.0.0..sroa_cast, align 8, !dbg !15, !alias.scope !17, !noalias !21 const val: %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %0, 0, 0, 0, !dbg !8
value=Unknown object of type Vector{Float32}
llvalue= %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %0, 0, 0, 0, !dbg !8
Stacktrace:
[1] ComposedFunction
@ ./operators.jl:1041
[2] ComposedFunction
@ ./operators.jl:0
Stacktrace:
[1] ComposedFunction
@ ./operators.jl:1041 [inlined]
[2] ComposedFunction
@ ./operators.jl:0 [inlined]
[3] augmented_julia_ComposedFunction_18168_inner_1wrap
@ ./operators.jl:0
[4] macro expansion
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8398 [inlined]
[5] enzyme_call
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7950 [inlined]
[6] AugmentedForwardThunk
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7787 [inlined]
[7] autodiff(rmode::ReverseMode{…}, f::Const{…}, ::Type{…}, args::Duplicated{…})
@ Enzyme ~/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:384
[8] _enzyme_gradient(f::Function, args::Duplicated{Vector{Float64}}; zero::Bool)
@ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:22
[9] _enzyme_gradient
@ ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:17 [inlined]
[10] gradient(f::Function, args::Duplicated{Vector{Float64}})
@ Flux ~/.julia/dev/Flux/src/gradient.jl:122
julia> Duplicated(c) = Duplicated(c, Enzyme.make_zero(c));
julia> Flux.gradient(|>, Duplicated(zeros(3)), Duplicated(sum ∘ LayerNorm(3)))
ERROR: setfield!: immutable struct of type ComposedFunction cannot be changed
Stacktrace:
[1] make_zero!
@ ~/.julia/packages/Enzyme/RTS5U/src/make_zero.jl:529 [inlined]
[2] make_zero!(prev::ComposedFunction{typeof(sum), LayerNorm{typeof(identity), Flux.Scale{…}, Float32, 1}})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/make_zero.jl:504
[3] _enzyme_gradient(::Function, ::Duplicated{Vector{Float64}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
@ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:19
julia> Flux.gradient(|>, Duplicated(zeros(3)), Const(sum ∘ LayerNorm(3)))
ERROR: Constant memory is stored (or returned) to a differentiable variable.
As a result, Enzyme cannot provably ensure correctness and throws this error.
This might be due to the use of a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/faq/#Runtime-Activity).
If Enzyme should be able to prove this use non-differentable, open an issue!
To work around this issue, either:
a) rewrite this variable to not be conditionally active (fastest, but requires a code change), or
b) set the Enzyme mode to turn on runtime activity (e.g. autodiff(set_runtime_activity(Reverse), ...) ). This will maintain correctness, but may slightly reduce performance.
Mismatched activity for: store {} addrspace(10)* %.fca.0.0.0.extract, {} addrspace(10)* addrspace(10)* %.innerparm.sroa.0.0..sroa_cast, align 8, !dbg !17, !alias.scope !20, !noalias !24 const val: %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %1, 0, 0, 0, !dbg !8
value=Unknown object of type Vector{Float32}
llvalue= %.fca.0.0.0.extract = extractvalue { { { {} addrspace(10)*, {} addrspace(10)* }, float, [1 x i64], i8 } } %1, 0, 0, 0, !dbg !8
# to see code?
julia> @less LayerNorm(3)(zeros(3))
julia> @less LayerNorm(3).diag(zeros(3))
Those are solved by 5c1650f. Remaining related bugs in `withgradient`:
julia> Flux.withgradient(|>, Duplicated(zeros(3)), Duplicated(sum ∘ LayerNorm(3)))
ERROR: AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Any, rettype = Active{Any}
Stacktrace:
[1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{…}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:4433
[2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:4169
[3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:7338
[4] codegen
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:6146 [inlined]
[5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468
[6] _thunk
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8468 [inlined]
[7] cached_compilation
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8509 [inlined]
[8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{…}, ::Type{…}, ::Type{…}, tt::Type{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Val{…}, ::Type{…}, ::Val{…}, ::Val{…})
@ Enzyme.Compiler ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8641
[9] #s2103#19072
@ ~/.julia/packages/Enzyme/RTS5U/src/compiler.jl:8778 [inlined]
[10]
@ Enzyme.Compiler ./none:0
[11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
@ Core ./boot.jl:602
[12] autodiff_thunk(::EnzymeCore.ReverseModeSplit{…}, ::Type{…}, ::Type{…}, ::Type{…}, ::Type{…})
@ Enzyme ~/.julia/packages/Enzyme/RTS5U/src/Enzyme.jl:968
[13] _enzyme_withgradient(::Function, ::Duplicated{Vector{Float64}}, ::Vararg{Union{Const, Duplicated}}; zero::Bool)
@ FluxEnzymeExt ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:72
[14] _enzyme_withgradient
@ ~/.julia/dev/Flux/ext/FluxEnzymeExt/FluxEnzymeExt.jl:177 [inlined]
[15] withgradient(::Function, ::Duplicated{Vector{Float64}}, ::Duplicated{ComposedFunction{typeof(sum), LayerNorm{…}}})
@ Flux ~/.julia/dev/Flux/src/gradient.jl:226
[16] top-level scope
@ REPL[112]:1
Some type information was truncated. Use `show(err)` to see complete types.
So this, I think, is a good demonstration of why using a direct autodiff instead of a thunk (say, returning the loss as a float and passing any extra info out via a Const(Ref()), or vice versa) would be useful.
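A sketch of that first variant with illustrative names (this is not the PR's implementation): the differentiated function returns a plain float so ordinary `Reverse` with an `Active` return applies, and anything extra is stashed through a `Const` `Ref`:

```julia
using Enzyme

# Illustrative names only: return a plain Float64, and smuggle the auxiliary
# output out through constant memory, undifferentiated.
function loss_and_stash!(aux, f, x)
    loss, extra = f(x)
    aux[] = extra
    return loss
end

f(x) = (sum(abs2, x), maximum(x))   # (loss, aux)
x = Duplicated([1.0, 2.0], zeros(2))
aux = Ref(0.0)

autodiff(Reverse, Const(loss_and_stash!), Active, Const(aux), Const(f), x)
aux[]   # 2.0, the auxiliary value
x.dval  # [2.0, 4.0], gradient of the loss only
```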
Ah, maybe I got that working at last, 016cfe6:

julia> Flux.withgradient(|>, Duplicated(zeros(3)), Duplicated(sum ∘ LayerNorm(3)))
(val = 0.0, grad = ([0.0, 0.0, 0.0], (outer = nothing, inner = (λ = nothing, diag = (scale = Float32[0.0, 0.0, 0.0], bias = Float32[1.0, 1.0, 1.0], σ = nothing), ϵ = nothing, size = (nothing,), affine = nothing))))

That's the "or vice versa" path, I think: write the loss into Ref(0.0) and return everything else.
This adds a method like `gradient(f, ::Duplicated)` which, like `train!(loss, model::Duplicated, data, opt)` from #2446, uses the `Duplicated` type to signal that you want to use Enzyme not Zygote. It returns the gradient (for compatibility?) and mutates the `Duplicated` object.

To avoid piracy, this creates a new function `Flux.gradient`, which by default calls `Zygote.gradient`. Unfortunately that's going to mean every `using Flux, Zygote` now produces ambiguities... so probably it should not be exported? Which means 0.15. Such ambiguities give a clear message, maybe that's OK? Maybe clearer than stopping exporting.

There's also `withgradient`, but it doesn't allow you to return a tuple the way Zygote does, not yet. Now it does.

There's also a method of `update!` which either needs to move to Optimisers.jl, or again we need to... we should let Flux own the function? Has moved to: Add `Duplicated` methods, Optimisers.jl#192.

Finally, `@layer Chain` defines a 1-argument `Duplicated(c::Chain)` method, so that you don't need to construct the dual by hand.

WIP, RFC? Needs tests, and docs.

PR Checklist
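For orientation, a rough sketch of how the interface described above is meant to be used, based on the examples in this thread (assumes Flux with this PR plus Enzyme are loaded; the loss here is an ad-hoc anonymous one, not a particular Flux loss function):

```julia
using Flux, Enzyme

model = Chain(Dense(2 => 3, relu), Dense(3 => 1))
dup = Duplicated(model, Enzyme.make_zero(model))   # or Duplicated(model) via the 1-arg method above

x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
loss(m, x, y) = sum(abs2, m(x) .- y)

# The Duplicated argument selects Enzyme; unmarked arguments are treated as implicit Const:
grads = Flux.gradient(loss, dup, x, y)
grads[1]   # gradient for the model, also accumulated into dup.dval
```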