Skip to content

Commit a484ad6

Browse files
Merge pull request #382 from SciML/krylov_enzyme
Fix KrylovJL_GMRES with Enzyme
2 parents 8008fa3 + 5eb5a18 commit a484ad6

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

ext/LinearSolveEnzymeExt.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using LinearSolve.LinearAlgebra
55
using EnzymeCore
66
using EnzymeCore: EnzymeRules
77

8+
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:LinearSolve.SciMLLinearSolveAlgorithm}) = true
9+
810
function EnzymeRules.forward(config::EnzymeRules.FwdConfigWidth{1},
911
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
1012
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
@@ -223,10 +225,10 @@ function EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
223225
elseif _linsolve.alg isa LinearSolve.AbstractKrylovSubspaceMethod
224226
# Doesn't modify `A`, so it's safe to just reuse it
225227
invprob = LinearSolve.LinearProblem(transpose(_linsolve.A), dy)
226-
solve(invprob, _linearsolve.alg;
227-
abstol = _linsolve.val.abstol,
228-
reltol = _linsolve.val.reltol,
229-
verbose = _linsolve.val.verbose)
228+
solve(invprob, _linsolve.alg;
229+
abstol = _linsolve.abstol,
230+
reltol = _linsolve.reltol,
231+
verbose = _linsolve.verbose)
230232
elseif _linsolve.alg isa LinearSolve.DefaultLinearSolver
231233
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
232234
else

src/iterative_wrappers.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::KrylovJL; kwargs...)
327327
else
328328
cache.u = convert(typeof(cache.u), cacheval.x)
329329
end
330-
331-
return SciMLBase.build_linear_solution(alg, cache.u, resid, cache;
330+
331+
return SciMLBase.build_linear_solution(alg, cache.u, Ref(resid), cache;
332332
iters = stats.niter, retcode, stats)
333333
end

test/enzyme.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ Enzyme.autodiff(Reverse, f2, Duplicated(copy(A), dA),
157157
@test db1 db12
158158
@test db2 db22
159159

160-
#=
161160
function f3(A, b1, b2; alg = KrylovJL_GMRES())
162161
prob = LinearProblem(A, b1)
163162
cache = init(prob, alg)
@@ -167,12 +166,14 @@ function f3(A, b1, b2; alg = KrylovJL_GMRES())
167166
norm(s1 + s2)
168167
end
169168

170-
Enzyme.autodiff(Reverse, f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
169+
dA = zeros(n, n);
170+
db1 = zeros(n);
171+
db2 = zeros(n);
172+
Enzyme.autodiff(set_runtime_activity(Reverse), f3, Duplicated(copy(A), dA), Duplicated(copy(b1), db1), Duplicated(copy(b2), db2))
171173

172174
@test dA dA2 atol=5e-5
173175
@test db1 db12
174176
@test db2 db22
175-
=#
176177

177178
A = rand(n, n);
178179
dA = zeros(n, n);

0 commit comments

Comments
 (0)