From 1aed64e91d8e67bf07b612979ebb4338c03a705e Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Sat, 11 Nov 2023 21:04:32 +0100 Subject: [PATCH] remove compile_wasm and MixTape; allow specifying a method_table --- README.md | 23 ++++++----- src/StaticCompiler.jl | 60 ++++++++--------------------- src/interpreter.jl | 44 ++++++++------------- src/target.jl | 39 +++++++++---------- test/Project.toml | 1 - test/runtests.jl | 1 - test/testcore.jl | 32 ++++++++++++++++ test/testintegration.jl | 84 ----------------------------------------- 8 files changed, 95 insertions(+), 189 deletions(-) diff --git a/README.md b/README.md index c93095f3..d4816cb9 100644 --- a/README.md +++ b/README.md @@ -43,27 +43,31 @@ Sometimes, a julia function you want to statically compile will do things (such ```julia julia> using Libdl, StaticCompiler -julia> f(x) = x + 1; +julia> f(x) = g(x) + 1; -julia> @device_override Base.:(+)(x::Int, y::Int) = x - y +julia> g(x) = 2x + +julia> @device_override g(x::Int) = x - 10 julia> f(1) # Gives the expected answer in regular julia -2 +3 julia> dlopen(compile_shlib(f, (Int,), "./")) do lib fptr = dlsym(lib, "f") # Now use the compiled version where + is replaced with - @ccall $fptr(1::Int)::Int end -0 +-8 ``` -Typically, errors should be overrided and replaced with `@print_and_throw`, which is StaticCompiler friendle, i.e. +Typically, errors should be overridden and replaced with `@print_and_throw`, which is StaticCompiler friendly, i.e. we define overrides such as ``` julia @device_override @noinline Base.Math.throw_complex_domainerror(f::Symbol, x) = @print_and_throw c"This operation requires a complex input to return a complex result" ``` +If, for some reason, you wish to use a different method table (defined with `Base.Experimental.@MethodTable` and `Base.Experimental.@overlay`) than the default one provided by StaticCompiler.jl, you can provide it to `compile_executable` and `compile_shlib` via a keyword argument `method_table`. 
+ ## Approach @@ -98,7 +102,7 @@ To enable code to be statically compiled, consider the following: ## Guide for Statically Compiling Code -If you're trying to statically compile generic code, you may run into issues if that code uses features not supported by StaticCompiler. One option is to change the code you're calling using the tips above. If that is not easy, you may by able to compile it anyway. One option is to use method overrides to change what methods are called. Another option is to use the Mixtape feature to change problematic code as part of compilation. For example, you could convert all Strings to StaticStrings. +If you're trying to statically compile generic code, you may run into issues if that code uses features not supported by StaticCompiler. One option is to change the code you're calling using the tips above. If that is not easy, you may be able to compile it anyway. One option is to use method overlays to change what methods are called. [Cthulhu](https://github.com/JuliaDebug/Cthulhu.jl) is a great help in digging into code, finding type instabilities, and finding other sources of code that may break static compilation. @@ -106,9 +110,4 @@ If you're trying to statically compile generic code, you may run into issues if Because Julia objects follow C memory layouts, compiled libraries should be usable from most languages that can interface with C. For example, results should be usable with Python's [CFFI](https://cffi.readthedocs.io/en/latest/) package. -For WebAssembly, interface helpers are available at [WebAssemblyInterfaces](https://github.com/tshort/WebAssemblyInterfaces.jl). - - - - - +For WebAssembly, interface helpers are available at [WebAssemblyInterfaces](https://github.com/tshort/WebAssemblyInterfaces.jl), and users should also see [WebAssemblyCompiler](https://github.com/tshort/WebAssemblyCompiler.jl) for a package more focused on compilation of WebAssembly in general. 
diff --git a/src/StaticCompiler.jl b/src/StaticCompiler.jl index d38fe729..98f46733 100644 --- a/src/StaticCompiler.jl +++ b/src/StaticCompiler.jl @@ -11,13 +11,12 @@ using Clang_jll: clang using LLD_jll: lld using StaticTools using StaticTools: @symbolcall, @c_str, println - +using Core: MethodTable export load_function, compile_shlib, compile_executable, compile_wasm export native_code_llvm, native_code_typed, native_llvm_module, native_code_native export @device_override, @print_and_throw -include("mixtape.jl") include("interpreter.jl") include("target.jl") include("pointer_warning.jl") @@ -33,6 +32,7 @@ compile_executable(f::Function, types::Tuple, path::String, [name::String=string filename::String=name, cflags=``, # Specify libraries you would like to link against, and other compiler options here also_expose=[], + method_table=StaticCompiler.method_table, kwargs... ) ``` @@ -124,8 +124,18 @@ end """ ```julia -compile_shlib(f::Function, types::Tuple, [path::String="./"], [name::String=string(nameof(f))]; filename::String=name, cflags=``, kwargs...) -compile_shlib(funcs::Array, [path::String="./"]; filename="libfoo", demangle=true, cflags=``, kwargs...) +compile_shlib(f::Function, types::Tuple, [path::String="./"], [name::String=string(nameof(f))]; + filename::String=name, + cflags=``, + method_table=StaticCompiler.method_table, + kwargs...) + +compile_shlib(funcs::Array, [path::String="./"]; + filename="libfoo", + demangle=true, + cflags=``, + method_table=StaticCompiler.method_table, + kwargs...) ``` As `compile_executable`, but compiling to a standalone `.dylib`/`.so` shared library. @@ -184,38 +194,7 @@ function compile_shlib(funcs::Union{Array,Tuple}, path::String="./"; joinpath(abspath(path), filename * "." * Libdl.dlext) end - -""" -```julia -compile_wasm(f::Function, types::Tuple, [path::String="./"], [name::String=string(nameof(f))]; filename::String=name, flags=``, kwargs...) 
-compile_wasm(funcs::Union{Array,Tuple}, [path::String="./"]; filename="libfoo", demangle=true, flags=``, kwargs...) -``` -As `compile_shlib`, but compiling to a WebAssembly library. - -If `demangle` is set to `false`, compiled function names are prepended with "julia_". -``` -""" -function compile_wasm(f::Function, types=(); - path::String = "./", - filename = fix_name(f), - flags = ``, - kwargs... - ) - tt = Base.to_tuple_type(types) - obj_path, name = generate_obj(f, tt, true, path, filename; target = (triple = "wasm32-unknown-wasi", cpu = "", features = ""), remove_julia_addrspaces = true, kwargs...) - run(`$(lld()) -flavor wasm --no-entry --export-all $flags $obj_path/obj.o -o $path/$name.wasm`) - joinpath(abspath(path), filename * ".wasm") -end -function compile_wasm(funcs::Union{Array,Tuple}; - path::String="./", - filename="libfoo", - flags=``, - kwargs... - ) - obj_path, name = generate_obj(funcs, true, path, filename; target = (triple = "wasm32-unknown-wasi", cpu = "", features = ""), remove_julia_addrspaces = true, kwargs...) - run(`$(lld()) -flavor wasm --no-entry --export-all $flags $obj_path/$filename.o -o $path/$filename.wasm`) - joinpath(abspath(path), filename * ".wasm") -end + """ ```julia @@ -260,6 +239,7 @@ function generate_shlib_fptr(path::String, name, filename::String=name) @assert fptr != C_NULL fptr end + # As above, but also compile (maybe remove this method in the future?) function generate_shlib_fptr(f, tt, path::String=tempname(), name=fix_name(f), filename::String=name; temp::Bool=true, @@ -478,7 +458,6 @@ end """ ```julia generate_obj(f, tt, external::Bool, path::String = tempname(), filenamebase::String="obj"; - mixtape = NoContext(), target = (), demangle = true, strip_llvm = false, @@ -490,9 +469,6 @@ Low level interface for compiling object code (`.o`) for for function `f` given a tuple type `tt` characterizing the types of the arguments for which the function will be compiled. 
-`mixtape` defines a context that can be used to transform IR prior to compilation using -[Mixtape](https://github.com/JuliaCompilerPlugins/Mixtape.jl) features. - `target` can be used to change the output target. This is useful for compiling to WebAssembly and embedded targets. This is a named tuple with fields `triple`, `cpu`, and `features` (each of these are strings). The defaults compile to the native target. @@ -522,7 +498,6 @@ end """ ```julia generate_obj(funcs::Union{Array,Tuple}, external::Bool, path::String = tempname(), filenamebase::String="obj"; - mixtape = NoContext(), target = (), demangle =false, strip_llvm = false, @@ -534,9 +509,6 @@ Low level interface for compiling object code (`.o`) for an array of Tuples (f, tt) where each function `f` and tuple type `tt` determine the set of methods which will be compiled. -`mixtape` defines a context that can be used to transform IR prior to compilation using -[Mixtape](https://github.com/JuliaCompilerPlugins/Mixtape.jl) features. - `target` can be used to change the output target. This is useful for compiling to WebAssembly and embedded targets. This is a named tuple with fields `triple`, `cpu`, and `features` (each of these are strings). The defaults compile to the native target. 
diff --git a/src/interpreter.jl b/src/interpreter.jl index c0e00c37..344cc53d 100644 --- a/src/interpreter.jl +++ b/src/interpreter.jl @@ -7,7 +7,7 @@ using GPUCompiler: using CodeInfoTools using CodeInfoTools: resolve -struct StaticInterpreter{M} <: AbstractInterpreter +struct StaticInterpreter <: AbstractInterpreter global_cache::CodeCache method_table::Union{Nothing,Core.MethodTable} @@ -20,13 +20,10 @@ struct StaticInterpreter{M} <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams - # Mixtape context - mixtape::M - - function StaticInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams, mixtape::CompilationContext) + function StaticInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams) @assert world <= Base.get_world_counter() - return new{typeof(mixtape)}( + return new( cache, mt, @@ -38,10 +35,7 @@ struct StaticInterpreter{M} <: AbstractInterpreter # parameters for inference and optimization ip, - op, - - # Mixtape context - mixtape + op ) end end @@ -79,9 +73,6 @@ function custom_pass!(interp::StaticInterpreter, result::InferenceResult, mi::Co mi.specTypes isa UnionAll && return src sig = Tuple(mi.specTypes.parameters) as = map(resolve_generic, sig) - if allow(interp.mixtape, mi.def.module, as...) 
- src = transform(interp.mixtape, src, sig) - end return src end @@ -102,22 +93,21 @@ end Core.Compiler.may_optimize(interp::StaticInterpreter) = true Core.Compiler.may_compress(interp::StaticInterpreter) = true Core.Compiler.may_discard_trees(interp::StaticInterpreter) = true -if VERSION >= v"1.7.0-DEV.577" Core.Compiler.verbose_stmt_info(interp::StaticInterpreter) = false -end + if isdefined(Base.Experimental, Symbol("@overlay")) -using Core.Compiler: OverlayMethodTable -if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120" -Core.Compiler.method_table(interp::StaticInterpreter) = - OverlayMethodTable(interp.world, interp.method_table) -else -Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) = - OverlayMethodTable(interp.world, interp.method_table) -end + using Core.Compiler: OverlayMethodTable + if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120" + Core.Compiler.method_table(interp::StaticInterpreter) = + OverlayMethodTable(interp.world, interp.method_table) + else + Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) = + OverlayMethodTable(interp.world, interp.method_table) + end else -Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) = - WorldOverlayMethodTable(interp.world) + Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) = + WorldOverlayMethodTable(interp.world) end # semi-concrete interepretation is broken with overlays (JuliaLang/julia#47349) @@ -134,13 +124,11 @@ end struct StaticCompilerParams <: AbstractCompilerParams opt::Bool optlevel::Int - mixtape::CompilationContext cache::CodeCache end function StaticCompilerParams(; opt = false, optlevel = Base.JLOptions().opt_level, - mixtape = NoContext(), cache = CodeCache()) - return StaticCompilerParams(opt, optlevel, mixtape, cache) + return StaticCompilerParams(opt, optlevel, cache) end diff --git a/src/target.jl b/src/target.jl index 80be6280..777960a3 100644 --- a/src/target.jl 
+++ b/src/target.jl @@ -29,16 +29,18 @@ macro device_override(ex) return esc(code) end -Base.@kwdef struct NativeCompilerTarget <: GPUCompiler.AbstractCompilerTarget +Base.@kwdef struct NativeCompilerTarget{MT} <: GPUCompiler.AbstractCompilerTarget triple::String=Sys.MACHINE cpu::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName()) features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures()) + method_table::MT = method_table end -Base.@kwdef struct ExternalNativeCompilerTarget <: GPUCompiler.AbstractCompilerTarget +Base.@kwdef struct ExternalNativeCompilerTarget{MT} <: GPUCompiler.AbstractCompilerTarget triple::String=Sys.MACHINE cpu::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName()) features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures()) + method_table::MT = method_table end module StaticRuntime @@ -66,44 +68,43 @@ for target in (:NativeCompilerTarget, :ExternalNativeCompilerTarget) return tm end - GPUCompiler.runtime_slug(job::GPUCompiler.CompilerJob{$target}) = "native_$(job.config.target.cpu)-$(hash(job.config.target.features))" + GPUCompiler.runtime_slug(job::GPUCompiler.CompilerJob{<:$target}) = "native_$(job.config.target.cpu)-$(hash(job.config.target.features))" - GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{$target}) = StaticRuntime - GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = StaticRuntime + GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{<:$target}) = StaticRuntime + GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = StaticRuntime - GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = true - GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{$target}) = true + GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = true + 
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{<:$target}) = true - GPUCompiler.get_interpreter(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = + GPUCompiler.get_interpreter(job::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = StaticInterpreter(job.config.params.cache, GPUCompiler.method_table(job), job.world, - GPUCompiler.inference_params(job), GPUCompiler.optimization_params(job), - job.config.params.mixtape) - GPUCompiler.ci_cache(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = job.config.params.cache + GPUCompiler.inference_params(job), GPUCompiler.optimization_params(job)) + GPUCompiler.ci_cache(job::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = job.config.params.cache + GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{<:$target})) = job.config.target.method_table end end -GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{ExternalNativeCompilerTarget})) = method_table -GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{ExternalNativeCompilerTarget, StaticCompilerParams})) = method_table - function native_job(@nospecialize(func::Function), @nospecialize(types::Type), external::Bool; - mixtape = NoContext(), name = fix_name(func), kernel::Bool = false, - target = (), + target = (;), + method_table=method_table, kwargs... ) + target = merge(target, (;method_table)) source = methodinstance(typeof(func), Base.to_tuple_type(types)) target = external ? ExternalNativeCompilerTarget(;target...) : NativeCompilerTarget(;target...) - params = StaticCompilerParams(mixtape = mixtape) + params = StaticCompilerParams() config = GPUCompiler.CompilerConfig(target, params, name = name, kernel = kernel) StaticCompiler.CompilerJob(source, config), kwargs end -function native_job(@nospecialize(func), @nospecialize(types), external; mixtape = NoContext(), kernel::Bool=false, name=fix_name(repr(func)), target = (), kwargs...) 
+function native_job(@nospecialize(func), @nospecialize(types), external; kernel::Bool=false, name=fix_name(repr(func)), target = (;), method_table=method_table, kwargs...) + target = merge(target, (; method_table)) source = methodinstance(typeof(func), Base.to_tuple_type(types)) target = external ? ExternalNativeCompilerTarget(;target...) : NativeCompilerTarget(;target...) - params = StaticCompilerParams(mixtape = mixtape) + params = StaticCompilerParams() config = GPUCompiler.CompilerConfig(target, params, name = name, kernel = kernel) GPUCompiler.CompilerJob(source, config), kwargs end diff --git a/test/Project.toml b/test/Project.toml index 4e8f83bf..5498846c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,4 @@ [deps] -CodeInfoTools = "bc773b8a-8374-437a-b9f2-0e9785855863" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0" GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55" diff --git a/test/runtests.jl b/test/runtests.jl index d25b1ad0..542659c2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,7 +7,6 @@ using ManualMemory using Distributed using StaticTools using StrideArraysCore -using CodeInfoTools using MacroTools using LLD_jll using Bumper diff --git a/test/testcore.jl b/test/testcore.jl index 7e1e821c..e705ed64 100644 --- a/test/testcore.jl +++ b/test/testcore.jl @@ -124,3 +124,35 @@ end @test ccall(fptr, Float64, (Float64,), 10.) == squaresquaresquare(10.) #Compile dylib end + + +# Overlays + +module SubFoo + +rand(args...) = Base.rand(args...) 
+ +function f() + x = rand() + y = rand() + return x + y +end + +end + +@device_override SubFoo.rand() = 2 + +# Lets test having another method table around +Base.Experimental.@MethodTable AnotherTable +Base.Experimental.@overlay AnotherTable SubFoo.rand() = 3 + +@testset "Overlays" begin + Libdl.dlopen(compile_shlib(SubFoo.f, (), workdir)) do lib + fptr = Libdl.dlsym(lib, "f") + @test @ccall($fptr()::Int) == 4 + end + Libdl.dlopen(compile_shlib(SubFoo.f, (), workdir; method_table=AnotherTable)) do lib + fptr = Libdl.dlsym(lib, "f") + @test @ccall($fptr()::Int) == 6 + end +end diff --git a/test/testintegration.jl b/test/testintegration.jl index 338e7dee..47e3c778 100644 --- a/test/testintegration.jl +++ b/test/testintegration.jl @@ -334,90 +334,6 @@ end end - -# Mixtape - -module SubFoo - -function f() - x = rand() - y = rand() - return x + y -end - -function stringfun(s1, s2) - return s1 * s2 -end - -function teststring() - return stringfun("ab", "c") == "abc" -end - -end - -struct MyMix <: CompilationContext end - -@testset "Mixtape" begin - # 101: How2Mix - - # A few little utility functions for working with Expr instances. - swap(e) = e - function swap(e::Expr) - new = MacroTools.postwalk(e) do s - isexpr(s, :call) || return s - s.args[1] == Base.rand || return s - return 4 - end - return new - end - - # This is pre-inference - you get to see a CodeInfoTools.Builder instance. - function StaticCompiler.transform(::MyMix, src) - b = CodeInfoTools.Builder(src) - for (v, st) in b - b[v] = swap(st) - end - return CodeInfoTools.finish(b) - end - - # MyMix will only transform functions which you explicitly allow. - # You can also greenlight modules. 
- StaticCompiler.allow(ctx::MyMix, m::Module) = m == SubFoo - - # redefine swap to test caching and add StaticString substitution - function swap(e::Expr) - new = MacroTools.postwalk(e) do s - s isa String && return StaticTools.StaticString(tuple(codeunits(s)..., 0x00)) - isexpr(s, :call) || return s - s.args[1] == Base.rand || return s - return 2 - end - return new - end - path = compile_shlib(SubFoo.f, (), testpath, mixtape=MyMix()) - ptr = Libdl.dlopen(path, Libdl.RTLD_LOCAL) - - fptr = Libdl.dlsym(ptr, "f") - @test @ccall($fptr()::Int) == 4 -end - - - -@testset "Cross compiling to WebAssembly" begin - testpath = pwd() - scratch = tempdir() - cd(scratch) - - m2(x) = 2x - m3(x) = 3x - wasm_path = compile_wasm(m2, Tuple{Float64}) - wasm_path2 = compile_wasm([(m2, Tuple{Float64}), (m3, Tuple{Float64})]) - - wasm_path = compile_wasm(m2, (Float64,)) - wasm_path2 = compile_wasm([(m2, (Float64,)), (m3, (Float64,))]) - -end - ## --- Clean up cd(testpath)