Skip to content

Commit

Permalink
Use static compilation for kernels.
Browse files Browse the repository at this point in the history
Signed-off-by: Ilya Enkovich <[email protected]>
  • Loading branch information
ienkovich committed Jun 20, 2024
1 parent ede8a8e commit 974c6ed
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 307 deletions.
2 changes: 2 additions & 0 deletions python/triton/runtime/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def _build(name, src, srcdir, library_dirs, include_dirs, libraries):
# CPU backend uses C++ (driver.cpp). Some old version compilers need a specific C++17 flag.
if src.endswith(".cpp") or src.endswith(".cc"):
cc_cmd += ["-std=c++17", "-fopenmp"]
if src.endswith(".s"):
cc_cmd += ["-gdwarf-5"]
ret = subprocess.check_call(cc_cmd)
if ret == 0:
return so
Expand Down
16 changes: 4 additions & 12 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def supports_target(target: GPUTarget):

def __init__(self, target: tuple) -> None:
super().__init__(target)
self.binary_ext = "bc"
self.binary_ext = "asm"

def parse_options(self, opts) -> Any:
args = {k: opts[k] for k in CPUOptions.__dataclass_fields__.keys() if k in opts}
Expand Down Expand Up @@ -138,22 +138,14 @@ def make_llir(src, metadata, options):
return ret

@staticmethod
def make_bc(src, metadata, options):
if os.environ.get("TRITON_CPU_ASM_DUMP", "0") == "1":
from triton.runtime.cache import get_cache_manager

asm = llvm.translate_to_host_asm(src, options.enable_fp_fusion)
fn_cache_manager = get_cache_manager(metadata['hash'])
fn_cache_manager.put(asm, f"{metadata['name']}.asm")

ret = llvm.translate_to_bc(src)
return ret
def make_asm(src, metadata, options):
return llvm.translate_to_host_asm(src, options.enable_fp_fusion)

def add_stages(self, stages, options):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
stages["ttcir"] = lambda src, metadata: self.make_ttcir(src, metadata, options)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
stages["bc"] = lambda src, metadata: self.make_bc(src, metadata, options)
stages["asm"] = lambda src, metadata: self.make_asm(src, metadata, options)

@functools.lru_cache()
def hash(self):
Expand Down
224 changes: 0 additions & 224 deletions third_party/cpu/backend/driver.cpp

This file was deleted.

94 changes: 23 additions & 71 deletions third_party/cpu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,74 +8,9 @@
from triton.backends.compiler import GPUTarget

dirname = os.getenv("TRITON_SYS_PATH", default="/usr/local")
llvm_root = os.getenv("LLVM_PATH", default="~/.triton/llvm")
llvm_root = os.path.expanduser(llvm_root)
llvm_dirs = os.listdir(llvm_root)
if len(llvm_dirs) == 1:
llvm_root = os.path.join(llvm_root, llvm_dirs[0])
include_dir = [
os.path.join(dirname, "include"),
os.path.join(llvm_root, "include"),
]
library_dir = [os.path.join(dirname, "lib"), os.path.join(llvm_root, "lib")]
libraries = [
"LLVMOrcJIT",
"LLVMPasses",
"LLVMX86CodeGen",
"LLVMX86AsmParser",
"LLVMX86Desc",
"LLVMX86Info",
"LLVMGlobalISel",
"LLVMSelectionDAG",
"LLVMHipStdPar",
"LLVMCoroutines",
"LLVMipo",
"LLVMFrontendOpenMP",
"LLVMInstrumentation",
"LLVMAsmPrinter",
"LLVMCodeGen",
"LLVMObjCARCOpts",
"LLVMLinker",
"LLVMVectorize",
"LLVMScalarOpts",
"LLVMInstCombine",
"LLVMFrontendOffloading",
"LLVMExecutionEngine",
"LLVMAggressiveInstCombine",
"LLVMTransformUtils",
"LLVMTarget",
"LLVMRuntimeDyld",
"LLVMJITLink",
"LLVMIRPrinter",
"LLVMBitWriter",
"LLVMAnalysis",
"LLVMProfileData",
"LLVMSymbolize",
"LLVMDebugInfoDWARF",
"LLVMObject",
"LLVMTextAPI",
"LLVMMCParser",
"LLVMMCDisassembler",
"LLVMMC",
"LLVMIRReader",
"LLVMCFGuard",
"LLVMBitReader",
"LLVMAsmParser",
"LLVMCore",
"LLVMBinaryFormat",
"LLVMOrcTargetProcess",
"LLVMTargetParser",
"LLVMRemarks",
"LLVMOrcShared",
"LLVMOption",
"LLVMDebugInfoCodeView",
"LLVMCodeGenTypes",
"LLVMBitstreamReader",
"LLVMSupport",
"LLVMDemangle",
"stdc++",
"z",
]
include_dir = [os.path.join(dirname, "include")]
library_dir = [os.path.join(dirname, "lib")]
libraries = ["stdc++"]


def compile_module_from_src(src, name):
Expand Down Expand Up @@ -110,9 +45,26 @@ def __new__(cls):
return cls.instance

def __init__(self):
dirname = os.path.dirname(os.path.realpath(__file__))
mod = compile_module_from_src(Path(os.path.join(dirname, "driver.cpp")).read_text(), "cpu_utils")
self.load_binary = mod.load_binary
pass

def load_binary(self, name, src, shared_mem, device):
# src actually holds asm text, compile to a shared library.
key = hashlib.md5(src).hexdigest()
cache = get_cache_manager(key)
cache_path = cache.get_file(f"{name}.so")
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
asm_path = os.path.join(tmpdir, "kernel.s")
Path(asm_path).write_bytes(src)
Path("kernel.s").write_bytes(src)
so = _build(name, asm_path, tmpdir, library_dir, include_dir, ["gcc", "m"])
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
import ctypes
lib = ctypes.cdll.LoadLibrary(cache_path)
fn_ptr = getattr(lib, name)
fn_ptr_as_void_p = ctypes.cast(fn_ptr, ctypes.c_void_p).value
return (fn_ptr, fn_ptr_as_void_p, 0, 0)

def get_device_properties(self, *args):
return {"max_shared_mem": 0}
Expand Down

0 comments on commit 974c6ed

Please sign in to comment.