From 2c50679344ec86b5a53f0197faf9f3d4421b69db Mon Sep 17 00:00:00 2001 From: Mohamed Tarek Date: Sun, 30 Jul 2023 23:07:55 +1000 Subject: [PATCH] cache the LU factorisation in the direct linear solver and better static array support (#64) * better direct lin solver and static array support * Revamped byproduct specification, better tests for static arrays, compatible with static arrays of all sizes * Fix and better test staticarrays * Add application for byproducts & group docs examples * Reintroduce ignored docstrings * More tests for byproducts * fix type instability * fix CI --------- Co-authored-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- Project.toml | 8 +- docs/Manifest.toml | 211 ++++++++------- docs/Project.toml | 1 + docs/make.jl | 25 +- docs/src/api.md | 6 + docs/src/faq.md | 45 ++-- docs/src/index.md | 2 +- examples/{0_basic.jl => 0_intro.jl} | 10 +- examples/1_basic.jl | 242 ++++++++++++++++++ examples/1_unconstrained_optim.jl | 99 ------- .../{4_constrained_optim.jl => 2_advanced.jl} | 33 ++- examples/2_nonlinear_solve.jl | 85 ------ examples/3_fixed_points.jl | 85 ------ examples/{5_multiargs.jl => 3_tricks.jl} | 86 ++++++- examples/6_byproduct.jl | 8 - ext/ImplicitDifferentiationChainRulesExt.jl | 69 +++-- ext/ImplicitDifferentiationForwardDiffExt.jl | 45 ++-- ext/ImplicitDifferentiationStaticArraysExt.jl | 26 ++ src/ImplicitDifferentiation.jl | 10 +- src/conditions.jl | 22 ++ src/forward.jl | 30 +++ src/implicit_function.jl | 151 ++++------- src/linear_solver.jl | 50 ++++ src/utils.jl | 40 +-- test/misc.jl | 143 ----------- test/runtests.jl | 19 +- test/systematic.jl | 187 ++++++++++++++ 27 files changed, 992 insertions(+), 746 deletions(-) rename examples/{0_basic.jl => 0_intro.jl} (93%) create mode 100644 examples/1_basic.jl delete mode 100644 examples/1_unconstrained_optim.jl rename examples/{4_constrained_optim.jl => 2_advanced.jl} (86%) delete mode 100644 examples/2_nonlinear_solve.jl delete mode 100644 examples/3_fixed_points.jl rename examples/{5_multiargs.jl => 3_tricks.jl} (51%) delete mode 100644 examples/6_byproduct.jl create mode 100644 ext/ImplicitDifferentiationStaticArraysExt.jl create mode 100644 src/conditions.jl create mode 100644 src/forward.jl create mode 100644 src/linear_solver.jl delete mode 100644 test/misc.jl create mode 100644 test/systematic.jl diff --git a/Project.toml b/Project.toml index 2331525..2dec795 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.5.0-DEV" AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LinearOperators = "5c8ed15e-5a4c-59e4-a42b-c7e8811fb125" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" @@ -14,14 +15,15 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] ImplicitDifferentiationChainRulesExt = "ChainRulesCore" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" +ImplicitDifferentiationStaticArraysExt = "StaticArrays" [compat] AbstractDifferentiation = "0.5" -Aqua = "0.6.1" ChainRulesCore = "1.14" ForwardDiff = "0.10" Krylov = "0.8, 0.9" @@ -40,14 +42,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "LinearAlgebra", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "Test", "Zygote"] +test = ["Aqua", "ChainRulesCore", "ChainRulesTestUtils", "ComponentArrays", "Documenter", "ForwardDiff", "JET", "JuliaFormatter", "NLsolve", "Optim", "Pkg", "Random", "SparseArrays", "StaticArrays", "Test", "Zygote"] diff --git a/docs/Manifest.toml b/docs/Manifest.toml index bdecaf8..0a2dbdb 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.0" +julia_version = "1.9.2" manifest_format = "2.0" -project_hash = "e7f4896b7e8c3921c7466749f467ab7680867992" +project_hash = "ff84ddc3d5227f964f2cd507ce5cbc83b4fba207" [[deps.AMD]] deps = ["Libdl", "LinearAlgebra", "SparseArrays", "Test"] @@ -40,9 +40,9 @@ version = "0.5.2" [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] -git-tree-sha1 = "16b6dbc4cf7caee4e1e75c49485ec67b667098a0" +git-tree-sha1 = "cad4c758c0038eea30394b1b671526921ca85b21" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.3.1" +version = "1.4.0" weakdeps = ["ChainRulesCore"] [deps.AbstractFFTs.extensions] @@ -50,9 +50,9 @@ weakdeps = ["ChainRulesCore"] [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef" +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.1" +version = "3.6.2" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -63,10 +63,10 @@ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" [[deps.ArrayInterface]] -deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "38911c7737e123b28182d89027f4216cfc8a9da7" +deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.3" +version = "7.4.11" [deps.ArrayInterface.extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" @@ -97,21 +97,21 @@ version = "0.4.2" [[deps.ChainRules]] deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "7d20c2fb8ab838e41069398685e7b6b5f89ed85b" +git-tree-sha1 = "f98ae934cd677d51d2941088849f0bf2f59e6f6e" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.48.0" +version = "1.53.0" [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "c6d890a52d2c4d55d326439580c3b8d0875a77d9" +git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.15.7" +version = "1.16.0" [[deps.ChainRulesTestUtils]] deps = ["ChainRulesCore", "Compat", "FiniteDifferences", "LinearAlgebra", "Random", "Test"] -git-tree-sha1 = "1237bdbcfec728721718ef57dcb855a19c11bf3a" +git-tree-sha1 = "5ab2a7bc21ecc3eb0226478ff8f87e9685b11818" uuid = "cdddcdb0-9152-4a09-a978-84456f9df70a" -version = "1.10.1" +version = "1.11.0" [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] @@ -121,9 +121,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["UUIDs"] -git-tree-sha1 = "7a60c856b9fa189eb34f5f8a6f6b5529b7942957" +git-tree-sha1 = "5ce999a19f4ca23ea484e92a1774a61b8ca4cf8e" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.6.1" +version = "4.8.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -132,13 +132,13 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.2+0" +version = "1.0.5+0" [[deps.ComponentArrays]] deps = ["ArrayInterface", "ChainRulesCore", "ForwardDiff", "Functors", "LinearAlgebra", "Requires", "StaticArrayInterface"] -git-tree-sha1 = "891f08177789faff56f0deda1e23615ec220ce44" +git-tree-sha1 = "e1a6694a62f7a00cf2c4f65dccd118ac28a6c099" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" -version = "0.13.12" +version = "0.14.0" [deps.ComponentArrays.extensions] ComponentArraysConstructionBaseExt = "ConstructionBase" @@ -158,28 +158,28 @@ version = "0.13.12" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "89a9db8d28102b094992472d333674bd1a83ce2a" +git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.1" +version = "1.5.3" [deps.ConstructionBase.extensions] - IntervalSetsExt = "IntervalSets" - StaticArraysExt = "StaticArrays" + ConstructionBaseIntervalSetsExt = "IntervalSets" + ConstructionBaseStaticArraysExt = "StaticArrays" [deps.ConstructionBase.weakdeps] IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.DataAPI]] -git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630" +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.14.0" +version = "1.15.0" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +git-tree-sha1 = "cf25ccb972fec4e4817764d01c82386ae94f77b4" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.13" +version = "0.18.14" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -198,15 +198,19 @@ version = "1.1.0" [[deps.DiffRules]] deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8" +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.13.0" +version = "1.15.1" [[deps.Distances]] -deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2" +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.8" +version = "0.10.9" +weakdeps = ["SparseArrays"] + + [deps.Distances.extensions] + DistancesSparseArraysExt = "SparseArrays" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] @@ -220,9 +224,9 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "58fea7c536acd71f3eef6be3b21c0df5f3df88fd" +git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.24" +version = "0.27.25" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -230,9 +234,9 @@ uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" version = "1.6.0" [[deps.ExprTools]] -git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.9" +version = "0.1.10" [[deps.FastClosures]] git-tree-sha1 = "acebe244d53ee1b461970f8910c235b259e772ef" @@ -244,21 +248,31 @@ uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790" +git-tree-sha1 = "f372472e8672b1d993e93dada09e23139b509f9e" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.11" +version = "1.5.0" [[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "03fcb1c42ec905d15b305359603888ec3e65f886" +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.19.0" +version = "2.21.1" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.FiniteDifferences]] deps = ["ChainRulesCore", "LinearAlgebra", "Printf", "Random", "Richardson", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "3f605dd6db5640c5278f2551afc9427656439f42" +git-tree-sha1 = "549636fd813ddf1816d8407efb23f486822f4b63" uuid = "26cc04aa-876d-5657-8c51-4c34ba976000" -version = "0.12.26" +version = "0.12.29" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] @@ -272,9 +286,9 @@ weakdeps = ["StaticArrays"] [[deps.Functors]] deps = ["LinearAlgebra"] -git-tree-sha1 = "478f8c3145bb91d82c2cf20433e8c1b30df454cc" +git-tree-sha1 = "9a68d75d466ccc1218d0552a8e1631151c569545" uuid = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -version = "0.4.4" +version = "0.4.5" [[deps.Future]] deps = ["Random"] @@ -282,27 +296,27 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "9ade6983c3dbbd492cf5729f865fe030d1541463" +git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.6.6" +version = "8.8.1" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa" +git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.4" +version = "0.1.5" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" +version = "0.2.3" [[deps.IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "0ade27f0c49cebd8db2523c4eeccf779407cf12c" +git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.9" +version = "0.4.10" [[deps.IfElse]] git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" @@ -310,15 +324,16 @@ uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" version = "0.1.1" [[deps.ImplicitDifferentiation]] -deps = ["AbstractDifferentiation", "Krylov", "LinearOperators", "Requires", "SimpleUnPack"] +deps = ["AbstractDifferentiation", "Krylov", "LinearAlgebra", "LinearOperators", "Requires", "SimpleUnPack"] path = ".." uuid = "57b37032-215b-411a-8a7c-41a003a55207" version = "0.5.0-DEV" -weakdeps = ["ChainRulesCore", "ForwardDiff"] +weakdeps = ["ChainRulesCore", "ForwardDiff", "StaticArrays"] [deps.ImplicitDifferentiation.extensions] ImplicitDifferentiationChainRulesExt = "ChainRulesCore" ImplicitDifferentiationForwardDiffExt = "ForwardDiff" + ImplicitDifferentiationStaticArraysExt = "StaticArrays" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -348,9 +363,9 @@ version = "0.21.4" [[deps.Krylov]] deps = ["LinearAlgebra", "Printf", "SparseArrays"] -git-tree-sha1 = "dd90aacbfb622f898a97c2a4411ac49101ebab8a" +git-tree-sha1 = "6dc4ad9cd74ad4ca0a8e219e945dbd22039f2125" uuid = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" -version = "0.9.0" +version = "0.9.2" [[deps.LDLFactorizations]] deps = ["AMD", "LinearAlgebra", "SparseArrays", "Test"] @@ -360,15 +375,15 @@ version = "0.10.0" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a8960cae30b42b66dd41808beb76490519f6f9e2" +git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "5.0.0" +version = "6.1.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35" +git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.21+0" +version = "0.0.23+0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -420,9 +435,9 @@ version = "2.14.0" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b" +git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.23" +version = "0.3.24" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -505,14 +520,14 @@ version = "0.5.5+0" [[deps.Optim]] deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "a89b11f0f354f06099e4001c151dffad7ebab015" +git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.7.5" +version = "1.7.6" [[deps.OrderedCollections]] -git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" +git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.0" +version = "1.6.2" [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] @@ -521,15 +536,15 @@ uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" version = "0.12.3" [[deps.Parsers]] -deps = ["Dates", "SnoopPrecompile"] -git-tree-sha1 = "478ac6c952fddd4399e71d4779797c538d0ff2bf" +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "4b2e829ee66d4218e0cef22c0a64ee37cf258c29" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.5.8" +version = "2.7.1" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.0" +version = "1.9.2" [[deps.PositiveFactorizations]] deps = ["LinearAlgebra"] @@ -537,11 +552,17 @@ git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" version = "0.2.4" +[[deps.PrecompileTools]] +deps = ["Preferences"] +git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" +uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +version = "1.1.2" + [[deps.Preferences]] deps = ["TOML"] -git-tree-sha1 = "47e5f437cc0e7ef2ce8406ce1e7e24d44915f88d" +git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.3.0" +version = "1.4.0" [[deps.Printf]] deps = ["Unicode"] @@ -607,9 +628,9 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.0" +version = "1.1.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] @@ -617,9 +638,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" +git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.2.0" +version = "2.3.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -627,9 +648,9 @@ weakdeps = ["ChainRulesCore"] [[deps.Static]] deps = ["IfElse"] -git-tree-sha1 = "dbde6766fc677423598138a5951269432b0fcc90" +git-tree-sha1 = "f295e0a1da4ca425659c57441bcb59abb035a4bc" uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.8.7" +version = "0.8.8" [[deps.StaticArrayInterface]] deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "Static", "SuiteSparse"] @@ -646,15 +667,19 @@ version = "1.4.0" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "63e84b7fdf5021026d0f17f76af7c57772313d99" +deps = ["LinearAlgebra", "Random", "StaticArraysCore"] +git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.21" +version = "1.6.2" +weakdeps = ["Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] -git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.0" +version = "1.4.2" [[deps.Statistics]] deps = ["LinearAlgebra", "SparseArrays"] @@ -669,9 +694,9 @@ version = "1.6.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" +version = "0.34.0" [[deps.StructArrays]] deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] @@ -716,9 +741,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.22" +version = "0.5.23" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -738,10 +763,10 @@ uuid = "83775a58-1f1d-513f-b197-d71354ab007a" version = "1.2.13+0" [[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e" +deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "PrecompileTools", "Random", "Requires", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] +git-tree-sha1 = "5be3ddb88fc992a7d8ea96c3f10a49a7e98ebc7b" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.60" +version = "0.6.62" [deps.Zygote.extensions] ZygoteColorsExt = "Colors" @@ -762,7 +787,7 @@ version = "0.2.3" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.7.0+0" +version = "5.8.0+0" [[deps.nghttp2_jll]] deps = ["Artifacts", "Libdl"] diff --git a/docs/Project.toml b/docs/Project.toml index aa54fc0..521326f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -10,5 +10,6 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/docs/make.jl b/docs/make.jl index 44850de..5a8c714 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -3,6 +3,7 @@ using Documenter using ForwardDiff: ForwardDiff using ImplicitDifferentiation using Literate +using StaticArrays: StaticArrays DocMeta.setdocmeta!( ImplicitDifferentiation, :DocTestSetup, :(using ImplicitDifferentiation); recursive=true @@ -50,24 +51,40 @@ end pages = [ "Home" => "index.md", - "API reference" => "api.md", "Examples" => example_pages, + "API reference" => "api.md", "FAQ" => "faq.md", ] -format = Documenter.HTML(; +fmt = Documenter.HTML(; prettyurls=get(ENV, "CI", "false") == "true", canonical="https://gdalle.github.io/ImplicitDifferentiation.jl", assets=String[], edit_link=:commit, ) +if isdefined(Base, :get_extension) + extension_modules = [ + Base.get_extension(ImplicitDifferentiation, :ImplicitDifferentiationChainRulesExt) + Base.get_extension(ImplicitDifferentiation, :ImplicitDifferentiationForwardDiffExt) + Base.get_extension( + ImplicitDifferentiation, :ImplicitDifferentiationStaticArraysExt + ) + ] +else + extension_modules = [ + ImplicitDifferentiation.ImplicitDifferentiationChainRulesExt, + ImplicitDifferentiation.ImplicitDifferentiationForwardDiffExt, + ImplicitDifferentiation.ImplicitDifferentiationStaticArraysExt, + ] +end + makedocs(; - modules=[ImplicitDifferentiation], + modules=vcat([ImplicitDifferentiation], extension_modules), authors="Guillaume Dalle, Mohamed Tarek and contributors", repo="https://github.com/gdalle/ImplicitDifferentiation.jl/blob/{commit}{path}#{line}", sitename="ImplicitDifferentiation.jl", - format=format, + format=fmt, pages=pages, linkcheck=true, ) diff --git a/docs/src/api.md b/docs/src/api.md index dac41bc..400d41a 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -4,6 +4,11 @@ ```@docs ImplicitFunction +DirectLinearSolver +IterativeLinearSolver +HandleByproduct +ReturnByproduct +ChainRulesCore.rrule ``` ## Internals @@ -11,6 +16,7 @@ ImplicitFunction ```@docs ImplicitDifferentiation.Forward ImplicitDifferentiation.Conditions +ImplicitDifferentiation.AbstractLinearSolver ImplicitDifferentiation.PushforwardMul! ImplicitDifferentiation.PullbackMul! ``` diff --git a/docs/src/faq.md b/docs/src/faq.md index 425291f..c9e4cf3 100644 --- a/docs/src/faq.md +++ b/docs/src/faq.md @@ -2,22 +2,36 @@ ## Supported autodiff backends -- Forward mode: [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) -- Reverse mode: all the packages compatible with [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) +| Mode | Backend | Support | +| ------- | ---------------------------------------------------------- | ------- | +| Forward | [ForwardDiff.jl] | yes | +| Reverse | [ChainRules.jl]-compatible ([Zygote.jl], [ReverseDiff.jl]) | yes | +| Forward | [ChainRules.jl]-compatible ([Diffractor.jl]) | soon | +| Both | [Enzyme.jl] | someday | -In the future, we would like to add [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) support. +[ForwardDiff.jl]: https://github.com/JuliaDiff/ForwardDiff.jl +[ChainRules.jl]: https://github.com/JuliaDiff/ChainRules.jl +[Zygote.jl]: https://github.com/FluxML/Zygote.jl +[ReverseDiff.jl]: https://github.com/JuliaDiff/ReverseDiff.jl +[Enzyme.jl]: https://github.com/EnzymeAD/Enzyme.jl +[Diffractor.jl]: https://github.com/JuliaDiff/Diffractor.jl -## Higher-dimensional arrays +## Writing conditions + +We recommend that the conditions themselves do not involve calls to autodiff, even when they describe a gradient. +Otherwise, you will need to make sure that nested autodiff works well in your case. +For instance, if you're differentiating your implicit function in reverse mode with Zygote.jl, you may want to use [`Zygote.forwarddiff`](https://fluxml.ai/Zygote.jl/stable/utils/#Zygote.forwarddiff) to wrap the conditions and differentiate them with ForwardDiff.jl instead. + +## Matrices and higher-order arrays For simplicity, our examples only display functions that eat and spit out vectors. However, arbitrary array shapes are supported, as long as the forward mapping _and_ conditions return similar arrays. Beware however, sparse arrays will be densified in the differentiation process. -## Scalar input / output +## Scalars Functions that eat or spit out a single number are not supported. -The forward mapping _and_ conditions need arrays: for example, instead of returning `value` you should return `[value]` (a 1-element `Vector`). -Consider using an `SVector` from [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) if you seek increased performance. +The forward mapping _and_ conditions need arrays: for example, instead of returning `val` you should return `[val]` (a 1-element `Vector`). ## Multiple inputs / outputs @@ -41,17 +55,18 @@ The same trick works for multiple outputs. ## Using byproducts -At first glance, it is not obvious why we impose that the forward mapping should return a byproduct `z` in addition to `y`. -It is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the `conditions`. -We will provide simple examples soon. -In the meantime, an advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl). +Why would the forward mapping return a byproduct `z` in addition to `y`? +It is mainly useful when the solution procedure creates objects such as Jacobians, which we want to reuse when computing or differentiating the conditions. +In that case, you may want to write the differentiation rules yourself for the conditions. +A more advanced application is given by [DifferentiableFrankWolfe.jl](https://github.com/gdalle/DifferentiableFrankWolfe.jl). + +Keep in mind that derivatives of `z` will not be computed: the byproduct is considered constant during differentiation (unlike the case of multiple outputs outlined above). -## Differentiating byproducts +## Performance tips -Nope. Sorry. Don't even think about it. -The package is not designed to compute derivatives of `z`, only `y`, which is why the byproduct is considered constant during differentiation. +If you work with small arrays (say, less than 100 elements), consider using [StaticArrays.jl](https://github.com/JuliaArrays/StaticArrays.jl) if you seek increased performance. -## Modeling constrained optimization problems +## Modeling tips To express constrained optimization problems as implicit functions, you might need differentiable projections or proximal operators to write the optimality conditions. See [_Efficient and modular implicit differentiation_](https://arxiv.org/abs/2105.15183) for precise formulations. diff --git a/docs/src/index.md b/docs/src/index.md index be50c78..42c296d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -7,7 +7,7 @@ CurrentModule = ImplicitDifferentiation [ImplicitDifferentiation.jl](https://github.com/gdalle/ImplicitDifferentiation.jl) is a package for automatic differentiation of functions defined implicitly, i.e., _forward mappings_ ```math -x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m +f: x \in \mathbb{R}^n \longmapsto y(x) \in \mathbb{R}^m ``` whose output is defined by _conditions_ diff --git a/examples/0_basic.jl b/examples/0_intro.jl similarity index 93% rename from examples/0_basic.jl rename to examples/0_intro.jl index 3f85b3d..4e3df7e 100644 --- a/examples/0_basic.jl +++ b/examples/0_intro.jl @@ -1,7 +1,7 @@ -# # Basic use +# # Introduction #= -In this example, we demonstrate the basics of our package on a simple function that is not amenable to automatic differentiation. +We explain the basics of our package on a simple function that is not amenable to naive automatic differentiation. =# using ChainRulesCore #src @@ -147,11 +147,7 @@ We can even go higher-order by mixing the two packages (forward-over-reverse mod The only technical requirement is to switch the linear solver to something that can handle dual numbers: =# -manual_linear_solver(A, b) = (Matrix(A) \ b, (solved=true,)) - -implicit_higher_order = ImplicitFunction( - forward, conditions; linear_solver=manual_linear_solver -) +implicit_higher_order = ImplicitFunction(forward, conditions, DirectLinearSolver()) #= Then the Jacobian itself is differentiable. diff --git a/examples/1_basic.jl b/examples/1_basic.jl new file mode 100644 index 0000000..942e68f --- /dev/null +++ b/examples/1_basic.jl @@ -0,0 +1,242 @@ +# # Basic use cases + +#= +We show how to differentiate through very common routines: +- an unconstrained optimization problem +- a nonlinear system of equations +- a fixed point iteration +=# + +using ForwardDiff +using ImplicitDifferentiation +using LinearAlgebra +using NLsolve +using Optim +using Random +using Test #src +using Zygote + +Random.seed!(63); + +#= +In all three cases, we will use the square root as our forward mapping, but expressed in three different ways. +Here's our heroic test vector: +=# + +x = rand(2); + +#= +Since we already know the mathematical expression of the Jacobian, we will be able to compare it with our numerical results. +=# + +J = Diagonal(0.5 ./ sqrt.(x)) + +# ## Unconstrained optimization + +#= +First, we show how to differentiate through the solution of an unconstrained optimization problem: +```math +y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) +``` +The optimality conditions are given by gradient stationarity: +```math +\nabla_2 f(x, y) = 0 +``` +=# + +#= +To make verification easy, we minimize the following objective: +```math +f(x, y) = \lVert y \odot y - x \rVert^2 +``` +In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl). +Note the presence of a keyword argument. +=# + +function forward_optim(x; method) + f(y) = sum(abs2, y .^ 2 .- x) + y0 = ones(eltype(x), size(x)) + result = optimize(f, y0, method) + return Optim.minimizer(result) +end + +#= +Even though they are defined as a gradient, it is better to provide optimality conditions explicitly: that way we avoid nesting autodiff calls. By default, the conditions should accept two arguments as input. +The forward mapping and the conditions should accept the same set of keyword arguments. +=# + +function conditions_optim(x, y; method) + ∇₂f = @. 4 * (y^2 - x) * y + return ∇₂f +end + +#= +We now have all the ingredients to construct our implicit function. +=# + +implicit_optim = ImplicitFunction(forward_optim, conditions_optim) + +# And indeed, it behaves as it should when we call it: + +implicit_optim(x; method=LBFGS()) .^ 2 +@test implicit_optim(x; method=LBFGS()) .^ 2 ≈ x #src + +# Forward mode autodiff + +ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) +@test ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) ≈ J #src + +#= +In this instance, we could use ForwardDiff.jl directly on the solver, but it returns the wrong result (not sure why). +=# + +ForwardDiff.jacobian(_x -> forward_optim(x; method=LBFGS()), x) + +# Reverse mode autodiff + +Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] +@test Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] ≈ J #src + +#= +In this instance, we cannot use Zygote.jl directly on the solver (due to unsupported `try/catch` statements). +=# + +try + Zygote.jacobian(_x -> forward_optim(x; method=LBFGS()), x)[1] +catch e + e +end + +# ## Nonlinear system + +#= +Next, we show how to differentiate through the solution of a nonlinear system of equations: +```math +\text{find} \quad y(x) \quad \text{such that} \quad F(x, y(x)) = 0 +``` +The optimality conditions are pretty obvious: +```math +F(x, y) = 0 +``` +=# + +#= +To make verification easy, we solve the following system: +```math +F(x, y) = y \odot y - x = 0 +``` +In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl). +=# + +function forward_nlsolve(x; method) + F!(storage, y) = (storage .= y .^ 2 .- x) + initial_y = similar(x) + initial_y .= 1 + result = nlsolve(F!, initial_y; method) + return result.zero +end + +#- + +function conditions_nlsolve(x, y; method) + c = y .^ 2 .- x + return c +end + +#- + +implicit_nlsolve = ImplicitFunction(forward_nlsolve, conditions_nlsolve) + +#- + +implicit_nlsolve(x; method=:newton) .^ 2 +@test implicit_nlsolve(x; method=:newton) .^ 2 ≈ x #src + +# Forward mode autodiff + +ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) +@test ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) ≈ J #src + +#- + +ForwardDiff.jacobian(_x -> forward_nlsolve(_x; method=:newton), x) + +# Reverse mode autodiff + +Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] +@test Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] ≈ J #src + +#- + +try + Zygote.jacobian(_x -> forward_nlsolve(_x; method=:newton), x)[1] +catch e + e +end + +# ## Fixed point + +#= +Finally, we show how to differentiate through the limit of a fixed point iteration: +```math +y \longmapsto T(x, y) +``` +The optimality conditions are pretty obvious: +```math +y = T(x, y) +``` +=# + +#= +To make verification easy, we consider [Heron's method](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Heron's_method): +```math +T(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) +``` +In this case, the fixed point algorithm boils down to the componentwise square root function, but we implement it manually. +=# + +function forward_fixedpoint(x; iterations) + y = ones(eltype(x), size(x)) + for _ in 1:iterations + y .= 0.5 .* (y .+ x ./ y) + end + return y +end + +#- + +function conditions_fixedpoint(x, y; iterations) + T = 0.5 .* (y .+ x ./ y) + return T .- y +end + +#- + +implicit_fixedpoint = ImplicitFunction(forward_fixedpoint, conditions_fixedpoint) + +#- + +implicit_fixedpoint(x; iterations=10) .^ 2 +@test implicit_fixedpoint(x; iterations=10) .^ 2 ≈ x #src + +# Forward mode autodiff + +ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) +@test ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) ≈ J #src + +#- + +ForwardDiff.jacobian(_x -> forward_fixedpoint(_x; iterations=10), x) + +# Reverse mode autodiff + +Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] +@test Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] ≈ J #src + +#- + +try + Zygote.jacobian(_x -> forward_fixedpoint(_x; iterations=10), x)[1] +catch e + e +end diff --git a/examples/1_unconstrained_optim.jl b/examples/1_unconstrained_optim.jl deleted file mode 100644 index 44b7707..0000000 --- a/examples/1_unconstrained_optim.jl +++ /dev/null @@ -1,99 +0,0 @@ -# # Unconstrained optimization - -#= -In this example, we show how to differentiate through the solution of an unconstrained optimization problem: -```math -y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) -``` -The optimality conditions are given by gradient stationarity: -```math -\nabla_2 f(x, y) = 0 -``` -=# - -using ForwardDiff -using ImplicitDifferentiation -using LinearAlgebra -using Optim -using Random -using Test #src -using Zygote - -Random.seed!(63); - -# ## Implicit function - -#= -To make verification easy, we minimize the following objective: -```math -f(x, y) = \lVert y \odot y - x \rVert^2 -``` -In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl). -Note the presence of a keyword argument. -=# - -function forward_optim(x; method) - f(y) = sum(abs2, y .^ 2 .- x) - y0 = ones(eltype(x), size(x)) - result = optimize(f, y0, method) - return Optim.minimizer(result) -end - -#= -Even though they are defined as a gradient, it is better to provide optimality conditions explicitly: that way we avoid nesting autodiff calls. By default, the conditions should accept two arguments as input. -The forward mapping and the conditions should accept the same set of keyword arguments. -=# - -function conditions_optim(x, y; method) - ∇₂f = 2 .* (y .^ 2 .- x) - return ∇₂f -end - -#= -We now have all the ingredients to construct our implicit function. -=# - -implicit_optim = ImplicitFunction(forward_optim, conditions_optim) - -# And indeed, it behaves as it should when we call it: - -x = rand(2) - -#- - -implicit_optim(x; method=LBFGS()) .^ 2 -@test implicit_optim(x; method=LBFGS()) .^ 2 ≈ x #src - -#= -Let's see what the explicit Jacobian looks like. -=# - -J = Diagonal(0.5 ./ sqrt.(x)) - -# ## Forward mode autodiff - -ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) -@test ForwardDiff.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x) ≈ J #src - -#= -Unsurprisingly, the Jacobian is the identity. -In this instance, we could use ForwardDiff.jl directly on the solver, but it returns the wrong result (not sure why). -=# - -ForwardDiff.jacobian(_x -> forward_optim(x; method=LBFGS()), x) - -# ## Reverse mode autodiff - -Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] -@test Zygote.jacobian(_x -> implicit_optim(_x; method=LBFGS()), x)[1] ≈ J #src - -#= -Again, the Jacobian is the identity. -In this instance, we cannot use Zygote.jl directly on the solver (due to unsupported `try/catch` statements). -=# - -try - Zygote.jacobian(_x -> forward_optim(x; method=LBFGS()), x)[1] -catch e - e -end diff --git a/examples/4_constrained_optim.jl b/examples/2_advanced.jl similarity index 86% rename from examples/4_constrained_optim.jl rename to examples/2_advanced.jl index 6a59ab8..d1d48ff 100644 --- a/examples/4_constrained_optim.jl +++ b/examples/2_advanced.jl @@ -1,15 +1,8 @@ -# # Constrained optimization +# # Advanced use cases #= -In this example, we show how to differentiate through the solution of a constrained optimization problem: -```math -y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) \quad \text{subject to} \quad g(x, y) \leq 0 -``` -The optimality conditions are a bit trickier than in the previous cases. -We can projection on the feasible set $\mathcal{C}(x) = \{y: g(x, y) \leq 0 \}$ and exploit the convergence of projected gradient descent with step size $\eta$: -```math -y = \mathrm{proj}_{\mathcal{C}(x)} (y - \eta \nabla_2 f(x, y)) -``` +We dive into more advanced applications of implicit differentiation: +- constrained optimization problems =# using ForwardDiff @@ -22,7 +15,19 @@ using Zygote Random.seed!(63); -# ## Implicit function +# ## Constrained optimization + +#= +First, we show how to differentiate through the solution of a constrained optimization problem: +```math +y(x) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(x, y) \quad \text{subject to} \quad g(x, y) \leq 0 +``` +The optimality conditions are a bit trickier than in the previous cases. +We can projection on the feasible set $\mathcal{C}(x) = \{y: g(x, y) \leq 0 \}$ and exploit the convergence of projected gradient descent with step size $\eta$: +```math +y = \mathrm{proj}_{\mathcal{C}(x)} (y - \eta \nabla_2 f(x, y)) +``` +=# #= To make verification easy, we minimize the following objective: @@ -50,7 +55,7 @@ function proj_hypercube(p) end function conditions_cstr_optim(x, y) - ∇₂f = 2 .* (y .^ 2 .- x) + ∇₂f = @. 4 * (y^2 - x) * y η = 0.1 return y .- proj_hypercube(y .- η .* ∇₂f) end @@ -74,7 +79,7 @@ implicit_cstr_optim(x) .^ 2 J_thres = Diagonal([0.5 / sqrt(x[1]), 0]) -# ## Forward mode autodiff +# Forward mode autodiff ForwardDiff.jacobian(implicit_cstr_optim, x) @test ForwardDiff.jacobian(implicit_cstr_optim, x) ≈ J_thres #src @@ -83,7 +88,7 @@ ForwardDiff.jacobian(implicit_cstr_optim, x) ForwardDiff.jacobian(forward_cstr_optim, x) -# ## Reverse mode autodiff +# Reverse mode autodiff Zygote.jacobian(implicit_cstr_optim, x)[1] @test Zygote.jacobian(implicit_cstr_optim, x)[1] ≈ J_thres #src diff --git a/examples/2_nonlinear_solve.jl b/examples/2_nonlinear_solve.jl deleted file mode 100644 index 103a910..0000000 --- a/examples/2_nonlinear_solve.jl +++ /dev/null @@ -1,85 +0,0 @@ -# # Nonlinear solve - -#= -In this example, we show how to differentiate through the solution of a nonlinear system of equations: -```math -\text{find} \quad y(x) \quad \text{such that} \quad F(x, y(x)) = 0 -``` -The optimality conditions are pretty obvious: -```math -F(x, y) = 0 -``` -=# - -using ForwardDiff -using ImplicitDifferentiation -using LinearAlgebra -using NLsolve -using Random -using Test #src -using Zygote - -Random.seed!(63); - -# ## Implicit function - -#= -To make verification easy, we solve the following system: -```math -F(x, y) = y \odot y - x = 0 -``` -In this case, the optimization problem boils down to the componentwise square root function, but we implement it using a black box solver from [NLsolve.jl](https://github.com/JuliaNLSolvers/NLsolve.jl). -=# - -function forward_nlsolve(x; method) - F!(storage, y) = (storage .= y .^ 2 - x) - initial_y = ones(eltype(x), size(x)) - result = nlsolve(F!, initial_y; method) - return result.zero -end - -#- - -function conditions_nlsolve(x, y; method) - c = y .^ 2 .- x - return c -end - -#- - -implicit_nlsolve = ImplicitFunction(forward_nlsolve, conditions_nlsolve) - -#- - -x = rand(2) - -#- - -implicit_nlsolve(x; method=:newton) .^ 2 -@test implicit_nlsolve(x; method=:newton) .^ 2 ≈ x #src - -#- - -J = Diagonal(0.5 ./ sqrt.(x)) - -# ## Forward mode autodiff - -ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) -@test ForwardDiff.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x) ≈ J #src - -#- - -ForwardDiff.jacobian(_x -> forward_nlsolve(_x; method=:newton), x) - -# ## Reverse mode autodiff - -Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] -@test Zygote.jacobian(_x -> implicit_nlsolve(_x; method=:newton), x)[1] ≈ J #src - -#- - -try - Zygote.jacobian(_x -> forward_nlsolve(_x; method=:newton), x)[1] -catch e - e -end diff --git a/examples/3_fixed_points.jl b/examples/3_fixed_points.jl deleted file mode 100644 index 1741c5d..0000000 --- a/examples/3_fixed_points.jl +++ /dev/null @@ -1,85 +0,0 @@ -# # Fixed point - -#= -In this example, we show how to differentiate through the limit of a fixed point iteration: -```math -y \longmapsto T(x, y) -``` -The optimality conditions are pretty obvious: -```math -y = T(x, y) -``` -=# - -using ForwardDiff -using ImplicitDifferentiation -using LinearAlgebra -using Random -using Test #src -using Zygote - -Random.seed!(63); - -# ## Implicit function - -#= -To make verification easy, we consider [Heron's method](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Heron's_method): -```math -T(x, y) = \frac{1}{2} \left(y + \frac{x}{y}\right) -``` -In this case, the fixed point algorithm boils down to the componentwise square root function, but we implement it manually. -=# - -function forward_fixedpoint(x; iterations) - y = ones(eltype(x), size(x)) - for _ in 1:iterations - y .= 0.5 .* (y .+ x ./ y) - end - return y -end - -#- - -function conditions_fixedpoint(x, y; iterations) - T = 0.5 .* (y .+ x ./ y) - return T .- y -end - -#- - -implicit_fixedpoint = ImplicitFunction(forward_fixedpoint, conditions_fixedpoint) - -#- - -x = rand(2) - -#- - -implicit_fixedpoint(x; iterations=10) .^ 2 -@test implicit_fixedpoint(x; iterations=10) .^ 2 ≈ x #src - -#- - -J = Diagonal(0.5 ./ sqrt.(x)) - -# ## Forward mode autodiff - -ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) -@test ForwardDiff.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x) ≈ J #src - -#- - -ForwardDiff.jacobian(_x -> forward_fixedpoint(_x; iterations=10), x) - -# ## Reverse mode autodiff - -Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] -@test Zygote.jacobian(_x -> implicit_fixedpoint(_x; iterations=10), x)[1] ≈ J #src - -#- - -try - Zygote.jacobian(_x -> forward_fixedpoint(_x; iterations=10), x)[1] -catch e - e -end diff --git a/examples/5_multiargs.jl b/examples/3_tricks.jl similarity index 51% rename from examples/5_multiargs.jl rename to examples/3_tricks.jl index 27c7301..ded836f 100644 --- a/examples/5_multiargs.jl +++ b/examples/3_tricks.jl @@ -1,11 +1,7 @@ -# # Multiple arguments +# # Tricks #= -In this example, we explain what to do when your function takes multiple input arguments: -```math -y(a, b) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(a, b, y) -``` -The key idea is to store both $a$ and $b$ inside a single vector $x$. +We demonstrate several features that may come in handy for some users. =# using ComponentArrays @@ -19,7 +15,15 @@ using Zygote Random.seed!(63); -# ## Implicit function +# ## Multiple arguments + +#= +First, we explain what to do when your forward mapping takes multiple input arguments: +```math +y(a, b) = \underset{y \in \mathbb{R}^m}{\mathrm{argmin}} ~ f(a, b, y) +``` +The key idea is to store both $a$ and $b$ inside a single vector $x$. +=# #= To make verification easy, we minimize the following objective: @@ -71,12 +75,76 @@ Let's see what the explicit Jacobian looks like. J = hcat(Diagonal(0.5 ./ sqrt.(x.a + 2x.b)), 2 * Diagonal(0.5 ./ sqrt.(x.a + 2x.b))) -# ## Forward mode autodiff +# Forward mode autodiff ForwardDiff.jacobian(implicit_components, x) @test ForwardDiff.jacobian(implicit_components, x) ≈ J #src -# ## Reverse mode autodiff +# Reverse mode autodiff Zygote.jacobian(implicit_components, x)[1] @test Zygote.jacobian(implicit_components, x)[1] ≈ J #src + +# ## Byproducts + +#= +Next, we explain what to do when your forward mapping computes another object that you want to keep track of, which we will call its "byproduct". +The difference between this and multiple outputs (which should be managed with ComponentArrays.jl) is that _we do not compute derivatives with respect to byproducts_. +=# + +#= +Imagine a situation where, depending on a coin toss, said mapping either doubles or halves the input. +After all, why not? +For each individual run, the algorithmic derivative is well-defined. +But to obtain it, you need to store the result of the toss. +=# + +function forward_cointoss(x) + z = rand(Bool) + if z + y = 2x + else + y = x / 2 + end + return y, z +end + +#= +And naturally, the optimality condition also depends on the toss. +=# + +function conditions_cointoss(x, y, z) + if z + return y .- 2x + else + return 2y .- x + end +end + +#= +To make sure that the implicit function you create takes this byproduct into account, just construct it like this: +=# + +implicit_cointoss = ImplicitFunction( + forward_cointoss, conditions_cointoss, HandleByproduct() +) + +#= +Then you have two ways of calling the function: the standard way will only return `y` +=# + +x = [1.0, 1.0] + +implicit_cointoss(x) + +#= +Or if you also need the byproduct, you can do +=# + +implicit_cointoss(x, ReturnByproduct()) + +#= +But whatever you choose, the byproduct is taken into account during differentiation! +=# + +Zygote.withjacobian(implicit_cointoss, x) diff --git a/examples/6_byproduct.jl b/examples/6_byproduct.jl deleted file mode 100644 index b20145b..0000000 --- a/examples/6_byproduct.jl +++ /dev/null @@ -1,8 +0,0 @@ -# # Byproduct - -#= -In this example, we show that returning a byproduct with the forward mapping can be useful. - -!!! warning "Work in progress" - Come back soon! -=# diff --git a/ext/ImplicitDifferentiationChainRulesExt.jl b/ext/ImplicitDifferentiationChainRulesExt.jl index dbfbc55..cfedb03 100644 --- a/ext/ImplicitDifferentiationChainRulesExt.jl +++ b/ext/ImplicitDifferentiationChainRulesExt.jl @@ -1,35 +1,37 @@ module ImplicitDifferentiationChainRulesExt using AbstractDifferentiation: ReverseRuleConfigBackend, pullback_function -using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, ZeroTangent, unthunk -using ImplicitDifferentiation: ImplicitFunction, PullbackMul!, check_solution +using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig, ZeroTangent, rrule, unthunk +using ImplicitDifferentiation: + ImplicitFunction, PullbackMul!, ReturnByproduct, presolve, solve +using LinearAlgebra: lmul!, mul! using LinearOperators: LinearOperator using SimpleUnPack: @unpack """ rrule(rc, implicit, x; kwargs...) - rrule(rc, implicit, x, Val(return_byproduct); kwargs...) + rrule(rc, implicit, x, ReturnByproduct(); kwargs...) Custom reverse rule for an [`ImplicitFunction`](@ref), to ensure compatibility with reverse mode autodiff. -This is only available if ChainRulesCore.jl is loaded (extension). +This is only available if ChainRulesCore.jl is loaded (extension), except on Julia < 1.9 where it is always available. -- If `return_byproduct=false` (the default), this returns a single output `y(x)` with a pullback accepting a single cotangent `̄y`. -- If `return_byproduct=true`, this returns a couple of outputs `(y(x),z(x))` with a pullback accepting a couple of cotangents `(̄y, ̄z)` (remember that `z(x)` is not differentiated so its cotangent is ignored). +- By default, this returns a single output `y(x)` with a pullback accepting a single cotangent `dy`. +- If `ReturnByproduct()` is passed as an argument, this returns a couple of outputs `(y(x),z(x))` with a pullback accepting a couple of cotangents `(dy, dz)` (remember that `z(x)` is not differentiated so its cotangent is ignored). -We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu`. +We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and setting `Jᵀv = -Bᵀu` (see [`ImplicitFunction`](@ref) for the definition of `A` and `B`). Keyword arguments are given to both `implicit.forward` and `implicit.conditions`. """ function ChainRulesCore.rrule( rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}, - ::Val{return_byproduct}; + ::ReturnByproduct; kwargs..., -) where {R,return_byproduct} +) where {R} @unpack conditions, linear_solver = implicit - y, z = implicit(x, Val(true); kwargs...) + y, z = implicit(x, ReturnByproduct(); kwargs...) n, m = length(x), length(y) backend = ReverseRuleConfigBackend(rc) @@ -37,44 +39,39 @@ function ChainRulesCore.rrule( pbB = pullback_function(backend, _x -> conditions(_x, y, z; kwargs...), x) pbmA = PullbackMul!(pbA, size(y)) pbmB = PullbackMul!(pbB, size(y)) + Aᵀ_op = LinearOperator(R, m, m, false, false, pbmA) Bᵀ_op = LinearOperator(R, n, m, false, false, pbmB) - implicit_pullback = ImplicitPullback( - Aᵀ_op, Bᵀ_op, linear_solver, x, Val(return_byproduct) - ) + Aᵀ_op_presolved = presolve(linear_solver, Aᵀ_op, y) + + implicit_pullback = ImplicitPullback(Aᵀ_op_presolved, Bᵀ_op, linear_solver, x) + + return (y, z), implicit_pullback +end - if return_byproduct - return (y, z), implicit_pullback - else - return y, implicit_pullback - end +function ChainRulesCore.rrule( + rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray{R}; kwargs... +) where {R} + (y, z), implicit_pullback = rrule(rc, implicit, x, ReturnByproduct(); kwargs...) + implicit_pullback_no_byproduct(dy) = Base.front(implicit_pullback((dy, nothing))) + return y, implicit_pullback_no_byproduct end -struct ImplicitPullback{return_byproduct,A,B,L,X} +struct ImplicitPullback{A,B,L,X} Aᵀ_op::A Bᵀ_op::B linear_solver::L x::X - _v::Val{return_byproduct} -end - -function (implicit_pullback_nobyproduct::ImplicitPullback{false})(dy) - @unpack Aᵀ_op, Bᵀ_op, linear_solver, x = implicit_pullback_nobyproduct - implicit_pullback_byproduct = ImplicitPullback( - Aᵀ_op, Bᵀ_op, linear_solver, x, Val(true) - ) - return implicit_pullback_byproduct((dy, nothing)) end -function (implicit_pullback_byproduct::ImplicitPullback{true})((dy, _)) - @unpack Aᵀ_op, Bᵀ_op, linear_solver, x = implicit_pullback_byproduct +function (implicit_pullback::ImplicitPullback)((dy, dz)) + @unpack Aᵀ_op, Bᵀ_op, linear_solver, x = implicit_pullback R = eltype(x) - - dy_vec = convert(Vector{R}, vec(unthunk(dy))) - dF_vec, stats = linear_solver(Aᵀ_op, dy_vec) - check_solution(linear_solver, stats) - dx_vec = Bᵀ_op * dF_vec - dx_vec .*= -1 + dy_vec = convert(AbstractVector{R}, vec(unthunk(dy))) + dF_vec = solve(linear_solver, Aᵀ_op, dy_vec) + dx_vec = vec(similar(x)) + mul!(dx_vec, Bᵀ_op, dF_vec) + lmul!(-one(R), dx_vec) dx = reshape(dx_vec, size(x)) return (NoTangent(), dx, NoTangent()) end diff --git a/ext/ImplicitDifferentiationForwardDiffExt.jl b/ext/ImplicitDifferentiationForwardDiffExt.jl index 99fa2b1..ae62b28 100644 --- a/ext/ImplicitDifferentiationForwardDiffExt.jl +++ b/ext/ImplicitDifferentiationForwardDiffExt.jl @@ -6,45 +6,51 @@ else using ..ForwardDiff: Dual, Partials, jacobian, partials, value end -using AbstractDifferentiation: ForwardDiffBackend, pushforward_function -using ImplicitDifferentiation: ImplicitFunction, PushforwardMul!, check_solution +using AbstractDifferentiation: + AbstractDifferentiation, ForwardDiffBackend, pushforward_function +using ImplicitDifferentiation: + ImplicitFunction, PushforwardMul!, ReturnByproduct, presolve, solve +using LinearAlgebra: lmul!, mul! using LinearOperators: LinearOperator using SimpleUnPack: @unpack """ implicit(x_and_dx::AbstractArray{<:Dual}; kwargs...) - implicit(x_and_dx::AbstractArray{<:Dual}, Val(return_byproduct); kwargs...) + implicit(x_and_dx::AbstractArray{<:Dual}, ReturnByproduct(); kwargs...) Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with forward mode autodiff. This is only available if ForwardDiff.jl is loaded (extension). -- If `return_byproduct=false` (the default), this returns a single output `y_and_dy(x)`. -- If `return_byproduct=true`, this returns a couple of outputs `(y_and_dy(x),z(x))` (remember that `z(x)` is not differentiated so `dz(x)` doesn't exist). +- By default, this returns a single output `y_and_dy(x)`. +- If `ReturnByproduct()` is passed as an argument, this returns a couple of outputs `(y_and_dy(x),z(x))` (remember that `z(x)` is not differentiated so `dz(x)` doesn't exist). -We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u`. +We compute the Jacobian-vector product `Jv` by solving `Au = Bv` and setting `Jv = u` (see [`ImplicitFunction`](@ref) for the definition of `A` and `B`). Keyword arguments are given to both `implicit.forward` and `implicit.conditions`. """ function (implicit::ImplicitFunction)( - x_and_dx::AbstractArray{Dual{T,R,N}}, ::Val{return_byproduct}=Val(false); kwargs... -) where {T,R,N,return_byproduct} + x_and_dx::AbstractArray{Dual{T,R,N}}, ::ReturnByproduct; kwargs... +) where {T,R,N} @unpack conditions, linear_solver = implicit x = value.(x_and_dx) - y, z = implicit(x, Val(true); kwargs...) + y, z = implicit(x, ReturnByproduct(); kwargs...) n, m = length(x), length(y) backend = ForwardDiffBackend() pfA = pushforward_function(backend, _y -> conditions(x, _y, z; kwargs...), y) pfB = pushforward_function(backend, _x -> conditions(_x, y, z; kwargs...), x) + A_op = LinearOperator(R, m, m, false, false, PushforwardMul!(pfA, size(y))) B_op = LinearOperator(R, m, n, false, false, PushforwardMul!(pfB, size(x))) + A_op_presolved = presolve(linear_solver, A_op, y) - dy = map(1:N) do k + dy = ntuple(Val(N)) do k dₖx_vec = vec(partials.(x_and_dx, k)) - dₖy_vec, stats = linear_solver(A_op, B_op * dₖx_vec) - dₖy_vec .*= -1 - check_solution(linear_solver, stats) + Bdx = vec(similar(y)) + mul!(Bdx, B_op, dₖx_vec) + dₖy_vec = solve(linear_solver, A_op_presolved, Bdx) + lmul!(-one(R), dₖy_vec) reshape(dₖy_vec, size(y)) end @@ -55,11 +61,14 @@ function (implicit::ImplicitFunction)( reshape(y_and_dy_vec, size(y)) end - if return_byproduct - return y_and_dy, z - else - return y_and_dy - end + return y_and_dy, z +end + +function (implicit::ImplicitFunction)( + x_and_dx::AbstractArray{Dual{T,R,N}}; kwargs... +) where {T,R,N} + y_and_dy, z = implicit(x_and_dx, ReturnByproduct(); kwargs...) + return y_and_dy end end diff --git a/ext/ImplicitDifferentiationStaticArraysExt.jl b/ext/ImplicitDifferentiationStaticArraysExt.jl new file mode 100644 index 0000000..bfae8d5 --- /dev/null +++ b/ext/ImplicitDifferentiationStaticArraysExt.jl @@ -0,0 +1,26 @@ +module ImplicitDifferentiationStaticArraysExt + +@static if isdefined(Base, :get_extension) + using StaticArrays: StaticArray, MMatrix +else + using ..StaticArrays: StaticArray, MMatrix +end + +import ImplicitDifferentiation: ImplicitDifferentiation, DirectLinearSolver +using LinearAlgebra: lu, mul! + +function ImplicitDifferentiation.presolve( + ::DirectLinearSolver, A, y::StaticArray{S,T,N} +) where {S,T,N} + m = length(y) + A_static = zero(MMatrix{m,m,T}) + for i in axes(A_static, 2) + v = vec(similar(y)) + v .= zero(T) + v[i] = one(T) + mul!(@view(A_static[:, i]), A, v) + end + return lu(A_static) +end + +end diff --git a/src/ImplicitDifferentiation.jl b/src/ImplicitDifferentiation.jl index 9e2178d..3b3c446 100644 --- a/src/ImplicitDifferentiation.jl +++ b/src/ImplicitDifferentiation.jl @@ -1,15 +1,20 @@ module ImplicitDifferentiation -using AbstractDifferentiation: LazyJacobian, ReverseRuleConfigBackend, lazy_jacobian using Krylov: KrylovStats, gmres using LinearOperators: LinearOperators, LinearOperator +using LinearAlgebra: lu, SingularException using Requires: @require using SimpleUnPack: @unpack include("utils.jl") +include("forward.jl") +include("conditions.jl") +include("linear_solver.jl") include("implicit_function.jl") export ImplicitFunction +export IterativeLinearSolver, DirectLinearSolver +export HandleByproduct, ReturnByproduct @static if !isdefined(Base, :get_extension) include("../ext/ImplicitDifferentiationChainRulesExt.jl") @@ -17,6 +22,9 @@ export ImplicitFunction @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin include("../ext/ImplicitDifferentiationForwardDiffExt.jl") end + @require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin + include("../ext/ImplicitDifferentiationStaticArraysExt.jl") + end end end diff --git a/src/conditions.jl b/src/conditions.jl new file mode 100644 index 0000000..1d3d725 --- /dev/null +++ b/src/conditions.jl @@ -0,0 +1,22 @@ +""" + Conditions{byproduct,C} + +Callable wrapper for conditions `c::C`, which ensures that a byproduct `z` is always accepted in addition to `x` and `y`. + +The type parameter `byproduct` is a boolean stating whether or not `c` natively accepts `z`. +""" +struct Conditions{byproduct,C} + c::C + function Conditions{byproduct}(c::C) where {byproduct,C} + return new{byproduct,C}(c) + end +end + +function Base.show(io::IO, conditions::Conditions{byproduct}) where {byproduct} + return print(io, "Conditions{$byproduct}($(conditions.c))") +end + +(conditions::Conditions{true})(x, y, z; kwargs...) = conditions.c(x, y, z; kwargs...) +(conditions::Conditions{false})(x, y, z; kwargs...) = conditions.c(x, y; kwargs...) + +handles_byproduct(::Conditions{byproduct}) where {byproduct} = byproduct diff --git a/src/forward.jl b/src/forward.jl new file mode 100644 index 0000000..97fa2ba --- /dev/null +++ b/src/forward.jl @@ -0,0 +1,30 @@ +""" + Forward{byproduct,F} + +Callable wrapper for a forward mapping `f::F`, which ensures that a byproduct `z(x)` is always returned in addition to `y(x)`. + +The type parameter `byproduct` is a boolean stating whether or not `f` natively returns `z(x)`. +""" +struct Forward{byproduct,F} + f::F + function Forward{byproduct}(f::F) where {byproduct,F} + return new{byproduct,F}(f) + end +end + +function Base.show(io::IO, forward::Forward{byproduct}) where {byproduct} + return print(io, "Forward{$byproduct}($(forward.f))") +end + +function (forward::Forward{true})(x; kwargs...) + y, z = forward.f(x; kwargs...) + return y, z +end + +function (forward::Forward{false})(x; kwargs...) + y = forward.f(x; kwargs...) + z = 0 + return y, z +end + +handles_byproduct(::Forward{byproduct}) where {byproduct} = byproduct diff --git a/src/implicit_function.jl b/src/implicit_function.jl index a66f504..2bf2465 100644 --- a/src/implicit_function.jl +++ b/src/implicit_function.jl @@ -1,135 +1,86 @@ """ - Forward{handle_byproduct,F} - -Callable wrapper for a forward mapping `f::F`, which ensures that a byproduct `z(x)` is always returned in addition to `y(x)`. - -The type parameter `handle_byproduct` is a boolean stating whether or not `f` natively returns `z(x)`. -""" -struct Forward{handle_byproduct,F} - f::F - function Forward{handle_byproduct}(f::F) where {handle_byproduct,F} - return new{handle_byproduct,F}(f) - end -end - -function Base.show(io::IO, forward::Forward{handle_byproduct}) where {handle_byproduct} - return print("Forward{$handle_byproduct}($(forward.f))") -end - -""" - forward(x; kwargs...) - -Apply `forward.f` to `x`, returning a dummy byproduct `z(x)=0` if needed. -""" -(forward::Forward{true})(x; kwargs...) = forward.f(x; kwargs...) -(forward::Forward{false})(x; kwargs...) = (forward.f(x; kwargs...), 0) - -""" - Conditions{handle_byproduct,C} - -Callable wrapper for conditions `c::C`, which ensures that a byproduct `z` is always accepted in addition to `x` and `y`. - -The type parameter `handle_byproduct` is a boolean stating whether or not `c` natively accepts `z`. -""" -struct Conditions{handle_byproduct,C} - c::C - function Conditions{handle_byproduct}(c::C) where {handle_byproduct,C} - return new{handle_byproduct,C}(c) - end -end - -function Base.show( - io::IO, conditions::Conditions{handle_byproduct} -) where {handle_byproduct} - return print("Conditions{$handle_byproduct}($(conditions.c))") -end - -""" - conditions(x, y, z; kwargs...) - -Apply `conditions.c` to `(x, y, z)`, discarding `z` beforehand if needed. -""" -(conditions::Conditions{true})(x, y, z; kwargs...) = conditions.c(x, y, z; kwargs...) -(conditions::Conditions{false})(x, y, z; kwargs...) = conditions.c(x, y; kwargs...) - -""" - ImplicitFunction{handle_byproduct,FF<:Forward,CC<:Conditions,LS} + ImplicitFunction{FF,CC,LS} Differentiable wrapper for an implicit function defined by a forward mapping and a set of conditions. # Constructors - ImplicitFunction(f, c; linear_solver=gmres) - ImplicitFunction(f, c, Val(handle_byproduct); linear_solver=gmres) - -Construct an `ImplicitFunction` from a forward mapping `f` and conditions `c`, both of which are Julia callables. +You can construct an `ImplicitFunction` from a forward mapping `f` and conditions `c`, both of which must be callables (function-like objects). While `f` does not not need to be compatible with automatic differentiation, `c` has to be. + ImplicitFunction(f, c[, HandleByproduct()]) + ImplicitFunction(f, c, linear_solver[, HandleByproduct()]) + +# Callable behavior + +An `ImplicitFunction` object `implicit` behaves like a function, and every call to it is differentiable. + + implicit(x::AbstractArray[, ReturnByproduct()]; kwargs...) + # Details -- If `handle_byproduct=false` (the default), the forward mapping is `x -> y(x)` and the conditions are `c(x,y(x)) = 0`. -- If `handle_byproduct=true`, the forward mapping is `x -> (y(x),z(x))` and the conditions are `c(x,y(x),z(x)) = 0`. In this case, `z(x)` can contain additional information generated by the forward mapping, but beware that we consider it constant for differentiation purposes. +- By default, we assume that the forward mapping is `x -> y(x)` and the conditions are `c(x,y(x)) = 0`. +- If `HandleByproduct()` is passed as an argument to the constructor, we assume instead that the forward mapping is `x -> (y(x),z(x))` and the conditions are `c(x,y(x),z(x)) = 0`. In this case, `z(x)` can contain additional information generated by the forward mapping, but beware that we consider it constant for differentiation purposes. Given `x ∈ ℝⁿ` and `y ∈ ℝᵈ`, we need as many conditions as output dimensions: `c(x,y,z) ∈ ℝᵈ`. We can then compute the Jacobian of `y(⋅)` using the implicit function theorem: ``` ∂₂c(x,y(x),z(x)) * ∂y(x) = -∂₁c(x,y(x),z(x)) ``` This requires solving a linear system `A * J = -B`, where `A ∈ ℝᵈˣᵈ`, `B ∈ ℝᵈˣⁿ` and `J ∈ ℝᵈˣⁿ`. -The default linear solver is `Krylov.gmres`, but this can be changed with a keyword argument. # Fields -- `forward::FF`: a wrapper of type [`Forward`](@ref) coherent with the value of `handle_byproduct` -- `conditions::FF`: a wrapper of type [`Conditions`](@ref) coherent with the value of `handle_byproduct` -- `linear_solver::LS`: a callable of the form `(A,b) -> (u,stats)` such that `Au = b` and `stats.solved ∈ {true,false}`, typically taken from [Krylov.jl](https://github.com/JuliaSmoothOptimizers/Krylov.jl) +- `forward::FF`: a wrapper of type [`Forward`](@ref) around the callable `f` +- `conditions::CC`: a wrapper of type [`Conditions`](@ref) around the callable `c` +- `linear_solver::LS`: an object subtyping [`AbstractLinearSolver`](@ref) (defaults to [`IterativeLinearSolver`](@ref)). """ -struct ImplicitFunction{ - handle_byproduct,FF<:Forward{handle_byproduct},CC<:Conditions{handle_byproduct},LS -} +struct ImplicitFunction{FF<:Forward,CC<:Conditions,LS<:AbstractLinearSolver} forward::FF conditions::CC linear_solver::LS - function ImplicitFunction( - f, c, ::Val{handle_byproduct}=Val(false); linear_solver=gmres - ) where {handle_byproduct} - forward = Forward{handle_byproduct}(f) - conditions = Conditions{handle_byproduct}(c) - return new{ - handle_byproduct,typeof(forward),typeof(conditions),typeof(linear_solver) - }( + function ImplicitFunction(f, c, linear_solver::AbstractLinearSolver) + forward = Forward{false}(f) + conditions = Conditions{false}(c) + return new{typeof(forward),typeof(conditions),typeof(linear_solver)}( + forward, conditions, linear_solver + ) + end + + function ImplicitFunction(f, c, linear_solver::AbstractLinearSolver, ::HandleByproduct) + forward = Forward{true}(f) + conditions = Conditions{true}(c) + return new{typeof(forward),typeof(conditions),typeof(linear_solver)}( forward, conditions, linear_solver ) end end -function Base.show( - io::IO, implicit::ImplicitFunction{handle_byproduct} -) where {handle_byproduct} - @unpack forward, conditions, linear_solver = implicit - @unpack f = forward - @unpack c = conditions - return print(io, "ImplicitFunction{$handle_byproduct}($f, $c, $linear_solver)") +function ImplicitFunction(f, c) + linear_solver = IterativeLinearSolver() + return ImplicitFunction(f, c, linear_solver) end -""" - implicit(x::AbstractArray; kwargs...) - implicit(x::AbstractArray, Val(return_byproduct), ; kwargs...) +function ImplicitFunction(f, c, ::HandleByproduct) + linear_solver = IterativeLinearSolver() + return ImplicitFunction(f, c, linear_solver, HandleByproduct()) +end -Make an [`ImplicitFunction`](@ref) callable by applying the forward mapping `implicit.forward`. +function Base.show(io::IO, implicit::ImplicitFunction) + @unpack forward, conditions, linear_solver = implicit + return print(io, "ImplicitFunction($(forward.f), $(conditions.c), $linear_solver)") +end -- If `return_byproduct=false` (the default), this returns a single output `y(x)`. -- If `return_byproduct=true`, this returns a couple of outputs `(y(x),z(x))`. +function (implicit::ImplicitFunction)(x::AbstractArray; kwargs...) + y, z = implicit.forward(x; kwargs...) + return y +end -The argument `return_byproduct` is independent from the type parameter `handle_byproduct` in `ImplicitFunction`, so any combination is possible. -""" -function (implicit::ImplicitFunction)( - x::AbstractArray, ::Val{return_byproduct}=Val(false); kwargs... -) where {return_byproduct} - y, z = implicit.forward(x, ; kwargs...) - if return_byproduct - return (y, z) - else - return y - end +function (implicit::ImplicitFunction)(x::AbstractArray, ::ReturnByproduct; kwargs...) + y, z = implicit.forward(x; kwargs...) + return (y, z) +end + +function handles_byproduct(implicit::ImplicitFunction) + return handles_byproduct(implicit.forward) && handles_byproduct(implicit.conditions) end diff --git a/src/linear_solver.jl b/src/linear_solver.jl new file mode 100644 index 0000000..4d9b86d --- /dev/null +++ b/src/linear_solver.jl @@ -0,0 +1,50 @@ +""" + AbstractLinearSolver + +All linear solvers used within an `ImplicitFunction` must satisfy this interface. + +# Required methods + +- `presolve(linear_solver, A, y)`: return a matrix-like object `A` for which it is cheaper to solve several linear systems with different vectors `b` (a typical example would be to perform LU factorization). +- `solve(linear_solver, A, b)`: return a tuple `(x, stats)` where `x` satisfies `Ax = b` and `stats.solved ∈ {true, false}`. +""" +abstract type AbstractLinearSolver end + +""" + IterativeLinearSolver + +An implementation of `AbstractLinearSolver` using `Krylov.gmres`. +""" +struct IterativeLinearSolver <: AbstractLinearSolver end + +presolve(::IterativeLinearSolver, A, y) = A + +function solve(::IterativeLinearSolver, A, b) + x, stats = gmres(A, b) + if !stats.solved + throw(SolverFailureException(gmres, stats)) + end + return x +end + +""" + DirectLinearSolver + +An implementation of `AbstractLinearSolver` using the built-in `\` operator. +""" +struct DirectLinearSolver <: AbstractLinearSolver end + +presolve(::DirectLinearSolver, A, y) = lu(Matrix(A)) +solve(::DirectLinearSolver, A, b) = A \ b + +struct SolverFailureException{A,B} <: Exception + solver::A + stats::B +end + +function Base.show(io::IO, sfe::SolverFailureException) + return println( + io, + "SolverFailureException: \n Linear solver: $(sfe.solver) \n Solver stats: $(string(sfe.stats))", + ) +end diff --git a/src/utils.jl b/src/utils.jl index 2f005c6..3ef77ca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,22 +1,20 @@ -struct SolverFailureException{A,B} <: Exception - solver::A - stats::B -end +""" + HandleByproduct -function Base.show(io::IO, sfe::SolverFailureException) - return println( - io, - "SolverFailureException: \n Solver: $(sfe.solver) \n Solver stats: $(string(sfe.stats))", - ) -end +Trivial struct specifying that the forward mapping and conditions handle a byproduct. -function check_solution(solver, stats) - if stats.solved - return nothing - else - throw(SolverFailureException(solver, stats)) - end -end +Used in the constructor for `ImplicitFunction`. +""" +struct HandleByproduct end + +""" + ReturnByproduct + +Trivial struct specifying that we want to obtain a byproduct in addition to the solution. + +Used when calling an `ImplicitFunction`. +""" +struct ReturnByproduct end """ PushforwardMul!{P,N} @@ -49,13 +47,17 @@ end function (pfm::PushforwardMul!)(res::AbstractVector, δinput_vec::AbstractVector) δinput = reshape(δinput_vec, pfm.input_size) δoutput = only(pfm.pushforward(δinput)) - return res .= vec(δoutput) + for i in eachindex(IndexLinear(), res, δoutput) + res[i] = δoutput[i] + end end function (pbm::PullbackMul!)(res::AbstractVector, δoutput_vec::AbstractVector) δoutput = reshape(δoutput_vec, pbm.output_size) δinput = only(pbm.pullback(δoutput)) - return res .= vec(δinput) + for i in eachindex(IndexLinear(), res, δinput) + res[i] = δinput[i] + end end ## Override this function from LinearOperators to avoid generating the whole methods table diff --git a/test/misc.jl b/test/misc.jl deleted file mode 100644 index 4fb708d..0000000 --- a/test/misc.jl +++ /dev/null @@ -1,143 +0,0 @@ -using ChainRulesCore -using ChainRulesTestUtils -using ForwardDiff -using ImplicitDifferentiation -using JET -using LinearAlgebra -using Random -using Test -using Zygote - -Random.seed!(63); - -""" - mysqrt(x) - -Compute the elementwise square root, breaking Zygote.jl and ForwardDiff.jl in the process. -""" -function mysqrt(x::AbstractArray) - a = [0.0] - a[1] = first(x) - return sqrt.(x) -end - -myval(::Val{X}) where {X} = X - -function make_implicit_sqrt(::Val{handle_byproduct}) where {handle_byproduct} - if handle_byproduct - forward_byproduct(x) = (mysqrt(x), 0) - conditions_byproduct(x, y, z) = y .^ 2 .- x - implicit = ImplicitFunction(forward_byproduct, conditions_byproduct, Val(true)) - else - forward(x) = mysqrt(x) - conditions(x, y) = y .^ 2 .- x - implicit = ImplicitFunction(forward, conditions) - end - return implicit -end - -for handle_byproduct in (Val(true), Val(false)) - testsetname = myval(handle_byproduct) ? "With byproduct" : "Without byproduct" - @testset "$testsetname" verbose = true begin - implicit = make_implicit_sqrt(handle_byproduct) - # Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities - @testset verbose = true "ChainRulesTestUtils.jl" begin - @test_skip test_rrule(implicit, x) - @test_skip test_rrule(implicit, X) - end - - @testset verbose = true "Vectors" begin - x = rand(2) - y = implicit(x) - J = Diagonal(0.5 ./ sqrt.(x)) - - @testset "Call" begin - @test (@inferred implicit(x)) ≈ sqrt.(x) - if VERSION >= v"1.7" - test_opt(implicit, (typeof(x),)) - end - end - - @testset verbose = true "Forward" begin - @test ForwardDiff.jacobian(implicit, x) ≈ J - x_and_dx = ForwardDiff.Dual.(x, ((0, 0),)) - for return_byproduct in (true, false) - res_and_dres = @inferred implicit(x_and_dx, Val(return_byproduct)) - if return_byproduct - y_and_dy, z = res_and_dres - @test size(y_and_dy) == size(y) - else - y_and_dy = res_and_dres - @test size(y_and_dy) == size(y) - end - end - end - - @testset "Reverse" begin - @test Zygote.jacobian(implicit, x)[1] ≈ J - for return_byproduct in (true, false) - _, pullback = @inferred rrule( - Zygote.ZygoteRuleConfig(), implicit, x, Val(return_byproduct) - ) - dy, dz = zero(implicit(x)), 0 - if return_byproduct - @test (@inferred pullback((dy, dz))) == pullback((dy, dz)) - _, dx = pullback((dy, dz)) - @test size(dx) == size(x) - else - @test (@inferred pullback(dy)) == pullback(dy) - _, dx = pullback(dy) - @test size(dx) == size(x) - end - end - end - end - - @testset verbose = true "Arrays" begin - X = rand(2, 3, 4) - Y = implicit(X) - JJ = Diagonal(0.5 ./ sqrt.(vec(X))) - - @testset "Call" begin - @test (@inferred implicit(X)) ≈ sqrt.(X) - if VERSION >= v"1.7" - test_opt(implicit, (typeof(X),)) - end - end - - @testset "Forward" begin - @test ForwardDiff.jacobian(implicit, X) ≈ JJ - X_and_dX = ForwardDiff.Dual.(X, ((0, 0),)) - for return_byproduct in (true, false) - res_and_dres = @inferred implicit(X_and_dX, Val(return_byproduct)) - if return_byproduct - Y_and_dY, Z = res_and_dres - @test size(Y_and_dY) == size(Y) - else - Y_and_dY = res_and_dres - @test size(Y_and_dY) == size(Y) - end - end - end - - @testset "Reverse" begin - @test Zygote.jacobian(implicit, X)[1] ≈ JJ - for return_byproduct in (true, false) - _, pullback = @inferred rrule( - Zygote.ZygoteRuleConfig(), implicit, X, Val(return_byproduct) - ) - dY, dZ = zero(implicit(X)), 0 - if return_byproduct - @test (@inferred pullback((dY, dZ))) == pullback((dY, dZ)) - _, dX = pullback((dY, dZ)) - @test size(dX) == size(X) - else - @test (@inferred pullback(dY)) == pullback(dY) - _, dX = pullback(dY) - @test size(dX) == size(X) - end - end - end - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index d26134b..1256129 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Aqua using Documenter +using ForwardDiff using ImplicitDifferentiation using JET using JuliaFormatter @@ -38,25 +39,31 @@ EXAMPLES_DIR_JL = joinpath(dirname(@__DIR__), "examples") Aqua.test_undefined_exports(ImplicitDifferentiation) Aqua.test_piracy(ImplicitDifferentiation) Aqua.test_project_extras(ImplicitDifferentiation) - Aqua.test_stale_deps(ImplicitDifferentiation; ignore=[:ChainRulesCore]) + Aqua.test_stale_deps( + ImplicitDifferentiation; ignore=[:AbstractDifferentiation, :ChainRulesCore] + ) Aqua.test_deps_compat(ImplicitDifferentiation) if VERSION >= v"1.7" Aqua.test_project_toml_formatting(ImplicitDifferentiation) end end @testset verbose = true "Formatting (JuliaFormatter.jl)" begin - @test format(ImplicitDifferentiation; verbose=true, overwrite=false) + @test format(ImplicitDifferentiation; verbose=false, overwrite=false) end @testset verbose = true "Static checking (JET.jl)" begin - if VERSION >= v"1.8" - JET.test_package(ImplicitDifferentiation; toplevel_logger=nothing) + if VERSION >= v"1.9" + JET.test_package( + ImplicitDifferentiation; + target_defined_modules=true, + toplevel_logger=nothing, + ) end end @testset verbose = false "Doctests (Documenter.jl)" begin doctest(ImplicitDifferentiation) end - @testset verbose = true "Miscellaneous" begin - include("misc.jl") + @testset verbose = true "Systematic" begin + include("systematic.jl") end @testset verbose = true "Examples" begin for file in readdir(EXAMPLES_DIR_JL) diff --git a/test/systematic.jl b/test/systematic.jl new file mode 100644 index 0000000..568026d --- /dev/null +++ b/test/systematic.jl @@ -0,0 +1,187 @@ +using ChainRulesCore +using ChainRulesTestUtils +using ForwardDiff +using ImplicitDifferentiation +using ImplicitDifferentiation: handles_byproduct +using JET +using LinearAlgebra +using Random +using StaticArrays +using Test +using Zygote +using Zygote: ZygoteRuleConfig + +@static if VERSION < v"1.9" + macro test_opt(x...) + return :() + end + macro test_call(x...) + return :() + end +end + +Random.seed!(63); + +function is_static_array(a) + return ( + typeof(a) <: StaticArray || + typeof(a) <: (Base.ReshapedArray{T,N,<:StaticArray} where {T,N}) + ) +end + +function break_forwarddiff_zygote(x) + a = [0.0] + a[1] = float(first(x)) + return nothing +end + +function mysqrt(x::AbstractArray) + break_forwarddiff_zygote(x) + return sqrt.(x) +end + +function mysqrt_byproduct(x::AbstractArray) + break_forwarddiff_zygote(x) + z = rand((2,)) + y = x .^ (1 / z) + return y, z +end + +function make_implicit_sqrt(linear_solver) + forward(x) = mysqrt(x) + conditions(x, y) = y .^ 2 .- x + implicit = ImplicitFunction(forward, conditions, linear_solver) + return implicit +end + +function make_implicit_sqrt_byproduct(linear_solver) + forward(x) = mysqrt_byproduct(x) + conditions(x, y, z) = y .^ z .- x + implicit = ImplicitFunction(forward, conditions, linear_solver, HandleByproduct()) + return implicit +end + +function test_implicit_call(implicit, x; y_true) + @test_throws MethodError implicit("hello") + @test_throws MethodError implicit(x, x) + y1 = @inferred implicit(x) + y2, z2 = @inferred implicit(x, ReturnByproduct()) + @test y1 ≈ y_true + @test y2 ≈ y_true + if typeof(x) <: StaticArray + @test is_static_array(y1) + @test is_static_array(y2) + end + if handles_byproduct(implicit) + @test z2 == 2 + else + @test z2 == 0 + end + @test_opt target_modules = (ImplicitDifferentiation,) implicit(x) + @test_call target_modules = (ImplicitDifferentiation,) implicit(x) +end + +function test_implicit_forward(implicit, x; y_true, J_true) + # High-level + J1 = ForwardDiff.jacobian(implicit, x) + J2 = ForwardDiff.jacobian(x -> implicit(x, ReturnByproduct())[1], x) + @test J1 ≈ J_true + @test J2 ≈ J_true + # Low-level + x_and_dx = ForwardDiff.Dual.(x, ((0, 0),)) + y_and_dy1 = @inferred implicit(x_and_dx) + y_and_dy2, z2 = @inferred implicit(x_and_dx, ReturnByproduct()) + @test size(y_and_dy1) == size(y_true) + @test size(y_and_dy2) == size(y_true) + @test ForwardDiff.value.(y_and_dy1) ≈ y_true + @test ForwardDiff.value.(y_and_dy2) ≈ y_true + if typeof(x) <: StaticArray + @test is_static_array(y_and_dy1) + @test is_static_array(y_and_dy2) + end + if handles_byproduct(implicit) + @test z2 == 2 + else + @test z2 == 0 + end + @test_opt target_modules = (ImplicitDifferentiation,) implicit(x_and_dx) + @test_call target_modules = (ImplicitDifferentiation,) implicit(x_and_dx) +end + +function test_implicit_reverse(implicit, x; y_true, J_true) + # High-level + J1 = Zygote.jacobian(implicit, x)[1] + J2 = Zygote.jacobian(x -> implicit(x, ReturnByproduct())[1], x)[1] + @test J1 ≈ J_true + @test J2 ≈ J_true + # Low-level + y1, pb1 = @inferred rrule(ZygoteRuleConfig(), implicit, x) + (y2, z2), pb2 = @inferred rrule(ZygoteRuleConfig(), implicit, x, ReturnByproduct()) + @test y1 ≈ y_true + @test y2 ≈ y_true + dy1 = zeros(eltype(y1), size(y1)...) + dy2 = zeros(eltype(y2), size(y2)...) + dz2 = nothing + dimp1, dx1 = @inferred pb1(dy1) + dimp2, dx2, drp = @inferred pb2((dy2, dz2)) + @test size(dx1) == size(x) + @test size(dx2) == size(x) + if typeof(x) <: StaticArray + @test is_static_array(y1) + @test is_static_array(y2) + @test is_static_array(dx1) + @test is_static_array(dx2) + end + @test dimp1 isa NoTangent + @test dimp2 isa NoTangent + @test drp isa NoTangent + if handles_byproduct(implicit) + @test z2 == 2 + else + @test z2 == 0 + end + @test_skip @test_opt target_modules = (ImplicitDifferentiation,) rrule( + ZygoteRuleConfig(), implicit, x + ) + @test_skip @test_opt target_modules = (ImplicitDifferentiation,) pb1(dy1) + @test_call target_modules = (ImplicitDifferentiation,) rrule( + ZygoteRuleConfig(), implicit, x + ) + @test_call target_modules = (ImplicitDifferentiation,) pb1(dy1) + # Skipped because of https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/232 and because it detects weird type instabilities + @test_skip test_rrule(implicit, x) + @test_skip test_rrule(x -> implicit(x, ReturnByproduct()), x) +end + +x_candidates = ( + rand(2), rand(2, 3, 4), SVector{2}(rand(2)), SArray{Tuple{2,3,4}}(rand(2, 3, 4)) +); + +linear_solver_candidates = (IterativeLinearSolver(), DirectLinearSolver()) + +for linear_solver in linear_solver_candidates, x in x_candidates + if x isa StaticArray && linear_solver isa IterativeLinearSolver + continue + end + y_true = sqrt.(x) + J_true = Diagonal(0.5 ./ vec(sqrt.(x))) + + testsetname = "$(typeof(x)) - $(typeof(linear_solver))" + implicit_sqrt = make_implicit_sqrt(linear_solver) + implicit_sqrt_byproduct = make_implicit_sqrt_byproduct(linear_solver) + + @testset verbose = true "$testsetname" begin + @testset "Call" begin + test_implicit_call(implicit_sqrt, x; y_true) + test_implicit_call(implicit_sqrt_byproduct, x; y_true) + end + @testset "Forward" begin + test_implicit_forward(implicit_sqrt, x; y_true, J_true) + test_implicit_forward(implicit_sqrt_byproduct, x; y_true, J_true) + end + @testset "Reverse" begin + test_implicit_reverse(implicit_sqrt, x; y_true, J_true) + test_implicit_reverse(implicit_sqrt_byproduct, x; y_true, J_true) + end + end +end