Commit
cleanup + refactor (#125)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Dec 31, 2024
1 parent ee79499 commit ed19747
Showing 43 changed files with 783 additions and 741 deletions.
10 changes: 5 additions & 5 deletions cute_kernels/cpp_registry.yml
@@ -2,28 +2,28 @@
     - add_tensor_forward_cuda
   sources:
     - kernels/add/add_tensor/cuda_implementation/ops.cpp
-    - kernels/add/add_tensor/cuda_implementation/kernels_forward.cu
+    - kernels/add/add_tensor/cuda_implementation/forward.cu
   build_path: add_tensor
 
 - functions:
     - add_scalar_forward_cuda
   sources:
     - kernels/add/add_scalar/cuda_implementation/ops.cpp
-    - kernels/add/add_scalar/cuda_implementation/kernels_forward.cu
+    - kernels/add/add_scalar/cuda_implementation/forward.cu
   build_path: add_scalar
 
 - functions:
     - contiguous_count_cuda
   sources:
     - kernels/contiguous_count/cuda_implementation/ops.cpp
-    - kernels/contiguous_count/cuda_implementation/kernels_forward.cu
+    - kernels/contiguous_count/cuda_implementation/forward.cu
   build_path: contiguous_count
 
 - functions:
     - swiglu_forward_cuda
     - swiglu_backward_cuda
   sources:
     - kernels/swiglu/cuda_implementation/ops.cpp
-    - kernels/swiglu/cuda_implementation/kernels_forward.cu
-    - kernels/swiglu/cuda_implementation/kernels_backward.cu
+    - kernels/swiglu/cuda_implementation/forward.cu
+    - kernels/swiglu/cuda_implementation/backward.cu
   build_path: swiglu
@@ -1,15 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....jit import cpp_jit
-from .....utils import cute_op
-
-
-_KERNEL_NAME = "add_scalar_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-@cpp_jit(_KERNEL_NAME)
-def add_scalar_forward_cuda(
-    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None: ...
+from .forward import add_scalar_forward_cuda
15 changes: 15 additions & 0 deletions cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py
@@ -0,0 +1,15 @@
+import torch
+
+from .....constants import LIBRARY_NAME
+from .....jit import cpp_jit
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_scalar_forward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+@cpp_jit(_KERNEL_NAME)
+def add_scalar_forward_cuda(
+    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None: ...
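
As a quick orientation, a minimal call sketch for the relocated CUDA wrapper. The absolute import path follows the new file location shown above; the tensor size and the vector_instruction_width/BLOCK_SIZE values are illustrative assumptions, not values taken from this commit.

import torch

from cute_kernels.kernels.add.add_scalar.cuda_implementation.forward import add_scalar_forward_cuda

# output is pre-allocated and mutated in place, matching mutates_args={"output"}
x = torch.randn(1024, device="cuda")
output = torch.empty_like(x)
add_scalar_forward_cuda(x=x, y=3.0, output=output, vector_instruction_width=4, BLOCK_SIZE=1024)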
@@ -1,20 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....math import ceil_divide
-from .....utils import cute_op
-from .kernels_forward import _add_scalar_forward_triton_kernel
-
-
-_KERNEL_NAME = "add_scalar_forward_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def add_scalar_forward_triton(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None:
-    num_elements = x.numel()
-    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
-
-    with torch.device(x.device):
-        _add_scalar_forward_triton_kernel[(num_programs,)](
-            x_ptr=x, y=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
-        )
+from .forward import add_scalar_forward_triton
@@ -0,0 +1,34 @@
+import torch
+import triton
+import triton.language as tl
+
+from .....constants import LIBRARY_NAME
+from .....math import ceil_divide
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_scalar_forward_triton"
+
+
+@triton.jit
+def _add_scalar_forward_triton_kernel(x_ptr, y, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(axis=0)
+
+    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = indices < num_elements
+
+    x = tl.load(x_ptr + indices, mask=mask)
+    output = x + y
+
+    tl.store(output_ptr + indices, output, mask=mask)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def add_scalar_forward_triton(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None:
+    num_elements = x.numel()
+    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
+
+    with torch.device(x.device):
+        _add_scalar_forward_triton_kernel[(num_programs,)](
+            x_ptr=x, y=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
+        )
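
A similar hedged sketch for the Triton wrapper. The module path is an assumption that mirrors the CUDA layout (this page collapses the new file's name), and the BLOCK_SIZE value is an arbitrary choice.

import torch

# assumed path: this hunk only shows relative imports for the new file
from cute_kernels.kernels.add.add_scalar.triton_implementation.forward import add_scalar_forward_triton

x = torch.randn(4096, device="cuda")
output = torch.empty_like(x)  # written in place by the op
add_scalar_forward_triton(x=x, y=2.5, output=output, BLOCK_SIZE=1024)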

This file was deleted.

@@ -1,15 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....jit import cpp_jit
-from .....utils import cute_op
-
-
-_KERNEL_NAME = "add_tensor_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-@cpp_jit(_KERNEL_NAME)
-def add_tensor_forward_cuda(
-    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None: ...
+from .forward import add_tensor_forward_cuda
15 changes: 15 additions & 0 deletions cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py
@@ -0,0 +1,15 @@
+import torch
+
+from .....constants import LIBRARY_NAME
+from .....jit import cpp_jit
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_tensor_forward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+@cpp_jit(_KERNEL_NAME)
+def add_tensor_forward_cuda(
+    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None: ...
@@ -1,19 +1 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....math import ceil_divide
-from .....utils import cute_op
-from .kernels_forward import _add_tensor_forward_triton_kernel
-
-
-_KERNEL_NAME = "add_tensor_forward_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def add_tensor_forward_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None:
-    num_elements = x.numel()
-    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
-
-    _add_tensor_forward_triton_kernel[(num_programs,)](
-        x_ptr=x, y_ptr=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
-    )
+from .forward import add_tensor_forward_triton
@@ -0,0 +1,35 @@
+import torch
+import triton
+import triton.language as tl
+
+from .....constants import LIBRARY_NAME
+from .....math import ceil_divide
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_tensor_forward_triton"
+
+
+@triton.jit
+def _add_tensor_forward_triton_kernel(x_ptr, y_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(axis=0)
+
+    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = indices < num_elements
+
+    x = tl.load(x_ptr + indices, mask=mask)
+    y = tl.load(y_ptr + indices, mask=mask)
+
+    output = x + y
+
+    tl.store(output_ptr + indices, output, mask=mask)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def add_tensor_forward_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None:
+    num_elements = x.numel()
+    num_programs = ceil_divide(num_elements, BLOCK_SIZE)
+
+    _add_tensor_forward_triton_kernel[(num_programs,)](
+        x_ptr=x, y_ptr=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
+    )

This file was deleted.

@@ -1,15 +1 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....jit import cpp_jit
-from ....utils import cute_op
-
-
-_KERNEL_NAME = "contiguous_count_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-@cpp_jit(_KERNEL_NAME)
-def contiguous_count_cuda(
-    x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int
-) -> None: ...
+from .forward import contiguous_count_cuda
@@ -0,0 +1,15 @@
+import torch
+
+from ....constants import LIBRARY_NAME
+from ....jit import cpp_jit
+from ....utils import cute_op
+
+
+_KERNEL_NAME = "contiguous_count_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+@cpp_jit(_KERNEL_NAME)
+def contiguous_count_cuda(
+    x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int
+) -> None: ...
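
A hedged call sketch for the CUDA counter. The import path, the output dtype, the way sm_count is obtained, and the thread_block_cluster_size/BLOCK_SIZE values are all illustrative assumptions (the file name is collapsed on this page and this commit does not show a caller).

import torch

# assumed path; only relative imports are visible in this hunk
from cute_kernels.kernels.contiguous_count.cuda_implementation.forward import contiguous_count_cuda

x = torch.randint(0, 8, (4096,), device="cuda", dtype=torch.int32)
output = torch.zeros(8, device="cuda", dtype=torch.int32)  # one counter per value in [0, size)
sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
contiguous_count_cuda(x=x, output=output, sm_count=sm_count, thread_block_cluster_size=8, size=8, BLOCK_SIZE=1024)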
@@ -1,29 +1 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....math import ceil_divide
-from ....utils import cute_op, get_sm_count
-from .kernels_forward import _contiguous_count_triton_kernel
-
-
-_KERNEL_NAME = "contiguous_count_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def contiguous_count_triton(
-    x: torch.Tensor, output: torch.Tensor, size: int, BLOCK_SIZE: int, BLOCK_SIZE_C: int
-) -> None:
-    B = x.numel()
-
-    sm_count = get_sm_count(x.device)
-    num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE))
-
-    with torch.device(x.device):
-        _contiguous_count_triton_kernel[(num_programs,)](
-            x_ptr=x,
-            output_ptr=output,
-            B=B,
-            C=size,
-            BLOCK_SIZE_B=BLOCK_SIZE,
-            BLOCK_SIZE_C=BLOCK_SIZE_C,
-        )
+from .forward import contiguous_count_triton
@@ -1,6 +1,14 @@
+import torch
 import triton
 import triton.language as tl
 
+from ....constants import LIBRARY_NAME
+from ....math import ceil_divide
+from ....utils import cute_op, get_sm_count
+
+
+_KERNEL_NAME = "contiguous_count_triton"
+
 
 @triton.jit
 def _contiguous_count_triton_kernel(x_ptr, output_ptr, B, C, BLOCK_SIZE_B: tl.constexpr, BLOCK_SIZE_C: tl.constexpr):
@@ -30,3 +38,23 @@ def _contiguous_count_triton_kernel(x_ptr, output_ptr, B, C, BLOCK_SIZE_B: tl.constexpr, BLOCK_SIZE_C: tl.constexpr):
         counts += tl.sum(equal, axis=0)
 
     tl.atomic_add(output_ptr + indices_c, counts, mask=mask_c)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def contiguous_count_triton(
+    x: torch.Tensor, output: torch.Tensor, size: int, BLOCK_SIZE: int, BLOCK_SIZE_C: int
+) -> None:
+    B = x.numel()
+
+    sm_count = get_sm_count(x.device)
+    num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE))
+
+    with torch.device(x.device):
+        _contiguous_count_triton_kernel[(num_programs,)](
+            x_ptr=x,
+            output_ptr=output,
+            B=B,
+            C=size,
+            BLOCK_SIZE_B=BLOCK_SIZE,
+            BLOCK_SIZE_C=BLOCK_SIZE_C,
+        )
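
For the Triton counter, whose wrapper now lives next to its kernel, a minimal usage sketch under stated assumptions: the import path mirrors the relative imports above (the file name is collapsed on this page), output is a pre-zeroed integer tensor of length size that the kernel accumulates into via atomic adds, and the block sizes are placeholder values.

import torch

# assumed path; only relative imports are visible in this hunk
from cute_kernels.kernels.contiguous_count.triton_implementation.forward import contiguous_count_triton

x = torch.randint(0, 8, (4096,), device="cuda", dtype=torch.int32)
output = torch.zeros(8, device="cuda", dtype=torch.int32)  # counts for values 0..size-1
contiguous_count_triton(x=x, output=output, size=8, BLOCK_SIZE=1024, BLOCK_SIZE_C=8)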