Skip to content

Commit

Permalink
Add new deferred compilation mechanism.
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored and maleadt committed Jul 24, 2024
1 parent d68a7fc commit e532812
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 33 deletions.
134 changes: 106 additions & 28 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,41 @@ function JuliaContext(f; kwargs...)
end


## deferred compilation

# Marker intrinsic for the new deferred-compilation mechanism. It deliberately has no
# methods: calls to it are intercepted during abstract interpretation (see the
# `CC.abstract_call_known` overload in src/jlgen.jl) and rewritten into `gpuc.lookup`
# foreigncalls during inlining, so the function itself is never invoked at run time.
function var"gpuc.deferred" end

# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
begin
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
# this could both be generalized (e.g. supporting actual function calls, instead of
# returning a function pointer), and be integrated with the nonrecursive codegen.
#
# jobs are registered here by the @generated function below, keyed by a dense integer
# id that is embedded in the emitted call and decoded again during codegen.
const deferred_codegen_jobs = Dict{Int, Any}()

# We make this function explicitly callable so that we can drive OrcJIT's
# lazy compilation from it, while also enabling recursive compilation.
# At run time this is just the identity on an already-compiled function pointer.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
    ptr
end

# Register a deferred-compilation job for function type `ft` and argument types `tt`,
# and emit a call to the `deferred_codegen` runtime carrying only the job id.
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
    id = length(deferred_codegen_jobs) + 1
    deferred_codegen_jobs[id] = (; ft, tt)
    # don't bother looking up the method instance, as we'll do so again during codegen
    # using the world age of the parent.
    #
    # this also works around an issue on <1.10, where we don't know the world age of
    # generated functions so use the current world counter, which may be too new
    # for the world we're compiling for.

    quote
        # TODO: add an edge to this method instance to support method redefinitions
        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
    end
end
end


## compiler entrypoint

export compile
Expand Down Expand Up @@ -127,33 +162,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
error("Unknown compilation output $output")
end

# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
# this could both be generalized (e.g. supporting actual function calls, instead of
# returning a function pointer), and be integrated with the nonrecursive codegen.
#
# maps an integer job id (as embedded in emitted `deferred_codegen` calls) to the
# requested function/argument types; consumed later during codegen.
const deferred_codegen_jobs = Dict{Int, Any}()

# We make this function explicitly callable so that we can drive OrcJIT's
# lazy compilation from it, while also enabling recursive compilation.
# At run time this is just the identity on an already-compiled function pointer.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
    ptr
end

# Register a deferred-compilation job for function type `ft` and argument types `tt`,
# and emit a call to the `deferred_codegen` runtime carrying only the job id.
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
    id = length(deferred_codegen_jobs) + 1
    deferred_codegen_jobs[id] = (; ft, tt)
    # don't bother looking up the method instance, as we'll do so again during codegen
    # using the world age of the parent.
    #
    # this also works around an issue on <1.10, where we don't know the world age of
    # generated functions so use the current world counter, which may be too new
    # for the world we're compiling for.

    quote
        # TODO: add an edge to this method instance to support method redefinitions
        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
    end
end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
Expand Down Expand Up @@ -183,6 +191,77 @@ const __llvm_initialized = Ref(false)
entry = finish_module!(job, ir, entry)

