diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index f4e70da..113cff1 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,13 +18,13 @@ jobs: strategy: fail-fast: false matrix: - version: ['1.6', '1'] + version: ['1.10', '1'] os: [ubuntu-latest] arch: [x64] allow_failure: [false] steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} @@ -45,7 +45,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v1 + - uses: julia-actions/setup-julia@v2 with: version: '1' - uses: julia-actions/cache@v1 diff --git a/CITATION.bib b/CITATION.bib index a9a43d6..981d199 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -2,7 +2,7 @@ @misc{ImplicitDifferentiation.jl author = {Guillaume Dalle, Mohamed Tarek and contributors}, title = {ImplicitDifferentiation.jl}, url = {https://github.com/gdalle/ImplicitDifferentiation.jl}, - version = {v0.5.0}, - year = {2023}, - month = {8} + version = {v0.6.0}, + year = {2024}, + month = {4} } diff --git a/Project.toml b/Project.toml index b7be278..d1107d5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,64 +1,73 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.5.2" +version = "0.6.0" [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" -PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -Requires = "ae029012-a4dd-5104-9daa-d747884805df" -SimpleUnPack = 
"ce78b400-467f-4804-87d8-8f486da07d0a" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore" +ImplicitDifferentiationEnzymeExt = "Enzyme" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" -ImplicitDifferentiationStaticArraysExt = "StaticArrays" -ImplicitDifferentiationZygoteExt = "Zygote" [compat] -AbstractDifferentiation = "0.5, 0.6" -ChainRulesCore = "1.14" -ForwardDiff = "0.10" -Krylov = "0.8, 0.9" -LinearAlgebra = "1.6" -LinearOperators = "2.2" -PrecompileTools = "1.1" -Requires = "1.3" -SimpleUnPack = "1.1" -StaticArrays = "1.6" -Zygote = "0.6" -julia = "1.6" +ChainRulesCore = "1.23.0" +Enzyme = "0.11.20" +ForwardDiff = "0.10.36" +Krylov = "0.9.5" +LinearAlgebra = "1.10" +LinearOperators = "2.7.0" +julia = "1.10" [extras] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = 
"2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "ReverseDiff", "SparseArrays", "StaticArrays", "Test", "Zygote"] +test = [ + "ADTypes", + "Aqua", + "ChainRulesCore", + "ChainRulesTestUtils", + "ComponentArrays", + "DifferentiationInterface", + "Documenter", + "Enzyme", + "ForwardDiff", + "JET", + "JuliaFormatter", + "NLsolve", + "Optim", + "Random", + "SparseArrays", + "StaticArrays", + "Test", + "Zygote", +] diff --git a/README.md b/README.md index fc10902..b513215 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ Please read the [documentation](https://gdalle.github.io/ImplicitDifferentiation In Julia: +- [SciML](https://sciml.ai/) ecosystem, especially [LinearSolve.jl](https://github.com/SciML/LinearSolve.jl), [NonlinearSolve.jl](https://github.com/SciML/NonlinearSolve.jl) and [Optimization.jl](https://github.com/SciML/Optimization.jl) - [jump-dev/DiffOpt.jl](https://github.com/jump-dev/DiffOpt.jl): differentiation of convex optimization problems - [axelparmentier/InferOpt.jl](https://github.com/axelparmentier/InferOpt.jl): approximate differentiation of combinatorial optimization problems - [JuliaNonconvex/NonconvexUtils.jl](https://github.com/JuliaNonconvex/NonconvexUtils.jl): contains the original implementation from which this package drew inspiration diff --git a/benchmark/Project.toml b/benchmark/Project.toml deleted file mode 100644 index 9c1fae3..0000000 --- a/benchmark/Project.toml +++ /dev/null @@ -1,13 +0,0 @@ -[deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -ChainRulesCore = 
"d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" -Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" -ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/benchmark/analysis.jl b/benchmark/analysis.jl deleted file mode 100644 index be425e1..0000000 --- a/benchmark/analysis.jl +++ /dev/null @@ -1,168 +0,0 @@ - -## Benchmark analysis - -using CSV -using DataFrames -using Plots - -function export_results( - results; - scenario_symbols, - linear_solver_symbols, - backend_symbols, - conditions_backend_symbols, - input_sizes, - output_sizes, - path=joinpath(@__DIR__, "benchmark_results.csv"), -) - min_results = minimum(results) - - data = DataFrame() - - for sc in scenario_symbols, - ls in linear_solver_symbols, - ba in backend_symbols, - cb in conditions_backend_symbols, - is in input_sizes, - os in output_sizes - - try - perf = min_results[sc][ls][ba][cb][is][os] - @unpack time, gctime, memory, allocs = perf - row = (; - scenario=sc, - linear_solver=ls, - backend=ba, - conditions_backend=cb, - input_size=is, - output_size=os, - time, - gctime, - memory, - allocs, - ) - push!(data, row) - catch KeyError - nothing - end - end - - if !isnothing(path) - open(path, "w") do file - CSV.write(file, data) - end - end - return data -end - -function plot_results( - data; - scenario::Symbol, - linear_solver_symbols=unique(data[!, :linear_solver]), - backend_symbols=unique(data[!, :backend]), - conditions_backend_symbols=unique(data[!, :conditions_backend]), - input_size=nothing, - output_size=nothing, - path=joinpath( - @__DIR__, - 
"benchmark_plot_$(scenario)_$(linear_solver_symbols)_$(backend_symbols)_$(conditions_backend_symbols)_$(input_size)_$(output_size).png", - ), -) - pl = plot(; - size=(800, 400), - ylabel="Time [s] (log)", - legendtitle="lin. solver / AD / cond. AD", - legend=:outerright, - xaxis=:log10, - yaxis=:log10, - margin=5Plots.mm, - legendtitlefontsize=7, - legendfontsize=6, - ) - - data = subset(data, :scenario => _col -> _col .== scenario) - - if isnothing(input_size) && isnothing(output_size) - error("Cannot plot if neither input nor output size is fixed") - elseif !isnothing(input_size) && !isnothing(output_size) - error("Cannot plot if both input and output size are fixed") - elseif !isnothing(input_size) - plot!( - pl; - xlabel="Output dimension (log)", - title="Implicit diff. - $scenario - input size $input_size", - ) - data = subset(data, :input_size => _col -> _col .== Ref(input_size)) - else - plot!( - pl; - xlabel="Input dimension (log)", - title="Implicit diff. - $scenario - output size $output_size", - ) - data = subset(data, :output_size => _col -> _col .== Ref(output_size)) - end - - for ls in linear_solver_symbols, ba in backend_symbols, cb in conditions_backend_symbols - filtered_data = subset( - data, - :linear_solver => _col -> _col .== ls, - :backend => _col -> _col .== ba, - :conditions_backend => _col -> _col .== cb, - ) - - if !isempty(filtered_data) - x = nothing - if !isnothing(output_size) - x = map(prod, filtered_data[!, :input_size]) - elseif !isnothing(output_size) - x = map(prod, filtered_data[!, :output_size]) - end - y = filtered_data[!, :time] ./ 1e9 - plot!( - pl, - x, - y; - linestyle=:auto, - markershape=:auto, - label="$ls / $ba / $(cb == :nothing ? 
ba : cb)", - ) - end - end - - if !isnothing(path) - savefig(pl, path) - end - return pl -end - -# results = BenchmarkTools.run(SUITE; verbose=true, evals=1, seconds=1) - -# data = export_results( -# results; -# scenario_symbols, -# linear_solver_symbols, -# backend_symbols, -# conditions_backend_symbols, -# input_sizes, -# output_sizes, -# ) - -# plot_results(data; scenario=:pullback, input_size=(1,)) -# plot_results(data; scenario=:pullback, input_size=(10,)) -# plot_results(data; scenario=:pullback, input_size=(100,)) -# plot_results(data; scenario=:pullback, input_size=(1000,)) - -# plot_results(data; scenario=:pushforward, output_size=(1,)) -# plot_results(data; scenario=:pushforward, output_size=(10,)) -# plot_results(data; scenario=:pushforward, output_size=(100,)) -# plot_results(data; scenario=:pushforward, output_size=(1000,)) - -# plot_results(data; scenario=:rrule, input_size=(1,)) -# plot_results(data; scenario=:rrule, input_size=(10,)) -# plot_results(data; scenario=:rrule, input_size=(100,)) -# plot_results(data; scenario=:rrule, input_size=(1000,)) - -# plot_results(data; scenario=:jacobian, input_size=(1,)) -# plot_results(data; scenario=:jacobian, input_size=(10,)) -# plot_results(data; scenario=:jacobian, input_size=(100,)) -# plot_results(data; scenario=:jacobian, input_size=(1000,)) diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl deleted file mode 100644 index eefeb23..0000000 --- a/benchmark/benchmarks.jl +++ /dev/null @@ -1,143 +0,0 @@ -using Pkg -Pkg.activate(@__DIR__) - -using AbstractDifferentiation: ForwardDiffBackend -using BenchmarkTools -using ForwardDiff: ForwardDiff, Dual -using ProgressMeter -using Random -using SimpleUnPack -using Zygote: Zygote - -using ImplicitDifferentiation - -## Benchmark definition - -forward(x; output_size) = fill(sqrt(sum(x)), output_size...) 
-conditions(x, y; output_size) = abs2.(y) .- sum(x) - -function get_linear_solver(linear_solver_symbol::Symbol) - if linear_solver_symbol == :direct - return DirectLinearSolver() - elseif linear_solver_symbol == :iterative - return IterativeLinearSolver() - end -end - -function get_conditions_backend(conditions_backend_symbol::Symbol) - if conditions_backend_symbol == :nothing - return nothing - elseif conditions_backend_symbol == :ForwardDiff - return ForwardDiffBackend() - end -end - -function create_benchmarkable(; - scenario_symbol, - linear_solver_symbol, - backend_symbol, - conditions_backend_symbol, - input_size, - output_size, -) - linear_solver = get_linear_solver(linear_solver_symbol) - conditions_backend = get_conditions_backend(conditions_backend_symbol) - - if scenario_symbol == :jacobian && prod(input_size) * prod(output_size) >= 10^5 - return nothing - end - - x = rand(input_size...) - implicit = ImplicitFunction( - x -> forward(x; output_size), - (x, y) -> conditions(x, y; output_size); - linear_solver, - conditions_backend, - ) - - dx = similar(x) - dx .= one(eltype(x)) - x_and_dx = Dual.(x, dx) - y = implicit(x) - dy = similar(y) - dy .= one(eltype(y)) - - if scenario_symbol == :jacobian && backend_symbol == :ForwardDiff - return @benchmarkable ForwardDiff.jacobian($implicit, $x) seconds = 1 samples = 100 - elseif scenario_symbol == :jacobian && backend_symbol == :Zygote - return @benchmarkable Zygote.jacobian($implicit, $x) seconds = 1 samples = 100 - elseif scenario_symbol == :rrule && backend_symbol == :Zygote - return @benchmarkable Zygote.pullback($implicit, $x) seconds = 1 samples = 100 - elseif scenario_symbol == :pullback && backend_symbol == :Zygote - _, back = Zygote.pullback(implicit, x) - return @benchmarkable ($back)($dy) seconds = 1 samples = 100 - elseif scenario_symbol == :pushforward && backend_symbol == :ForwardDiff - return @benchmarkable $implicit($x_and_dx) seconds = 1 samples = 100 - else - return nothing - end -end - 
-function make_suite(; - scenario_symbols, - linear_solver_symbols, - backend_symbols, - conditions_backend_symbols, - input_sizes, - output_sizes, -) - SUITE = BenchmarkGroup() - - for sc in scenario_symbols, - ls in linear_solver_symbols, - ba in backend_symbols, - cb in conditions_backend_symbols, - is in input_sizes, - os in output_sizes - - bench = create_benchmarkable(; - scenario_symbol=sc, - linear_solver_symbol=ls, - backend_symbol=ba, - conditions_backend_symbol=cb, - input_size=is, - output_size=os, - ) - - isnothing(bench) && continue - - if !haskey(SUITE, sc) - SUITE[sc] = BenchmarkGroup() - end - if !haskey(SUITE[sc], ls) - SUITE[sc][ls] = BenchmarkGroup() - end - if !haskey(SUITE[sc][ls], ba) - SUITE[sc][ls][ba] = BenchmarkGroup() - end - if !haskey(SUITE[sc][ls][ba], cb) - SUITE[sc][ls][ba][cb] = BenchmarkGroup() - end - if !haskey(SUITE[sc][ls][ba][cb], is) - SUITE[sc][ls][ba][cb][is] = BenchmarkGroup() - end - SUITE[sc][ls][ba][cb][is][os] = bench - end - return SUITE -end - -scenario_symbols = (:jacobian, :rrule, :pullback, :pushforward) -linear_solver_symbols = (:direct, :iterative) -backend_symbols = (:Zygote, :ForwardDiff) -conditions_backend_symbols = (:nothing, :ForwardDiff) -input_sizes = [(n,) for n in floor.(Int, 10 .^ (0:1:3))]; -output_sizes = [(n,) for n in floor.(Int, 10 .^ (0:1:3))]; - -SUITE = make_suite(; - scenario_symbols, - linear_solver_symbols, - backend_symbols, - conditions_backend_symbols, - input_sizes, - output_sizes, -) diff --git a/benchmark/judge.jl b/benchmark/judge.jl deleted file mode 100644 index c361ba6..0000000 --- a/benchmark/judge.jl +++ /dev/null @@ -1,13 +0,0 @@ -using BenchmarkTools -using PkgBenchmark - -pkg = dirname(@__DIR__) # this git repo -baseline = "112d549" # commit id -target = "82242b9" # commit id - -results_baseline = benchmarkpkg(pkg, baseline; verbose=true, retune=false) -results_target = benchmarkpkg(pkg, target; verbose=true, retune=false) - -judgement = judge(results_target, 
results_baseline, minimum) - -export_markdown(joinpath(@__DIR__, "benchmark_judgement.md"), judgement) diff --git a/docs/Project.toml b/docs/Project.toml index b070058..b310c93 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,8 +1,8 @@ [deps] -AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ImplicitDifferentiation = "57b37032-215b-411a-8a7c-41a003a55207" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" @@ -15,4 +15,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -Documenter = "1.3" \ No newline at end of file +Documenter = "1.3" diff --git a/docs/make.jl b/docs/make.jl index 5f8b71e..b08b857 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,4 +1,3 @@ -using ChainRulesCore: ChainRulesCore using Documenter using ForwardDiff: ForwardDiff using ImplicitDifferentiation @@ -11,36 +10,11 @@ DocMeta.setdocmeta!( ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true ) -base_url = "https://github.com/gdalle/ImplicitDifferentiation.jl/blob/main/" - -open(joinpath(@__DIR__, "src", "index.md"), "w") do io - # Point to source license file - println( - io, - """ - ```@meta - EditURL = "$(base_url)README.md" - ``` - """, - ) - # Write the contents out below the meta block - for line in eachline(joinpath(dirname(@__DIR__), "README.md")) - println(io, line) - end -end - -function markdown_title(path) - title = "?" 
- open(path, "r") do file - for line in eachline(file) - if startswith(line, '#') - title = strip(line, [' ', '#']) - break - end - end - end - return title -end +cp( + joinpath(dirname(@__DIR__), "README.md"), + joinpath(@__DIR__, "src", "index.md"); + force=true, +) EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") EXAMPLES_DIR_MD = joinpath(@__DIR__, "src", "examples") @@ -60,53 +34,25 @@ for file in readdir(EXAMPLES_DIR_JL) ) end -example_pages = Pair{String,String}[] -for file in sort(readdir(EXAMPLES_DIR_MD)) - if endswith(file, ".md") - title = markdown_title(joinpath(EXAMPLES_DIR_MD, file)) - path = joinpath("examples", file) - push!(example_pages, title => path) - end -end - pages = [ "Home" => "index.md", - "FAQ" => "faq.md", - "Examples" => example_pages, - "API reference" => "api.md", + "Examples" => [ + joinpath("examples", file) for + file in sort(readdir(EXAMPLES_DIR_MD)) if endswith(file, ".md") + ], + "api.md", + "faq.md", ] -fmt = Documenter.HTML(; - prettyurls=get(ENV, "CI", "false") == "true", - canonical="https://gdalle.github.io/ImplicitDifferentiation.jl", - assets=String[], - edit_link=:commit, -) - -if isdefined(Base, :get_extension) - extension_modules = [ - Base.get_extension(ID, :ImplicitDifferentiationChainRulesCoreExt), - Base.get_extension(ID, :ImplicitDifferentiationForwardDiffExt), - ] -else - extension_modules = [ - ID.ImplicitDifferentiationChainRulesCoreExt, - ID.ImplicitDifferentiationForwardDiffExt, - ] -end - makedocs(; modules=[ImplicitDifferentiation], authors="Guillaume Dalle, Mohamed Tarek and contributors", repo=Documenter.Remotes.GitHub("gdalle", "ImplicitDifferentiation.jl"), sitename="ImplicitDifferentiation.jl", - format=fmt, + format=Documenter.HTML(; + canonical="https://gdalle.github.io/ImplicitDifferentiation.jl" + ), pages=pages, - linkcheck=true, -) - -deploydocs(; - repo="github.com/gdalle/ImplicitDifferentiation.jl", devbranch="main", push_preview=true ) -rm(joinpath(@__DIR__, "src", "index.md")) 
+deploydocs(; repo="github.com/gdalle/ImplicitDifferentiation.jl", devbranch="main") diff --git a/docs/src/faq.md b/docs/src/faq.md index 4451f57..6f52441 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -1,21 +1,18 @@ -# Frequently Asked Questions +# FAQ ## Supported autodiff backends To differentiate an `ImplicitFunction`, the following backends are supported. | Backend | Forward mode | Reverse mode | -| ---------------------------------------------------------------------- | ------------ | ------------ | +| :--------------------------------------------------------------------- | :----------- | :----------- | | [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | yes | - | -| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | soon | yes | -| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | someday | someday | +| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | no | yes | +| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | yes | soon | -By default, the conditions are differentiated with the same backend as the `ImplicitFunction` that contains them. -However, this can be switched to any backend compatible with [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) (i.e. a subtype of `AD.AbstractBackend`). -You can specify it with the `conditions_backend` keyword argument when constructing an `ImplicitFunction`. - -!!! warning "Warning" - At the moment, `conditions_backend` can only be `nothing` or `AD.ForwardDiffBackend()`. We are investigating why the other backends fail. +By default, the conditions are differentiated using the same "outer" backend that is trying to differentiate the `ImplicitFunction`. +However, this can be switched to any other "inner" backend compatible with [DifferentiationInterface.jl](https://github.com/gdalle/DifferentiationInterface.jl) (i.e. a subtype of `ADTypes.AbstractADType`). 
+You can override the default with the `conditions_x_backend` and `conditions_y_backend` keyword arguments when constructing an `ImplicitFunction`. ## Input and output types @@ -37,7 +34,7 @@ Or better yet, wrap it in a static vector: `SVector(val)`. Sparse arrays are not officially supported and might give incorrect values or `NaN`s! With ForwardDiff.jl, differentiation of sparse arrays will always give wrong results due to [sparsity pattern cancellation](https://github.com/JuliaDiff/ForwardDiff.jl/issues/658). -With Zygote.jl it appears to work, but this functionality is considered experimental and might evolve. +That is why we do not test behavior for sparse inputs. ## Number of inputs and outputs @@ -46,9 +43,9 @@ What can you do to handle multiple inputs or outputs? Well, it depends whether you want their derivatives or not. | | Derivatives needed | Derivatives not needed | -| -------------------- | --------------------------------------- | --------------------------------------- | +| :------------------- | :-------------------------------------- | :-------------------------------------- | | **Multiple inputs** | Make `x` a `ComponentVector` | Supply `args` and `kwargs` to `forward` | -| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct | +| **Multiple outputs** | Make `y` and `c` two `ComponentVector`s | Let `forward` return a byproduct `z` | We now detail each of these options. @@ -100,7 +97,7 @@ A more advanced application is given by [DifferentiableFrankWolfe.jl](https://gi ### Writing conditions We recommend that the conditions themselves do not involve calls to autodiff, even when they describe a gradient. -Otherwise, you will need to make sure that nested autodiff works well in your case. +Otherwise, you will need to make sure that nested autodiff works well in your case (i.e. that the "outer" backend can differentiate through the "inner" backend). 
For instance, if you're differentiating your implicit function (and your conditions) in reverse mode with Zygote.jl, you may want to use ForwardDiff.jl mode to compute gradients inside the conditions. ### Dealing with constraints diff --git a/examples/0_intro.jl b/examples/0_intro.jl index ba5d3f3..051f381 100644 --- a/examples/0_intro.jl +++ b/examples/0_intro.jl @@ -4,17 +4,13 @@ We explain the basics of our package on a simple function that is not amenable to naive automatic differentiation. =# -using ChainRulesCore #src using ForwardDiff using ImplicitDifferentiation using JET #src using LinearAlgebra -using Random using Test #src using Zygote -Random.seed!(63); - # ## Why do we bother? #= @@ -24,17 +20,17 @@ While they are very generic, there are simple language constructs that they cann function badsqrt(x::AbstractArray) a = [0.0] - a[1] = first(x) + a[1] = x[1] return sqrt.(x) -end +end; #= This is essentially the componentwise square root function but with an additional twist: `a::Vector{Float64}` is created internally, and its only element is replaced with the first element of `x`. We can check that it does what it's supposed to do. =# -x = rand(2) -badsqrt(x) ≈ sqrt.(x) +x = [4.0, 9.0] +badsqrt(x) @test badsqrt(x) ≈ sqrt.(x) #src #= @@ -79,9 +75,9 @@ x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m ``` whose output is defined by conditions ```math -F(x,y(x)) = 0 \in \mathbb{R}^m +c(x,y(x)) = 0 \in \mathbb{R}^m ``` -We represent it using a type called `ImplicitFunction`, which you will see in action shortly. +We represent it using a type called [`ImplicitFunction`](@ref), which you will see in action shortly. =# #= @@ -90,7 +86,7 @@ It returns the actual output $y(x)$ of the function, and can be thought of as a Importantly, this Julia callable _doesn't need to be differentiable by automatic differentiation packages but the underlying function still needs to be mathematically differentiable_. 
=# -forward(x) = badsqrt(x) +forward(x) = badsqrt(x); #= Then we define `conditions` $c(x, y) = 0$ that the output $y(x)$ is supposed to satisfy. @@ -102,7 +98,7 @@ Here the conditions are very obvious: the square of the square root should be eq function conditions(x, y) c = y .^ 2 .- x return c -end +end; #= Finally, we construct a wrapper `implicit` around the previous objects. @@ -114,10 +110,10 @@ implicit = ImplicitFunction(forward, conditions) #= What does this wrapper do? -When we call it as a function, it just falls back on `first ∘ implicit.forward`, so unsurprisingly we get the first output $y(x)$. +When we call it as a function, it just falls back on `implicit.forward`, so unsurprisingly we get the output $y(x)$. =# -implicit(x) ≈ sqrt.(x) +implicit(x) @test implicit(x) ≈ sqrt.(x) #src #= @@ -139,23 +135,3 @@ And so does Zygote.jl. Hurray! Zygote.jacobian(implicit, x)[1] ≈ J @test Zygote.jacobian(implicit, x)[1] ≈ J #src - -# ## Second derivative - -#= -We can even go higher-order by mixing the two packages (forward-over-reverse mode). -The only technical requirement is to switch the linear solver to something that can handle dual numbers: -=# - -implicit_higher_order = ImplicitFunction( - forward, conditions; linear_solver=DirectLinearSolver() -) - -#= -Then the Jacobian itself is differentiable. 
-=# - -h = rand(2) -J_Z(t) = Zygote.jacobian(implicit_higher_order, x .+ t .* h)[1] -ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) -@test ForwardDiff.derivative(J_Z, 0) ≈ Diagonal((-0.25 .* h) ./ (x .^ 1.5)) #src diff --git a/examples/1_basic.jl b/examples/1_basic.jl index 942e68f..c713d17 100644 --- a/examples/1_basic.jl +++ b/examples/1_basic.jl @@ -5,6 +5,8 @@ We show how to differentiate through very common routines: - an unconstrained optimization problem - a nonlinear system of equations - a fixed point iteration + +Note that some packages from the [SciML](https://sciml.ai/) ecosystem provide a similar implicit differentiation mechanism. =# using ForwardDiff @@ -12,18 +14,15 @@ using ImplicitDifferentiation using LinearAlgebra using NLsolve using Optim -using Random using Test #src using Zygote -Random.seed!(63); - #= In all three cases, we will use the square root as our forward mapping, but expressed in three different ways. Here's our heroic test vector: =# -x = rand(2); +x = [4.0, 9.0]; #= Since we already know the mathematical expression of the Jacobian, we will be able to compare it with our numerical results. @@ -40,7 +39,7 @@ y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) ``` The optimality conditions are given by gradient stationarity: ```math -\nabla_2 f(x, y) = 0 +c(x, y) = \nabla_2 f(x, y) = 0 ``` =# @@ -58,7 +57,7 @@ function forward_optim(x; method) y0 = ones(eltype(x), size(x)) result = optimize(f, y0, method) return Optim.minimizer(result) -end +end; #= Even though they are defined as a gradient, it is better to provide optimality conditions explicitly: that way we avoid nesting autodiff calls. By default, the conditions should accept two arguments as input. @@ -68,13 +67,13 @@ The forward mapping and the conditions should accept the same set of keyword arg function conditions_optim(x, y; method) ∇₂f = @. 
4 * (y^2 - x) * y return ∇₂f -end +end; #= We now have all the ingredients to construct our implicit function. =# -implicit_optim = ImplicitFunction(forward_optim, conditions_optim) +implicit_optim = ImplicitFunction(; forward=forward_optim, conditions=conditions_optim) # And indeed, it behaves as it should when we call it: @@ -90,7 +89,7 @@ ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) In this instance, we could use ForwardDiff.jl directly on the solver, but it returns the wrong result (not sure why). =# -ForwardDiff.jacobian(_x -> forward_optim(x; method=LBFGS()), x) +ForwardDiff.jacobian(_x -> forward_optim(_x; method=LBFGS()), x) # Reverse mode autodiff @@ -102,7 +101,7 @@ In this instance, we cannot use Zygote.jl directly on the solver (due to unsuppo =# try - Zygote.jacobian(_x -> forward_optim(x; method=LBFGS()), x)[1] + Zygote.jacobian(_x -> forward_optim(_x; method=LBFGS()), x)[1] catch e e end @@ -112,18 +111,18 @@ end #= Next, we show how to differentiate through the solution of a nonlinear system of equations: ```math -\text{find} \quad y(x) \quad \text{such that} \quad F(x, y(x)) = 0 +\text{find} \quad y(x) \quad \text{such that} \quad c(x, y(x)) = 0 ``` The optimality conditions are pretty obvious: ```math -F(x, y) = 0 +c(x, y) = 0 ``` =# #= To make verification easy, we solve the following system: ```math -F(x, y) = y \odot y - x = 0 +c(x, y) = y \odot y - x = 0 ``` In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl). 
=# @@ -134,14 +133,14 @@ function forward_nlsolve(x; method) initial_y .= 1 result = nlsolve(F!, initial_y; method) return result.zero -end +end; #- function conditions_nlsolve(x, y; method) c = y .^ 2 .- x return c -end +end; #- @@ -179,18 +178,18 @@ end #= Finally, we show how to differentiate through the limit of a fixed point iteration: ```math -y \longmapsto T(x, y) +y \longmapsto g(x, y) ``` The optimality conditions are pretty obvious: ```math -y = T(x, y) +c(x, y) = g(x, y) - y = 0 ``` =# #= To make verification easy, we consider [Heron's method](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Heron's_method): ```math -T(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) +g(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) ``` In this case, the fixed point algorithm boils down to the componentwise square root function, but we implement it manually. =# @@ -198,17 +197,17 @@ In this case, the fixed point algorithm boils down to the componentwise square r function forward_fixedpoint(x; iterations) y = ones(eltype(x), size(x)) for _ in 1:iterations - y .= 0.5 .* (y .+ x ./ y) + y .= (y .+ x ./ y) ./ 2 end return y -end +end; #- function conditions_fixedpoint(x, y; iterations) - T = 0.5 .* (y .+ x ./ y) - return T .- y -end + g = (y .+ x ./ y) ./ 2 + return g .- y +end; #- diff --git a/examples/2_advanced.jl b/examples/2_advanced.jl index d1d48ff..2309c27 100644 --- a/examples/2_advanced.jl +++ b/examples/2_advanced.jl @@ -9,12 +9,9 @@ using ForwardDiff using ImplicitDifferentiation using LinearAlgebra using Optim -using Random using Test #src using Zygote -Random.seed!(63); - # ## Constrained optimization #= @@ -46,19 +43,17 @@ function forward_cstr_optim(x) res = optimize(f, lower, upper, y0, Fminbox(GradientDescent())) y = Optim.minimizer(res) return y -end +end; #- -function proj_hypercube(p) - return max.(0, min.(1, p)) -end +proj_hypercube(p) = max.(0, min.(1, p)) function conditions_cstr_optim(x, y) ∇₂f = @. 
4 * (y^2 - x) * y η = 0.1 return y .- proj_hypercube(y .- η .* ∇₂f) -end +end; # We now have all the ingredients to construct our implicit function. @@ -66,7 +61,7 @@ implicit_cstr_optim = ImplicitFunction(forward_cstr_optim, conditions_cstr_optim # And indeed, it behaves as it should when we call it: -x = rand(2) .+ [0, 1] +x = [0.3, 1.4] #= The second component of $x$ is $> 1$, so its square root will be thresholded to one, and the corresponding derivative will be $0$. diff --git a/examples/3_tricks.jl b/examples/3_tricks.jl index f4532f6..cc1c524 100644 --- a/examples/3_tricks.jl +++ b/examples/3_tricks.jl @@ -9,12 +9,9 @@ using ForwardDiff using ImplicitDifferentiation using Krylov using LinearAlgebra -using Random using Test #src using Zygote -Random.seed!(63); - # ## ComponentArrays # For when you need derivatives with respect to multiple inputs or outputs. @@ -55,7 +52,7 @@ Krylov.ktypeof(::ComponentVector{T,V}) where {T,V} = V # Now we're good to go. -a, b, m = rand(2), rand(3), 7 +a, b, m = [1.0, 2.0], [3.0, 4.0, 5.0], 6.0 x = ComponentVector(; a=a, b=b, m=m) implicit_components(x) @@ -83,7 +80,7 @@ end; # For when you need an additional output but don't care about its derivative. 
function forward_byproduct(x) - z = rand((2, 2)) # "randomized" choice + z = 2 # "randomized" choice y = x .^ (1 / z) return y, z end @@ -99,7 +96,7 @@ implicit_byproduct = ImplicitFunction(forward_byproduct, conditions_byproduct); # But this time the return value is a tuple `(y, z)` -x = rand(3) +x = [4.0, 9.0] implicit_byproduct(x) # And it works with both ForwardDiff.jl and Zygote.jl diff --git a/ext/ImplicitDifferentiationChainRulesCoreExt.jl b/ext/ImplicitDifferentiationChainRulesCoreExt.jl index cd24173..cea6a32 100644 --- a/ext/ImplicitDifferentiationChainRulesCoreExt.jl +++ b/ext/ImplicitDifferentiationChainRulesCoreExt.jl @@ -1,95 +1,52 @@ module ImplicitDifferentiationChainRulesCoreExt -using AbstractDifferentiation: AbstractBackend, ReverseRuleConfigBackend, ruleconfig -using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig +using ADTypes: AbstractADType, AutoChainRules +using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, RuleConfig using ChainRulesCore: rrule, rrule_via_ad, unthunk, @not_implemented -using ImplicitDifferentiation: ImplicitDifferentiation -using ImplicitDifferentiation: ImplicitFunction -using ImplicitDifferentiation: conditions_reverse_operators -using ImplicitDifferentiation: get_output, presolve, solve -using LinearAlgebra: mul! -using SimpleUnPack: @unpack +using ImplicitDifferentiation: ImplicitFunction, build_Aᵀ, build_Bᵀ, output -""" - rrule(rc, implicit, x, args...; kwargs...) - -Custom reverse rule for an [`ImplicitFunction`](@ref), to ensure compatibility with reverse mode autodiff. - -This is only available if ChainRulesCore.jl is loaded (extension), except on Julia < 1.9 where it is always available. - -We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`. -Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`. -""" function ChainRulesCore.rrule( - rc::RuleConfig, implicit::ImplicitFunction, x::X, args...; kwargs... 
-) where {R,X<:AbstractArray{R}} - linear_solver = implicit.linear_solver + rc::RuleConfig, + implicit::ImplicitFunction, + x::AbstractVector, + args::Vararg{T,N}; + kwargs..., +) where {T,N} y_or_yz = implicit(x, args...; kwargs...) - backend = reverse_conditions_backend(rc, implicit) - Aᵀ_vec, Bᵀ_vec = conditions_reverse_operators( - backend, implicit, x, y_or_yz, args; kwargs - ) - Aᵀ_vec_presolved = presolve(linear_solver, Aᵀ_vec, get_output(y_or_yz)) - byproduct = y_or_yz isa Tuple - nbargs = length(args) - implicit_pullback = ImplicitPullback{byproduct,nbargs}( - Aᵀ_vec_presolved, Bᵀ_vec, linear_solver, vec(x), size(x) + suggested_backend = AutoChainRules(rc) + Aᵀ = build_Aᵀ(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) + Bᵀ = build_Bᵀ(implicit, x, y_or_yz, args...; suggested_backend, kwargs...) + project_x = ProjectTo(x) + + implicit_pullback = ImplicitPullback( + Aᵀ, Bᵀ, implicit.linear_solver, project_x, Val{N}() ) return y_or_yz, implicit_pullback end -function reverse_conditions_backend( - rc::RuleConfig, ::ImplicitFunction{F,C,L,Nothing} -) where {F,C,L} - return ReverseRuleConfigBackend(rc) -end - -function reverse_conditions_backend( - ::RuleConfig, implicit::ImplicitFunction{F,C,L,<:AbstractBackend} -) where {F,C,L} - return implicit.conditions_backend -end - -struct ImplicitPullback{byproduct,nbargs,A,B,L,X,N} - Aᵀ_vec::A - Bᵀ_vec::B +struct ImplicitPullback{N,M1,M2,L,P} + Aᵀ::M1 + Bᵀ::M2 linear_solver::L - x_vec::X - x_size::NTuple{N,Int} - - function ImplicitPullback{byproduct,nbargs}( - Aᵀ_vec::A, Bᵀ_vec::B, linear_solver::L, x_vec::X, x_size::NTuple{N,Int} - ) where {byproduct,nbargs,A,B,L,X,N} - return new{byproduct,nbargs,A,B,L,X,N}(Aᵀ_vec, Bᵀ_vec, linear_solver, x_vec, x_size) - end + project_x::P + nargs::Val{N} end -function (implicit_pullback::ImplicitPullback{true})((dy, dz)) - return apply_implicit_pullback(implicit_pullback, dy) -end - -function (implicit_pullback::ImplicitPullback{false})(dy) - return 
apply_implicit_pullback(implicit_pullback, dy) +function (ip::ImplicitPullback{N})(dy_or_dydz) where {N} + (; Aᵀ, Bᵀ, linear_solver, project_x) = ip + dy = output(unthunk(dy_or_dydz)) + dc = linear_solver(Aᵀ, -dy) + dx = Bᵀ * dc + df = NoTangent() + dargs = ntuple(unimplemented_tangent, N) + return (df, project_x(dx), dargs...) end function unimplemented_tangent(_) return @not_implemented( - "Tangents for positional arguments of an ImplicitFunction beyond x (the first one) are not implemented" + "Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented" ) end -function apply_implicit_pullback( - implicit_pullback::ImplicitPullback{byproduct,nbargs}, dy_thunk -) where {byproduct,nbargs} - @unpack Aᵀ_vec, Bᵀ_vec, linear_solver, x_vec, x_size = implicit_pullback - dy = unthunk(dy_thunk) - dy_vec = vec(dy) - dc_vec = solve(linear_solver, Aᵀ_vec, -dy_vec) - dx_vec = similar(x_vec) - mul!(dx_vec, Bᵀ_vec, dc_vec) - dx = reshape(dx_vec, x_size) - return (NoTangent(), dx, ntuple(unimplemented_tangent, nbargs)...) -end - end diff --git a/ext/ImplicitDifferentiationEnzymeExt.jl b/ext/ImplicitDifferentiationEnzymeExt.jl new file mode 100644 index 0000000..eaeec45 --- /dev/null +++ b/ext/ImplicitDifferentiationEnzymeExt.jl @@ -0,0 +1,54 @@ +module ImplicitDifferentiationEnzymeExt + +using ADTypes +using Enzyme +using Enzyme.EnzymeCore +using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output + +function EnzymeRules.forward( + func::Const{<:ImplicitFunction}, + RT::Type{<:Union{BatchDuplicated,BatchDuplicatedNoNeed}}, + func_x::Union{BatchDuplicated{T,N},BatchDuplicatedNoNeed{T,N}}, + func_args::Vararg{Const,P}, +) where {T,N,P} + implicit = func.val + x = func_x.val + dx = func_x.dval + args = map(a -> a.val, func_args) + + y_or_yz = implicit(x, args...) 
+ y = output(y_or_yz) + Y = typeof(y) + + suggested_backend = AutoEnzyme(Enzyme.Forward) + A = build_A(implicit, x, y_or_yz, args...; suggested_backend) + B = build_B(implicit, x, y_or_yz, args...; suggested_backend) + + dx_batch = reduce(hcat, dx) + dc_batch = mapreduce(hcat, eachcol(dx_batch)) do dₖx + B * dₖx + end + dy_batch = implicit.linear_solver(A, -dc_batch) + + dy::NTuple{N,Y} = ntuple(k -> convert(Y, dy_batch[:, k]), Val(N)) + + if y_or_yz isa AbstractArray + if RT <: BatchDuplicated + return BatchDuplicated(y, dy) + elseif RT <: BatchDuplicatedNoNeed + return dy + end + elseif y_or_yz isa Tuple + yz = y_or_yz + z = byproduct(yz) + Z = typeof(z) + dyz::NTuple{N,Tuple{Y,Z}} = ntuple(k -> (dy[k], make_zero(z)), Val(N)) + if RT <: BatchDuplicated + return BatchDuplicated(yz, dyz) + elseif RT <: BatchDuplicatedNoNeed + return dyz + end + end +end + +end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index cb83390..3337a92 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -1,82 +1,52 @@ module ImplicitDifferentiationForwardDiffExt -@static if isdefined(Base, :get_extension) - using ForwardDiff: Dual, Partials, jacobian, partials, value -else - using ..ForwardDiff: Dual, Partials, jacobian, partials, value -end - -using AbstractDifferentiation: AbstractBackend, ForwardDiffBackend -using ImplicitDifferentiation: ImplicitFunction, DirectLinearSolver, IterativeLinearSolver -using ImplicitDifferentiation: conditions_forward_operators -using ImplicitDifferentiation: get_output, get_byproduct, presolve, solve -using ImplicitDifferentiation: identity_break_autodiff -using LinearAlgebra: mul! -using PrecompileTools: @compile_workload - -""" - implicit(x_and_dx::AbstractArray{<:Dual}, args...; kwargs...) - -Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with forward mode autodiff. 
+using ADTypes: AutoForwardDiff +using ForwardDiff: Chunk, Dual, Partials, jacobian, partials, value +using ImplicitDifferentiation: ImplicitFunction, build_A, build_B, byproduct, output -This is only available if ForwardDiff.jl is loaded (extension). +chunksize(::Chunk{N}) where {N} = N -We compute the Jacobian-vector product `Jv` by solving `Au = -Bv` and setting `Jv = u`. -Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`. -""" function (implicit::ImplicitFunction)( - x_and_dx::AbstractArray{Dual{T,R,N}}, args...; kwargs... + x_and_dx::AbstractVector{Dual{T,R,N}}, args...; kwargs... ) where {T,R,N} - linear_solver = implicit.linear_solver - x = value.(x_and_dx) y_or_yz = implicit(x, args...; kwargs...) - y = get_output(y_or_yz) - y_vec = vec(y) - - backend = forward_conditions_backend(implicit) - A_vec, B_vec = conditions_forward_operators(backend, implicit, x, y_or_yz, args; kwargs) - A_vec_presolved = presolve(linear_solver, A_vec, y) - - dy = ntuple(Val(N)) do k - dₖx = partials.(x_and_dx, k) - dₖx_vec = vec(dₖx) - dₖc_vec = similar(y_vec) - mul!(dₖc_vec, B_vec, dₖx_vec) - dₖy_vec = solve(implicit.linear_solver, A_vec_presolved, -dₖc_vec) - reshape(dₖy_vec, size(y)) + y = output(y_or_yz) + + A = build_A( + implicit, + x, + y_or_yz, + args...; + suggested_backend=AutoForwardDiff(; tag=T(), chunksize=chunksize(Chunk(y))), + kwargs..., + ) + B = build_B( + implicit, + x, + y_or_yz, + args...; + suggested_backend=AutoForwardDiff(; tag=T(), chunksize=chunksize(Chunk(x))), + kwargs..., + ) + + dX = mapreduce(hcat, 1:N) do k + partials.(x_and_dx, k) + end + dC = mapreduce(hcat, eachcol(dX)) do dₖx + B * dₖx end + dY = implicit.linear_solver(A, -dC) - y_and_dy = map(eachindex(IndexCartesian(), y)) do i - Dual{T}(y[i], Partials(ntuple(k -> dy[k][i], Val(N)))) + y_and_dy = map(eachindex(y)) do i + Dual{T}(y[i], Partials(ntuple(k -> dY[i, k], Val(N)))) end if y_or_yz isa Tuple - return y_and_dy, get_byproduct(y_or_yz) + 
return y_and_dy, byproduct(y_or_yz) else return y_and_dy end end -function forward_conditions_backend(::ImplicitFunction{F,C,L,Nothing}) where {F,C,L} - return ForwardDiffBackend() -end - -function forward_conditions_backend( - implicit::ImplicitFunction{F,C,L,<:AbstractBackend} -) where {F,C,L} - return implicit.conditions_backend -end - -@compile_workload begin - forward(x) = sqrt.(identity_break_autodiff(x)) - conditions(x, y) = y .^ 2 .- x - for linear_solver in (DirectLinearSolver(), IterativeLinearSolver()) - implicit = ImplicitFunction(forward, conditions; linear_solver) - x = rand(2) - implicit(x) - jacobian(implicit, x) - end -end - end diff --git a/ext/ImplicitDifferentiationStaticArraysExt.jl b/ext/ImplicitDifferentiationStaticArraysExt.jl deleted file mode 100644 index 90f3998..0000000 --- a/ext/ImplicitDifferentiationStaticArraysExt.jl +++ /dev/null @@ -1,25 +0,0 @@ -module ImplicitDifferentiationStaticArraysExt - -@static if isdefined(Base, :get_extension) - using StaticArrays: StaticArray, MMatrix, StaticVector -else - using ..StaticArrays: StaticArray, MMatrix, StaticVector -end - -import ImplicitDifferentiation: ImplicitDifferentiation, DirectLinearSolver -using LinearAlgebra: lu, mul! 
- -function ImplicitDifferentiation.presolve(::DirectLinearSolver, A, y::StaticArray) - T = eltype(A) - m = length(y) - A_static = zero(MMatrix{m,m,T}) - v = vec(similar(y, T)) - for i in axes(A_static, 2) - v .= zero(T) - v[i] = one(T) - mul!(@view(A_static[:, i]), A, v) - end - return lu(A_static) -end - -end diff --git a/ext/ImplicitDifferentiationZygoteExt.jl b/ext/ImplicitDifferentiationZygoteExt.jl deleted file mode 100644 index fb64e4e..0000000 --- a/ext/ImplicitDifferentiationZygoteExt.jl +++ /dev/null @@ -1,24 +0,0 @@ -module ImplicitDifferentiationZygoteExt - -@static if isdefined(Base, :get_extension) - using Zygote: jacobian -else - using ..Zygote: jacobian -end - -using ImplicitDifferentiation: ImplicitFunction, identity_break_autodiff -using ImplicitDifferentiation: DirectLinearSolver, IterativeLinearSolver -using PrecompileTools: @compile_workload - -@compile_workload begin - forward(x) = sqrt.(identity_break_autodiff(x)) - conditions(x, y) = y .^ 2 .- x - for linear_solver in (DirectLinearSolver(), IterativeLinearSolver()) - implicit = ImplicitFunction(forward, conditions; linear_solver) - x = rand(2) - implicit(x) - jacobian(implicit, x) - end -end - -end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 0fc55d8..07c6d7b 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -7,38 +7,20 @@ Its main export is the type [`ImplicitFunction`](@ref). 
""" module ImplicitDifferentiation -using AbstractDifferentiation: AbstractBackend -using AbstractDifferentiation: pushforward_function, pullback_function, jacobian -using Krylov: gmres -using LinearOperators: LinearOperators, LinearOperator -using LinearAlgebra: issuccess, lu -using PrecompileTools: @compile_workload -using Requires: @require -using SimpleUnPack: @unpack +using ADTypes: AbstractADType +using DifferentiationInterface: + jacobian, + prepare_pushforward, + prepare_pullback, + pushforward!!, + value_and_pullback!!_split +using Krylov: block_gmres, gmres +using LinearOperators: LinearOperator +using LinearAlgebra: factorize, lu -include("utils.jl") -include("linear_solver.jl") include("implicit_function.jl") include("operators.jl") export ImplicitFunction -export AbstractLinearSolver, IterativeLinearSolver, DirectLinearSolver - -@static if !isdefined(Base, :get_extension) - # Loaded unconditionally on Julia < 1.9 - include("../ext/ImplicitDifferentiationChainRulesCoreExt.jl") - function __init__() - # Loaded conditionally on Julia < 1.9 - @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin - include("../ext/ImplicitDifferentiationForwardDiffExt.jl") - end - @require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin - include("../ext/ImplicitDifferentiationStaticArraysExt.jl") - end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("../ext/ImplicitDifferentiationZygoteExt.jl") - end - end -end end diff --git a/src/implicit_function.jl b/src/implicit_function.jl index fd73f0a..6a91012 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -1,5 +1,20 @@ +struct DefaultLinearSolver end + +function (::DefaultLinearSolver)(A, b::AbstractVector) + x, stats = gmres(A, b) + return x +end + +function (::DefaultLinearSolver)(A, B::AbstractMatrix) + # X, stats = block_gmres(A, B) # https://github.com/JuliaSmoothOptimizers/Krylov.jl/issues/854 + X = mapreduce(hcat, eachcol(B)) do b + first(gmres(A, 
b)) + end + return X +end + """ - ImplicitFunction{F,C,L,B} + ImplicitFunction Wrapper for an implicit function defined by a forward mapping `y` and a set of conditions `c`. @@ -12,76 +27,67 @@ This requires solving a linear system `A * J = -B` where `A = ∂c/∂y`, `B = # Fields -- `forward::F`: a callable, does not need to be compatible with automatic differentiation -- `conditions::C`: a callable, must be compatible with automatic differentiation -- `linear_solver::L`: a subtype of `AbstractLinearSolver`, defines how the linear system will be solved -- `conditions_backend::B`: either `nothing` or a subtype of `AbstractDifferentiation.AbstractBackend`, defines how the conditions will be differentiated within the implicit function theorem +- `forward`: a callable, does not need to be compatible with automatic differentiation +- `conditions`: a callable, must be compatible with automatic differentiation +- `linear_solver`: a callable with two methods: + - `(A, b::AbstractVector) -> s::AbstractVector` such that `A * s = b` + - `(A, B::AbstractMatrix) -> S::AbstractMatrix` such that `A * S = B` +- `conditions_x_backend`: either `nothing` or an object subtyping `AbstractADType` from [ADTypes.jl](https://github.com/SciML/ADTypes.jl), defines how the conditions will be differentiated with respect to the first argument `x` +- `conditions_y_backend`: same for the second argument `y` There are two possible signatures for `forward` and `conditions`, which must be consistent with one another: -1. Standard: `forward(x, args...; kwargs...) = y` and `conditions(x, y, args...; kwargs...) = c` -2. Byproduct: `forward(x, args...; kwargs...) = (y, z)` and `conditions(x, y, z, args...; kwargs...) = c`. +| case | `forward` | `conditions` | +|:---|:---|:---| +| standard | `forward(x, args...; kwargs...) = y` | `conditions(x, y, args...; kwargs...) = c` | +| byproduct | `forward(x, args...; kwargs...) = (y, z)` | `conditions(x, y, z, args...; kwargs...)
= c` | -In both cases, `x`, `y` and `c` must be arrays, with `size(y) = size(c)`. +In both cases, `x`, `y` and `c` must be `AbstractVector`s, with `length(y) = length(c)`. In the second case, the byproduct `z` can be an arbitrary object generated by `forward`. The positional arguments `args...` and keyword arguments `kwargs...` must be the same for both `forward` and `conditions`. -!!! warning "Warning" - The byproduct `z` and the other positional arguments `args...` beyond `x` are considered constant for differentiation purposes. +The byproduct `z` and the other positional arguments `args...` beyond `x` are considered constant for differentiation purposes. """ -struct ImplicitFunction{F,C,L<:AbstractLinearSolver,B<:Union{Nothing,AbstractBackend}} +@kwdef struct ImplicitFunction{ + F,C,L,B1<:Union{Nothing,AbstractADType},B2<:Union{Nothing,AbstractADType} +} forward::F conditions::C - linear_solver::L - conditions_backend::B + linear_solver::L = DefaultLinearSolver() + conditions_x_backend::B1 = nothing + conditions_y_backend::B2 = nothing end -""" - ImplicitFunction( - forward, - conditions; - linear_solver=IterativeLinearSolver(), - conditions_backend=nothing, - ) - -Construct an `ImplicitFunction` with default parameters. -""" -function ImplicitFunction( - forward, conditions; linear_solver=IterativeLinearSolver(), conditions_backend=nothing -) - return ImplicitFunction(forward, conditions, linear_solver, conditions_backend) +function ImplicitFunction(forward, conditions; kwargs...) + return ImplicitFunction(; forward, conditions, kwargs...) 
end function Base.show(io::IO, implicit::ImplicitFunction) - @unpack forward, conditions, linear_solver, conditions_backend = implicit + (; forward, conditions, linear_solver, conditions_x_backend, conditions_y_backend) = + implicit return print( - io, "ImplicitFunction($forward, $conditions, $linear_solver, $conditions_backend)" + io, + "ImplicitFunction($forward, $conditions, $linear_solver, $conditions_x_backend, $conditions_y_backend)", ) end """ (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...) -Return `implicit.forward(x, args...; kwargs...)`, which can be either an array `y` or a tuple `(y, z)`. +Return `implicit.forward(x, args...; kwargs...)`, which can be either an `AbstractVector` `y` or a tuple `(y, z)`. -This call is differentiable. +This call makes `y` differentiable with respect to `x`. """ -function (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...) +function (implicit::ImplicitFunction)(x::AbstractVector, args...; kwargs...) y_or_yz = implicit.forward(x, args...; kwargs...) - valid = ( - y_or_yz isa AbstractArray || # - (y_or_yz isa Tuple && length(y_or_yz) == 2 && y_or_yz[1] isa AbstractArray) - ) - if !valid - throw( - DimensionMismatch( - "The forward mapping must return an array `y` or a tuple `(y, z)` where `y` is an array", - ), - ) - end return y_or_yz end -get_output(y::AbstractArray) = y -get_output(yz::Tuple) = yz[1] -get_byproduct(yz::Tuple) = yz[2] +output(y::AbstractVector) = y +byproduct(::AbstractVector) = error("No byproduct") + +output(yz::Tuple{<:Any,<:Any}) = yz[1] +byproduct(yz::Tuple{<:Any,<:Any}) = yz[2] + +output((y, z)) = y +byproduct((y, z)) = z diff --git a/src/linear_solver.jl b/src/linear_solver.jl deleted file mode 100644 index fd39398..0000000 --- a/src/linear_solver.jl +++ /dev/null @@ -1,78 +0,0 @@ -""" - AbstractLinearSolver - -All linear solvers used within an `ImplicitFunction` must satisfy this interface. 
- -It can be useful to roll out your own solver if you need more fine-grained control on convergence / speed / behavior in case of singularity. -Check out the source code of `IterativeLinearSolver` and `DirectLinearSolver` for implementation examples. - -# Required methods - -- `presolve(linear_solver, A, y)`: Returns a matrix-like object `A` for which it is cheaper to solve several linear systems with different vectors `b` of type similar to `y` (a typical example would be to perform LU factorization). -- `solve(linear_solver, A, b)`: Returns a vector `x` satisfying `Ax = b`. If the linear system has not been solved to satisfaction, every element of `x` should be a `NaN` of the appropriate floating point type. -""" -abstract type AbstractLinearSolver end - -""" - IterativeLinearSolver - -An implementation of `AbstractLinearSolver` using `Krylov.gmres`, set as the default for `ImplicitFunction`. - -# Fields - -- `verbose::Bool`: Whether to display a warning when the solver fails and returns `NaN`s (defaults to `true`) -- `accept_inconsistent::Bool`: Whether to accept approximate least squares solutions for inconsistent systems, or fail and return `NaN`s (defaults to `false`) - -!!! note - If you find that your implicit gradients contains `NaN`s, try using this solver with `accept_inconsistent=true`. - However, beware that the implicit function theorem does not cover the case of inconsistent linear systems `AJ = B`, so it is unclear what the result will mean. 
-""" -Base.@kwdef struct IterativeLinearSolver <: AbstractLinearSolver - verbose::Bool = true - accept_inconsistent::Bool = false -end - -presolve(::IterativeLinearSolver, A, y) = A - -function solve(sol::IterativeLinearSolver, A, b) - x, stats = gmres(A, b) - if sol.accept_inconsistent - success = stats.solved || stats.inconsistent - else - success = stats.solved && !stats.inconsistent - end - if !success - if sol.verbose - @warn "IterativeLinearSolver failed, result contains NaNs" - @show stats - end - x .= NaN - end - return x -end - -""" - DirectLinearSolver - -An implementation of `AbstractLinearSolver` using the built-in backslash operator. - -# Fields - -- `verbose::Bool`: Whether to throw a warning when the solver fails (defaults to `true`) -""" -Base.@kwdef struct DirectLinearSolver <: AbstractLinearSolver - verbose::Bool = true -end - -function presolve(::DirectLinearSolver, A, y) - return lu(Matrix(A); check=false) -end - -function solve(sol::DirectLinearSolver, A_lu, b) - x = A_lu \ b - if !issuccess(A_lu) - sol.verbose && @warn "DirectLinearSolver failed, result contains NaNs" - x .= NaN - end - return x -end diff --git a/src/operators.jl b/src/operators.jl index 4b2fee7..bb3f617 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1,115 +1,233 @@ -## Forward +## Partial conditions -function conditions_pushforwards( - ba::AbstractBackend, - implicit::ImplicitFunction, - x::AbstractArray, - y::AbstractArray, - args; - kwargs, -) - conditions = implicit.conditions - pfA = only ∘ pushforward_function(ba, _y -> conditions(x, _y, args...; kwargs...), y) - pfB = only ∘ pushforward_function(ba, _x -> conditions(_x, y, args...; kwargs...), x) - return pfA, pfB +struct ConditionsXNoByproduct{C,Y,A,K} + conditions::C + y::Y + args::A + kwargs::K end -function conditions_pushforwards( - ba::AbstractBackend, - implicit::ImplicitFunction, - x::AbstractArray, - yz::Tuple, - args; - kwargs, -) - conditions = implicit.conditions - y, z = yz - pfA = only ∘ 
pushforward_function(ba, _y -> conditions(x, _y, z, args...; kwargs...), y) - pfB = only ∘ pushforward_function(ba, _x -> conditions(_x, y, z, args...; kwargs...), x) - return pfA, pfB +function (conditions_x_nobyproduct::ConditionsXNoByproduct)(x::AbstractVector) + (; conditions, y, args, kwargs) = conditions_x_nobyproduct + return conditions(x, y, args...; kwargs...) end -struct PushforwardProd!{F,N} - pushforward::F - size::NTuple{N,Int} +struct ConditionsYNoByproduct{C,X,A,K} + conditions::C + x::X + args::A + kwargs::K end -function (pfp::PushforwardProd!)(dc_vec::AbstractVector, dy_vec::AbstractVector) - dy = reshape(dy_vec, pfp.size) - dc = pfp.pushforward(dy) - return dc_vec .= vec(dc) +function (conditions_y_nobyproduct::ConditionsYNoByproduct)(y::AbstractVector) + (; conditions, x, args, kwargs) = conditions_y_nobyproduct + return conditions(x, y, args...; kwargs...) end -function pushforwards_to_operators(x::AbstractArray, y::AbstractArray, pfA, pfB) - n, m = length(x), length(y) - A_vec = LinearOperator(eltype(y), m, m, false, false, PushforwardProd!(pfA, size(y))) - B_vec = LinearOperator(eltype(x), m, n, false, false, PushforwardProd!(pfB, size(x))) - return A_vec, B_vec +struct ConditionsXByproduct{C,Y,Z,A,K} + conditions::C + y::Y + z::Z + args::A + kwargs::K end -function conditions_forward_operators( - backend::AbstractBackend, implicit::ImplicitFunction, x, y_or_yz, args; kwargs -) - y = get_output(y_or_yz) - pfA, pfB = conditions_pushforwards(backend, implicit, x, y_or_yz, args; kwargs) - A_vec, B_vec = pushforwards_to_operators(x, y, pfA, pfB) - return A_vec, B_vec +function (conditions_x_byproduct::ConditionsXByproduct)(x::AbstractVector) + (; conditions, y, z, args, kwargs) = conditions_x_byproduct + return conditions(x, y, z, args...; kwargs...) 
end -## Reverse +struct ConditionsYByproduct{C,X,Z,A,K} + conditions::C + x::X + z::Z + args::A + kwargs::K +end -function conditions_pullbacks( - ba::AbstractBackend, - implicit::ImplicitFunction, - x::AbstractArray, - y::AbstractArray, - args; - kwargs, -) - conditions = implicit.conditions - pbAᵀ = only ∘ pullback_function(ba, _y -> conditions(x, _y, args...; kwargs...), y) - pbBᵀ = only ∘ pullback_function(ba, _x -> conditions(_x, y, args...; kwargs...), x) - return pbAᵀ, pbBᵀ +function (conditions_y_byproduct::ConditionsYByproduct)(y::AbstractVector) + (; conditions, x, z, args, kwargs) = conditions_y_byproduct + return conditions(x, y, z, args...; kwargs...) end -function conditions_pullbacks( - ba::AbstractBackend, - implicit::ImplicitFunction, - x::AbstractArray, - yz::Tuple, - args; - kwargs, -) - conditions = implicit.conditions - y, z = yz - pbAᵀ = only ∘ pullback_function(ba, _y -> conditions(x, _y, z, args...; kwargs...), y) - pbBᵀ = only ∘ pullback_function(ba, _x -> conditions(_x, y, z, args...; kwargs...), x) - return pbAᵀ, pbBᵀ +function ConditionsX(conditions, x, y_or_yz, args...; kwargs...) + y = output(y_or_yz) + if y_or_yz isa Tuple + z = byproduct(y_or_yz) + return ConditionsXByproduct(conditions, y, z, args, kwargs) + else + return ConditionsXNoByproduct(conditions, y, args, kwargs) + end +end + +function ConditionsY(conditions, x, y_or_yz, args...; kwargs...) 
+ if y_or_yz isa Tuple + z = byproduct(y_or_yz) + return ConditionsYByproduct(conditions, x, z, args, kwargs) + else + return ConditionsYNoByproduct(conditions, x, args, kwargs) + end +end + +## Lazy operators + +struct PushforwardOperator!{F,B,X,E,R} + f::F + backend::B + x::X + extras::E + res_backup::R end -struct PullbackProd!{F,N} - pullback::F - size::NTuple{N,Int} +function (po::PushforwardOperator!)(res, v, α, β) + if iszero(β) + res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) + res .= α .* res + else + po.res_backup .= res + res .= pushforward!!(po.f, res, po.backend, po.x, v, po.extras) + res .= α .* res .+ β .* po.res_backup + end + return res end -function (pbp::PullbackProd!)(dy_vec::AbstractVector, dc_vec::AbstractVector) - dc = reshape(dc_vec, pbp.size) - dy = pbp.pullback(dc) - return dy_vec .= vec(dy) +struct PullbackOperator!{PB,R} + pullbackfunc!!::PB + res_backup::R end -function pullbacks_to_operators(x::AbstractArray, y::AbstractArray, pbAᵀ, pbBᵀ) +function (po::PullbackOperator!)(res, v, α, β) + if iszero(β) + res .= po.pullbackfunc!!(res, v) + res .= α .* res + else + po.res_backup .= res + res .= po.pullbackfunc!!(res, v) + # 5-argument product: res = α * (pullback of v) + β * (previous res); β scales the backup, as in PushforwardOperator! + res .= α .* res .+ β .* po.res_backup + end + return res +end + +function build_A( + implicit::ImplicitFunction, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., +) + (; conditions, linear_solver, conditions_y_backend) = implicit + y = output(y_or_yz) + n, m = length(x), length(y) + back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend + cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...)
+ if linear_solver isa typeof(\) + J = jacobian(cond_y, back_y, y) + A = factorize(J) + else + extras = prepare_pushforward(cond_y, back_y, y) + A = LinearOperator( + eltype(y), + m, + m, + false, + false, + PushforwardOperator!(cond_y, back_y, y, extras, similar(y)), + typeof(y), + ) + end + return A +end + +function build_Aᵀ( + implicit::ImplicitFunction, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., +) + (; conditions, linear_solver, conditions_y_backend) = implicit + y = output(y_or_yz) + n, m = length(x), length(y) + back_y = isnothing(conditions_y_backend) ? suggested_backend : conditions_y_backend + cond_y = ConditionsY(conditions, x, y_or_yz, args...; kwargs...) + if linear_solver isa typeof(\) + Jᵀ = transpose(jacobian(cond_y, back_y, y)) + Aᵀ = factorize(Jᵀ) + else + extras = prepare_pullback(cond_y, back_y, y) + _, pullbackfunc!! = value_and_pullback!!_split(cond_y, back_y, y, extras) + Aᵀ = LinearOperator( + eltype(y), + m, + m, + false, + false, + PullbackOperator!(pullbackfunc!!, similar(y)), + typeof(y), + ) + end + return Aᵀ +end + +function build_B( + implicit::ImplicitFunction, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., +) + (; conditions, linear_solver, conditions_x_backend) = implicit + y = output(y_or_yz) n, m = length(x), length(y) - Aᵀ_vec = LinearOperator(eltype(y), m, m, false, false, PullbackProd!(pbAᵀ, size(y))) - Bᵀ_vec = LinearOperator(eltype(y), n, m, false, false, PullbackProd!(pbBᵀ, size(y))) - return Aᵀ_vec, Bᵀ_vec + back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend + cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) 
+ if linear_solver isa typeof(\) + # B = ∂c/∂x has size (m, n), matching the operator branch below; only build_Bᵀ transposes + B = jacobian(cond_x, back_x, x) + else + extras = prepare_pushforward(cond_x, back_x, x) + B = LinearOperator( + eltype(y), + m, + n, + false, + false, + PushforwardOperator!(cond_x, back_x, x, extras, similar(y)), + typeof(x), + ) + end + return B end -function conditions_reverse_operators( - backend::AbstractBackend, implicit::ImplicitFunction, x, y_or_yz, args; kwargs +function build_Bᵀ( + implicit::ImplicitFunction, + x::AbstractVector, + y_or_yz, + args...; + suggested_backend, + kwargs..., ) - y = get_output(y_or_yz) - pbAᵀ, pbBᵀ = conditions_pullbacks(backend, implicit, x, y_or_yz, args; kwargs) - Aᵀ_vec, Bᵀ_vec = pullbacks_to_operators(x, y, pbAᵀ, pbBᵀ) - return Aᵀ_vec, Bᵀ_vec + (; conditions, linear_solver, conditions_x_backend) = implicit + y = output(y_or_yz) + n, m = length(x), length(y) + back_x = isnothing(conditions_x_backend) ? suggested_backend : conditions_x_backend + cond_x = ConditionsX(conditions, x, y_or_yz, args...; kwargs...) + if linear_solver isa typeof(\) + Bᵀ = transpose(jacobian(cond_x, back_x, x)) + else + extras = prepare_pullback(cond_x, back_x, x) + _, pullbackfunc!!
= value_and_pullback!!_split(cond_x, back_x, x, extras) + Bᵀ = LinearOperator( + eltype(y), + n, + m, + false, + false, + # Bᵀ produces results of length n = length(x), so the backup buffer must be shaped like x + PullbackOperator!(pullbackfunc!!, similar(x)), + typeof(x), + ) + end + return Bᵀ end diff --git a/src/utils.jl b/src/utils.jl deleted file mode 100644 index 3c32c2a..0000000 --- a/src/utils.jl +++ /dev/null @@ -1,5 +0,0 @@ -function identity_break_autodiff(x) - a = [0.0] - a[1] = float(first(x)) - return x -end diff --git a/test/errors.jl b/test/errors.jl deleted file mode 100644 index 0a729f4..0000000 --- a/test/errors.jl +++ /dev/null @@ -1,80 +0,0 @@ -using ChainRulesCore -using ChainRulesTestUtils -using ForwardDiff -using ImplicitDifferentiation -using Test -using Zygote - -@testset "Byproduct handling" begin - f1 = (_) -> (1, 2) - f2 = (_) -> ([1.0], 2, 3) - c = nothing - imf1 = ImplicitFunction(f1, c) - imf2 = ImplicitFunction(f2, c) - @test_throws DimensionMismatch imf1(zeros(1)) - @test_throws DimensionMismatch imf2(zeros(1)) -end - -@testset "Only accept one array" begin - f = identity - c = nothing - imf = ImplicitFunction(f, c) - @test_throws MethodError imf((1.0,)) - @test_throws MethodError imf([1.0], [1.0]) -end - -@testset verbose = true "Derivative NaNs" begin - x = zeros(Float32, 2) - linear_solvers = ( - IterativeLinearSolver(; verbose=false), # - IterativeLinearSolver(; verbose=false, accept_inconsistent=true), # - DirectLinearSolver(; verbose=false), # - ) - function should_give_nan(linear_solver) - return linear_solver isa DirectLinearSolver || !linear_solver.accept_inconsistent - end - - @testset "Infinite derivative" begin - f = x -> sqrt.(x) # nondifferentiable at 0 - c = (x, y) -> y .^ 2 .- x - for linear_solver in linear_solvers - @testset "$(typeof(linear_solver))" begin - implicit = ImplicitFunction(f, c; linear_solver) - J1 = ForwardDiff.jacobian(implicit, x) - J2 = Zygote.jacobian(implicit, x)[1] - @test all(isnan, J1) == should_give_nan(linear_solver) - @test all(isnan, J2) == should_give_nan(linear_solver) -
@test eltype(J1) == Float32 - @test eltype(J2) == Float32 - end - end - end - - @testset "Singular linear system" begin - f = x -> x # wrong solver - c = (x, y) -> (x .+ 1) .^ 2 .- y .^ 2 - for linear_solver in linear_solvers - @testset "$(typeof(linear_solver))" begin - implicit = ImplicitFunction(f, c; linear_solver) - J1 = ForwardDiff.jacobian(implicit, x) - J2 = Zygote.jacobian(implicit, x)[1] - @test all(isnan, J1) == should_give_nan(linear_solver) - @test all(isnan, J2) == should_give_nan(linear_solver) - @test eltype(J1) == Float32 - @test eltype(J2) == Float32 - end - end - end -end - -@testset "Weird ChainRulesTestUtils behavior" begin - x = rand(3) - forward(x) = sqrt.(abs.(x)), 1 - conditions(x, y, z) = abs.(y ./ z) .- abs.(x) - implicit = ImplicitFunction(forward, conditions) - y, z = implicit(x) - dy = similar(y) - rc = Zygote.ZygoteRuleConfig() - test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, 0)) - @test_skip test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, NoTangent())) -end diff --git a/test/runtests.jl b/test/runtests.jl index 37bc32d..338eb57 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,6 @@ using ForwardDiff: ForwardDiff using ImplicitDifferentiation using JET using JuliaFormatter -using Pkg using Random using Test using Zygote: Zygote @@ -15,41 +14,21 @@ DocMeta.setdocmeta!( ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true ) -function markdown_title(path) - title = "?" 
- open(path, "r") do file - for line in eachline(file) - if startswith(line, '#') - title = strip(line, [' ', '#']) - break - end - end - end - return title -end - EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") ## Test sets @testset verbose = true "ImplicitDifferentiation.jl" begin @testset verbose = false "Code quality (Aqua.jl)" begin - if VERSION >= v"1.9" - Aqua.test_all(ImplicitDifferentiation; ambiguities=false, deps_compat=false) - Aqua.test_deps_compat(ImplicitDifferentiation; check_extras=false) - end + Aqua.test_all( + ImplicitDifferentiation; ambiguities=false, deps_compat=(check_extras = false) + ) end @testset verbose = true "Formatting (JuliaFormatter.jl)" begin @test format(ImplicitDifferentiation; verbose=false, overwrite=false) end @testset verbose = true "Static checking (JET.jl)" begin - if VERSION >= v"1.9" - JET.test_package( - ImplicitDifferentiation; - target_defined_modules=true, - toplevel_logger=nothing, - ) - end + JET.test_package(ImplicitDifferentiation; target_defined_modules=true) end @testset verbose = false "Doctests (Documenter.jl)" begin doctest(ImplicitDifferentiation) @@ -57,18 +36,12 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") @testset verbose = true "Examples" begin @info "Example tests" for file in readdir(EXAMPLES_DIR_JL) - path = joinpath(EXAMPLES_DIR_JL, file) - title = markdown_title(path) - @info "$title" - @testset verbose = true "$title" begin - include(path) + @info "$file" + @testset "$file" begin + include(joinpath(EXAMPLES_DIR_JL, file)) end end end - @testset verbose = true "Errors" begin - @info "Error tests" - include("errors.jl") - end @testset verbose = true "Systematic" begin @info "Systematic tests" include("systematic.jl") diff --git a/test/systematic.jl b/test/systematic.jl index 2f022e2..29ad82e 100644 --- a/test/systematic.jl +++ b/test/systematic.jl @@ -1,404 +1,53 @@ -import AbstractDifferentiation as AD -using ChainRulesCore -using ChainRulesTestUtils +using ADTypes +using 
Enzyme: Enzyme using ForwardDiff: ForwardDiff -import ImplicitDifferentiation as ID -using ImplicitDifferentiation: ImplicitFunction, identity_break_autodiff -using ImplicitDifferentiation: DirectLinearSolver, IterativeLinearSolver -using JET -using LinearAlgebra -using Random using SparseArrays using StaticArrays using Test using Zygote: Zygote, ZygoteRuleConfig -@static if VERSION < v"1.9" - macro test_opt(x...) - return :() - end - macro test_call(x...) - return :() - end -end - -Random.seed!(63); - -## Utils - -change_shape(x::AbstractArray{T,3}) where {T} = x[:, :, 1] -change_shape(x::AbstractSparseArray) = x - -function mysqrt(x::AbstractArray) - return identity_break_autodiff(sqrt.(abs.(x))) -end - -## Various signatures - -function make_implicit_sqrt(; kwargs...) - forward(x) = mysqrt(change_shape(x)) - conditions(x, y) = abs2.(y) .- abs.(change_shape(x)) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -function make_implicit_sqrt_byproduct(; kwargs...) - forward(x) = 1 * mysqrt(change_shape(x)), 1 - conditions(x, y, z::Integer) = abs2.(y ./ z) .- abs.(change_shape(x)) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -function make_implicit_sqrt_args(; kwargs...) - forward(x, p::Integer) = p * mysqrt(change_shape(x)) - conditions(x, y, p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x)) - implicit = ImplicitFunction(forward, conditions; kwargs...) - return implicit -end - -function make_implicit_sqrt_kwargs(; kwargs...) - forward(x; p::Integer) = p .* mysqrt(change_shape(x)) - conditions(x, y; p::Integer) = abs2.(y ./ p) .- abs.(change_shape(x)) - implicit = ImplicitFunction(forward, conditions; kwargs...) 
- return implicit -end - -## Low level tests - -function coherent_array_type(a, b) - if a isa Array - return b isa Array || b isa (Base.ReshapedArray{T,N,<:Array} where {T,N}) - elseif a isa StaticArray - return b isa StaticArray || - b isa (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) - elseif a isa AbstractSparseArray - return b isa AbstractSparseArray || - b isa (Base.ReshapedArray{T,N,<:AbstractSparseArray} where {T,N}) - else - error("New array type") - end -end - -function test_implicit_call(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) - - y_true = mysqrt(change_shape(x)) - y1 = @inferred imf1(x) - y2, z2 = @inferred imf2(x) - y3 = @inferred imf3(x, 1) - y4 = @inferred imf4(x; p=1) - - @testset "Exact value" begin - @test y1 ≈ y_true - @test y2 ≈ y_true - @test y3 ≈ y_true - @test y4 ≈ y_true - @test z2 ≈ 1 - end - - @testset "Array type" begin - @test coherent_array_type(x, y1) - @test coherent_array_type(x, y2) - @test coherent_array_type(x, y3) - @test coherent_array_type(x, y4) - end - - @testset "JET" begin - @test_opt target_modules = (ID,) imf1(x) - @test_opt target_modules = (ID,) imf2(x) - @test_opt target_modules = (ID,) imf3(x, 1) - @test_opt target_modules = (ID,) imf4(x; p=1) - - @test_call target_modules = (ID,) imf1(x) - @test_call target_modules = (ID,) imf2(x) - @test_call target_modules = (ID,) imf3(x, 1) - @test_call target_modules = (ID,) imf4(x; p=1) - end -end - -function test_implicit_duals(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
- - y_true = mysqrt(change_shape(x)) - dx = similar(x) - dx .= one(T) - x_and_dx = ForwardDiff.Dual.(x, dx) - - #= - TODO: fix AbstractDifferentiation.jl 0.6 - - y_and_dy1 = @inferred imf1(x_and_dx) - y_and_dy2, z2 = @inferred imf2(x_and_dx) - y_and_dy3 = @inferred imf3(x_and_dx, 1) - y_and_dy4 = @inferred imf4(x_and_dx; p=1) - =# - - y_and_dy1 = imf1(x_and_dx) - y_and_dy2, z2 = imf2(x_and_dx) - y_and_dy3 = imf3(x_and_dx, 1) - y_and_dy4 = imf4(x_and_dx; p=1) - - @testset "Dual numbers" begin - @test ForwardDiff.value.(y_and_dy1) ≈ y_true - @test ForwardDiff.value.(y_and_dy2) ≈ y_true - @test ForwardDiff.value.(y_and_dy3) ≈ y_true - @test ForwardDiff.value.(y_and_dy4) ≈ y_true - @test z2 ≈ 1 - end - - @testset "Static arrays" begin - @test coherent_array_type(x, y_and_dy1) - @test coherent_array_type(x, y_and_dy2) - @test coherent_array_type(x, y_and_dy3) - @test coherent_array_type(x, y_and_dy4) - end - - @testset "JET" begin - @test_opt target_modules = (ID,) imf1(x_and_dx) - @test_opt target_modules = (ID,) imf2(x_and_dx) - @test_opt target_modules = (ID,) imf3(x_and_dx, 1) - @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) - - @test_call target_modules = (ID,) imf1(x_and_dx) - @test_call target_modules = (ID,) imf2(x_and_dx) - @test_call target_modules = (ID,) imf3(x_and_dx, 1) - @test_call target_modules = (ID,) imf4(x_and_dx; p=1) - end -end - -function test_implicit_rrule(rc, x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
- - y_true = mysqrt(change_shape(x)) - dy = similar(y_true) - dy .= one(eltype(y_true)) - dz = nothing - - #= - # TODO: fix AbstractDifferentiation.jl 0.6 - - y1, pb1 = @inferred rrule(rc, imf1, x) - (y2, z2), pb2 = @inferred rrule(rc, imf2, x) - y3, pb3 = @inferred rrule(rc, imf3, x, 1) - y4, pb4 = @inferred rrule(rc, imf4, x; p=1) - - dimf1, dx1 = @inferred pb1(dy) - dimf2, dx2 = @inferred pb2((dy, dz)) - dimf3, dx3, dp3 = @inferred pb3(dy) - dimf4, dx4 = @inferred pb4(dy) - =# - - y1, pb1 = rrule(rc, imf1, x) - (y2, z2), pb2 = rrule(rc, imf2, x) - y3, pb3 = rrule(rc, imf3, x, 1) - y4, pb4 = rrule(rc, imf4, x; p=1) - - dimf1, dx1 = pb1(dy) - dimf2, dx2 = pb2((dy, dz)) - dimf3, dx3, dp3 = pb3(dy) - dimf4, dx4 = pb4(dy) - - @testset "Pullbacks" begin - @test y1 ≈ y_true - @test y2 ≈ y_true - @test y3 ≈ y_true - @test y4 ≈ y_true - @test z2 ≈ 1 - - @test dimf1 isa NoTangent - @test dimf2 isa NoTangent - @test dimf3 isa NoTangent - @test dimf4 isa NoTangent - - @test size(dx1) == size(x) - @test size(dx2) == size(x) - @test size(dx3) == size(x) - @test size(dx4) == size(x) - - @test dp3 isa ChainRulesCore.NotImplemented - end - - @testset "Array type" begin - @test coherent_array_type(x, y1) - @test coherent_array_type(x, y2) - @test coherent_array_type(x, y3) - @test coherent_array_type(x, y4) - - @test coherent_array_type(x, dx1) - @test coherent_array_type(x, dx2) - @test coherent_array_type(x, dx3) - @test coherent_array_type(x, dx4) - end - - @testset "JET" begin - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf1, x) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf2, x) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) - @test_skip @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) - - @test_skip @test_opt target_modules = (ID,) pb1(dy) - @test_skip @test_opt target_modules = (ID,) pb2((dy, dz)) - @test_skip @test_opt target_modules = (ID,) pb3(dy) - @test_skip @test_opt target_modules = (ID,) pb4(dy) - - @test_call 
target_modules = (ID,) rrule(rc, imf1, x) - @test_call target_modules = (ID,) rrule(rc, imf2, x) - @test_call target_modules = (ID,) rrule(rc, imf3, x, 1) - @test_call target_modules = (ID,) rrule(rc, imf4, x; p=1) - - @test_call target_modules = (ID,) pb1(dy) - @test_call target_modules = (ID,) pb2((dy, dz)) - @test_call target_modules = (ID,) pb3(dy) - @test_call target_modules = (ID,) pb4(dy) - end - - @testset "ChainRulesTestUtils" begin - test_rrule(rc, imf1, x; atol=1e-2, check_inferred=false) - test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0), check_inferred=false) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112 - test_rrule(rc, imf3, x, 1; atol=1e-2, check_inferred=false) - test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,), check_inferred=false) - end -end - -## High-level tests per backend - -function test_implicit_forwarddiff(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) - - J1 = ForwardDiff.jacobian(imf1, x) - J2 = ForwardDiff.jacobian(first ∘ imf2, x) - J3 = ForwardDiff.jacobian(_x -> imf3(_x, 1), x) - J4 = ForwardDiff.jacobian(_x -> imf4(_x; p=1), x) - J_true = ForwardDiff.jacobian(_x -> sqrt.(change_shape(_x)), x) - - @testset "Exact Jacobian" begin - @test J1 ≈ J_true - @test J2 ≈ J_true - @test J3 ≈ J_true - @test J4 ≈ J_true - - @test eltype(J1) == eltype(x) - @test eltype(J2) == eltype(x) - @test eltype(J3) == eltype(x) - @test eltype(J4) == eltype(x) - end - return nothing -end - -function test_implicit_zygote(x::AbstractArray{T}; kwargs...) where {T} - imf1 = make_implicit_sqrt(; kwargs...) - imf2 = make_implicit_sqrt_byproduct(; kwargs...) - imf3 = make_implicit_sqrt_args(; kwargs...) - imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
- - J1 = Zygote.jacobian(imf1, x)[1] - J2 = Zygote.jacobian(first ∘ imf2, x)[1] - J3 = Zygote.jacobian(imf3, x, 1)[1] - J4 = Zygote.jacobian(_x -> imf4(_x; p=1), x)[1] - J_true = Zygote.jacobian(_x -> sqrt.(change_shape(_x)), x)[1] - - @testset "Exact Jacobian" begin - @test J1 ≈ J_true - @test J2 ≈ J_true - @test J3 ≈ J_true - @test J4 ≈ J_true - - @test eltype(J1) == eltype(x) - @test eltype(J2) == eltype(x) - @test eltype(J3) == eltype(x) - @test eltype(J4) == eltype(x) - end - return nothing -end - -function test_implicit(x; kwargs...) - @testset verbose = true "Call" begin - test_implicit_call(x; kwargs...) - end - @testset verbose = true "ForwardDiff.jl" begin - if !(x isa AbstractSparseArray) - test_implicit_forwarddiff(x; kwargs...) - test_implicit_duals(x; kwargs...) - end - end - @testset verbose = true "Zygote.jl" begin - rc = Zygote.ZygoteRuleConfig() - test_implicit_zygote(x; kwargs...) - test_implicit_rrule(rc, x; kwargs...) - end - return nothing -end +include("utils.jl") ## Parameter combinations +backends = [ + AutoForwardDiff(; chunksize=1), # + AutoEnzyme(Enzyme.Forward), + AutoZygote(), +] + linear_solver_candidates = ( - IterativeLinearSolver(), # - DirectLinearSolver(), # + \, # + ID.DefaultLinearSolver(), ) conditions_backend_candidates = ( - nothing, # - AD.ForwardDiffBackend(), # - # AD.ZygoteBackend(), # TODO: failing - # AD.ReverseDiffBackend() # TODO: failing - # AD.FiniteDifferencesBackend() # TODO: failing + nothing, # + AutoForwardDiff(; chunksize=1), + # AutoEnzyme(Enzyme.Forward), ); x_candidates = ( - rand(Float32, 2, 3, 2), # - SArray{Tuple{2,3,2}}(rand(Float32, 2, 3, 2)), # - sparse(rand(Float32, 2)), # - sparse(rand(Float32, 2, 3)), # + Float32[3, 4], # + MVector{2}(Float32[3, 4]), # ); -params_candidates = [] - -for linear_solver in linear_solver_candidates, x in x_candidates - push!( - params_candidates, (; - linear_solver=linear_solver, # - conditions_backend=nothing, # - x=x, # - ) - ) -end - -for conditions_backend in 
conditions_backend_candidates - push!( - params_candidates, - (; - linear_solver=linear_solver_candidates[1], # - conditions_backend=conditions_backend, # - x=x_candidates[1], # - ), - ) -end - ## Test loop -for (linear_solver, conditions_backend, x) in params_candidates - testsetname = "$(typeof(linear_solver)) - $(typeof(conditions_backend)) - $(typeof(x))" - if ( - linear_solver isa DirectLinearSolver && - x isa AbstractSparseArray && - VERSION < v"1.9" - ) # missing linalg function for sparse arrays in 1.6 +@testset verbose = false "$(typeof(x)) - $linear_solver - $(typeof(conditions_backend))" for ( + x, linear_solver, conditions_backend +) in Iterators.product( + x_candidates, linear_solver_candidates, conditions_backend_candidates +) + if x isa StaticArray && (linear_solver != \) continue end - @info "$testsetname" - @testset "$testsetname" begin - test_implicit(x; linear_solver, conditions_backend) - end -end + @info "Testing $(typeof(x)) - $linear_solver - $(typeof(conditions_backend))" + test_implicit( + backends, + x; + linear_solver, + conditions_x_backend=conditions_backend, + conditions_y_backend=conditions_backend, + ) +end; diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 0000000..373a503 --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,298 @@ +using ADTypes +using ChainRulesCore +using ChainRulesTestUtils +using DifferentiationInterface: DifferentiationInterface +using ForwardDiff: ForwardDiff +import ImplicitDifferentiation as ID +using ImplicitDifferentiation: ImplicitFunction +using JET +using LinearAlgebra +using SparseArrays +using StaticArrays +using Test +using Zygote: Zygote, ZygoteRuleConfig + +## + +function identity_break_autodiff(x::X)::X where {R,X<:AbstractVector{R}} + float(first(x)) # break ForwardDiff + (Vector{R}(undef, 1))[1] = first(x) # break Zygote + result = try + throw(copy(x)) + catch y + y + end + return result +end + +function mysqrt(x::AbstractVector) + return identity_break_autodiff(sqrt.(abs.(x))) +end 
+ +## Various signatures + +function make_implicit_sqrt(; kwargs...) + forward(x) = mysqrt(x) + conditions(x, y) = abs2.(y) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +function make_implicit_sqrt_byproduct(; kwargs...) + forward(x) = one(eltype(x)) .* mysqrt(x), one(eltype(x)) + conditions(x, y, z) = abs2.(y ./ z) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +function make_implicit_sqrt_args(; kwargs...) + forward(x, p) = p .* mysqrt(x) + conditions(x, y, p) = abs2.(y ./ p) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +function make_implicit_sqrt_kwargs(; kwargs...) + forward(x; p) = p .* mysqrt(x) + conditions(x, y; p) = abs2.(y ./ p) .- abs.(x) + implicit = ImplicitFunction(forward, conditions; kwargs...) + return implicit +end + +## Low level tests + +function test_coherent_array_type(a, b) + @test eltype(a) == eltype(b) + if a isa Array + @test b isa Array || b isa (Base.ReshapedArray{T,N,<:Array} where {T,N}) + elseif a isa StaticArray + @test b isa StaticArray || b isa (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) + elseif a isa AbstractSparseArray + @test b isa AbstractSparseArray || + b isa (Base.ReshapedArray{T,N,<:AbstractSparseArray} where {T,N}) + else + error("New array type") + end +end + +function test_implicit_call(x::AbstractVector{T}; type_stability=false, kwargs...) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
+ + y_true = mysqrt(x) + y1 = imf1(x) + y2, z2 = imf2(x) + y3 = imf3(x, 1) + y4 = imf4(x; p=1) + + @testset "Primal value" begin + @test y1 ≈ y_true + @test y2 ≈ y_true + @test y3 ≈ y_true + @test y4 ≈ y_true + @test z2 ≈ 1 + end + + @testset "Array type" begin + test_coherent_array_type(x, y1) + test_coherent_array_type(x, y2) + test_coherent_array_type(x, y3) + test_coherent_array_type(x, y4) + end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) imf1(x) + @test_opt target_modules = (ID,) imf2(x) + @test_opt target_modules = (ID,) imf3(x, 1) + @test_opt target_modules = (ID,) imf4(x; p=1) + end + end +end + +tag(::AbstractVector{<:ForwardDiff.Dual{T}}) where {T} = T + +function test_implicit_duals( + x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
+ + y_true = mysqrt(x) + dx = similar(x) + dx .= 2 * one(T) + x_and_dx = ForwardDiff.Dual.(x, dx) + + y_and_dy1 = imf1(x_and_dx) + y_and_dy2, z2 = imf2(x_and_dx) + y_and_dy3 = imf3(x_and_dx, 1) + y_and_dy4 = imf4(x_and_dx; p=1) + + @testset "Dual numbers" begin + @test ForwardDiff.value.(y_and_dy1) ≈ y_true + @test ForwardDiff.value.(y_and_dy2) ≈ y_true + @test ForwardDiff.value.(y_and_dy3) ≈ y_true + @test ForwardDiff.value.(y_and_dy4) ≈ y_true + @test ForwardDiff.extract_derivative(tag(y_and_dy1), y_and_dy1) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy2), y_and_dy2) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy3), y_and_dy3) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test ForwardDiff.extract_derivative(tag(y_and_dy4), y_and_dy4) ≈ + 2 .* inv.(2 .* sqrt.(x)) + @test z2 ≈ 1 + end + + @testset "Array type" begin + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy1)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy2)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy3)) + test_coherent_array_type(x, ForwardDiff.value.(y_and_dy4)) + end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) imf1(x_and_dx) + @test_opt target_modules = (ID,) imf2(x_and_dx) + @test_opt target_modules = (ID,) imf3(x_and_dx, 1) + @test_opt target_modules = (ID,) imf4(x_and_dx; p=1) + end + end +end + +function test_implicit_rrule( + rc, x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) 
+ + y_true = mysqrt(x) + dy = similar(y_true) + dy .= one(eltype(y_true)) + dz = nothing + + y1, pb1 = rrule(rc, imf1, x) + (y2, z2), pb2 = rrule(rc, imf2, x) + y3, pb3 = rrule(rc, imf3, x, 1) + y4, pb4 = rrule(rc, imf4, x; p=1) + + dimf1, dx1 = pb1(dy) + dimf2, dx2 = pb2((dy, dz)) + dimf3, dx3, dp3 = pb3(dy) + dimf4, dx4 = pb4(dy) + + @testset "Pullbacks" begin + @test y1 ≈ y_true + @test y2 ≈ y_true + @test y3 ≈ y_true + @test y4 ≈ y_true + @test z2 ≈ 1 + + @test dimf1 isa NoTangent + @test dimf2 isa NoTangent + @test dimf3 isa NoTangent + @test dimf4 isa NoTangent + + @test size(dx1) == size(x) + @test size(dx2) == size(x) + @test size(dx3) == size(x) + @test size(dx4) == size(x) + + @test dp3 isa ChainRulesCore.NotImplemented + end + + @testset "Array type" begin + test_coherent_array_type(x, y1) + test_coherent_array_type(x, y2) + test_coherent_array_type(x, y3) + test_coherent_array_type(x, y4) + + test_coherent_array_type(x, dx1) + test_coherent_array_type(x, dx2) + test_coherent_array_type(x, dx3) + test_coherent_array_type(x, dx4) + end + + @testset "ChainRulesTestUtils" begin + test_rrule(rc, imf1, x; atol=1e-2, check_inferred=false) + test_rrule(rc, imf2, x; atol=5e-2, output_tangent=(dy, 0), check_inferred=false) # see issue https://github.com/gdalle/ImplicitDifferentiation.jl/issues/112 + test_rrule(rc, imf3, x, 1; atol=1e-2, check_inferred=false) + test_rrule(rc, imf4, x; atol=1e-2, fkwargs=(p=1,), check_inferred=false) + end + + if type_stability + @testset "Type stability" begin + @test_opt target_modules = (ID,) rrule(rc, imf1, x) + @test_opt target_modules = (ID,) rrule(rc, imf2, x) + @test_opt target_modules = (ID,) rrule(rc, imf3, x, 1) + @test_opt target_modules = (ID,) rrule(rc, imf4, x; p=1) + + @test_opt target_modules = (ID,) pb1(dy) + @test_opt target_modules = (ID,) pb2((dy, dz)) + @test_opt target_modules = (ID,) pb3(dy) + @test_opt target_modules = (ID,) pb4(dy) + end + end +end + +## High-level tests per backend + +function 
test_implicit_backend( + backend::ADTypes.AbstractADType, x::AbstractVector{T}; type_stability=false, kwargs... +) where {T} + imf1 = make_implicit_sqrt(; kwargs...) + imf2 = make_implicit_sqrt_byproduct(; kwargs...) + imf3 = make_implicit_sqrt_args(; kwargs...) + imf4 = make_implicit_sqrt_kwargs(; kwargs...) + + J1 = DifferentiationInterface.jacobian(imf1, backend, x) + J2 = DifferentiationInterface.jacobian(first ∘ imf2, backend, x) + J3 = DifferentiationInterface.jacobian(_x -> imf3(_x, one(eltype(x))), backend, x) + + J4 = if !(backend isa AutoEnzyme) + DifferentiationInterface.jacobian(_x -> imf4(_x; p=one(eltype(x))), backend, x) + else + nothing + end + + J_true = ForwardDiff.jacobian(_x -> sqrt.(_x), x) + + @testset "Exact Jacobian" begin + @test J1 ≈ J_true + @test J2 ≈ J_true + @test J3 ≈ J_true + + @test eltype(J1) == eltype(x) + @test eltype(J2) == eltype(x) + @test eltype(J3) == eltype(x) + + if !(backend isa AutoEnzyme) + @test J4 ≈ J_true + @test eltype(J4) == eltype(x) + end + end + return nothing +end + +function test_implicit(backends, x; type_stability=false, kwargs...) + @testset "Call" begin + test_implicit_call(x; kwargs...) + end + @testset "Duals" begin + test_implicit_duals(x; kwargs...) + end + @testset "ChainRule" begin + test_implicit_rrule(ZygoteRuleConfig(), x; kwargs...) + end + @testset "$backend" for backend in backends + test_implicit_backend(backend, x; kwargs...) + end + return nothing +end