
CUDA jit #108

Merged
merged 22 commits on Dec 21, 2024
1 change: 0 additions & 1 deletion cute_kernels/__init__.py
@@ -10,7 +10,6 @@
)
from .enums import KernelBackend
from .inductor import init_inductor
from .kernel_registry import KernelRegistry
from .kernels import (
    MoE_Torch,
    MoE_Triton,
cute_kernels/cpp_registry.yml
@@ -3,12 +3,14 @@
  sources:
    - kernels/add/add_tensor/cuda_implementation/ops.cpp
    - kernels/add/add_tensor/cuda_implementation/kernels_forward.cu
  build_path: add_tensor

- functions:
    - add_scalar_forward_cuda
  sources:
    - kernels/add/add_scalar/cuda_implementation/ops.cpp
    - kernels/add/add_scalar/cuda_implementation/kernels_forward.cu
  build_path: add_scalar

- functions:
    - swiglu_forward_cuda
@@ -17,3 +19,4 @@
    - kernels/swiglu/cuda_implementation/ops.cpp
    - kernels/swiglu/cuda_implementation/kernels_forward.cu
    - kernels/swiglu/cuda_implementation/kernels_backward.cu
  build_path: swiglu
89 changes: 89 additions & 0 deletions cute_kernels/jit.py
@@ -0,0 +1,89 @@
import inspect
import os
from typing import Callable

import torch
import yaml
from torch.utils.cpp_extension import load as load_cpp_extension


CPP_MODULE_PREFIX = "cute_cuda_kernels"
CPP_BUILD_DIRECTORY = "build"
CPP_FUNCTIONS = {}
CPP_REGISTRY_YAML = yaml.safe_load(open(os.path.join(os.path.dirname(__file__), "cpp_registry.yml"), "r"))


@torch._dynamo.disable
def compile_cpp(name: str) -> None:
    function_map = []
    all_functions = []
    source_map = []
    build_directories = []
    for module in CPP_REGISTRY_YAML:
        function_map.append(module["functions"])
        all_functions.extend(module["functions"])
        source_map.append([os.path.join(os.path.dirname(__file__), source) for source in module["sources"]])
        build_directories.append(module["build_path"])

    assert len(all_functions) == len(set(all_functions)), "function names are not unique"

    # find which files the function belongs to
    for index, (functions, build_directory) in enumerate(zip(function_map, build_directories)):
        if name in functions:
            break

    full_build_path = os.path.join(CPP_BUILD_DIRECTORY, build_directory)
    os.makedirs(full_build_path, exist_ok=True)

    module = load_cpp_extension(
        f"{CPP_MODULE_PREFIX}_{build_directory}",
        sources=source_map[index],
        with_cuda=True,
        extra_cflags=["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"],
        build_directory=full_build_path,
        verbose=True,
    )

    # populate all functions from the file
    for function in function_map[index]:
        CPP_FUNCTIONS[function] = getattr(module, function)


def get_cpp_function(name: str) -> Callable:
    function = CPP_FUNCTIONS.get(name, None)

    # if kernel is compiled, we return the torch op since its compatible with torch compile
    if function is None:
        compile_cpp(name)
        function = get_cpp_function(name)

    return function


def cpp_jit(function_name: str) -> Callable:
    cpp_function = None
    args_spec = None

    def _run(*args, **kwargs):
        nonlocal cpp_function

        if cpp_function is None:
            cpp_function = get_cpp_function(function_name)

        full_args = []
        full_args.extend(args)
        for variable_name in args_spec.args[len(args) :]:
            full_args.append(kwargs[variable_name])

        return cpp_function(*full_args)

    def inner(function: Callable) -> Callable:
        _run.__signature__ = inspect.signature(function)
        _run.__name__ = function.__name__

        nonlocal args_spec
        args_spec = inspect.getfullargspec(function)

        return _run

    return inner
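
For context, a minimal usage sketch (not part of the diff) of the new cpp_jit decorator, mirroring the kernel stubs updated below; the tensor sizes and the vector_instruction_width / BLOCK_SIZE values are illustrative assumptions:

import torch

from cute_kernels.jit import cpp_jit


# The decorated stub has an empty body; calls are forwarded to the compiled
# CUDA extension resolved through cpp_registry.yml on first use.
@cpp_jit("add_tensor_forward_cuda")
def add_tensor_forward_cuda(
    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
) -> None: ...


x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
output = torch.empty_like(x)

# The first call triggers torch.utils.cpp_extension.load for the matching
# registry entry; later calls reuse the cached function from CPP_FUNCTIONS.
add_tensor_forward_cuda(x, y, output, vector_instruction_width=4, BLOCK_SIZE=1024)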
61 changes: 0 additions & 61 deletions cute_kernels/kernel_registry.py

This file was deleted.

@@ -1,15 +1,15 @@
import torch

from .....constants import LIBRARY_NAME
from .....kernel_registry import KernelRegistry
from .....jit import cpp_jit
from .....utils import cute_op


_KERNEL_NAME = "add_scalar_forward_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
@cpp_jit(_KERNEL_NAME)
def add_scalar_forward_cuda(
    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
) -> None:
    KernelRegistry.get_kernel(_KERNEL_NAME)(x, y, output, vector_instruction_width, BLOCK_SIZE)
) -> None: ...
@@ -1,15 +1,15 @@
import torch

from .....constants import LIBRARY_NAME
from .....kernel_registry import KernelRegistry
from .....jit import cpp_jit
from .....utils import cute_op


_KERNEL_NAME = "add_tensor_forward_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
@cpp_jit(_KERNEL_NAME)
def add_tensor_forward_cuda(
    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
) -> None:
    KernelRegistry.get_kernel(_KERNEL_NAME)(x, y, output, vector_instruction_width, BLOCK_SIZE)
) -> None: ...
@@ -1,14 +1,15 @@
import torch

from ....constants import LIBRARY_NAME
from ....kernel_registry import KernelRegistry
from ....jit import cpp_jit
from ....utils import cute_op


_KERNEL_NAME = "swiglu_backward_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"})
@cpp_jit(_KERNEL_NAME)
def swiglu_backward_cuda(
    gate: torch.Tensor,
    up: torch.Tensor,
@@ -17,7 +18,4 @@ def swiglu_backward_cuda(
    up_grad: torch.Tensor,
    vector_instruction_width: int,
    BLOCK_SIZE: int,
) -> None:
    KernelRegistry.get_kernel(_KERNEL_NAME)(
        gate, up, output_grad, gate_grad, up_grad, vector_instruction_width, BLOCK_SIZE
    )
) -> None: ...
@@ -1,15 +1,15 @@
import torch

from ....constants import LIBRARY_NAME
from ....kernel_registry import KernelRegistry
from ....jit import cpp_jit
from ....utils import cute_op


_KERNEL_NAME = "swiglu_forward_cuda"


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
@cpp_jit(_KERNEL_NAME)
def swiglu_forward_cuda(
    gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
) -> None:
    KernelRegistry.get_kernel(_KERNEL_NAME)(gate, up, output, vector_instruction_width, BLOCK_SIZE)
) -> None: ...