Skip to content

Commit

Permalink
remove compile_wasm and MixTape; allow specifying a method_table
Browse files Browse the repository at this point in the history
  • Loading branch information
MasonProtter committed Nov 11, 2023
1 parent 7502758 commit 1aed64e
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 189 deletions.
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,31 @@ Sometimes, a julia function you want to statically compile will do things (such
```julia
julia> using Libdl, StaticCompiler

julia> f(x) = x + 1;
julia> f(x) = g(x) + 1;

julia> @device_override Base.:(+)(x::Int, y::Int) = x - y
julia> g(x) = 2x

julia> @device_override g(x::Int) = x - 10

julia> f(1) # Gives the expected answer in regular julia
2
3

julia> dlopen(compile_shlib(f, (Int,), "./")) do lib
fptr = dlsym(lib, "f")
# Now use the compiled version where + is replaced with -
@ccall $fptr(1::Int)::Int
end
0
-8
```
Typically, errors should be overrided and replaced with `@print_and_throw`, which is StaticCompiler friendle, i.e.
Typically, errors should be overrided and replaced with `@print_and_throw`, which is StaticCompiler friendly, i.e.
we define overrides such as
``` julia
@device_override @noinline Base.Math.throw_complex_domainerror(f::Symbol, x) =
@print_and_throw c"This operation requires a complex input to return a complex result"
```

If for some reason, you wish to use a different method table (defined with `Base.Experimental.@MethodTable` and `Base.Experimental.@overlay`) than the default one provided by StaticCompiler.jl, you can provide it to `compile_executable` and `compile_shlib` via a keyword argument `method_table`.


## Approach

Expand Down Expand Up @@ -98,17 +102,12 @@ To enable code to be statically compiled, consider the following:

## Guide for Statically Compiling Code

If you're trying to statically compile generic code, you may run into issues if that code uses features not supported by StaticCompiler. One option is to change the code you're calling using the tips above. If that is not easy, you may by able to compile it anyway. One option is to use method overrides to change what methods are called. Another option is to use the Mixtape feature to change problematic code as part of compilation. For example, you could convert all Strings to StaticStrings.
If you're trying to statically compile generic code, you may run into issues if that code uses features not supported by StaticCompiler. One option is to change the code you're calling using the tips above. If that is not easy, you may by able to compile it anyway. One option is to use method overlays to change what methods are called.

[Cthulhu](https://github.com/JuliaDebug/Cthulhu.jl) is a great help in digging into code, finding type instabilities, and finding other sources of code that may break static compilation.

## Foreign Function Interfacing

Because Julia objects follow C memory layouts, compiled libraries should be usable from most languages that can interface with C. For example, results should be usable with Python's [CFFI](https://cffi.readthedocs.io/en/latest/) package.

For WebAssembly, interface helpers are available at [WebAssemblyInterfaces](https://github.com/tshort/WebAssemblyInterfaces.jl).





For WebAssembly, interface helpers are available at [WebAssemblyInterfaces](https://github.com/tshort/WebAssemblyInterfaces.jl), and users should also see [WebAssemblyCompiler](https://github.com/tshort/WebAssemblyCompiler.jl) for a package more focused on compilation of WebAssebly in general.
60 changes: 16 additions & 44 deletions src/StaticCompiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ using Clang_jll: clang
using LLD_jll: lld
using StaticTools
using StaticTools: @symbolcall, @c_str, println

using Core: MethodTable

export load_function, compile_shlib, compile_executable, compile_wasm
export native_code_llvm, native_code_typed, native_llvm_module, native_code_native
export @device_override, @print_and_throw

include("mixtape.jl")
include("interpreter.jl")
include("target.jl")
include("pointer_warning.jl")
Expand All @@ -33,6 +32,7 @@ compile_executable(f::Function, types::Tuple, path::String, [name::String=string
filename::String=name,
cflags=``, # Specify libraries you would like to link against, and other compiler options here
also_expose=[],
method_table=StaticCompiler.method_table,
kwargs...
)
```
Expand Down Expand Up @@ -124,8 +124,18 @@ end

"""
```julia
compile_shlib(f::Function, types::Tuple, [path::String="./"], [name::String=string(nameof(f))]; filename::String=name, cflags=``, kwargs...)
compile_shlib(funcs::Array, [path::String="./"]; filename="libfoo", demangle=true, cflags=``, kwargs...)
compile_shlib(f::Function, types::Tuple, [path::String="./"], [name::String=string(nameof(f))];
filename::String=name,
cflags=``,
method_table=StaticCompiler.method_table,
kwargs...)
compile_shlib(funcs::Array, [path::String="./"];
filename="libfoo",
demangle=true,
cflags=``,
method_table=StaticCompiler.method_table,
kwargs...)
```
As `compile_executable`, but compiling to a standalone `.dylib`/`.so` shared library.
Expand Down Expand Up @@ -184,38 +194,7 @@ function compile_shlib(funcs::Union{Array,Tuple}, path::String="./";

joinpath(abspath(path), filename * "." * Libdl.dlext)
end

"""
```julia
compile_wasm(f::Function, types::Tuple, [path::String="./"], [name::String=string(nameof(f))]; filename::String=name, flags=``, kwargs...)
compile_wasm(funcs::Union{Array,Tuple}, [path::String="./"]; filename="libfoo", demangle=true, flags=``, kwargs...)
```
As `compile_shlib`, but compiling to a WebAssembly library.
If `demangle` is set to `false`, compiled function names are prepended with "julia_".
```
"""
function compile_wasm(f::Function, types=();
path::String = "./",
filename = fix_name(f),
flags = ``,
kwargs...
)
tt = Base.to_tuple_type(types)
obj_path, name = generate_obj(f, tt, true, path, filename; target = (triple = "wasm32-unknown-wasi", cpu = "", features = ""), remove_julia_addrspaces = true, kwargs...)
run(`$(lld()) -flavor wasm --no-entry --export-all $flags $obj_path/obj.o -o $path/$name.wasm`)
joinpath(abspath(path), filename * ".wasm")
end
function compile_wasm(funcs::Union{Array,Tuple};
path::String="./",
filename="libfoo",
flags=``,
kwargs...
)
obj_path, name = generate_obj(funcs, true, path, filename; target = (triple = "wasm32-unknown-wasi", cpu = "", features = ""), remove_julia_addrspaces = true, kwargs...)
run(`$(lld()) -flavor wasm --no-entry --export-all $flags $obj_path/$filename.o -o $path/$filename.wasm`)
joinpath(abspath(path), filename * ".wasm")
end


"""
```julia
Expand Down Expand Up @@ -260,6 +239,7 @@ function generate_shlib_fptr(path::String, name, filename::String=name)
@assert fptr != C_NULL
fptr
end

# As above, but also compile (maybe remove this method in the future?)
function generate_shlib_fptr(f, tt, path::String=tempname(), name=fix_name(f), filename::String=name;
temp::Bool=true,
Expand Down Expand Up @@ -478,7 +458,6 @@ end
"""
```julia
generate_obj(f, tt, external::Bool, path::String = tempname(), filenamebase::String="obj";
mixtape = NoContext(),
target = (),
demangle = true,
strip_llvm = false,
Expand All @@ -490,9 +469,6 @@ Low level interface for compiling object code (`.o`) for for function `f` given
a tuple type `tt` characterizing the types of the arguments for which the
function will be compiled.
`mixtape` defines a context that can be used to transform IR prior to compilation using
[Mixtape](https://github.com/JuliaCompilerPlugins/Mixtape.jl) features.
`target` can be used to change the output target. This is useful for compiling to WebAssembly and embedded targets.
This is a named tuple with fields `triple`, `cpu`, and `features` (each of these are strings).
The defaults compile to the native target.
Expand Down Expand Up @@ -522,7 +498,6 @@ end
"""
```julia
generate_obj(funcs::Union{Array,Tuple}, external::Bool, path::String = tempname(), filenamebase::String="obj";
mixtape = NoContext(),
target = (),
demangle =false,
strip_llvm = false,
Expand All @@ -534,9 +509,6 @@ Low level interface for compiling object code (`.o`) for an array of Tuples
(f, tt) where each function `f` and tuple type `tt` determine the set of methods
which will be compiled.
`mixtape` defines a context that can be used to transform IR prior to compilation using
[Mixtape](https://github.com/JuliaCompilerPlugins/Mixtape.jl) features.
`target` can be used to change the output target. This is useful for compiling to WebAssembly and embedded targets.
This is a named tuple with fields `triple`, `cpu`, and `features` (each of these are strings).
The defaults compile to the native target.
Expand Down
44 changes: 16 additions & 28 deletions src/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using GPUCompiler:
using CodeInfoTools
using CodeInfoTools: resolve

struct StaticInterpreter{M} <: AbstractInterpreter
struct StaticInterpreter <: AbstractInterpreter
global_cache::CodeCache
method_table::Union{Nothing,Core.MethodTable}

Expand All @@ -20,13 +20,10 @@ struct StaticInterpreter{M} <: AbstractInterpreter
inf_params::InferenceParams
opt_params::OptimizationParams

# Mixtape context
mixtape::M

function StaticInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams, mixtape::CompilationContext)
function StaticInterpreter(cache::CodeCache, mt::Union{Nothing,Core.MethodTable}, world::UInt, ip::InferenceParams, op::OptimizationParams)
@assert world <= Base.get_world_counter()

return new{typeof(mixtape)}(
return new(
cache,
mt,

Expand All @@ -38,10 +35,7 @@ struct StaticInterpreter{M} <: AbstractInterpreter

# parameters for inference and optimization
ip,
op,

# Mixtape context
mixtape
op
)
end
end
Expand Down Expand Up @@ -79,9 +73,6 @@ function custom_pass!(interp::StaticInterpreter, result::InferenceResult, mi::Co
mi.specTypes isa UnionAll && return src
sig = Tuple(mi.specTypes.parameters)
as = map(resolve_generic, sig)
if allow(interp.mixtape, mi.def.module, as...)
src = transform(interp.mixtape, src, sig)
end
return src
end

Expand All @@ -102,22 +93,21 @@ end
Core.Compiler.may_optimize(interp::StaticInterpreter) = true
Core.Compiler.may_compress(interp::StaticInterpreter) = true
Core.Compiler.may_discard_trees(interp::StaticInterpreter) = true
if VERSION >= v"1.7.0-DEV.577"
Core.Compiler.verbose_stmt_info(interp::StaticInterpreter) = false
end


if isdefined(Base.Experimental, Symbol("@overlay"))
using Core.Compiler: OverlayMethodTable
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::StaticInterpreter) =
OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) =
OverlayMethodTable(interp.world, interp.method_table)
end
using Core.Compiler: OverlayMethodTable
if v"1.8-beta2" <= VERSION < v"1.9-" || VERSION >= v"1.9.0-DEV.120"
Core.Compiler.method_table(interp::StaticInterpreter) =
OverlayMethodTable(interp.world, interp.method_table)
else
Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) =
OverlayMethodTable(interp.world, interp.method_table)
end
else
Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) =
WorldOverlayMethodTable(interp.world)
Core.Compiler.method_table(interp::StaticInterpreter, sv::InferenceState) =
WorldOverlayMethodTable(interp.world)
end

# semi-concrete interepretation is broken with overlays (JuliaLang/julia#47349)
Expand All @@ -134,13 +124,11 @@ end
struct StaticCompilerParams <: AbstractCompilerParams
opt::Bool
optlevel::Int
mixtape::CompilationContext
cache::CodeCache
end

function StaticCompilerParams(; opt = false,
optlevel = Base.JLOptions().opt_level,
mixtape = NoContext(),
cache = CodeCache())
return StaticCompilerParams(opt, optlevel, mixtape, cache)
return StaticCompilerParams(opt, optlevel, cache)
end
39 changes: 20 additions & 19 deletions src/target.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,18 @@ macro device_override(ex)
return esc(code)
end

Base.@kwdef struct NativeCompilerTarget <: GPUCompiler.AbstractCompilerTarget
Base.@kwdef struct NativeCompilerTarget{MT} <: GPUCompiler.AbstractCompilerTarget
triple::String=Sys.MACHINE
cpu::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName())
features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures())
method_table::MT = method_table
end

Base.@kwdef struct ExternalNativeCompilerTarget <: GPUCompiler.AbstractCompilerTarget
Base.@kwdef struct ExternalNativeCompilerTarget{MT} <: GPUCompiler.AbstractCompilerTarget
triple::String=Sys.MACHINE
cpu::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUName())
features::String=(LLVM.version() < v"8") ? "" : unsafe_string(LLVM.API.LLVMGetHostCPUFeatures())
method_table::MT = method_table
end

module StaticRuntime
Expand Down Expand Up @@ -66,44 +68,43 @@ for target in (:NativeCompilerTarget, :ExternalNativeCompilerTarget)
return tm
end

GPUCompiler.runtime_slug(job::GPUCompiler.CompilerJob{$target}) = "native_$(job.config.target.cpu)-$(hash(job.config.target.features))"
GPUCompiler.runtime_slug(job::GPUCompiler.CompilerJob{<:$target}) = "native_$(job.config.target.cpu)-$(hash(job.config.target.features))"

GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{$target}) = StaticRuntime
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = StaticRuntime
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{<:$target}) = StaticRuntime
GPUCompiler.runtime_module(::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = StaticRuntime


GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = true
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{$target}) = true
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = true
GPUCompiler.can_throw(job::GPUCompiler.CompilerJob{<:$target}) = true

GPUCompiler.get_interpreter(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) =
GPUCompiler.get_interpreter(job::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) =
StaticInterpreter(job.config.params.cache, GPUCompiler.method_table(job), job.world,
GPUCompiler.inference_params(job), GPUCompiler.optimization_params(job),
job.config.params.mixtape)
GPUCompiler.ci_cache(job::GPUCompiler.CompilerJob{$target, StaticCompilerParams}) = job.config.params.cache
GPUCompiler.inference_params(job), GPUCompiler.optimization_params(job))
GPUCompiler.ci_cache(job::GPUCompiler.CompilerJob{<:$target, StaticCompilerParams}) = job.config.params.cache
GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{<:$target})) = job.config.target.method_table
end
end

GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{ExternalNativeCompilerTarget})) = method_table
GPUCompiler.method_table(@nospecialize(job::GPUCompiler.CompilerJob{ExternalNativeCompilerTarget, StaticCompilerParams})) = method_table

function native_job(@nospecialize(func::Function), @nospecialize(types::Type), external::Bool;
mixtape = NoContext(),
name = fix_name(func),
kernel::Bool = false,
target = (),
target = (;),
method_table=method_table,
kwargs...
)
target = merge(target, (;method_table))
source = methodinstance(typeof(func), Base.to_tuple_type(types))
target = external ? ExternalNativeCompilerTarget(;target...) : NativeCompilerTarget(;target...)
params = StaticCompilerParams(mixtape = mixtape)
params = StaticCompilerParams()
config = GPUCompiler.CompilerConfig(target, params, name = name, kernel = kernel)
StaticCompiler.CompilerJob(source, config), kwargs
end

function native_job(@nospecialize(func), @nospecialize(types), external; mixtape = NoContext(), kernel::Bool=false, name=fix_name(repr(func)), target = (), kwargs...)
function native_job(@nospecialize(func), @nospecialize(types), external; kernel::Bool=false, name=fix_name(repr(func)), target = (;), method_table=method_table, kwargs...)
target = merge(target, (; method_table))
source = methodinstance(typeof(func), Base.to_tuple_type(types))
target = external ? ExternalNativeCompilerTarget(;target...) : NativeCompilerTarget(;target...)
params = StaticCompilerParams(mixtape = mixtape)
params = StaticCompilerParams()
config = GPUCompiler.CompilerConfig(target, params, name = name, kernel = kernel)
GPUCompiler.CompilerJob(source, config), kwargs
end
1 change: 0 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[deps]
CodeInfoTools = "bc773b8a-8374-437a-b9f2-0e9785855863"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Formatting = "59287772-0a20-5a39-b81b-1366585eb4c0"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
Expand Down
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ using ManualMemory
using Distributed
using StaticTools
using StrideArraysCore
using CodeInfoTools
using MacroTools
using LLD_jll
using Bumper
Expand Down
Loading

0 comments on commit 1aed64e

Please sign in to comment.