diff --git a/src/compiler_plugin.jl b/src/compiler_plugin.jl
new file mode 100644
index 00000000..7418e8ea
--- /dev/null
+++ b/src/compiler_plugin.jl
@@ -0,0 +1,240 @@
+# This file is a part of Julia. License is MIT: https://julialang.org/license
+
+# This is a forward port from https://github.com/JuliaLang/julia/pull/52964
+module CCMixin
+
+import Core.Compiler as CC
+import .CC: NativeInterpreter, AbstractInterpreter, ArgInfo, StmtInfo, AbsIntState, CallMeta, Effects,
+    get_max_methods, Const, method_table, MethodResultPure, CallInfo, NoCallInfo, singleton_type,
+    add_remark!
+
+@static if VERSION >= v"1.11.0-DEV.1498"
+    import Core.Compiler: get_inference_world
+    using Base: get_world_counter
+else
+    import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
+end
+
+abstract type AbstractCompiler end
+const CompilerInstance = Union{AbstractCompiler, Nothing}
+const NativeCompiler = Nothing
+
+# Interpreters that intercept `_call_within` and `current_method_table`. Declared before the
+# methods below that dispatch on it, since signature types are resolved at definition time.
+abstract type AbstractGPUInterpreter <: AbstractInterpreter end
+
+# current_compiler() = ccall(:jl_get_current_task, Ref{Task}, ()).compiler::CompilerInstance
+
+"""
+    abstract_interpreter(::CompilerInstance, world::UInt)
+
+Construct an appropriate abstract interpreter for the given compiler instance.
+"""
+function abstract_interpreter end
+
+abstract_interpreter(::Nothing, world::UInt) = NativeInterpreter(world)
+
+"""
+    compiler_world(::CompilerInstance)
+
+The compiler world to execute this compiler instance in.
+"""
+compiler_world(::Nothing) = unsafe_load(cglobal(:jl_typeinf_world, UInt))
+compiler_world(::AbstractCompiler) = get_world_counter() # equivalent to invokelatest
+
+# Call info recording which compiler instance analysed the wrapped call.
+struct WithinCallInfo <: CallInfo
+    compiler::CompilerInstance
+    info::CallInfo
+end
+
+# Method-less marker function; `abstract_call_known` below intercepts calls to it.
+function _call_within end
+
+
+"""
+    invoke_within(compiler, f, args...; kwargs...)
+
+Call `f(args...; kwargs...)` within the compiler context provided by `compiler`.
+"""
+function invoke_within(compiler::CompilerInstance, @nospecialize(f), @nospecialize args...; kwargs...)
+    kwargs = Base.merge(NamedTuple(), kwargs)
+    if isempty(kwargs)
+        return _call_within(compiler, f, args...)
+    end
+    return _call_within(compiler, Core.kwcall, kwargs, f, args...)
+end
+
+# Infer `_call_within(compiler, f, args...)` by running `abstract_call` for `f(args...)` with the
+# interpreter that `abstract_interpreter` builds for `compiler`.
+function abstract_call_within(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo,
+                              sv::AbsIntState, max_methods::Int=get_max_methods(interp, sv))
+    # argtypes is (_call_within, compiler, f, args...); without `f` there is nothing to call.
+    if length(argtypes) < 3
+        @static if VERSION < v"1.11.0-"
+            return CallMeta(Union{}, Effects(), NoCallInfo())
+        else
+            return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
+        end
+    end
+    CT = argtypes[2]
+    other_compiler = singleton_type(CT)
+    if other_compiler === nothing
+        if CT isa Const
+            other_compiler = CT.val
+        else
+            # Compiler is not a singleton type result may depend on runtime configuration
+            add_remark!(interp, sv, "Skipped call_within since compiler plugin not constant")
+            @static if VERSION < v"1.11.0-"
+                return CallMeta(Any, Effects(), NoCallInfo())
+            else
+                return CallMeta(Any, Any, Effects(), NoCallInfo())
+            end
+        end
+    end
+    # Change world to one where our methods exist.
+    cworld = invokelatest(compiler_world, other_compiler)::UInt
+    other_interp = Core._call_in_world(cworld, abstract_interpreter, other_compiler, get_inference_world(interp))
+    other_fargs = fargs === nothing ? nothing : fargs[3:end]
+    other_arginfo = ArgInfo(other_fargs, argtypes[3:end])
+    call = Core._call_in_world(cworld, CC.abstract_call, other_interp, other_arginfo, si, sv, max_methods)
+    # TODO: Edges? Effects?
+    @static if VERSION < v"1.11.0-"
+        return CallMeta(call.rt, call.effects, WithinCallInfo(other_compiler, call.info))
+    else
+        return CallMeta(call.rt, call.exct, call.effects, WithinCallInfo(other_compiler, call.info))
+    end
+end
+
+Base.getindex(ir::CC.IRCode, idx::Core.SSAValue) = CC.getindex(ir, idx)
+# flags are UInt8 on 1.10 and UInt32 on 1.11+ (see `FlagType` below)
+Base.setindex!(inst::CC.Instruction, val::Union{UInt8,UInt32}, idx::Symbol) = CC.setindex!(inst, val, idx)
+
+# allow inlining of WithinCallInfo, why not
+const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
+function CC.handle_call!(todo::Vector{Pair{Int,Any}}, ir::CC.IRCode, idx::CC.Int,
+                         stmt::Expr, info::WithinCallInfo, flag::FlagType,
+                         sig::CC.Signature, state::CC.InliningState)
+    # I failed at inlining the call, codegen currently can't handle call_within so we have to
+    # handle it ourselves.
+    minfo = info.info
+    if !(minfo isa CC.MethodMatchInfo)
+        return nothing
+    end
+    results = minfo.results
+    if length(results.matches) != 1
+        return nothing
+    end
+    match = only(results.matches)
+
+    # lookup the target mi with correct edge tracking
+    # do we need to do this within the other compiler?
+    case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state), info)
+
+    # `compileable_specialization` may not yield an `InvokeCase`; then leave the statement as is
+    case isa CC.InvokeCase || return nothing
+    @assert stmt.head === :call
+
+    # rewrite `_call_within(compiler, f, args...)` into the foreigncall that `find_edges` looks for
+    args = Any[
+        "extern gpuc.call_within",
+        ir[CC.SSAValue(idx)][:type],
+        Core.svec(Any, Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
+        0,
+        QuoteNode(:llvmcall),
+        info.compiler, # we could also use the compiler as passed in stmt.args[2]
+        case.invoke,
+        stmt.args[3:end]...
+    ]
+    stmt.head = :foreigncall
+    stmt.args = args
+    ir[CC.SSAValue(idx)][:flag] |= CC.flags_for_effects(case.effects)
+    return nothing
+end
+
+# Deferred `(compiler, method instance)` edges collected by `find_edges`.
+struct Edges
+    edges::Vector{Tuple{CompilerInstance, CC.MethodInstance}}
+end
+
+# Collect the unique `(compiler, method instance)` pairs of all `gpuc.call_within` foreigncalls
+# in `ir`, as emitted by `handle_call!` above.
+function find_edges(ir::CC.IRCode)
+    edges = Tuple{CompilerInstance, CC.MethodInstance}[]
+    # XXX: can we add this instead in handle_call?
+    for stmt in ir.stmts
+        inst = stmt[:inst]
+        inst isa Expr || continue
+        expr = inst::Expr
+        if expr.head === :foreigncall &&
+            expr.args[1] == "extern gpuc.call_within"
+            mi = expr.args[7]
+            compiler = expr.args[6]
+            push!(edges, (compiler, mi))
+        end
+    end
+    unique!(edges)
+    return edges
+end
+
+@static if VERSION >= v"1.11.0-"
+function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode,
+                                   caller::CC.InferenceResult)
+    edges = find_edges(ir)
+    if !isempty(edges)
+        CC.stack_analysis_result!(caller, Edges(edges))
+    end
+    @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode,
+                                      caller::CC.InferenceResult)
+end
+else # v1.10
+# 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis
+function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode,
+                   caller::CC.InferenceResult)
+    edges = find_edges(ir)
+    if !isempty(edges)
+        # HACK: we store the deferred edges in the argescapes field, which is invalid,
+        # but nobody should be running EA on our results.
+        caller.argescapes = Edges(edges)
+    end
+    @invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState,
+                      ir::CC.IRCode, caller::CC.InferenceResult)
+end
+end
+
+# Method-less marker function; `abstract_call_known` below intercepts calls to it.
+function current_method_table end
+
+# Infer `current_method_table()` as the constant method table view of `interp`.
+function abstract_call_current_method_table(interp::AbstractInterpreter, (; fargs, argtypes)::ArgInfo, si::StmtInfo,
+                                            sv::AbsIntState, max_methods::Int=get_max_methods(interp, sv))
+    if length(argtypes) != 1
+        @static if VERSION < v"1.11.0-"
+            return CallMeta(Union{}, Effects(), NoCallInfo())
+        else
+            return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
+        end
+    end
+    mt = Const(method_table(interp))
+    @static if VERSION < v"1.11.0-"
+        return CallMeta(mt, CC.EFFECTS_TOTAL, MethodResultPure())
+    else
+        return CallMeta(mt, Union{}, CC.EFFECTS_TOTAL, MethodResultPure())
+    end
+end
+
+# Route the two marker functions to their dedicated handlers, defer everything else.
+function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
+                                arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState,
+                                max_methods::Int = CC.get_max_methods(interp, f, sv))
+    if f === _call_within
+        return abstract_call_within(interp, arginfo, si, sv, max_methods)
+    elseif f === current_method_table
+        return abstract_call_current_method_table(interp, arginfo, si, sv, max_methods)
+    end
+    return @invoke CC.abstract_call_known(interp::AbstractInterpreter, f,
+                                          arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState,
+                                          max_methods::Int)
+end
+end
+
diff --git a/src/jlgen.jl b/src/jlgen.jl
index a34bd42e..0253c8e3 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -1,6 +1,10 @@
 # Julia compiler integration
 
 
