Skip to content

Commit

Permalink
Transparent handling of byproduct (#86)
Browse files Browse the repository at this point in the history
* Guess byproduct or not

* Fix static arrays

* Resolve docs

* Clarify docs
  • Loading branch information
gdalle authored Aug 4, 2023
1 parent 79e470f commit 9d34d8b
Show file tree
Hide file tree
Showing 13 changed files with 341 additions and 375 deletions.
8 changes: 1 addition & 7 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,13 @@
ImplicitFunction
DirectLinearSolver
IterativeLinearSolver
HandleByproduct
ReturnByproduct
ChainRulesCore.rrule
```

## Internals

```@docs
ImplicitDifferentiation.Forward
ImplicitDifferentiation.Conditions
ImplicitDifferentiation.AbstractLinearSolver
ImplicitDifferentiation.PushforwardMul!
ImplicitDifferentiation.PullbackMul!
ChainRulesCore.rrule
```

## Index
Expand Down
18 changes: 5 additions & 13 deletions examples/3_tricks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,29 +122,21 @@ function conditions_cointoss(x, y, z)
end

#=
To make sure that the implicit function you create takes this byproduct into account, just construct it like this:
The `ImplicitFunction` is created as usual:
=#

implicit_cointoss = ImplicitFunction(
forward_cointoss, conditions_cointoss, HandleByproduct()
)
implicit_cointoss = ImplicitFunction(forward_cointoss, conditions_cointoss)

#=
Then you have two ways of calling the function: the standard way will only return `y`
But this time, when you call it, it will return a tuple:
=#

x = [1.0, 1.0]

implicit_cointoss(x)

#=
Or if you also need the byproduct, you can do
Differentiation works by taking the byproduct into account but without computing a derivative for it:
=#

implicit_cointoss(x, ReturnByproduct())

#=
But whatever you choose, the byproduct is taken into account during differentiation!
=#

Zygote.withjacobian(implicit_cointoss, x)
Zygote.withjacobian(first ∘ implicit_cointoss, x)  # keep `y`, drop byproduct `z`
65 changes: 26 additions & 39 deletions ext/ImplicitDifferentiationChainRulesExt.jl
Original file line number Diff line number Diff line change
@@ -1,67 +1,54 @@
module ImplicitDifferentiationChainRulesExt

using AbstractDifferentiation: ReverseRuleConfigBackend, pullback_function
using AbstractDifferentiation: ReverseRuleConfigBackend
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, ZeroTangent, rrule, unthunk
using ImplicitDifferentiation: ImplicitFunction, PullbackMul!, ReturnByproduct
using ImplicitDifferentiation: presolve, solve
using ImplicitDifferentiation: ImplicitFunction, reverse_operators, solve
using LinearAlgebra: lmul!, mul!
using LinearOperators: LinearOperator
using SimpleUnPack: @unpack

"""
rrule(rc, implicit, x[, ReturnByproduct()]; kwargs...)
rrule(rc, implicit, x; kwargs...)
Custom reverse rule for an [`ImplicitFunction`](@ref), to ensure compatibility with reverse mode autodiff.
This is only available if ChainRulesCore.jl is loaded (extension), except on Julia < 1.9 where it is always available.
- By default, this returns a single output `y(x)` with a pullback accepting a single cotangent `dy`.
- If `ReturnByproduct()` is passed as an argument, this returns a couple of outputs `(y(x),z(x))` with a pullback accepting a couple of cotangents `(dy, dz)` (remember that `z(x)` is not differentiated so its cotangent is ignored).
We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu` (see [`ImplicitFunction`](@ref) for the definition of `A` and `B`).
We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`.
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
rc::RuleConfig,
implicit::ImplicitFunction,
x::AbstractArray{R},
::ReturnByproduct;
kwargs...,
) where {R}
@unpack conditions, linear_solver = implicit

# Evaluate the implicit function, keeping both the output `y` and the byproduct `z`.
y, z = implicit(x, ReturnByproduct(); kwargs...)
n, m = length(x), length(y)

# Pullbacks of the conditions: `pbA` w.r.t. `y` (with `x`, `z` held fixed) and
# `pbB` w.r.t. `x` (with `y`, `z` held fixed). The byproduct `z` is never differentiated.
backend = ReverseRuleConfigBackend(rc)
pbA = pullback_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
pbB = pullback_function(backend, _x -> conditions(_x, y, z; kwargs...), x)

# Wrap both pullbacks as matrix-free `LinearOperator`s with element type `R`
# (dimension arguments `m, m` for Aᵀ and `n, m` for Bᵀ).
Aᵀ_op = LinearOperator(R, m, m, false, false, PullbackMul!(pbA, size(y)))
Bᵀ_op = LinearOperator(R, n, m, false, false, PullbackMul!(pbB, size(y)))
# NOTE(review): `presolve` is defined elsewhere; presumably it lets the linear solver
# preprocess Aᵀ once before the pullback is called — confirm against its definition.
Aᵀ_op_presolved = presolve(linear_solver, Aᵀ_op, y)

implicit_pullback = ImplicitPullback(Aᵀ_op_presolved, Bᵀ_op, linear_solver, x)

# Primal result is the pair `(y, z)`; the pullback is the callable struct built above.
return (y, z), implicit_pullback
end

function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R}
(y, z), implicit_pullback = rrule(rc, implicit, x, ReturnByproduct(); kwargs...)
implicit_pullback_no_byproduct(dy) = Base.front(implicit_pullback((dy, nothing)))
return y, implicit_pullback_no_byproduct
y_or_yz = implicit(x; kwargs...)
backend = ReverseRuleConfigBackend(rc)
Aᵀ_op, Bᵀ_op = reverse_operators(backend, implicit, x, y_or_yz; kwargs)
byproduct = y_or_yz isa Tuple
implicit_pullback = ImplicitPullback{byproduct}(Aᵀ_op, Bᵀ_op, implicit.linear_solver, x)
return y_or_yz, implicit_pullback
end

struct ImplicitPullback{A,B,L,X}
struct ImplicitPullback{byproduct,A,B,L,X}
Aᵀ_op::A
Bᵀ_op::B
linear_solver::L
x::X

function ImplicitPullback{byproduct}(
Aᵀ_op::A, Bᵀ_op::B, linear_solver::L, x::X
) where {byproduct,A,B,L,X}
return new{byproduct,A,B,L,X}(Aᵀ_op, Bᵀ_op, linear_solver, x)
end
end

# Pullback call for the `byproduct = true` case: the cotangent arrives as a pair
# `(dy, dz)`; only `dy` is forwarded to `_apply`, `dz` is left unused.
(pb::ImplicitPullback{true})((dy, dz)) = _apply(pb, dy)

# Pullback call for the `byproduct = false` case: the cotangent is a single `dy`,
# forwarded as-is to `_apply`.
(pb::ImplicitPullback{false})(dy) = _apply(pb, dy)

function (implicit_pullback::ImplicitPullback)((dy, dz))
function _apply(implicit_pullback::ImplicitPullback, dy)
@unpack Aᵀ_op, Bᵀ_op, linear_solver, x = implicit_pullback
R = eltype(x)
dy_vec = convert(AbstractVector{R}, vec(unthunk(dy)))
Expand Down
48 changes: 18 additions & 30 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,37 @@ else
using ..ForwardDiff: Dual, Partials, jacobian, partials, value
end

using AbstractDifferentiation: ForwardDiffBackend, pushforward_function
using ImplicitDifferentiation: ImplicitFunction, PushforwardMul!, ReturnByproduct
using ImplicitDifferentiation: DirectLinearSolver, IterativeLinearSolver
using ImplicitDifferentiation: presolve, solve, identity_break_autodiff
using AbstractDifferentiation: AbstractBackend, ForwardDiffBackend, pushforward_function
using ImplicitDifferentiation: ImplicitFunction, DirectLinearSolver, IterativeLinearSolver
using ImplicitDifferentiation: forward_operators, solve, identity_break_autodiff
using LinearAlgebra: lmul!, mul!
using LinearOperators: LinearOperator
using PrecompileTools: @compile_workload
using SimpleUnPack: @unpack

"""
implicit(x_and_dx::AbstractArray{<:Dual}[, ReturnByproduct()]; kwargs...)
implicit(x_and_dx::AbstractArray{<:Dual}; kwargs...)
Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with forward mode autodiff.
This is only available if ForwardDiff.jl is loaded (extension).
- By default, this returns a single output `y_and_dy(x)`.
- If `ReturnByproduct()` is passed as an argument, this returns a couple of outputs `(y_and_dy(x),z(x))` (remember that `z(x)` is not differentiated so `dz(x)` doesn't exist).
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u` (see [`ImplicitFunction`](@ref) for the definition of `A` and `B`).
We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`.
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function (implicit::ImplicitFunction)(
x_and_dx::AbstractArray{Dual{T,R,N}}, ::ReturnByproduct; kwargs...
x_and_dx::AbstractArray{Dual{T,R,N}}; kwargs...
) where {T,R,N}
@unpack conditions, linear_solver = implicit

x = value.(x_and_dx)
y, z = implicit(x, ReturnByproduct(); kwargs...)
n, m = length(x), length(y)
y_or_yz = implicit(x; kwargs...)
y = _output(y_or_yz)

backend = ForwardDiffBackend()
pfA = pushforward_function(backend, _y -> conditions(x, _y, z; kwargs...), y)
pfB = pushforward_function(backend, _x -> conditions(_x, y, z; kwargs...), x)

A_op = LinearOperator(R, m, m, false, false, PushforwardMul!(pfA, size(y)))
B_op = LinearOperator(R, m, n, false, false, PushforwardMul!(pfB, size(x)))
A_op_presolved = presolve(linear_solver, A_op, y)
A_op, B_op = forward_operators(backend, implicit, x, y_or_yz; kwargs)

dy = ntuple(Val(N)) do k
dₖx_vec = vec(partials.(x_and_dx, k))
Bdx = vec(similar(y))
mul!(Bdx, B_op, dₖx_vec)
dₖy_vec = solve(linear_solver, A_op_presolved, Bdx)
dₖy_vec = solve(implicit.linear_solver, A_op, Bdx)
lmul!(-one(R), dₖy_vec)
reshape(dₖy_vec, size(y))
end
Expand All @@ -61,15 +48,16 @@ function (implicit::ImplicitFunction)(
reshape(y_and_dy_vec, size(y))
end

return y_and_dy, z
if y_or_yz isa Tuple
return y_and_dy, _byproduct(y_or_yz)
else
return y_and_dy
end
end

function (implicit::ImplicitFunction)(
    x_and_dx::AbstractArray{Dual{T,R,N}}; kwargs...
) where {T,R,N}
    # Delegate to the `ReturnByproduct()` method, then keep only the dual-valued output;
    # the second returned element (the byproduct) is discarded.
    y_and_dy, _ = implicit(x_and_dx, ReturnByproduct(); kwargs...)
    return y_and_dy
end
# Extract the output `y` from the value returned by the implicit function:
# a bare array is the output itself, a tuple carries the output in first position.
function _output(y::AbstractArray)
    return y
end

function _output(output_and_byproduct::Tuple)
    return output_and_byproduct[1]
end
# Extract the byproduct, stored in second position of the returned tuple.
function _byproduct(output_and_byproduct::Tuple)
    return output_and_byproduct[2]
end

@compile_workload begin
forward(x) = sqrt.(identity_break_autodiff(x))
Expand Down
9 changes: 4 additions & 5 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
module ImplicitDifferentiation

using Krylov: KrylovStats, gmres
using AbstractDifferentiation: AbstractBackend, pushforward_function, pullback_function
using Krylov: gmres
using LinearOperators: LinearOperators, LinearOperator
using LinearAlgebra: lu, SingularException, issuccess
using LinearAlgebra: issuccess, lu
using PrecompileTools: @compile_workload
using Requires: @require
using SimpleUnPack: @unpack

include("utils.jl")
include("forward.jl")
include("conditions.jl")
include("linear_solver.jl")
include("implicit_function.jl")
include("operators.jl")

export ImplicitFunction
export IterativeLinearSolver, DirectLinearSolver
export HandleByproduct, ReturnByproduct

@static if !isdefined(Base, :get_extension)
include("../ext/ImplicitDifferentiationChainRulesExt.jl")
Expand Down
22 changes: 0 additions & 22 deletions src/conditions.jl

This file was deleted.

37 changes: 0 additions & 37 deletions src/forward.jl

This file was deleted.

Loading

0 comments on commit 9d34d8b

Please sign in to comment.