# deferred code generation
if haskey(functions(ir), "gpuc.lookup")
    dyn_marker = functions(ir)["gpuc.lookup"]

    # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
    # target method instance from the LLVM IR
    # TODO: drive deferred compilation from the Julia IR instead
    #
    # Strip pointer casts and (on 1.11+) loads-from-globals until we reach the
    # underlying constant that encodes the MethodInstance pointer.
    function find_base_object(val)
        while true
            if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
                                        opcode(val) == LLVM.API.LLVMBitCast ||
                                        opcode(val) == LLVM.API.LLVMAddrSpaceCast)
                val = first(operands(val))
            elseif val isa LLVM.IntToPtrInst ||
                   val isa LLVM.BitCastInst ||
                   val isa LLVM.AddrSpaceCastInst
                val = first(operands(val))
            elseif val isa LLVM.LoadInst
                # In 1.11+ we no longer embed integer constants directly.
                gv = first(operands(val))
                if gv isa LLVM.GlobalValue
                    val = LLVM.initializer(gv)
                    continue
                end
                break
            else
                break
            end
        end
        return val
    end

    # group all gpuc.lookup call sites by the MethodInstance they reference
    worklist = Dict{Any, Vector{LLVM.CallInst}}()
    for use in uses(dyn_marker)
        # decode the call
        call = user(use)::LLVM.CallInst
        dyn_mi_inst = find_base_object(operands(call)[1])
        @compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
        dyn_mi = Base.unsafe_pointer_to_objref(
            convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
        push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
    end

    # replace each lookup call with a pointer to the compiled target function
    for dyn_mi in keys(worklist)
        dyn_fn_name = compiled[dyn_mi].specfunc
        dyn_fn = functions(ir)[dyn_fn_name]

        # insert a pointer to the function everywhere the entry is used
        T_ptr = convert(LLVMType, Ptr{Cvoid})
        for call in worklist[dyn_mi]
            @dispose builder=IRBuilder() begin
                position!(builder, call)
                # NOTE: T_ptr is rebound per branch; LLVM 17+ uses opaque pointers,
                # 1.12+ expects an i8* bitcast, older setups an integer pointer.
                fptr = if LLVM.version() >= v"17"
                    T_ptr = LLVM.PointerType()
                    bitcast!(builder, dyn_fn, T_ptr)
                elseif VERSION >= v"1.12.0-DEV.225"
                    T_ptr = LLVM.PointerType(LLVM.Int8Type())
                    bitcast!(builder, dyn_fn, T_ptr)
                else
                    ptrtoint!(builder, dyn_fn, T_ptr)
                end
                replace_uses!(call, fptr)
            end
            unsafe_delete!(LLVM.parent(call), call)
        end
    end

    # all deferred compilations should have been resolved
    @compiler_assert isempty(uses(dyn_marker)) job
    unsafe_delete!(ir, dyn_marker)
end
## old, deprecated implementation
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
jobs = Dict{CompilerJob, String}(job => entry_fn)
if has_deferred_jobs
Expand All @@ -194,7 +273,6 @@ const __llvm_initialized = Ref(false)
changed = false

# find deferred compiler
# TODO: recover this information earlier, from the Julia IR
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
Expand Down
6 changes: 6 additions & 0 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,12 @@ function irgen(@nospecialize(job::CompilerJob))
compiled[job.source] =
(; compiled[job.source].ci, func, specfunc)

for mi in keys(compiled)
mi == job.source && continue
ci, func, specfunc = compiled[mi]
compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
end

# minimal required optimization
@timeit_debug to "rewrite" begin
if job.config.kernel && needs_byval(job)
Expand Down
159 changes: 154 additions & 5 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Julia compiler integration


## world age lookups

# `tls_world_age` should be used to look up the current world age. in most cases, this is
Expand All @@ -12,6 +11,7 @@ else
tls_world_age() = ccall(:jl_get_tls_world_age, UInt, ())
end


## looking up method instances

export methodinstance, generic_methodinstance
Expand Down Expand Up @@ -159,6 +159,7 @@ end


## code instance cache

const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"

if !HAS_INTEGRATED_CACHE
Expand Down Expand Up @@ -318,7 +319,8 @@ else
get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
end

struct GPUInterpreter <: CC.AbstractInterpreter
abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
struct GPUInterpreter <: AbstractGPUInterpreter
world::UInt
method_table::GPUMethodTableView

Expand Down Expand Up @@ -436,6 +438,112 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
end


## deferred compilation

# Inference metadata attached to a `gpuc.deferred` call site: records the inferred
# return type and the call info of the underlying target call, so that the inlining
# pass (`CC.handle_call!`) can later rewrite the statement into a `gpuc.lookup`
# foreigncall.
struct DeferredCallInfo <: CC.CallInfo
    rt::DataType      # inferred return type of the deferred target call
    info::CC.CallInfo # call info of the underlying (deferred) call
end

# Recognize calls to `gpuc.deferred`: infer the wrapped call as usual, but report
# `Ptr{Cvoid}` as the result type and stash the real inference result in a
# `DeferredCallInfo` for the inlining pass to consume. All other calls are
# forwarded to the generic implementation.
function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
                                max_methods::Int = CC.get_max_methods(interp, f, sv))
    if f !== var"gpuc.deferred"
        # not our marker; defer to the default abstract interpreter behavior
        return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
                                              arginfo::CC.ArgInfo, si::CC.StmtInfo,
                                              sv::CC.AbsIntState, max_methods::Int)
    end

    # drop the marker itself and infer the underlying call
    inner_argtypes = arginfo.argtypes[2:end]
    inner = CC.abstract_call(interp, CC.ArgInfo(nothing, inner_argtypes), si, sv,
                             max_methods)
    meta = DeferredCallInfo(inner.rt, inner.info)

    # the marker call itself evaluates to a code pointer
    @static if VERSION < v"1.11.0-"
        return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), meta)
    else
        return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), meta)
    end
end

# during inlining, refine deferred calls to gpuc.lookup foreigncalls
# (the statement-flag type widened from UInt8 to UInt32 in Julia 1.11)
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8