+include("compiler_plugin.jl")
+import .CCMixin
+import .CCMixin: current_method_table
+
 ## world age lookups
 
 # `tls_world_age` should be used to look up the current world age. in most cases, this is
@@ -318,7 +322,7 @@ else
     get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
 end
 
-struct GPUInterpreter <: CC.AbstractInterpreter
+struct GPUInterpreter <: CCMixin.AbstractGPUInterpreter
     world::UInt
     method_table::GPUMethodTableView
 
diff --git a/test/compilerplugins_testsetup.jl b/test/compilerplugins_testsetup.jl
new file mode 100644
index 00000000..477ae085
--- /dev/null
+++ b/test/compilerplugins_testsetup.jl
@@ -0,0 +1,137 @@
+@testsetup module CompilerPlugins
+
+using Test
+using ReTestItems
+
+Base.Experimental.@MethodTable(FMAMT)
+
+for (jlf, f) in zip((:+, :*, :-), (:add, :mul, :sub))
+    for (T, llvmT) in ((:Float32, "float"), (:Float64, "double"))
+        ir = """
+            %x = f$f contract nsz $llvmT %0, %1
+            ret $llvmT %x
+        """
+        @eval begin
+            # the @pure is necessary so that we can constant propagate.
+            Base.Experimental.@overlay FMAMT @inline Base.@pure function $jlf(a::$T, b::$T)
+                Base.llvmcall($ir, $T, Tuple{$T, $T}, a, b)
+            end
+        end
+    end
+end
+
+# Define Compiler plugin that will replace methods with their contract version
+
+import Core.Compiler as CC
+import GPUCompiler.CCMixin
+import GPUCompiler
+
+struct FMACompiler <: CCMixin.AbstractCompiler
+    parent_mt::CC.MethodTableView
+end
+FMACompiler() = FMACompiler(CCMixin.current_method_table())
+# CCMixin.compiler_world(::FMACompiler) = COMPILER_WORLD[]
+CCMixin.abstract_interpreter(compiler::FMACompiler, world::UInt) =
+    FMAInterp(compiler; world)
+
+struct FMAInterp <: CCMixin.AbstractGPUInterpreter
+    compiler::FMACompiler
+@static if !GPUCompiler.HAS_INTEGRATED_CACHE
+    code_cache::GPUCompiler.CodeCache
+end
+    world::UInt
+    inf_params::CC.InferenceParams
+    opt_params::CC.OptimizationParams
+    inf_cache::Vector{CC.InferenceResult}
+    function FMAInterp(compiler::FMACompiler;
+                world::UInt = Base.get_world_counter(),
+                inf_params::CC.InferenceParams = CC.InferenceParams(),
+                opt_params::CC.OptimizationParams = CC.OptimizationParams(),
+                inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[])
+        @static if !GPUCompiler.HAS_INTEGRATED_CACHE
+            # TODO get a cache here properly...
+            return new(compiler, GPUCompiler.CodeCache(), world, inf_params, opt_params, inf_cache)
+        end
+        return new(compiler, world, inf_params, opt_params, inf_cache)
+    end
+end
+
+@static if VERSION >= v"1.11.0-DEV.1498"
+    import Core.Compiler: get_inference_world
+    using Base: get_world_counter
+else
+    import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
+end
+
+CC.InferenceParams(interp::FMAInterp) = interp.inf_params
+CC.OptimizationParams(interp::FMAInterp) = interp.opt_params
+get_inference_world(interp::FMAInterp) = interp.world
+CC.get_inference_cache(interp::FMAInterp) = interp.inf_cache
+if GPUCompiler.HAS_INTEGRATED_CACHE
+    CC.cache_owner(interp::FMAInterp) = interp.compiler
+else
+    CC.code_cache(interp::FMAInterp) = CC.WorldView(interp.code_cache, interp.world)
+end
+CC.method_table(interp::FMAInterp) = StackedMethodTable(get_inference_world(interp), FMAMT, interp.compiler.parent_mt)
+
+
+
+
+# vchuravy/Shenanigans.jl
+
+# In a stack MT the lower one takes priority
+
+import Core: MethodTable
+import Core.Compiler: MethodTableView, InternalMethodTable,
+    MethodMatchResult, MethodLookupResult, WorldRange
+struct StackedMethodTable{MTV<:MethodTableView} <: MethodTableView
+    world::UInt
+    mt::MethodTable
+    parent::MTV
+end
+StackedMethodTable(world::UInt, mt::MethodTable) = StackedMethodTable(world, mt, InternalMethodTable(world))
+StackedMethodTable(world::UInt, mt::MethodTable, parent::MethodTable) = StackedMethodTable(world, mt, StackedMethodTable(world, parent))
+
+import Core.Compiler: findall, _findall, length, vcat, isempty, max, min, getindex
+function findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
+    result = _findall(sig, table.mt, table.world, limit)
+    result === nothing && return nothing # too many matches
+    nr = length(result)
+    if nr ≥ 1 && getindex(result, nr).fully_covers
+        # no need to fall back to the parent method view
+        return MethodMatchResult(result, true)
+    end
+
+    parent_result = findall(sig, table.parent; limit)::Union{Nothing, MethodMatchResult}
+    parent_result === nothing && return nothing # too many matches
+
+    overlayed = parent_result.overlayed | !isempty(result)
+    parent_result = parent_result.matches::MethodLookupResult
+
+    # merge the parent match results with the internal method table
+    return MethodMatchResult(
+    MethodLookupResult(
+        vcat(result.matches, parent_result.matches),
+        WorldRange(
+            max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
+            min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
+        result.ambig | parent_result.ambig),
+    overlayed)
+end
+
+import Core.Compiler: isoverlayed
+isoverlayed(::StackedMethodTable) = true
+
+import Core.Compiler: findsup, _findsup
+function findsup(@nospecialize(sig::Type), table::StackedMethodTable)
+    match, valid_worlds = _findsup(sig, table.mt, table.world)
+    match !== nothing && return match, valid_worlds, true
+    # look up in parent
+    parent_match, parent_valid_worlds, overlayed = findsup(sig, table.parent)
+    return (
+        parent_match,
+        WorldRange(
+            max(valid_worlds.min_world, parent_valid_worlds.min_world),
+            min(valid_worlds.max_world, parent_valid_worlds.max_world)),
+        overlayed)
+end
diff --git a/test/ptx_tests.jl b/test/ptx_tests.jl
index c059ba60..d607c11c 100644
--- a/test/ptx_tests.jl
+++ b/test/ptx_tests.jl
@@ -323,6 +323,19 @@ end
 end
 end # testitem
 
+@testitem "Compiler Plugins" setup=[PTX, CompilerPlugins] begin
+    function kernel(ptr, a, b, c)
+        unsafe_store!(ptr, a*b+c)
+        return
+    end
+    ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float32}, Float32, Float32, Float32}))
+    @test occursin("fmul float", ir)
+    @test occursin("fadd float", ir)
+
+
+
+end
+
 @testitem "PTX precompile" setup=[Precompile,] begin
     precompile_test_harness("Inference caching") do load_path
         # Write out the PTX test setup as a micro package