C++ jit
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 committed Dec 21, 2024
1 parent 9323ac3 commit 1bef0c2
Showing 3 changed files with 15 additions and 13 deletions.
8 changes: 8 additions & 0 deletions cute_kernels/constants.py
@@ -1,5 +1,8 @@
+import os
+
 import torch
 import triton.language as tl
+import yaml
 
 from .math import get_powers_of_2
 
@@ -21,3 +24,8 @@
     torch.float16: tl.float16,
     torch.bfloat16: tl.bfloat16,
 }
+
+CPP_MODULE_PREFIX = "cute_cuda_kernels"
+CPP_BUILD_DIRECTORY = "build"
+CPP_FUNCTIONS = {}
+CPP_REGISTRY_YAML = yaml.safe_load(open(os.path.join(os.path.dirname(__file__), "cpp_registry.yml"), "r"))
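Note: the jit.py diff below iterates over the entries of CPP_REGISTRY_YAML and reads a "functions" list from each, and (judging by sources=source_map[index]) each entry also carries a "sources" list. A minimal sketch of the registry shape this implies follows; the schema and the kernel/file names here are assumptions for illustration, not taken from the commit:

import yaml

# Hypothetical cpp_registry.yml content, inferred from how compile_cpp
# consumes it: each entry groups the C++/CUDA sources built into one
# extension module together with the function names that module exposes.
registry = yaml.safe_load(
    """
- functions:
    - vector_add_forward
  sources:
    - kernels/add/vector_add.cu
- functions:
    - embedding_forward
  sources:
    - kernels/embedding/embedding.cu
"""
)

for index, module in enumerate(registry):
    print(index, module["functions"], module["sources"])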
File renamed without changes: cute_kernels/kernel_registry.yml → cute_kernels/cpp_registry.yml.
20 changes: 7 additions & 13 deletions cute_kernels/jit.py
@@ -6,20 +6,15 @@
 import yaml
 from torch.utils.cpp_extension import load as load_cpp_extension
 
-
-class _CUDA_JIT:
-    module_name = "cute_cuda_kernels"
-    build_directory = "build"
-    cuda_kernel_registry = {}
-    kernel_registry_yaml = yaml.safe_load(open(os.path.join(os.path.dirname(__file__), "kernel_registry.yml"), "r"))
+from .constants import CPP_BUILD_DIRECTORY, CPP_FUNCTIONS, CPP_MODULE_PREFIX, CPP_REGISTRY_YAML
 
 
 @torch._dynamo.disable
 def compile_cpp(name: str) -> None:
     function_map = []
     all_functions = []
     source_map = []
-    for module in _CUDA_JIT.kernel_registry_yaml:
+    for module in CPP_REGISTRY_YAML:
         function_map.append(module["functions"])
         all_functions.extend(module["functions"])
 
@@ -29,30 +24,29 @@ def compile_cpp(name: str) -> None:
 
     assert len(all_functions) == len(set(all_functions)), "function names are not unique"
 
-    build_directory = _CUDA_JIT.build_directory
-    os.makedirs(build_directory, exist_ok=True)
+    os.makedirs(CPP_BUILD_DIRECTORY, exist_ok=True)
 
     # find which files the function belongs to
     for index, functions in enumerate(function_map):
         if name in functions:
             break
 
     module = load_cpp_extension(
-        f"{_CUDA_JIT.module_name}_{index}",
+        f"{CPP_MODULE_PREFIX}_{index}",
         sources=source_map[index],
         with_cuda=True,
         extra_cflags=["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"],
-        build_directory=build_directory,
+        build_directory=CPP_BUILD_DIRECTORY,
         verbose=True,
     )
 
     # populate all functions from the file
     for function in function_map[index]:
-        _CUDA_JIT.cuda_kernel_registry[function] = getattr(module, function)
+        CPP_FUNCTIONS[function] = getattr(module, function)
 
 
 def get_cpp_function(name: str) -> Callable:
-    function = _CUDA_JIT.cuda_kernel_registry.get(name, None)
+    function = CPP_FUNCTIONS.get(name, None)
 
     # if kernel is compiled, we return the torch op since its compatible with torch compile
     if function is None:
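Taken together, get_cpp_function looks a kernel up in CPP_FUNCTIONS and, on a miss, presumably falls back to compile_cpp (the branch is truncated in this view), which JIT-builds only the extension module whose registry entry contains that name and then caches every function from that entry. A hedged usage sketch under those assumptions; the kernel name and its (x, y) call signature are made up for illustration:

import torch

from cute_kernels.jit import get_cpp_function

# "vector_add_forward" is a hypothetical kernel name. The first call for a
# name JIT-compiles its registry entry's sources; afterwards every function
# from that entry is served straight from the CPP_FUNCTIONS cache.
vector_add = get_cpp_function("vector_add_forward")

x = torch.randn(16, device="cuda")
y = torch.randn(16, device="cuda")
out = vector_add(x, y)

One consequence of the per-entry module name f"{CPP_MODULE_PREFIX}_{index}" is that all kernels grouped in one registry entry are compiled and loaded together, so requesting any one of them warms the cache for the rest.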