# Inlining hook: rewrite a `:call` statement that carries `DeferredCallInfo` into a
# `:foreigncall` of the `gpuc.lookup` intrinsic referencing the target's compileable
# specialization. Leaves the statement untouched (returns `nothing` early) unless
# inference found exactly one matching method.
function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int,
                         stmt::Expr, info::DeferredCallInfo, flag::FlagType,
                         sig::CC.Signature, state::CC.InliningState)
    minfo = info.info
    results = minfo.results
    # only a unique match can be resolved to a concrete MethodInstance
    if length(results.matches) != 1
        return nothing
    end
    match = only(results.matches)

    # lookup the target mi with correct edge tracking
    case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
                                         info)
    @assert case isa CC.InvokeCase
    @assert stmt.head === :call

    # foreigncall argument layout: (name, return type, argument types, nreq,
    # calling convention, then the actual arguments — here the MethodInstance
    # followed by the original call arguments)
    args = Any[
        "extern gpuc.lookup",
        Ptr{Cvoid},
        Core.svec(Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
        0,
        QuoteNode(:llvmcall),
        case.invoke,
        stmt.args[2:end]...
    ]
    # mutate the statement in place; no further inlining work is queued
    stmt.head = :foreigncall
    stmt.args = args
    return nothing
end

# Analysis result recording which MethodInstances a compiled method defers to
# (one entry per distinct `gpuc.lookup` target found in its optimized IR).
struct DeferredEdges
    edges::Vector{MethodInstance}
end

# Scan optimized IR for `gpuc.lookup` foreigncalls and return the deduplicated
# list of MethodInstances they reference; these become deferred-compilation edges.
function find_deferred_edges(ir::CC.IRCode)
    found = MethodInstance[]
    # XXX: can we add this instead in handle_call?
    for stmt in ir.stmts
        ex = stmt[:inst]
        ex isa Expr || continue
        is_lookup = ex.head === :foreigncall && ex.args[1] == "extern gpuc.lookup"
        is_lookup || continue
        # argument 6 of the foreigncall is the target MethodInstance
        push!(found, ex.args[6])
    end
    return unique!(found)
end

# After optimization, attach the deferred-compilation edges found in the IR to the
# inference result, so they can be recovered later from the code instance
# (see the edge collection in `compile_method_instance`).
if VERSION >= v"1.11.0-"
    # 1.11+: use the official per-result analysis-result storage
    function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode,
                                       caller::CC.InferenceResult)
        edges = find_deferred_edges(ir)
        if !isempty(edges)
            CC.stack_analysis_result!(caller, DeferredEdges(edges))
        end
        @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode,
                                          caller::CC.InferenceResult)
    end
else # v1.10
    # 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis
    function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode,
                      caller::CC.InferenceResult)
        edges = find_deferred_edges(ir)
        if !isempty(edges)
            # HACK: we store the deferred edges in the argescapes field, which is invalid,
            # but nobody should be running EA on our results.
            caller.argescapes = DeferredEdges(edges)
        end
        @invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState,
                          ir::CC.IRCode, caller::CC.InferenceResult)
    end
end


## world view of the cache
using Core.Compiler: WorldView

Expand Down Expand Up @@ -584,6 +692,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
end

compiled = IdDict()
llvm_mod, outstanding = compile_method_instance(job, compiled)
worklist = outstanding
while !isempty(worklist)
source = pop!(worklist)
haskey(compiled, source) && continue
job2 = CompilerJob(source, job.config)
@debug "Processing..." job2
llvm_mod2, outstanding = compile_method_instance(job2, compiled)
append!(worklist, outstanding)
@assert context(llvm_mod) == context(llvm_mod2)
link!(llvm_mod, llvm_mod2)
end

return llvm_mod, compiled
end

function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
# populate the cache
interp = get_interpreter(job)
cache = CC.code_cache(interp)
Expand All @@ -594,7 +720,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))

# create a callback to look-up function in our cache,
# and keep track of the method instances we needed.
method_instances = []
method_instances = Any[]
if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
function lookup_fun(mi, min_world, max_world)
push!(method_instances, mi)
Expand Down Expand Up @@ -659,7 +785,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
end

# process all compiled method instances
compiled = Dict()
for mi in method_instances
ci = ci_cache_lookup(cache, mi, job.world, job.world)
ci === nothing && continue
Expand Down Expand Up @@ -696,10 +821,34 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
end

# Collect the deferred edges
outstanding = Any[]
for mi in method_instances
!haskey(compiled, mi) && continue # Equivalent to ci_cache_lookup == nothing
ci = compiled[mi].ci
@static if VERSION >= v"1.11.0-"
edges = CC.traverse_analysis_results(ci) do @nospecialize result
return result isa DeferredEdges ? result : return
end
else
edges = ci.argescapes
if !(edges isa Union{Nothing, DeferredEdges})
edges = nothing
end
end
if edges !== nothing
for deferred_mi in (edges::DeferredEdges).edges
if !haskey(compiled, deferred_mi)
push!(outstanding, deferred_mi)
end
end
end
end

# ensure that the requested method instance was compiled
@assert haskey(compiled, job.source)

return llvm_mod, compiled
return llvm_mod, outstanding
end

# partially revert JuliaLangjulia#49391
Expand Down
Loading

0 comments on commit e532812

Please sign in to comment.