From 09354705a4554a8cbcab19730bb4554114d5ea0b Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 11 Nov 2024 11:04:24 +0000 Subject: [PATCH] Import matrix exponential rule from ChainRules (#365) * Matrix exponential * Bump patch version * Remove bad rule definition --- Project.toml | 6 ++++-- src/Mooncake.jl | 2 ++ src/rrules/linear_algebra.jl | 32 ++++++++++++++++++++++++++++++++ test/rrules/linear_algebra.jl | 3 +++ test/runtests.jl | 2 ++ 5 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 src/rrules/linear_algebra.jl create mode 100644 test/rrules/linear_algebra.jl diff --git a/Project.toml b/Project.toml index 1fdcbad09..c80a71522 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.40" +version = "0.4.41" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" @@ -43,6 +44,7 @@ AllocCheck = "0.2" Aqua = "0.8.9" BenchmarkTools = "1" CUDA = "5" +ChainRules = "1.71.0" ChainRulesCore = "1" DiffRules = "1" DiffTests = "0.1" @@ -56,8 +58,8 @@ LogDensityProblemsAD = "1" LuxLib = "1.2 - 1.3.3" MistyClosures = "2" NNlib = "0.9" -Random = "1" Pkg = "1" +Random = "1" Setfield = "1" SpecialFunctions = "2" StableRNGs = "1" diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 6889da23e..0b5a222ac 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -4,6 +4,7 @@ const CC = Core.Compiler using ADTypes, + ChainRules, DiffRules, ExprTools, Graphs, @@ -83,6 +84,7 @@ include(joinpath("rrules", "fastmath.jl")) include(joinpath("rrules", "foreigncall.jl")) include(joinpath("rrules", "iddict.jl")) include(joinpath("rrules", "lapack.jl")) +include(joinpath("rrules", "linear_algebra.jl")) include(joinpath("rrules", "low_level_maths.jl")) include(joinpath("rrules", "misc.jl")) include(joinpath("rrules", "new.jl")) diff --git a/src/rrules/linear_algebra.jl b/src/rrules/linear_algebra.jl new file mode 100644 index 000000000..dc628d2bc --- /dev/null +++ b/src/rrules/linear_algebra.jl @@ -0,0 +1,32 @@ +@is_primitive MinimalCtx Tuple{typeof(exp), Matrix{<:IEEEFloat}} + +struct ExpPullback{P} + pb + Ȳ::Matrix{P} + X̄::Matrix{P} +end + +function (pb::ExpPullback)(::NoRData) + _, X̄_inc = pb.pb(pb.Ȳ) + pb.X̄ .+= X̄_inc + return NoRData(), NoRData() +end + +function rrule!!(::CoDual{typeof(exp)}, X::CoDual{Matrix{P}}) where {P<:IEEEFloat} + Y, pb = ChainRules.rrule(exp, X.x) + Ȳ = zero(Y) + return CoDual(Y, Ȳ), ExpPullback{P}(pb, Ȳ, X.dx) +end + +function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:linear_algebra}) + test_cases = Any[ + (false, :none, nothing, exp, randn(3, 3)), + (false, :none, nothing, exp, randn(7, 7)), + ] + memory = Any[] + return test_cases, memory +end + +function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:linear_algebra}) + return Any[], Any[] +end diff --git a/test/rrules/linear_algebra.jl b/test/rrules/linear_algebra.jl new file mode 100644 index 000000000..f7e1ac0ed --- /dev/null +++ b/test/rrules/linear_algebra.jl @@ -0,0 +1,3 @@ +@testset "linear_algebra" begin + TestUtils.run_rrule!!_test_cases(StableRNG, Val(:linear_algebra)) +end diff --git a/test/runtests.jl b/test/runtests.jl index dd9d955ea..a4a9e7d71 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -40,6 +40,8 @@ include("front_matter.jl") include(joinpath("rrules", "iddict.jl")) @info "lapack" include(joinpath("rrules", "lapack.jl")) + @info "linear_algebra" + include(joinpath("rrules", "linear_algebra.jl")) @info "low_level_maths" include(joinpath("rrules", "low_level_maths.jl")) @info "misc"