diff --git a/Project.toml b/Project.toml index 14e3b5e..17313c2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.4.4" +version = "0.5.0-DEV" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" @@ -20,6 +20,7 @@ ImplicitDifferentiationForwardDiffExt = "ForwardDiff" [compat] AbstractDifferentiation = "0.5" +Aqua = "0.6.1" ChainRulesCore = "1.14" ForwardDiff = "0.10" Krylov = "0.8, 0.9" diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 10f9d2b..34c10a2 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -113,12 +113,6 @@ git-tree-sha1 = "1237bdbcfec728721718ef57dcb855a19c11bf3a" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" version = "1.10.1" -[[deps.ChangesOfVariables]] -deps = ["LinearAlgebra", "Test"] -git-tree-sha1 = "f84967c4497e0e1955f9a582c232b02847c5f589" -uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" -version = "0.1.7" - [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -319,7 +313,7 @@ version = "0.1.1" deps = ["AbstractDifferentiation", "Krylov", "LinearOperators", "Requires"] path = ".." uuid = "57b37032-215b-411a-8a7c-41a003a55207" -version = "0.4.1" +version = "0.5.0-DEV" weakdeps = ["ChainRulesCore", "ForwardDiff"] [deps.ImplicitDifferentiation.extensions] @@ -330,12 +324,6 @@ weakdeps = ["ChainRulesCore", "ForwardDiff"] deps = ["Markdown"] uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" -[[deps.InverseFunctions]] -deps = ["Test"] -git-tree-sha1 = "6667aadd1cdee2c6cd068128b3d226ebc4fb0c67" -uuid = "3587e190-3f89-42d0-90ee-14403ec27112" -version = "0.1.9" - [[deps.IrrationalConstants]] git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" @@ -435,13 +423,17 @@ deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" version = "0.3.23" -weakdeps = ["ChainRulesCore", "ChangesOfVariables", "InverseFunctions"] [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" LogExpFunctionsInverseFunctionsExt = "InverseFunctions" + [deps.LogExpFunctions.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" + InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" + [[deps.Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" diff --git a/docs/src/faq.md b/docs/src/faq.md index 1139ae5..aed9955 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -21,7 +21,7 @@ Consider using an `SVector` from [StaticArrays.jl](https://github.com/JuliaArray ## Multiple inputs / outputs -In this package, implicit functions can only take a single input array `x` and output a single output array `y` (plus the additional info `z`). +In this package, implicit functions can only take a single input array `x` and output a single output array `y` (plus the byproduct `z`). But sometimes, your forward pass or conditions may require multiple input arrays, say `a` and `b`: ```julia @@ -39,6 +39,13 @@ f(x::ComponentVector) = f(x.a, x.b) The same trick works for multiple outputs. +## Byproducts + +At first glance, it is not obvious why we impose that the `forward` callable returns a byproduct `z` in addition to `y`. +It is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the `conditions`. +We will provide simple examples soon. +In the meantime, an advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl). + ## Modeling constrained optimization problems To express constrained optimization problems as implicit functions, you might need differentiable projections or proximal operators to write the optimality conditions. diff --git a/examples/0_basic.jl b/examples/0_basic.jl index ef07c40..73ecd1e 100644 --- a/examples/0_basic.jl +++ b/examples/0_basic.jl @@ -87,8 +87,8 @@ We represent it using a type called `ImplicitFunction`, which you will see in ac #= First we define a `forward` pass correponding to the function we consider. -It returns the actual output $y(x)$ of the function, as well as additional information $z(x)$. -Here we don't need any additional information, so we set it to $0$. +It returns the actual output $y(x)$ of the function, as well as the optional byproduct $z(x)$. +Here we don't need any additional information, so we set $z(x)$ to $0$. Importantly, this forward pass _doesn't need to be differentiable_. =# @@ -118,11 +118,11 @@ What does this wrapper do? implicit = ImplicitFunction(forward, conditions) #= -When we call it as a function, it just falls back on `implicit.forward`, so unsurprisingly we get the same tuple $(y(x), z(x))$. +When we call it as a function, it just falls back on `first ∘ implicit.forward`, so unsurprisingly we get the first output $y(x)$. =# -(first ∘ implicit)(x) ≈ sqrt.(x) -@test (first ∘ implicit)(x) ≈ sqrt.(x) #src +implicit(x) ≈ sqrt.(x) +@test implicit(x) ≈ sqrt.(x) #src #= And when we try to compute its Jacobian, the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem) is applied in the background to circumvent the lack of differentiablility of the forward pass. @@ -134,15 +134,15 @@ And when we try to compute its Jacobian, the [implicit function theorem](https:/ Now ForwardDiff.jl works seamlessly. =# -ForwardDiff.jacobian(first ∘ implicit, x) ≈ J -@test ForwardDiff.jacobian(first ∘ implicit, x) ≈ J #src +ForwardDiff.jacobian(implicit, x) ≈ J +@test ForwardDiff.jacobian(implicit, x) ≈ J #src #= And so does Zygote.jl. Hurray! =# -Zygote.jacobian(first ∘ implicit, x)[1] ≈ J -@test Zygote.jacobian(first ∘ implicit, x)[1] ≈ J #src +Zygote.jacobian(implicit, x)[1] ≈ J +@test Zygote.jacobian(implicit, x)[1] ≈ J #src # ## Second derivative @@ -159,6 +159,6 @@ Then the Jacobian itself is differentiable. =# h = rand(2) -J_Z(t) = Zygote.jacobian(first ∘ implicit2, x .+ t .* h)[1] +J_Z(t) = Zygote.jacobian(implicit2, x .+ t .* h)[1] ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) @test ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src diff --git a/examples/1_unconstrained_optim.jl b/examples/1_unconstrained_optim.jl index e95342f..e38dda3 100644 --- a/examples/1_unconstrained_optim.jl +++ b/examples/1_unconstrained_optim.jl @@ -41,7 +41,7 @@ end #= First, we create the forward pass which returns the solution $y(x)$. -Remember that it should also return additional information $z(x)$, which is useless here. +Remember that it should also return a byproduct $z(x)$, which is useless here. =# function forward_optim(x; method) y = mysqrt_optim(x; method) @@ -70,8 +70,8 @@ x = rand(2) #- -first(implicit_optim(x; method=LBFGS())) .^ 2 -@test first(implicit_optim(x; method=LBFGS())) .^ 2 ≈ x #src +implicit_optim(x; method=LBFGS()) .^ 2 +@test implicit_optim(x; method=LBFGS()) .^ 2 ≈ x #src #= Let's see what the explicit Jacobian looks like. @@ -81,8 +81,8 @@ J = Diagonal(0.5 ./ sqrt.(x)) # ## Forward mode autodiff -ForwardDiff.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x) -@test ForwardDiff.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x) ≈ J #src +ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) +@test ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) ≈ J #src #= Unsurprisingly, the Jacobian is the identity. @@ -93,8 +93,8 @@ ForwardDiff.jacobian(_x -> mysqrt_optim(x; method=LBFGS()), x) # ## Reverse mode autodiff -Zygote.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)[1] -@test Zygote.jacobian(_x -> first(implicit_optim(_x; method=LBFGS())), x)[1] ≈ J #src +Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] +@test Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] ≈ J #src #= Again, the Jacobian is the identity. diff --git a/examples/2_nonlinear_solve.jl b/examples/2_nonlinear_solve.jl index c564bf4..b5688fe 100644 --- a/examples/2_nonlinear_solve.jl +++ b/examples/2_nonlinear_solve.jl @@ -63,8 +63,8 @@ x = rand(2) #- -first(implicit_nlsolve(x; method=:newton)) .^ 2 -@test first(implicit_nlsolve(x; method=:newton)) .^ 2 ≈ x #src +implicit_nlsolve(x; method=:newton) .^ 2 +@test implicit_nlsolve(x; method=:newton) .^ 2 ≈ x #src #- @@ -72,8 +72,8 @@ J = Diagonal(0.5 ./ sqrt.(x)) # ## Forward mode autodiff -ForwardDiff.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x) -@test ForwardDiff.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x) ≈ J #src +ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) +@test ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) ≈ J #src #- @@ -81,8 +81,8 @@ ForwardDiff.jacobian(_x -> mysqrt_nlsolve(_x; method=:newton), x) # ## Reverse mode autodiff -Zygote.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)[1] -@test Zygote.jacobian(_x -> first(implicit_nlsolve(_x; method=:newton)), x)[1] ≈ J #src +Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] +@test Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] ≈ J #src #- diff --git a/examples/3_fixed_points.jl b/examples/3_fixed_points.jl index c6bde74..810ffaa 100644 --- a/examples/3_fixed_points.jl +++ b/examples/3_fixed_points.jl @@ -63,8 +63,8 @@ x = rand(2) #- -first(implicit_fixedpoint(x; iterations=10)) .^ 2 -@test first(implicit_fixedpoint(x; iterations=10)) .^ 2 ≈ x #src +implicit_fixedpoint(x; iterations=10) .^ 2 +@test implicit_fixedpoint(x; iterations=10) .^ 2 ≈ x #src #- @@ -72,8 +72,8 @@ J = Diagonal(0.5 ./ sqrt.(x)) # ## Forward mode autodiff -ForwardDiff.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x) -@test ForwardDiff.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x) ≈ J #src +ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) +@test ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) ≈ J #src #- @@ -81,8 +81,8 @@ ForwardDiff.jacobian(_x -> mysqrt_fixedpoint(_x; iterations=10), x) # ## Reverse mode autodiff -Zygote.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)[1] -@test Zygote.jacobian(_x -> first(implicit_fixedpoint(_x; iterations=10)), x)[1] ≈ J #src +Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] +@test Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] ≈ J #src #- diff --git a/examples/4_constrained_optim.jl b/examples/4_constrained_optim.jl index 71a369f..9581611 100644 --- a/examples/4_constrained_optim.jl +++ b/examples/4_constrained_optim.jl @@ -75,8 +75,8 @@ x = rand(2) .+ [0, 1] The second component of $x$ is $> 1$, so its square root will be thresholded to one, and the corresponding derivative will be $0$. =# -(first ∘ implicit_cstr_optim)(x) .^ 2 -@test (first ∘ implicit_cstr_optim)(x) .^ 2 ≈ [x[1], 1] #src +implicit_cstr_optim(x) .^ 2 +@test implicit_cstr_optim(x) .^ 2 ≈ [x[1], 1] #src #- @@ -84,8 +84,8 @@ J_thres = Diagonal([0.5 / sqrt(x[1]), 0]) # ## Forward mode autodiff -ForwardDiff.jacobian(first ∘ implicit_cstr_optim, x) -@test ForwardDiff.jacobian(first ∘ implicit_cstr_optim, x) ≈ J_thres #src +ForwardDiff.jacobian(implicit_cstr_optim, x) +@test ForwardDiff.jacobian(implicit_cstr_optim, x) ≈ J_thres #src #- @@ -93,8 +93,8 @@ ForwardDiff.jacobian(mysqrt_cstr_optim, x) # ## Reverse mode autodiff -Zygote.jacobian(first ∘ implicit_cstr_optim, x)[1] -@test Zygote.jacobian(first ∘ implicit_cstr_optim, x)[1] ≈ J_thres #src +Zygote.jacobian(implicit_cstr_optim, x)[1] +@test Zygote.jacobian(implicit_cstr_optim, x)[1] ≈ J_thres #src #- diff --git a/examples/5_multiargs.jl b/examples/5_multiargs.jl index 6bacfce..f382637 100644 --- a/examples/5_multiargs.jl +++ b/examples/5_multiargs.jl @@ -64,8 +64,8 @@ x = ComponentVector(; a=rand(2), b=rand(2)) #- -first(implicit_components(x)) .^ 2 -@test first(implicit_components(x)) .^ 2 ≈ x.a + 2x.b #src +implicit_components(x) .^ 2 +@test implicit_components(x) .^ 2 ≈ x.a + 2x.b #src #= Let's see what the explicit Jacobian looks like. @@ -75,10 +75,10 @@ J = hcat(Diagonal(0.5 ./ sqrt.(x.a + 2x.b)), 2 * Diagonal(0.5 ./ sqrt.(x.a + 2x. # ## Forward mode autodiff -ForwardDiff.jacobian(_x -> first(implicit_components(_x)), x) -@test ForwardDiff.jacobian(_x -> first(implicit_components(_x)), x) ≈ J #src +ForwardDiff.jacobian(implicit_components, x) +@test ForwardDiff.jacobian(implicit_components, x) ≈ J #src # ## Reverse mode autodiff -Zygote.jacobian(_x -> first(implicit_components(_x)), x)[1] -@test Zygote.jacobian(_x -> first(implicit_components(_x)), x)[1] ≈ J #src +Zygote.jacobian(implicit_components, x)[1] +@test Zygote.jacobian(implicit_components, x)[1] ≈ J #src diff --git a/ext/ImplicitDifferentiationChainRulesExt.jl b/ext/ImplicitDifferentiationChainRulesExt.jl index d098356..6673938 100644 --- a/ext/ImplicitDifferentiationChainRulesExt.jl +++ b/ext/ImplicitDifferentiationChainRulesExt.jl @@ -14,12 +14,16 @@ We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and settin Keyword arguments are given to both `implicit.forward` and `implicit.conditions`. """ function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... -) where {R} + rc::RuleConfig, + implicit::ImplicitFunction, + x::AbstractArray{R}, + ::Val{return_byproduct}; + kwargs..., +) where {R,return_byproduct} conditions = implicit.conditions linear_solver = implicit.linear_solver - y, z = implicit(x; kwargs...) + y, z = implicit(x, Val(true); kwargs...) n, m = length(x), length(y) backend = ReverseRuleConfigBackend(rc) @@ -29,19 +33,26 @@ function ChainRulesCore.rrule( pbmB = PullbackMul!(pbB, size(y)) Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA) Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB) - implicit_pullback = ImplicitPullback(Aᵀ_op, Bᵀ_op, linear_solver, x) + implicit_pullback = ImplicitPullback( + Aᵀ_op, Bᵀ_op, linear_solver, x, Val(return_byproduct) + ) - return (y, z), implicit_pullback + return (return_byproduct ? (y, z) : y), implicit_pullback end -struct ImplicitPullback{A,B,L,X} +struct ImplicitPullback{return_byproduct,A,B,L,X} Aᵀ_op::A Bᵀ_op::B linear_solver::L x::X + _v::Val{return_byproduct} end -function (implicit_pullback::ImplicitPullback)((dy, dz)) +function (pb::ImplicitPullback{false})(dy) + _pb = ImplicitPullback(pb.Aᵀ_op, pb.Bᵀ_op, pb.linear_solver, pb.x, Val(true)) + return _pb((dy, nothing)) +end +function (implicit_pullback::ImplicitPullback{true})((dy, _)) Aᵀ_op = implicit_pullback.Aᵀ_op Bᵀ_op = implicit_pullback.Bᵀ_op linear_solver = implicit_pullback.linear_solver @@ -54,7 +65,7 @@ function (implicit_pullback::ImplicitPullback)((dy, dz)) dx_vec = Bᵀ_op * dF_vec dx_vec .*= -1 dx = reshape(dx_vec, size(x)) - return (NoTangent(), dx) + return (NoTangent(), dx, NoTangent()) end end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 5b2f4e6..8172104 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -16,13 +16,13 @@ using LinearOperators: LinearOperator Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with ForwardDiff.jl. """ function (implicit::ImplicitFunction)( - x_and_dx::AbstractArray{Dual{T,R,N}}; kwargs... -) where {T,R,N} + x_and_dx::AbstractArray{Dual{T,R,N}}, ::Val{return_byproduct}=Val(false); kwargs... +) where {T,R,N,return_byproduct} conditions = implicit.conditions linear_solver = implicit.linear_solver x = value.(x_and_dx) - y, z = implicit(x; kwargs...) + y, z = implicit(x, Val(true); kwargs...) n, m = length(x), length(y) backend = ForwardDiffBackend() @@ -45,7 +45,7 @@ function (implicit::ImplicitFunction)( end reshape(y_and_dy_vec, size(y)) end - return y_and_dy, z + return return_byproduct ? (y_and_dy, z) : y_and_dy end end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index 1d2ee20..83cc889 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -3,8 +3,7 @@ Differentiable wrapper for an implicit function `x -> y(x)` whose output is defined by conditions `F(x,y(x)) = 0`. -More generally, we consider functions `x -> (y(x),z(x))` and conditions `F(x,y(x),z(x)) = 0`, where `z(x)` contains additional information that _is considered constant for differentiation purposes_. - +More generally, we consider functions `x -> (y(x),z(x))` and conditions `F(x,y(x),z(x)) = 0`, where `z(x)` is a byproduct _considered constant for differentiation purposes_. If `x ∈ ℝⁿ` and `y ∈ ℝᵈ`, then we need as many conditions as output dimensions: `F(x,y,z) ∈ ℝᵈ`. Thanks to these conditions, we can compute the Jacobian of `y(⋅)` using the implicit function theorem: ``` ∂₂F(x,y(x),z(x)) * ∂y(x) = -∂₁F(x,y(x),z(x)) @@ -33,10 +32,15 @@ end """ implicit(x[; kwargs...]) + implicit(x, Val(true), [; kwargs...]) Make `ImplicitFunction` callable by applying `implicit.forward`. + +The first (default) call signature only returns `y(x)`, while the second returns `(y(x), z(x))`. """ -function (implicit::ImplicitFunction)(x; kwargs...) +function (implicit::ImplicitFunction)( + x, ::Val{return_byproduct}=Val(false); kwargs... +) where {return_byproduct} y, z = implicit.forward(x; kwargs...) - return y, z + return return_byproduct ? (y, z) : y end diff --git a/test/misc.jl b/test/misc.jl index 944a3f1..59c9bff 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -31,24 +31,38 @@ end y, _ = implicit(x) J = Diagonal(0.5 ./ sqrt.(x)) - @testset "Exactness" begin - @test (first ∘ implicit)(x) ≈ sqrt.(x) - @test ForwardDiff.jacobian(first ∘ implicit, x) ≈ J - @test Zygote.jacobian(first ∘ implicit, x)[1] ≈ J + @testset "Call" begin + @test (@inferred implicit(x)) ≈ sqrt.(x) + if VERSION >= v"1.7" + test_opt(implicit, (typeof(x),)) + end end - @testset verbose = true "Forward inference" begin + @testset verbose = true "Forward" begin + @test ForwardDiff.jacobian(implicit, x) ≈ J x_and_dx = ForwardDiff.Dual.(x, ((0, 0),)) @test (@inferred implicit(x_and_dx)) == implicit(x_and_dx) y_and_dy, _ = implicit(x_and_dx) @test size(y_and_dy) == size(y) end - @testset "Reverse type inference" begin - _, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, x) - dy, dz = zero(implicit(x)[1]), 0 - @test (@inferred pullback((dy, dz))) == pullback((dy, dz)) - _, dx = pullback((dy, dz)) - @test size(dx) == size(x) + + @testset "Reverse" begin + @test Zygote.jacobian(implicit, x)[1] ≈ J + for return_byproduct in (true, false) + _, pullback = @inferred rrule( + Zygote.ZygoteRuleConfig(), implicit, x, Val(return_byproduct) + ) + dy, dz = zero(implicit(x)), 0 + if return_byproduct + @test (@inferred pullback((dy, dz))) == pullback((dy, dz)) + _, dx = pullback((dy, dz)) + @test size(dx) == size(x) + else + @test (@inferred pullback(dy)) == pullback(dy) + _, dx = pullback(dy) + @test size(dx) == size(x) + end + end end end @@ -57,24 +71,37 @@ end Y, _ = implicit(X) JJ = Diagonal(0.5 ./ sqrt.(vec(X))) - @testset "Exactness" begin - @test (first ∘ implicit)(X) ≈ sqrt.(X) - @test ForwardDiff.jacobian(first ∘ implicit, X) ≈ JJ - @test Zygote.jacobian(first ∘ implicit, X)[1] ≈ JJ + @testset "Call" begin + @test (@inferred implicit(X)) ≈ sqrt.(X) + if VERSION >= v"1.7" + test_opt(implicit, (typeof(X),)) + end end - @testset "Forward type inference" begin + @testset "Forward" begin + @test ForwardDiff.jacobian(implicit, X) ≈ JJ X_and_dX = ForwardDiff.Dual.(X, ((0, 0),)) @test (@inferred implicit(X_and_dX)) == implicit(X_and_dX) Y_and_dY, _ = implicit(X_and_dX) @test size(Y_and_dY) == size(Y) end - @testset "Reverse type inference" begin - _, pullback = @inferred rrule(Zygote.ZygoteRuleConfig(), implicit, X) - dY, dZ = zero(implicit(X)[1]), 0 - @test (@inferred pullback((dY, dZ))) == pullback((dY, dZ)) - _, dX = pullback((dY, dZ)) - @test size(dX) == size(X) + @testset "Reverse" begin + @test Zygote.jacobian(implicit, X)[1] ≈ JJ + for return_byproduct in (true, false) + _, pullback = @inferred rrule( + Zygote.ZygoteRuleConfig(), implicit, X, Val(return_byproduct) + ) + dY, dZ = zero(implicit(X)), 0 + if return_byproduct + @test (@inferred pullback((dY, dZ))) == pullback((dY, dZ)) + _, dX = pullback((dY, dZ)) + @test size(dX) == size(X) + else + @test (@inferred pullback(dY)) == pullback(dY) + _, dX = pullback(dY) + @test size(dX) == size(X) + end + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 4bdb59e..d26134b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -58,11 +58,13 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") @testset verbose = true "Miscellaneous" begin include("misc.jl") end - for file in readdir(EXAMPLES_DIR_JL) - path = joinpath(EXAMPLES_DIR_JL, file) - title = markdown_title(path) - @testset verbose = true "$title" begin - include(path) + @testset verbose = true "Examples" begin + for file in readdir(EXAMPLES_DIR_JL) + path = joinpath(EXAMPLES_DIR_JL, file) + title = markdown_title(path) + @testset verbose = true "$title" begin + include(path) + end end end end