Skip to content

Commit

Permalink
Performance tweaks (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored May 24, 2023
1 parent a710a81 commit 3f3bc10
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
3 changes: 2 additions & 1 deletion ext/ImplicitDifferentiationChainRulesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ function (implicit_pullback::ImplicitPullback)((dy, dz))
dy_vec = convert(Vector{R}, vec(unthunk(dy)))
dF_vec, stats = linear_solver(Aᵀ_op, dy_vec)
check_solution(linear_solver, stats)
dx_vec = -(Bᵀ_op * dF_vec)
dx_vec = Bᵀ_op * dF_vec
dx_vec .*= -1
dx = reshape(dx_vec, size(x))
return (NoTangent(), dx)
end
Expand Down
3 changes: 2 additions & 1 deletion ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ function (implicit::ImplicitFunction)(

dy = map(1:N) do k
dₖx_vec = vec(partials.(x_and_dx, k))
dₖy_vec, stats = linear_solver(A_op, -(B_op * dₖx_vec))
dₖy_vec, stats = linear_solver(A_op, B_op * dₖx_vec)
dₖy_vec .*= -1
check_solution(linear_solver, stats)
reshape(dₖy_vec, size(y))
end
Expand Down
2 changes: 1 addition & 1 deletion src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ImplicitDifferentiation

using AbstractDifferentiation: LazyJacobian, ReverseRuleConfigBackend, lazy_jacobian
using Krylov: KrylovStats, gmres
using LinearOperators: LinearOperator
using LinearOperators: LinearOperators, LinearOperator
using Requires: @require

include("utils.jl")
Expand Down
5 changes: 5 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,8 @@ function (pbm::PullbackMul!)(res::AbstractVector, δoutput_vec::AbstractVector)
δinput = only(pbm.pullback(δoutput))
return res .= vec(δinput)
end

## Override this function from LinearOperators to avoid generating the whole methods table

LinearOperators.get_nargs(pfm::PushforwardMul!) = 1
LinearOperators.get_nargs(pbm::PullbackMul!) = 1

0 comments on commit 3f3bc10

Please sign in to comment.