Skip to content

Commit

Permalink
Precompilation is cool, we should do more of it (#2160)
Browse files Browse the repository at this point in the history
* Precompilation is cool, we should do more of it

* fix

* tm stuff

* ix attempt

* reset

* more

* ix

* reduce

* fix
  • Loading branch information
wsmoses authored Dec 4, 2024
1 parent 3ad827f commit 358d647
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 6 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ObjectFile = "d8793406-e978-5875-9003-1fc021f44a92"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
2 changes: 2 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1587,4 +1587,6 @@ Returns true if within autodiff, otherwise false.
"""
@inline EnzymeCore.within_autodiff() = false

include("precompile.jl")

end # module
18 changes: 12 additions & 6 deletions src/compiler/orcv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function define_absolute_symbol(jd, name)
return false
end

function __init__()
function setup_globals()
opt_level = Base.JLOptions().opt_level
if opt_level < 2
optlevel = LLVM.API.LLVMCodeGenLevelNone
Expand All @@ -105,11 +105,6 @@ function __init__()
dg = LLVM.CreateDynamicLibrarySearchGeneratorForProcess(prefix)
LLVM.add!(jd_main, dg)

if Sys.iswindows() && Int === Int64
# TODO can we check isGNU?
define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
end

es = ExecutionSession(lljit)
try
lctm = LLVM.LocalLazyCallThroughManager(triple(lljit), es)
Expand All @@ -120,6 +115,17 @@ function __init__()
jit[] = CompilerInstance(lljit, nothing, nothing)
end

jd_main, lljit
end

function __init__()
jd_main, lljit = setup_globals()

if Sys.iswindows() && Int === Int64
# TODO can we check isGNU?
define_absolute_symbol(jd_main, mangle(lljit, "___chkstk_ms"))
end

hnd = unsafe_load(cglobal(:jl_libjulia_handle, Ptr{Cvoid}))
for (k, v) in Compiler.JuliaGlobalNameMap
ptr = unsafe_load(Base.reinterpret(Ptr{Ptr{Cvoid}}, Libdl.dlsym(hnd, k)))
Expand Down
13 changes: 13 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
using PrecompileTools: @setup_workload, @compile_workload

@setup_workload begin
precompile_module = @eval module $(gensym())
f(x) = x^2
end

Compiler.JIT.setup_globals()

@compile_workload begin
Enzyme.autodiff(Reverse, precompile_module.f, Active(2.0))
end
end

0 comments on commit 358d647

Please sign in to comment.