diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl index 595d4d3c55..68d6b59cdd 100644 --- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl +++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl @@ -20,7 +20,7 @@ function Flux._enzyme_gradient(f, args::Union{Const, Duplicated}...; zero::Bool= zero && x isa Duplicated && _make_zero!(x.dval) _check_mutable(x) end - Enzyme.autodiff(ReverseWithPrimal, f, Active, args...) + Enzyme.autodiff(Reverse, f, Active, args...) map(_grad_or_nothing, args) end