
Commit

Return the output only from ImplicitFunction not the byproduct by default (#56)

* return the output only not the byprod by default

* rename var

* Some formatting and documentation additions

* Fix JET test on 1.6

* Fix unknown macro

* Fix tuple in test_opt

* Reformat Project.toml

---------

Co-authored-by: Guillaume Dalle <[email protected]>
mohamed82008 and gdalle authored May 27, 2023
1 parent 58fa014 commit 7f9fa8f
Showing 14 changed files with 144 additions and 100 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.4.4"
version = "0.5.0-DEV"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
@@ -20,6 +20,7 @@ ImplicitDifferentiationForwardDiffExt = "ForwardDiff"

[compat]
AbstractDifferentiation = "0.5"
Aqua = "0.6.1"
ChainRulesCore = "1.14"
ForwardDiff = "0.10"
Krylov = "0.8, 0.9"
20 changes: 6 additions & 14 deletions docs/Manifest.toml
@@ -113,12 +113,6 @@ git-tree-sha1 = "1237bdbcfec728721718ef57dcb855a19c11bf3a"
uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a"
version = "1.10.1"

[[deps.ChangesOfVariables]]
deps = ["LinearAlgebra", "Test"]
git-tree-sha1 = "f84967c4497e0e1955f9a582c232b02847c5f589"
uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
version = "0.1.7"

[[deps.CommonSubexpressions]]
deps = ["MacroTools", "Test"]
git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
@@ -319,7 +313,7 @@ version = "0.1.1"
deps = ["AbstractDifferentiation", "Krylov", "LinearOperators", "Requires"]
path = ".."
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
version = "0.4.1"
version = "0.5.0-DEV"
weakdeps = ["ChainRulesCore", "ForwardDiff"]

[deps.ImplicitDifferentiation.extensions]
@@ -330,12 +324,6 @@ weakdeps = ["ChainRulesCore", "ForwardDiff"]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[deps.InverseFunctions]]
deps = ["Test"]
git-tree-sha1 = "6667aadd1cdee2c6cd068128b3d226ebc4fb0c67"
uuid = "3587e190-3f89-42d0-90ee-14403ec27112"
version = "0.1.9"

[[deps.IrrationalConstants]]
git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2"
uuid = "92d709cd-6900-40b7-9082-c6be49f344b6"
@@ -435,13 +423,17 @@ deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"]
git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b"
uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
version = "0.3.23"
weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"]

[deps.LogExpFunctions.extensions]
LogExpFunctionsChainRulesCoreExt = "ChainRulesCore"
LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables"
LogExpFunctionsInverseFunctionsExt = "InverseFunctions"

[deps.LogExpFunctions.weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

9 changes: 8 additions & 1 deletion docs/src/faq.md
@@ -21,7 +21,7 @@ Consider using an `SVector` from [StaticArrays.jl](https://github.com/JuliaArray

## Multiple inputs / outputs

In this package, implicit functions can only take a single input array `x` and output a single output array `y` (plus the additional info `z`).
In this package, implicit functions can only take a single input array `x` and output a single output array `y` (plus the byproduct `z`).
But sometimes, your forward pass or conditions may require multiple input arrays, say `a` and `b`:

```julia
@@ -39,6 +39,13 @@ f(x::ComponentVector) = f(x.a, x.b)

The same trick works for multiple outputs.
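To make the trick concrete, here is a minimal sketch (the function `f` and the field names `a` and `b` mirror the snippet above; the numeric values are illustrative assumptions):

```julia
using ComponentArrays

# A two-argument function we want to expose as a single-input one
f(a, b) = sqrt.(a .+ 2 .* b)

# Bundle both arguments into one ComponentVector and dispatch on it
f(x::ComponentVector) = f(x.a, x.b)

x = ComponentVector(; a=[1.0, 2.0], b=[0.0, 1.0])
f(x)  # same result as f([1.0, 2.0], [0.0, 1.0])
```

The wrapper `x` behaves like an ordinary vector for autodiff purposes, which is what lets a single-input `ImplicitFunction` see the pair `(a, b)` as one array.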

## Byproducts

At first glance, it is not obvious why we impose that the `forward` callable returns a byproduct `z` in addition to `y`.
It is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the `conditions`.
We will provide simple examples soon.
In the meantime, an advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl).
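Pending those examples, here is a minimal sketch of the idea (the linear-system setting, the matrix `A`, and the helper code are illustrative assumptions, not part of the package): the forward pass builds a factorization while solving and returns it as the byproduct `z`, so that `conditions` can reuse it instead of re-forming the matrix.

```julia
using LinearAlgebra

# Illustrative problem: y solves A * y = x for a fixed positive definite A
A = [2.0 -1.0; -1.0 2.0]

function forward(x)
    F = cholesky(A)  # costly object built during the solve...
    y = F \ x
    return y, F      # ...returned as the byproduct z
end

function conditions(x, y, F)
    # Reuse the byproduct: F.L * F.U equals A up to rounding,
    # so this computes A * y - x without touching A again
    return F.L * (F.U * y) - x
end

x = [1.0, 3.0]
y, F = forward(x)
conditions(x, y, F)  # ≈ [0.0, 0.0] at the solution
```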

## Modeling constrained optimization problems

To express constrained optimization problems as implicit functions, you might need differentiable projections or proximal operators to write the optimality conditions.
20 changes: 10 additions & 10 deletions examples/0_basic.jl
@@ -87,8 +87,8 @@ We represent it using a type called `ImplicitFunction`, which you will see in ac

#=
First we define a `forward` pass corresponding to the function we consider.
It returns the actual output $y(x)$ of the function, as well as additional information $z(x)$.
Here we don't need any additional information, so we set it to $0$.
It returns the actual output $y(x)$ of the function, as well as the optional byproduct $z(x)$.
Here we don't need any additional information, so we set $z(x)$ to $0$.
Importantly, this forward pass _doesn't need to be differentiable_.
=#

@@ -118,11 +118,11 @@ What does this wrapper do?
implicit = ImplicitFunction(forward, conditions)

#=
When we call it as a function, it just falls back on `implicit.forward`, so unsurprisingly we get the same tuple $(y(x), z(x))$.
When we call it as a function, it just falls back on `first ∘ implicit.forward`, so unsurprisingly we get the first output $y(x)$.
=#

(first ∘ implicit)(x) ≈ sqrt.(x)
@test (first ∘ implicit)(x) ≈ sqrt.(x) #src
implicit(x) ≈ sqrt.(x)
@test implicit(x) ≈ sqrt.(x) #src

#=
And when we try to compute its Jacobian, the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem) is applied in the background to circumvent the lack of differentiability of the forward pass.
@@ -134,15 +134,15 @@
Now ForwardDiff.jl works seamlessly.
=#

ForwardDiff.jacobian(first ∘ implicit, x) ≈ J
@test ForwardDiff.jacobian(first ∘ implicit, x) ≈ J #src
ForwardDiff.jacobian(implicit, x) ≈ J
@test ForwardDiff.jacobian(implicit, x) ≈ J #src

#=
And so does Zygote.jl. Hurray!
=#

Zygote.jacobian(first ∘ implicit, x)[1] ≈ J
@test Zygote.jacobian(first ∘ implicit, x)[1] ≈ J #src
Zygote.jacobian(implicit, x)[1] ≈ J
@test Zygote.jacobian(implicit, x)[1] ≈ J #src

# ## Second derivative

@@ -159,6 +159,6 @@ Then the Jacobian itself is differentiable.
=#

h = rand(2)
J_Z(t) = Zygote.jacobian(first ∘ implicit2, x .+ t .* h)[1]
J_Z(t) = Zygote.jacobian(implicit2, x .+ t .* h)[1]
ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5))
@test ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src
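For reference, the Jacobian values asserted throughout the square-root example above follow from the implicit function theorem in a few lines. With conditions $c(x, y) = y \odot y - x$ equal to zero componentwise at the solution $y = \sqrt{x}$, differentiating each component gives:

```latex
c(x, y) = y^2 - x = 0
\quad\Longrightarrow\quad
2y\,\frac{\partial y}{\partial x} - 1 = 0
\quad\Longrightarrow\quad
\frac{\partial y}{\partial x} = \frac{1}{2y} = \frac{1}{2\sqrt{x}},
\qquad
\frac{\partial^2 y}{\partial x^2} = -\frac{1}{4\,x^{3/2}}
```

These match the diagonal entries of `J = Diagonal(0.5 ./ sqrt.(x))` and, for the second derivative, the directional check `Diagonal((-0.25 .* h) ./ (x .^ 1.5))` with direction `h`.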
14 changes: 7 additions & 7 deletions examples/1_unconstrained_optim.jl
@@ -41,7 +41,7 @@ end

#=
First, we create the forward pass which returns the solution $y(x)$.
Remember that it should also return additional information $z(x)$, which is useless here.
Remember that it should also return a byproduct $z(x)$, which is useless here.
=#
function forward_optim(x; method)
y = mysqrt_optim(x; method)
@@ -70,8 +70,8 @@ x = rand(2)

#-

first(implicit_optim(x; method=LBFGS())) .^ 2
@test first(implicit_optim(x; method=LBFGS())) .^ 2 ≈ x #src
implicit_optim(x; method=LBFGS()) .^ 2
@test implicit_optim(x; method=LBFGS()) .^ 2 ≈ x #src

#=
Let's see what the explicit Jacobian looks like.
@@ -81,8 +81,8 @@ J = Diagonal(0.5 ./ sqrt.(x))

# ## Forward mode autodiff

ForwardDiff.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)
@test ForwardDiff.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x) ≈ J #src
ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)
@test ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) ≈ J #src

#=
Unsurprisingly, the Jacobian is the identity.
@@ -93,8 +93,8 @@ ForwardDiff.jacobian(_x -> mysqrt_optim(x; method=LBFGS()), x)

# ## Reverse mode autodiff

Zygote.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)[1]
@test Zygote.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)[1] ≈ J #src
Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1]
@test Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] ≈ J #src

#=
Again, the Jacobian is the identity.
12 changes: 6 additions & 6 deletions examples/2_nonlinear_solve.jl
@@ -63,26 +63,26 @@ x = rand(2)

#-

first(implicit_nlsolve(x; method=:newton)) .^ 2
@test first(implicit_nlsolve(x; method=:newton)) .^ 2 ≈ x #src
implicit_nlsolve(x; method=:newton) .^ 2
@test implicit_nlsolve(x; method=:newton) .^ 2 ≈ x #src

#-

J = Diagonal(0.5 ./ sqrt.(x))

# ## Forward mode autodiff

ForwardDiff.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)
@test ForwardDiff.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x) ≈ J #src
ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)
@test ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) ≈ J #src

#-

ForwardDiff.jacobian(_x -> mysqrt_nlsolve(_x; method=:newton), x)

# ## Reverse mode autodiff

Zygote.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)[1]
@test Zygote.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)[1] ≈ J #src
Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1]
@test Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] ≈ J #src

#-

12 changes: 6 additions & 6 deletions examples/3_fixed_points.jl
@@ -63,26 +63,26 @@ x = rand(2)

#-

first(implicit_fixedpoint(x; iterations=10)) .^ 2
@test first(implicit_fixedpoint(x; iterations=10)) .^ 2 ≈ x #src
implicit_fixedpoint(x; iterations=10) .^ 2
@test implicit_fixedpoint(x; iterations=10) .^ 2 ≈ x #src

#-

J = Diagonal(0.5 ./ sqrt.(x))

# ## Forward mode autodiff

ForwardDiff.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)
@test ForwardDiff.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x) ≈ J #src
ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)
@test ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) ≈ J #src

#-

ForwardDiff.jacobian(_x -> mysqrt_fixedpoint(_x; iterations=10), x)

# ## Reverse mode autodiff

Zygote.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)[1]
@test Zygote.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)[1] ≈ J #src
Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1]
@test Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] ≈ J #src

#-

12 changes: 6 additions & 6 deletions examples/4_constrained_optim.jl
@@ -75,26 +75,26 @@ x = rand(2) .+ [0, 1]
The second component of $x$ is $> 1$, so its square root will be thresholded to one, and the corresponding derivative will be $0$.
=#

(first ∘ implicit_cstr_optim)(x) .^ 2
@test (first ∘ implicit_cstr_optim)(x) .^ 2 ≈ [x[1], 1] #src
implicit_cstr_optim(x) .^ 2
@test implicit_cstr_optim(x) .^ 2 ≈ [x[1], 1] #src

#-

J_thres = Diagonal([0.5 / sqrt(x[1]), 0])

# ## Forward mode autodiff

ForwardDiff.jacobian(first ∘ implicit_cstr_optim, x)
@test ForwardDiff.jacobian(first ∘ implicit_cstr_optim, x) ≈ J_thres #src
ForwardDiff.jacobian(implicit_cstr_optim, x)
@test ForwardDiff.jacobian(implicit_cstr_optim, x) ≈ J_thres #src

#-

ForwardDiff.jacobian(mysqrt_cstr_optim, x)

# ## Reverse mode autodiff

Zygote.jacobian(first ∘ implicit_cstr_optim, x)[1]
@test Zygote.jacobian(first ∘ implicit_cstr_optim, x)[1] ≈ J_thres #src
Zygote.jacobian(implicit_cstr_optim, x)[1]
@test Zygote.jacobian(implicit_cstr_optim, x)[1] ≈ J_thres #src

#-

12 changes: 6 additions & 6 deletions examples/5_multiargs.jl
@@ -64,8 +64,8 @@ x = ComponentVector(; a=rand(2), b=rand(2))

#-

first(implicit_components(x)) .^ 2
@test first(implicit_components(x)) .^ 2 ≈ x.a + 2x.b #src
implicit_components(x) .^ 2
@test implicit_components(x) .^ 2 ≈ x.a + 2x.b #src

#=
Let's see what the explicit Jacobian looks like.
Expand All @@ -75,10 +75,10 @@ J = hcat(Diagonal(0.5 ./ sqrt.(x.a + 2x.b)), 2 * Diagonal(0.5 ./ sqrt.(x.a + 2x.

# ## Forward mode autodiff

ForwardDiff.jacobian(_x -> first(implicit_components(_x)), x)
@test ForwardDiff.jacobian(_x -> first(implicit_components(_x)), x) ≈ J #src
ForwardDiff.jacobian(implicit_components, x)
@test ForwardDiff.jacobian(implicit_components, x) ≈ J #src

# ## Reverse mode autodiff

Zygote.jacobian(_x -> first(implicit_components(_x)), x)[1]
@test Zygote.jacobian(_x -> first(implicit_components(_x)), x)[1] ≈ J #src
Zygote.jacobian(implicit_components, x)[1]
@test Zygote.jacobian(implicit_components, x)[1] ≈ J #src
27 changes: 19 additions & 8 deletions ext/ImplicitDifferentiationChainRulesExt.jl
@@ -14,12 +14,16 @@ We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and settin
Keyword arguments are given to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs...
) where {R}
rc::RuleConfig,
implicit::ImplicitFunction,
x::AbstractArray{R},
::Val{return_byproduct};
kwargs...,
) where {R,return_byproduct}
conditions = implicit.conditions
linear_solver = implicit.linear_solver

y, z = implicit(x; kwargs...)
y, z = implicit(x, Val(true); kwargs...)
n, m = length(x), length(y)

backend = ReverseRuleConfigBackend(rc)
@@ -29,19 +29,26 @@ function ChainRulesCore.rrule(
pbmB = PullbackMul!(pbB, size(y))
Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA)
Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB)
implicit_pullback = ImplicitPullback(Aᵀ_op, Bᵀ_op, linear_solver, x)
implicit_pullback = ImplicitPullback(
Aᵀ_op, Bᵀ_op, linear_solver, x, Val(return_byproduct)
)

return (y, z), implicit_pullback
return (return_byproduct ? (y, z) : y), implicit_pullback
end

struct ImplicitPullback{A,B,L,X}
struct ImplicitPullback{return_byproduct,A,B,L,X}
Aᵀ_op::A
Bᵀ_op::B
linear_solver::L
x::X
_v::Val{return_byproduct}
end

function (implicit_pullback::ImplicitPullback)((dy, dz))
function (pb::ImplicitPullback{false})(dy)
_pb = ImplicitPullback(pb.Aᵀ_op, pb.Bᵀ_op, pb.linear_solver, pb.x, Val(true))
return _pb((dy, nothing))
end
function (implicit_pullback::ImplicitPullback{true})((dy, _))
Aᵀ_op = implicit_pullback.Aᵀ_op
Bᵀ_op = implicit_pullback.Bᵀ_op
linear_solver = implicit_pullback.linear_solver
@@ -54,7 +65,7 @@ function (implicit_pullback::ImplicitPullback)((dy, dz))
dx_vec = Bᵀ_op * dF_vec
dx_vec .*= -1
dx = reshape(dx_vec, size(x))
return (NoTangent(), dx)
return (NoTangent(), dx, NoTangent())
end

end
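From the caller's side, the change this extension supports boils down to the following sketch (not self-contained: it assumes an `ImplicitFunction` named `implicit` and an input array `x` built as in the examples above):

```julia
# New default: calling the implicit function returns only the solution y
y = implicit(x)

# The byproduct is now opt-in, via the Val argument threaded through the rrule
y, z = implicit(x, Val(true))
```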