From bd9ff747198f89df46684cbd30bc846716d6fbab Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 20 Jun 2024 12:31:58 -0400 Subject: [PATCH 1/4] Add tests for precompilation support --- Project.toml | 2 ++ test/Project.toml | 1 + test/precompile.jl | 59 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 test/precompile.jl diff --git a/Project.toml b/Project.toml index e47535efbd..4bd13a772f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0" Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -37,6 +38,7 @@ GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1, 7, 8" LogExpFunctions = "0.3" ObjectFile = "0.4" +PrecompileTools = "1.2" Preferences = "1.4" SpecialFunctions = "1, 2" StaticArrays = "1" diff --git a/test/Project.toml b/test/Project.toml index 5c8286d1af..fbdc4b6038 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ LLVM_jll = "86de99a1-58d6-5da7-8064-bd56ce2e322c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" diff --git a/test/precompile.jl b/test/precompile.jl new file mode 100644 index 0000000000..e7a9a85248 --- /dev/null +++ b/test/precompile.jl @@ -0,0 +1,59 @@ +using Test + +function precompile_test_harness(@nospecialize(f), testset::String) + @testset "$testset" begin + precompile_test_harness(f, true) + end +end +function precompile_test_harness(@nospecialize(f), separate::Bool) + load_path = mktempdir() + load_cache_path = separate ? mktempdir() : load_path + try + pushfirst!(LOAD_PATH, load_path) + pushfirst!(DEPOT_PATH, load_cache_path) + f(load_path) + finally + try + rm(load_path, force=true, recursive=true) + catch err + @show err + end + if separate + try + rm(load_cache_path, force=true, recursive=true) + catch err + @show err + end + end + filter!((≠)(load_path), LOAD_PATH) + separate && filter!((≠)(load_cache_path), DEPOT_PATH) + end + nothing +end + +precompile_test_harness("Inference caching") do load_path + write(joinpath(load_path, "InferenceCaching.jl"), :(module InferenceCaching + using Enzyme + using PrecompileTools + + function mul(x, y) + return x * y + end + + @setup_workload begin + @compile_workload begin + autodiff(Reverse, mul, Active, Active(1.0), Active(2.0)) + autodiff(Forward, mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) + end + end + end) |> string) + + Base.compilecache(Base.PkgId("InferenceCaching")) + @eval let + using InferenceCaching + using Enzyme + + autodiff(Reverse, InferenceCaching.mul, Active, Active(1.0), Active(2.0)) + autodiff(Forward, InferenceCaching.mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) + end +end \ No newline at end of file From 493b50a84c92f794e85ffe0886233e2bbc1675e9 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 23 Jun 2024 21:42:26 -0400 Subject: [PATCH 2/4] attempt to use PrecompileTools --- src/Enzyme.jl | 9 +++++++++ test/precompile.jl | 11 +++++++---- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 9ff56bdd81..12f7495907 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1475,4 +1475,13 @@ macro import_rrule(args...) return _import_rrule(args...) end +# using PrecompileTools +# Crashes on 1.11 +# @setup_workload let +# @compile_workload begin +# autodiff(ReverseMode{false,InlineABI,false}(), ()->nothing, Const) +# autodiff(ForwardMode{InlineABI}(), ()->nothing, Const) +# end +# end + end # module diff --git a/test/precompile.jl b/test/precompile.jl index e7a9a85248..98dfc95b1f 100644 --- a/test/precompile.jl +++ b/test/precompile.jl @@ -42,8 +42,10 @@ precompile_test_harness("Inference caching") do load_path @setup_workload begin @compile_workload begin - autodiff(Reverse, mul, Active, Active(1.0), Active(2.0)) - autodiff(Forward, mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) + autodiff(ReverseMode{false,InlineABI,false}(), mul, Active, Active(1.0), Active(2.0)) + # Non-Inline mode uses `@generated` functions and poisons the caller + # autodiff(Reverse, mul, Active, Active(1.0), Active(2.0)) + # autodiff(Forward, mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) end end end) |> string) @@ -53,7 +55,8 @@ precompile_test_harness("Inference caching") do load_path using InferenceCaching using Enzyme - autodiff(Reverse, InferenceCaching.mul, Active, Active(1.0), Active(2.0)) - autodiff(Forward, InferenceCaching.mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) + @test autodiff(ReverseMode{false,InlineABI,false}(), InferenceCaching.mul, Active, Active(1.0), Active(2.0)) == ((2.0, 1.0),) + # autodiff(Reverse, InferenceCaching.mul, Active, Active(1.0), Active(2.0)) + # autodiff(Forward, InferenceCaching.mul, Duplicated, Duplicated(1.0, 1.0), Const(2.0)) end end \ No newline at end of file From 3697135dac208b1bcbeb1ed66f367a261afefb9a Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 23 Jun 2024 21:42:46 -0400 Subject: [PATCH 3/4] tell the serializer to not cache code instances originating from Enzyme --- src/compiler/interpreter.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index e1652c5895..cd231abb66 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -68,6 +68,19 @@ else Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world) end +@static if HAS_INTEGRATED_CACHE + function CC.CodeInstance(interp::EnzymeInterpreter, result::CC.InferenceResult, + valid_worlds::CC.WorldRange) + ci = @invoke CC.CodeInstance(interp::CC.AbstractInterpreter, result::CC.InferenceResult, + valid_worlds::CC.WorldRange) + + # FIXME: Enzyme embeds global pointers and other fun things directly + # So forbid the caching of the results. + ci.relocatability = 0x0 + return ci + end +end + # No need to do any locking since we're not putting our results into the runtime cache Core.Compiler.lock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing Core.Compiler.unlock_mi_inference(interp::EnzymeInterpreter, mi::MethodInstance) = nothing From 6f73850e5401f8fec919c1ee28909c05f3932a9f Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 23 Jun 2024 21:53:00 -0400 Subject: [PATCH 4/4] fixup! tell the serializer to not cache code instances originating from Enzyme --- src/compiler/interpreter.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index cd231abb66..cc778e7d88 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -68,6 +68,7 @@ else Core.Compiler.code_cache(interp::EnzymeInterpreter) = WorldView(interp.code_cache, interp.world) end +const CC = Core.Compiler @static if HAS_INTEGRATED_CACHE function CC.CodeInstance(interp::EnzymeInterpreter, result::CC.InferenceResult, valid_worlds::CC.WorldRange)