diff --git a/examples/jit.jl b/examples/jit.jl
index 8a70a543..6f0a259c 100644
--- a/examples/jit.jl
+++ b/examples/jit.jl
@@ -116,31 +116,31 @@ function get_trampoline(job)
     return addr
 end
 
-import GPUCompiler: deferred_codegen_jobs
-@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
-    # manual version of native_job because we have a function type
-    source = methodinstance(F, Base.to_tuple_type(tt), world)
-    target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
-    # XXX: do we actually require the Julia runtime?
-    #      with jlruntime=false, we reach an unreachable.
-    params = TestCompilerParams()
-    config = CompilerConfig(target, params; kernel=false)
-    job = CompilerJob(source, config, world)
-    # XXX: invoking GPUCompiler from a generated function is not allowed!
-    #      for things to work, we need to forward the correct world, at least.
-
-    addr = get_trampoline(job)
-    trampoline = pointer(addr)
-    id = Base.reinterpret(Int, trampoline)
-
-    deferred_codegen_jobs[id] = job
-
-    quote
-        ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
-        assume(ptr != C_NULL)
-        return ptr
-    end
-end
+# import GPUCompiler: deferred_codegen_jobs
+# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
+#     # manual version of native_job because we have a function type
+#     source = methodinstance(F, Base.to_tuple_type(tt), world)
+#     target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
+#     # XXX: do we actually require the Julia runtime?
+#     #      with jlruntime=false, we reach an unreachable.
+#     params = TestCompilerParams()
+#     config = CompilerConfig(target, params; kernel=false)
+#     job = CompilerJob(source, config, world)
+#     # XXX: invoking GPUCompiler from a generated function is not allowed!
+#     #      for things to work, we need to forward the correct world, at least.
+
+#     addr = get_trampoline(job)
+#     trampoline = pointer(addr)
+#     id = Base.reinterpret(Int, trampoline)
+
+#     deferred_codegen_jobs[id] = job
+
+#     quote
+#         ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
+#         assume(ptr != C_NULL)
+#         return ptr
+#     end
+# end
 
 @generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F,
                              args::Vararg{Any, N}) where {T, RT, F, N}
     argtt = tt.parameters[1]
@@ -224,8 +224,9 @@ end
 @inline function call_delayed(f::F, args...) where F
     tt = Tuple{map(Core.Typeof, args)...}
     rt = Core.Compiler.return_type(f, tt)
-    world = GPUCompiler.tls_world_age()
-    ptr = deferred_codegen(f, Val(tt), Val(world))
+    # FIXME: Horrible idea; ideally `var"gpuc.deferred"` would do this work itself,
+    #        but that will only be needed here, and in Enzyme...
+    ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
     abi_call(ptr, rt, tt, f, args...)
 end
 
diff --git a/src/driver.jl b/src/driver.jl
index 9e05eb63..728e2763 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -39,6 +39,17 @@ function JuliaContext(f; kwargs...)
 end
 
 
+## deferred compilation
+
+"""
+    var"gpuc.deferred"(f, args...)::Ptr{Cvoid}
+
+Instead of calling `f(args...)` right away, put down a marker and
+return a function pointer through which the compiled callee can be
+called later.
+"""
+function var"gpuc.deferred" end
+
 ## compiler entrypoint
 
 export compile
@@ -127,33 +138,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
     error("Unknown compilation output $output")
 end
 
-# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
-# this could both be generalized (e.g. supporting actual function calls, instead of
-# returning a function pointer), and be integrated with the nonrecursive codegen.
-const deferred_codegen_jobs = Dict{Int, Any}()
-
-# We make this function explicitly callable so that we can drive OrcJIT's
-# lazy compilation from, while also enabling recursive compilation.
-Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
-    ptr
-end
-
-@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
-    id = length(deferred_codegen_jobs) + 1
-    deferred_codegen_jobs[id] = (; ft, tt)
-    # don't bother looking up the method instance, as we'll do so again during codegen
-    # using the world age of the parent.
-    #
-    # this also works around an issue on <1.10, where we don't know the world age of
-    # generated functions so use the current world counter, which may be too new
-    # for the world we're compiling for.
-
-    quote
-        # TODO: add an edge to this method instance to support method redefinitions
-        ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
-    end
-end
-
 const __llvm_initialized = Ref(false)
 
 @locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -178,78 +162,74 @@ const __llvm_initialized = Ref(false)
         entry = functions(ir)[entry_fn]
     end
 
-    # finalize the current module. this needs to happen before linking deferred modules,
-    # since those modules have been finalized themselves, and we don't want to re-finalize.
+    # finalize the current module.
     entry = finish_module!(job, ir, entry)
 
-    # deferred code generation
-    has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
-    jobs = Dict{CompilerJob, String}(job => entry_fn)
-    if has_deferred_jobs
-        dyn_marker = functions(ir)["deferred_codegen"]
-
-        # iterative compilation (non-recursive)
-        changed = true
-        while changed
-            changed = false
-
-            # find deferred compiler
-            # TODO: recover this information earlier, from the Julia IR
-            worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
-            for use in uses(dyn_marker)
-                # decode the call
-                call = user(use)::LLVM.CallInst
-                id = convert(Int, first(operands(call)))
-
-                global deferred_codegen_jobs
-                dyn_val = deferred_codegen_jobs[id]
-
-                # get a job in the appopriate world
-                dyn_job = if dyn_val isa CompilerJob
-                    # trust that the user knows what they're doing
-                    dyn_val
+    # rewrite "gpuc.lookup" for deferred code generation
+    run_optimization_for_deferred = false
+    if haskey(functions(ir), "gpuc.lookup")
+        run_optimization_for_deferred = true
+        dyn_marker = functions(ir)["gpuc.lookup"]
+
+        # gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
+        # target method instance from the LLVM IR
+        function find_base_object(val)
+            while true
+                if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
+                                            opcode(val) == LLVM.API.LLVMBitCast ||
+                                            opcode(val) == LLVM.API.LLVMAddrSpaceCast)
+                    val = first(operands(val))
+                elseif val isa LLVM.IntToPtrInst ||
+                       val isa LLVM.BitCastInst ||
+                       val isa LLVM.AddrSpaceCastInst
+                    val = first(operands(val))
+                elseif val isa LLVM.LoadInst
+                    # In 1.11+ we no longer embed integer constants directly.
+ gv = first(operands(val)) + if gv isa LLVM.GlobalValue + val = LLVM.initializer(gv) + continue + end + break else - ft, tt = dyn_val - dyn_src = methodinstance(ft, tt, tls_world_age()) - CompilerJob(dyn_src, job.config) + break end - - push!(get!(worklist, dyn_job, LLVM.CallInst[]), call) end + return val + end - # compile and link - for dyn_job in keys(worklist) - # cached compilation - dyn_entry_fn = get!(jobs, dyn_job) do - dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false, - parent_job=job) - dyn_entry_fn = LLVM.name(dyn_meta.entry) - merge!(compiled, dyn_meta.compiled) - @assert context(dyn_ir) == context(ir) - link!(ir, dyn_ir) - changed = true - dyn_entry_fn - end - dyn_entry = functions(ir)[dyn_entry_fn] - - # insert a pointer to the function everywhere the entry is used - T_ptr = convert(LLVMType, Ptr{Cvoid}) - for call in worklist[dyn_job] - @dispose builder=IRBuilder() begin - position!(builder, call) - fptr = if LLVM.version() >= v"17" - T_ptr = LLVM.PointerType() - bitcast!(builder, dyn_entry, T_ptr) - elseif VERSION >= v"1.12.0-DEV.225" - T_ptr = LLVM.PointerType(LLVM.Int8Type()) - bitcast!(builder, dyn_entry, T_ptr) - else - ptrtoint!(builder, dyn_entry, T_ptr) - end - replace_uses!(call, fptr) + worklist = Dict{Any, Vector{LLVM.CallInst}}() + for use in uses(dyn_marker) + # decode the call + call = user(use)::LLVM.CallInst + dyn_mi_inst = find_base_object(operands(call)[1]) + @compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job + dyn_mi = Base.unsafe_pointer_to_objref( + convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst))) + push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call) + end + + for dyn_mi in keys(worklist) + dyn_fn_name = compiled[dyn_mi].specfunc + dyn_fn = functions(ir)[dyn_fn_name] + + # insert a pointer to the function everywhere the entry is used + T_ptr = convert(LLVMType, Ptr{Cvoid}) + for call in worklist[dyn_mi] + @dispose builder=IRBuilder() begin + position!(builder, call) + fptr = if LLVM.version() >= v"17" + T_ptr = LLVM.PointerType() + bitcast!(builder, dyn_fn, T_ptr) + elseif VERSION >= v"1.12.0-DEV.225" + T_ptr = LLVM.PointerType(LLVM.Int8Type()) + bitcast!(builder, dyn_fn, T_ptr) + else + ptrtoint!(builder, dyn_fn, T_ptr) end - erase!(call) + replace_uses!(call, fptr) end + erase!( call) end end @@ -285,7 +265,7 @@ const __llvm_initialized = Ref(false) # global variables. this makes sure that the optimizer can, e.g., # rewrite function signatures. if toplevel - preserved_gvs = collect(values(jobs)) + preserved_gvs = [entry_fn] for gvar in globals(ir) if linkage(gvar) == LLVM.API.LLVMExternalLinkage push!(preserved_gvs, LLVM.name(gvar)) @@ -317,7 +297,7 @@ const __llvm_initialized = Ref(false) # deferred codegen has some special optimization requirements, # which also need to happen _after_ regular optimization. # XXX: make these part of the optimizer pipeline? - if has_deferred_jobs + if run_optimization_for_deferred @dispose pb=NewPMPassBuilder() begin add!(pb, NewPMFunctionPassManager()) do fpm add!(fpm, InstCombinePass()) @@ -353,15 +333,15 @@ const __llvm_initialized = Ref(false) # finish the module # # we want to finish the module after optimization, so we cannot do so - # during deferred code generation. instead, process the deferred jobs - # here. + # during deferred code generation. Instead, process the merged module + # from all the jobs here. 
     if toplevel
         entry = finish_ir!(job, ir, entry)
 
-        for (job′, fn′) in jobs
-            job′ == job && continue
-            finish_ir!(job′, ir, functions(ir)[fn′])
-        end
+        # for (job′, fn′) in jobs
+        #     job′ == job && continue
+        #     finish_ir!(job′, ir, functions(ir)[fn′])
+        # end
     end
 
     # replace non-entry function definitions with a declaration
diff --git a/src/irgen.jl b/src/irgen.jl
index 7d8ee4be..874ed961 100644
--- a/src/irgen.jl
+++ b/src/irgen.jl
@@ -80,6 +80,19 @@ function irgen(@nospecialize(job::CompilerJob))
     compiled[job.source] =
         (; compiled[job.source].ci, func, specfunc)
 
+    # Earlier we sanitize global names, which invalidates the
+    # func and specfunc names saved in `compiled`. Update the names now,
+    # so that when we use the compiled mappings to look up the
+    # LLVM function for a method instance (deferred codegen) we have
+    # valid targets.
+    for mi in keys(compiled)
+        mi == job.source && continue
+        ci, func, specfunc = compiled[mi]
+        compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
+    end
+
+    # TODO: Should we rewrite gpuc.lookup here?
+
     # minimal required optimization
     @timeit_debug to "rewrite" begin
         if job.config.kernel && needs_byval(job)
@@ -95,9 +108,16 @@ function irgen(@nospecialize(job::CompilerJob))
         end
     end
 
-    # internalize all functions and, but keep exported global variables.
+    # internalize all functions, but keep exported global variables.
     linkage!(entry, LLVM.API.LLVMExternalLinkage)
     preserved_gvs = String[LLVM.name(entry)]
+    for mi in keys(compiled)
+        # delay internalizing of deferred calls since
+        # gpuc.lookup is not yet rewritten.
+        mi == job.source && continue
+        _, _, specfunc = compiled[mi]
+        push!(preserved_gvs, specfunc) # this could be deleted if we rewrite gpuc.lookup earlier
+    end
     for gvar in globals(mod)
         push!(preserved_gvs, LLVM.name(gvar))
     end
diff --git a/src/jlgen.jl b/src/jlgen.jl
index a34bd42e..c6be8c94 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -1,6 +1,5 @@
 # Julia compiler integration
 
-
 ## world age lookups
 
 # `tls_world_age` should be used to look up the current world age. in most cases, this is
@@ -12,6 +11,7 @@ else
 tls_world_age() = ccall(:jl_get_tls_world_age, UInt, ())
 end
 
+
 ## looking up method instances
 
 export methodinstance, generic_methodinstance
@@ -159,6 +159,7 @@ end
 
 ## code instance cache
 
+
 const HAS_INTEGRATED_CACHE = VERSION >= v"1.11.0-DEV.1552"
 
 if !HAS_INTEGRATED_CACHE
@@ -436,6 +437,112 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
 end
 
+
+## deferred compilation
+
+struct DeferredCallInfo <: CC.CallInfo
+    rt::DataType
+    info::CC.CallInfo
+end
+
+# recognize calls to gpuc.deferred and save DeferredCallInfo metadata
+function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
+                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                max_methods::Int = CC.get_max_methods(interp, f, sv))
+    (; fargs, argtypes) = arginfo
+    if f === var"gpuc.deferred"
+        argvec = argtypes[2:end]
+        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
+        callinfo = DeferredCallInfo(call.rt, call.info)
+        @static if VERSION < v"1.11.0-"
+            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
+        else
+            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
+        end
+    end
+    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
+                                          arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                          max_methods::Int)
+end
+
+# during inlining, refine deferred calls to gpuc.lookup foreigncalls
+const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
+function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int,
+                         stmt::Expr, info::DeferredCallInfo, flag::FlagType,
+                         sig::CC.Signature, state::CC.InliningState)
+    minfo = info.info
+    results = minfo.results
+    if length(results.matches) != 1
+        return nothing
+    end
+    match = only(results.matches)
+
+    # lookup the target mi with correct edge tracking
+    case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
+                                         info)
+    @assert case isa CC.InvokeCase
+    @assert stmt.head === :call
+
+    args = Any[
+        "extern gpuc.lookup",
+        Ptr{Cvoid},
+        Core.svec(Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
+        0,
+        QuoteNode(:llvmcall),
+        case.invoke,
+        stmt.args[2:end]...
+    ]
+    stmt.head = :foreigncall
+    stmt.args = args
+    return nothing
+end
+
+struct DeferredEdges
+    edges::Vector{MethodInstance}
+end
+
+function find_deferred_edges(ir::CC.IRCode)
+    edges = MethodInstance[]
+    # XXX: can we add this in handle_call! instead?
+    for stmt in ir.stmts
+        inst = stmt[:inst]
+        inst isa Expr || continue
+        expr = inst::Expr
+        if expr.head === :foreigncall &&
+           expr.args[1] == "extern gpuc.lookup"
+            deferred_mi = expr.args[6]
+            push!(edges, deferred_mi)
+        end
+    end
+    unique!(edges)
+    return edges
+end
+
+if VERSION >= v"1.11.0-"
+function CC.ipo_dataflow_analysis!(interp::GPUInterpreter, ir::CC.IRCode,
+                                   caller::CC.InferenceResult)
+    edges = find_deferred_edges(ir)
+    if !isempty(edges)
+        CC.stack_analysis_result!(caller, DeferredEdges(edges))
+    end
+    @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode,
+                                      caller::CC.InferenceResult)
+end
+else # v1.10
+# 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis
+function CC.finish(interp::GPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode,
+                   caller::CC.InferenceResult)
+    edges = find_deferred_edges(ir)
+    if !isempty(edges)
+        # HACK: we store the deferred edges in the argescapes field, which is invalid,
+        #       but nobody should be running EA on our results.
+        caller.argescapes = DeferredEdges(edges)
+    end
+    @invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState,
+                      ir::CC.IRCode, caller::CC.InferenceResult)
+end
+end
+
+
 ## world view of the cache
 
 using Core.Compiler: WorldView
@@ -584,6 +691,30 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
         error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
     end
 
+    # A poor man's worklist implementation.
+    # `compiled` contains a mapping from `mi -> (ci, func, specfunc)`.
+    # FIXME: Since we are disabling Julia's internal caching we might
+    #        generate multiple LLVM functions for the same mi.
+    # `outstanding` are the missing edges that were not compiled by `compile_method_instance`.
+    # Currently these edges are generated through deferred codegen.
+    compiled = IdDict()
+    llvm_mod, outstanding = compile_method_instance(job, compiled)
+    worklist = outstanding
+    while !isempty(worklist)
+        source = pop!(worklist)
+        haskey(compiled, source) && continue # We have fulfilled the request already
+        # Create a new compiler job for this edge, reusing the config settings from the initial one
+        job2 = CompilerJob(source, job.config)
+        llvm_mod2, outstanding = compile_method_instance(job2, compiled)
+        append!(worklist, outstanding) # merge worklist with new outstanding edges
+        @assert context(llvm_mod) == context(llvm_mod2)
+        link!(llvm_mod, llvm_mod2)
+    end
+
+    return llvm_mod, compiled
+end
+
+function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
     # populate the cache
     interp = get_interpreter(job)
     cache = CC.code_cache(interp)
@@ -594,7 +725,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
 
     # create a callback to look-up function in our cache,
     # and keep track of the method instances we needed.
-    method_instances = []
+    method_instances = Any[]
     if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
         function lookup_fun(mi, min_world, max_world)
             push!(method_instances, mi)
@@ -659,7 +790,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
     end
 
     # process all compiled method instances
-    compiled = Dict()
    for mi in method_instances
         ci = ci_cache_lookup(cache, mi, job.world, job.world)
         ci === nothing && continue
@@ -693,13 +823,39 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
 
         # NOTE: it's not safe to store raw LLVM functions here, since those may get
         #       removed or renamed during optimization, so we store their name instead.
+        # FIXME: Enable this assert when we have a fully featured worklist
+        # @assert !haskey(compiled, mi)
         compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
     end
 
+    # Collect the deferred edges
+    outstanding = Any[]
+    for mi in method_instances
+        !haskey(compiled, mi) && continue # Equivalent to ci_cache_lookup == nothing
+        ci = compiled[mi].ci
+        @static if VERSION >= v"1.11.0-"
+            edges = CC.traverse_analysis_results(ci) do @nospecialize result
+                return result isa DeferredEdges ? result : return
+            end
+        else
+            edges = ci.argescapes
+            if !(edges isa Union{Nothing, DeferredEdges})
+                edges = nothing
+            end
+        end
+        if edges !== nothing
+            for deferred_mi in (edges::DeferredEdges).edges
+                if !haskey(compiled, deferred_mi)
+                    push!(outstanding, deferred_mi)
+                end
+            end
+        end
+    end
+
     # ensure that the requested method instance was compiled
     @assert haskey(compiled, job.source)
 
-    return llvm_mod, compiled
+    return llvm_mod, outstanding
 end
 
 # partially revert JuliaLang/julia#49391
diff --git a/test/native_tests.jl b/test/native_tests.jl
index 2b4b8b48..cd4a20c0 100644
--- a/test/native_tests.jl
+++ b/test/native_tests.jl
@@ -162,6 +162,22 @@ end
         ir = fetch(t)
         @test contains(ir, r"add i64 %\d+, 3")
     end
+
+    @testset "deferred" begin
+        @gensym child kernel unrelated
+        @eval @noinline $child(i) = i
+        @eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)
+
+        # smoke test
+        job, _ = Native.create_job(eval(kernel), (Int64,))
+
+        # TODO: Add a `kernel=true` test
+
+        ci, rt = only(GPUCompiler.code_typed(job))
+        @test rt === Ptr{Cvoid}
+
+        ir = sprint(io->GPUCompiler.code_llvm(io, job))
+    end
 end
 
############################################################################################
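
For context, a minimal sketch of how callers are expected to drive the new entry point, modelled on the `call_delayed` helper in `examples/jit.jl` and the new `native_tests.jl` testset above. The `Native.create_job` helper and the `child`/`kernel` names are illustrative only, and the snippet assumes the GPUInterpreter changes from `src/jlgen.jl` are in place.

```julia
using GPUCompiler

@noinline child(i) = i + 1

# `var"gpuc.deferred"` has no regular methods; it is only given meaning by the
# GPUInterpreter, which refines the call into an "extern gpuc.lookup" foreigncall
# that yields a Ptr{Cvoid} to the deferred-compiled callee.
kernel(i) = GPUCompiler.var"gpuc.deferred"(child, i)

# Compiling `kernel` through a GPUCompiler job resolves the marker:
#   job, _ = Native.create_job(kernel, (Int64,))   # test-suite helper, illustrative
#   ci, rt = only(GPUCompiler.code_typed(job))
#   @assert rt === Ptr{Cvoid}                      # the deferred call returns a pointer
# The returned pointer can then be invoked through an ABI shim such as
# `abi_call` in examples/jit.jl.
```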