From ed197477bb937d68a19edca9ba84dcdb062baeb8 Mon Sep 17 00:00:00 2001 From: Mayank Mishra <32954280+mayank31398@users.noreply.github.com> Date: Tue, 31 Dec 2024 02:07:07 -0500 Subject: [PATCH] cleanup + refactor (#125) Signed-off-by: Mayank Mishra --- cute_kernels/cpp_registry.yml | 10 +- .../cuda_implementation/__init__.py | 16 +- .../{kernels_forward.cu => forward.cu} | 0 .../add_scalar/cuda_implementation/forward.py | 15 ++ .../triton_implementation/__init__.py | 21 +- .../triton_implementation/forward.py | 34 +++ .../triton_implementation/kernels_forward.py | 15 -- .../cuda_implementation/__init__.py | 16 +- .../{kernels_forward.cu => forward.cu} | 0 .../add_tensor/cuda_implementation/forward.py | 15 ++ .../triton_implementation/__init__.py | 20 +- .../triton_implementation/forward.py | 35 +++ .../triton_implementation/kernels_forward.py | 17 -- .../cuda_implementation/__init__.py | 16 +- .../{kernels_forward.cu => forward.cu} | 0 .../cuda_implementation/forward.py | 15 ++ .../triton_implementation/__init__.py | 30 +-- .../{kernels_forward.py => forward.py} | 28 +++ .../triton_implementation/__init__.py | 68 +----- .../triton_implementation/backward.py | 75 +++++++ .../triton_implementation/forward.py | 64 ++++++ .../triton_implementation/kernels_backward.py | 37 ---- .../triton_implementation/kernels_forward.py | 31 --- .../rmsnorm/triton_implementation/__init__.py | 166 +------------- .../rmsnorm/triton_implementation/backward.py | 209 ++++++++++++++++++ ... backward.py->rmsnorm_backward_triton.yml} | 0 .../rmsnorm/triton_implementation/forward.py | 89 ++++++++ ...=> forward.py->rmsnorm_forward_triton.yml} | 0 .../triton_implementation/kernels_backward.py | 81 ------- .../triton_implementation/kernels_forward.py | 46 ---- .../swiglu/cuda_implementation/__init__.py | 28 +-- .../{kernels_backward.cu => backward.cu} | 0 .../swiglu/cuda_implementation/backward.py | 20 ++ .../{kernels_forward.cu => forward.cu} | 0 .../swiglu/cuda_implementation/forward.py | 13 ++ .../swiglu/triton_implementation/__init__.py | 50 +---- .../swiglu/triton_implementation/backward.py | 56 +++++ .../swiglu/triton_implementation/forward.py | 39 ++++ .../triton_implementation/kernels_backward.py | 25 --- .../triton_implementation/kernels_forward.py | 17 -- .../triton_implementation/__init__.py | 52 +---- .../{kernels_backward.py => backward.py} | 30 +++ .../{kernels_forward.py => forward.py} | 25 +++ 43 files changed, 783 insertions(+), 741 deletions(-) rename cute_kernels/kernels/add/add_scalar/cuda_implementation/{kernels_forward.cu => forward.cu} (100%) create mode 100644 cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py create mode 100644 cute_kernels/kernels/add/add_scalar/triton_implementation/forward.py delete mode 100644 cute_kernels/kernels/add/add_scalar/triton_implementation/kernels_forward.py rename cute_kernels/kernels/add/add_tensor/cuda_implementation/{kernels_forward.cu => forward.cu} (100%) create mode 100644 cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py create mode 100644 cute_kernels/kernels/add/add_tensor/triton_implementation/forward.py delete mode 100644 cute_kernels/kernels/add/add_tensor/triton_implementation/kernels_forward.py rename cute_kernels/kernels/contiguous_count/cuda_implementation/{kernels_forward.cu => forward.cu} (100%) create mode 100644 cute_kernels/kernels/contiguous_count/cuda_implementation/forward.py rename cute_kernels/kernels/contiguous_count/triton_implementation/{kernels_forward.py => forward.py} (60%) create mode 100644 
cute_kernels/kernels/embedding/triton_implementation/backward.py create mode 100644 cute_kernels/kernels/embedding/triton_implementation/forward.py delete mode 100644 cute_kernels/kernels/embedding/triton_implementation/kernels_backward.py delete mode 100644 cute_kernels/kernels/embedding/triton_implementation/kernels_forward.py create mode 100644 cute_kernels/kernels/rmsnorm/triton_implementation/backward.py rename cute_kernels/kernels/rmsnorm/triton_implementation/{__init__.py->rmsnorm_backward_triton.yml => backward.py->rmsnorm_backward_triton.yml} (100%) create mode 100644 cute_kernels/kernels/rmsnorm/triton_implementation/forward.py rename cute_kernels/kernels/rmsnorm/triton_implementation/{__init__.py->rmsnorm_forward_triton.yml => forward.py->rmsnorm_forward_triton.yml} (100%) delete mode 100644 cute_kernels/kernels/rmsnorm/triton_implementation/kernels_backward.py delete mode 100644 cute_kernels/kernels/rmsnorm/triton_implementation/kernels_forward.py rename cute_kernels/kernels/swiglu/cuda_implementation/{kernels_backward.cu => backward.cu} (100%) create mode 100644 cute_kernels/kernels/swiglu/cuda_implementation/backward.py rename cute_kernels/kernels/swiglu/cuda_implementation/{kernels_forward.cu => forward.cu} (100%) create mode 100644 cute_kernels/kernels/swiglu/cuda_implementation/forward.py create mode 100644 cute_kernels/kernels/swiglu/triton_implementation/backward.py create mode 100644 cute_kernels/kernels/swiglu/triton_implementation/forward.py delete mode 100644 cute_kernels/kernels/swiglu/triton_implementation/kernels_backward.py delete mode 100644 cute_kernels/kernels/swiglu/triton_implementation/kernels_forward.py rename cute_kernels/kernels/swiglu_unchunked/triton_implementation/{kernels_backward.py => backward.py} (60%) rename cute_kernels/kernels/swiglu_unchunked/triton_implementation/{kernels_forward.py => forward.py} (53%) diff --git a/cute_kernels/cpp_registry.yml b/cute_kernels/cpp_registry.yml index 6b8abeaf..bb60f6a0 100644 --- a/cute_kernels/cpp_registry.yml +++ b/cute_kernels/cpp_registry.yml @@ -2,21 +2,21 @@ - add_tensor_forward_cuda sources: - kernels/add/add_tensor/cuda_implementation/ops.cpp - - kernels/add/add_tensor/cuda_implementation/kernels_forward.cu + - kernels/add/add_tensor/cuda_implementation/forward.cu build_path: add_tensor - functions: - add_scalar_forward_cuda sources: - kernels/add/add_scalar/cuda_implementation/ops.cpp - - kernels/add/add_scalar/cuda_implementation/kernels_forward.cu + - kernels/add/add_scalar/cuda_implementation/forward.cu build_path: add_scalar - functions: - contiguous_count_cuda sources: - kernels/contiguous_count/cuda_implementation/ops.cpp - - kernels/contiguous_count/cuda_implementation/kernels_forward.cu + - kernels/contiguous_count/cuda_implementation/forward.cu build_path: contiguous_count - functions: @@ -24,6 +24,6 @@ - swiglu_backward_cuda sources: - kernels/swiglu/cuda_implementation/ops.cpp - - kernels/swiglu/cuda_implementation/kernels_forward.cu - - kernels/swiglu/cuda_implementation/kernels_backward.cu + - kernels/swiglu/cuda_implementation/forward.cu + - kernels/swiglu/cuda_implementation/backward.cu build_path: swiglu diff --git a/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py b/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py index 63068af0..bff51d50 100644 --- a/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py +++ b/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py @@ -1,15 +1 @@ -import torch - -from .....constants 
import LIBRARY_NAME -from .....jit import cpp_jit -from .....utils import cute_op - - -_KERNEL_NAME = "add_scalar_forward_cuda" - - -@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) -@cpp_jit(_KERNEL_NAME) -def add_scalar_forward_cuda( - x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int -) -> None: ... +from .forward import add_scalar_forward_cuda diff --git a/cute_kernels/kernels/add/add_scalar/cuda_implementation/kernels_forward.cu b/cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.cu similarity index 100% rename from cute_kernels/kernels/add/add_scalar/cuda_implementation/kernels_forward.cu rename to cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.cu diff --git a/cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py b/cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py new file mode 100644 index 00000000..63068af0 --- /dev/null +++ b/cute_kernels/kernels/add/add_scalar/cuda_implementation/forward.py @@ -0,0 +1,15 @@ +import torch + +from .....constants import LIBRARY_NAME +from .....jit import cpp_jit +from .....utils import cute_op + + +_KERNEL_NAME = "add_scalar_forward_cuda" + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +@cpp_jit(_KERNEL_NAME) +def add_scalar_forward_cuda( + x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int +) -> None: ... diff --git a/cute_kernels/kernels/add/add_scalar/triton_implementation/__init__.py b/cute_kernels/kernels/add/add_scalar/triton_implementation/__init__.py index 76975e2c..ce3c4199 100644 --- a/cute_kernels/kernels/add/add_scalar/triton_implementation/__init__.py +++ b/cute_kernels/kernels/add/add_scalar/triton_implementation/__init__.py @@ -1,20 +1 @@ -import torch - -from .....constants import LIBRARY_NAME -from .....math import ceil_divide -from .....utils import cute_op -from .kernels_forward import _add_scalar_forward_triton_kernel - - -_KERNEL_NAME = "add_scalar_forward_triton" - - -@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) -def add_scalar_forward_triton(x: torch.Tensor, y: float, output: torch.Tensor, BLOCK_SIZE: int) -> None: - num_elements = x.numel() - num_programs = ceil_divide(num_elements, BLOCK_SIZE) - - with torch.device(x.device): - _add_scalar_forward_triton_kernel[(num_programs,)]( - x_ptr=x, y=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE - ) +from .forward import add_scalar_forward_triton diff --git a/cute_kernels/kernels/add/add_scalar/triton_implementation/forward.py b/cute_kernels/kernels/add/add_scalar/triton_implementation/forward.py new file mode 100644 index 00000000..9f7e0a03 --- /dev/null +++ b/cute_kernels/kernels/add/add_scalar/triton_implementation/forward.py @@ -0,0 +1,34 @@ +import torch +import triton +import triton.language as tl + +from .....constants import LIBRARY_NAME +from .....math import ceil_divide +from .....utils import cute_op + + +_KERNEL_NAME = "add_scalar_forward_triton" + + +@triton.jit +def _add_scalar_forward_triton_kernel(x_ptr, y, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + + indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = indices < num_elements + + x = tl.load(x_ptr + indices, mask=mask) + output = x + y + + tl.store(output_ptr + indices, output, mask=mask) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +def add_scalar_forward_triton(x: torch.Tensor, y: float, 
output: torch.Tensor, BLOCK_SIZE: int) -> None: + num_elements = x.numel() + num_programs = ceil_divide(num_elements, BLOCK_SIZE) + + with torch.device(x.device): + _add_scalar_forward_triton_kernel[(num_programs,)]( + x_ptr=x, y=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE + ) diff --git a/cute_kernels/kernels/add/add_scalar/triton_implementation/kernels_forward.py b/cute_kernels/kernels/add/add_scalar/triton_implementation/kernels_forward.py deleted file mode 100644 index 1d50fa22..00000000 --- a/cute_kernels/kernels/add/add_scalar/triton_implementation/kernels_forward.py +++ /dev/null @@ -1,15 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def _add_scalar_forward_triton_kernel(x_ptr, y, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - - indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = indices < num_elements - - x = tl.load(x_ptr + indices, mask=mask) - output = x + y - - tl.store(output_ptr + indices, output, mask=mask) diff --git a/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py b/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py index fa3171cc..f6bcdd61 100644 --- a/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py +++ b/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py @@ -1,15 +1 @@ -import torch - -from .....constants import LIBRARY_NAME -from .....jit import cpp_jit -from .....utils import cute_op - - -_KERNEL_NAME = "add_tensor_forward_cuda" - - -@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) -@cpp_jit(_KERNEL_NAME) -def add_tensor_forward_cuda( - x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int -) -> None: ... +from .forward import add_tensor_forward_cuda diff --git a/cute_kernels/kernels/add/add_tensor/cuda_implementation/kernels_forward.cu b/cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.cu similarity index 100% rename from cute_kernels/kernels/add/add_tensor/cuda_implementation/kernels_forward.cu rename to cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.cu diff --git a/cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py b/cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py new file mode 100644 index 00000000..fa3171cc --- /dev/null +++ b/cute_kernels/kernels/add/add_tensor/cuda_implementation/forward.py @@ -0,0 +1,15 @@ +import torch + +from .....constants import LIBRARY_NAME +from .....jit import cpp_jit +from .....utils import cute_op + + +_KERNEL_NAME = "add_tensor_forward_cuda" + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +@cpp_jit(_KERNEL_NAME) +def add_tensor_forward_cuda( + x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int +) -> None: ... 
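(Illustrative example, not part of the patch.) Because each package `__init__.py` now only re-exports the moved symbols, callers of the add kernels are unaffected by the file moves. A minimal usage sketch for the Triton add-scalar launcher defined above, assuming cute_kernels is installed, a CUDA device is available, the `cute_op`-wrapped launcher is directly callable with keyword arguments, and BLOCK_SIZE=1024 is an arbitrary choice:

import torch

from cute_kernels.kernels.add.add_scalar.triton_implementation import add_scalar_forward_triton

x = torch.randn(4096, device="cuda")
output = torch.empty_like(x)

# the launcher writes into `output` in place (mutates_args={"output"});
# BLOCK_SIZE=1024 is an arbitrary value chosen for this sketch
add_scalar_forward_triton(x=x, y=3.0, output=output, BLOCK_SIZE=1024)

assert torch.allclose(output, x + 3.0)
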
diff --git a/cute_kernels/kernels/add/add_tensor/triton_implementation/__init__.py b/cute_kernels/kernels/add/add_tensor/triton_implementation/__init__.py index 9bd45bba..f04a0e4c 100644 --- a/cute_kernels/kernels/add/add_tensor/triton_implementation/__init__.py +++ b/cute_kernels/kernels/add/add_tensor/triton_implementation/__init__.py @@ -1,19 +1 @@ -import torch - -from .....constants import LIBRARY_NAME -from .....math import ceil_divide -from .....utils import cute_op -from .kernels_forward import _add_tensor_forward_triton_kernel - - -_KERNEL_NAME = "add_tensor_forward_triton" - - -@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) -def add_tensor_forward_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: - num_elements = x.numel() - num_programs = ceil_divide(num_elements, BLOCK_SIZE) - - _add_tensor_forward_triton_kernel[(num_programs,)]( - x_ptr=x, y_ptr=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE - ) +from .forward import add_tensor_forward_triton diff --git a/cute_kernels/kernels/add/add_tensor/triton_implementation/forward.py b/cute_kernels/kernels/add/add_tensor/triton_implementation/forward.py new file mode 100644 index 00000000..39287800 --- /dev/null +++ b/cute_kernels/kernels/add/add_tensor/triton_implementation/forward.py @@ -0,0 +1,35 @@ +import torch +import triton +import triton.language as tl + +from .....constants import LIBRARY_NAME +from .....math import ceil_divide +from .....utils import cute_op + + +_KERNEL_NAME = "add_tensor_forward_triton" + + +@triton.jit +def _add_tensor_forward_triton_kernel(x_ptr, y_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + + indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = indices < num_elements + + x = tl.load(x_ptr + indices, mask=mask) + y = tl.load(y_ptr + indices, mask=mask) + + output = x + y + + tl.store(output_ptr + indices, output, mask=mask) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +def add_tensor_forward_triton(x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: + num_elements = x.numel() + num_programs = ceil_divide(num_elements, BLOCK_SIZE) + + _add_tensor_forward_triton_kernel[(num_programs,)]( + x_ptr=x, y_ptr=y, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE + ) diff --git a/cute_kernels/kernels/add/add_tensor/triton_implementation/kernels_forward.py b/cute_kernels/kernels/add/add_tensor/triton_implementation/kernels_forward.py deleted file mode 100644 index f7b7f7ef..00000000 --- a/cute_kernels/kernels/add/add_tensor/triton_implementation/kernels_forward.py +++ /dev/null @@ -1,17 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def _add_tensor_forward_triton_kernel(x_ptr, y_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr): - pid = tl.program_id(axis=0) - - indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = indices < num_elements - - x = tl.load(x_ptr + indices, mask=mask) - y = tl.load(y_ptr + indices, mask=mask) - - output = x + y - - tl.store(output_ptr + indices, output, mask=mask) diff --git a/cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py b/cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py index ea7e28a8..4f4bc456 100644 --- a/cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py +++ b/cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py @@ -1,15 +1 @@ -import torch - -from 
....constants import LIBRARY_NAME -from ....jit import cpp_jit -from ....utils import cute_op - - -_KERNEL_NAME = "contiguous_count_cuda" - - -@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) -@cpp_jit(_KERNEL_NAME) -def contiguous_count_cuda( - x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int -) -> None: ... +from .forward import contiguous_count_cuda diff --git a/cute_kernels/kernels/contiguous_count/cuda_implementation/kernels_forward.cu b/cute_kernels/kernels/contiguous_count/cuda_implementation/forward.cu similarity index 100% rename from cute_kernels/kernels/contiguous_count/cuda_implementation/kernels_forward.cu rename to cute_kernels/kernels/contiguous_count/cuda_implementation/forward.cu diff --git a/cute_kernels/kernels/contiguous_count/cuda_implementation/forward.py b/cute_kernels/kernels/contiguous_count/cuda_implementation/forward.py new file mode 100644 index 00000000..ea7e28a8 --- /dev/null +++ b/cute_kernels/kernels/contiguous_count/cuda_implementation/forward.py @@ -0,0 +1,15 @@ +import torch + +from ....constants import LIBRARY_NAME +from ....jit import cpp_jit +from ....utils import cute_op + + +_KERNEL_NAME = "contiguous_count_cuda" + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +@cpp_jit(_KERNEL_NAME) +def contiguous_count_cuda( + x: torch.Tensor, output: torch.Tensor, sm_count: int, thread_block_cluster_size: int, size: int, BLOCK_SIZE: int +) -> None: ... diff --git a/cute_kernels/kernels/contiguous_count/triton_implementation/__init__.py b/cute_kernels/kernels/contiguous_count/triton_implementation/__init__.py index aaa163ff..14d1179b 100644 --- a/cute_kernels/kernels/contiguous_count/triton_implementation/__init__.py +++ b/cute_kernels/kernels/contiguous_count/triton_implementation/__init__.py @@ -1,29 +1 @@ -import torch - -from ....constants import LIBRARY_NAME -from ....math import ceil_divide -from ....utils import cute_op, get_sm_count -from .kernels_forward import _contiguous_count_triton_kernel - - -_KERNEL_NAME = "contiguous_count_triton" - - -@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) -def contiguous_count_triton( - x: torch.Tensor, output: torch.Tensor, size: int, BLOCK_SIZE: int, BLOCK_SIZE_C: int -) -> None: - B = x.numel() - - sm_count = get_sm_count(x.device) - num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE)) - - with torch.device(x.device): - _contiguous_count_triton_kernel[(num_programs,)]( - x_ptr=x, - output_ptr=output, - B=B, - C=size, - BLOCK_SIZE_B=BLOCK_SIZE, - BLOCK_SIZE_C=BLOCK_SIZE_C, - ) +from .forward import contiguous_count_triton diff --git a/cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py b/cute_kernels/kernels/contiguous_count/triton_implementation/forward.py similarity index 60% rename from cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py rename to cute_kernels/kernels/contiguous_count/triton_implementation/forward.py index 8182fe38..1cf0f059 100644 --- a/cute_kernels/kernels/contiguous_count/triton_implementation/kernels_forward.py +++ b/cute_kernels/kernels/contiguous_count/triton_implementation/forward.py @@ -1,6 +1,14 @@ +import torch import triton import triton.language as tl +from ....constants import LIBRARY_NAME +from ....math import ceil_divide +from ....utils import cute_op, get_sm_count + + +_KERNEL_NAME = "contiguous_count_triton" + @triton.jit def _contiguous_count_triton_kernel(x_ptr, output_ptr, B, C, BLOCK_SIZE_B: 
tl.constexpr, BLOCK_SIZE_C: tl.constexpr): @@ -30,3 +38,23 @@ def _contiguous_count_triton_kernel(x_ptr, output_ptr, B, C, BLOCK_SIZE_B: tl.co counts += tl.sum(equal, axis=0) tl.atomic_add(output_ptr + indices_c, counts, mask=mask_c) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +def contiguous_count_triton( + x: torch.Tensor, output: torch.Tensor, size: int, BLOCK_SIZE: int, BLOCK_SIZE_C: int +) -> None: + B = x.numel() + + sm_count = get_sm_count(x.device) + num_programs = min(sm_count, ceil_divide(B, BLOCK_SIZE)) + + with torch.device(x.device): + _contiguous_count_triton_kernel[(num_programs,)]( + x_ptr=x, + output_ptr=output, + B=B, + C=size, + BLOCK_SIZE_B=BLOCK_SIZE, + BLOCK_SIZE_C=BLOCK_SIZE_C, + ) diff --git a/cute_kernels/kernels/embedding/triton_implementation/__init__.py b/cute_kernels/kernels/embedding/triton_implementation/__init__.py index aa8335d6..d59cfdfd 100644 --- a/cute_kernels/kernels/embedding/triton_implementation/__init__.py +++ b/cute_kernels/kernels/embedding/triton_implementation/__init__.py @@ -1,66 +1,2 @@ -import torch - -from ....constants import LIBRARY_NAME -from ....math import ceil_divide -from ....utils import cute_op -from .kernels_backward import _embedding_backward_triton_kernel -from .kernels_forward import _embedding_forward_triton_kernel - - -_FORWARD_KERNEL_NAME = "embedding_forward_triton" -_BACKWARD_KERNEL_NAME = "embedding_backward_triton" - - -@cute_op(f"{LIBRARY_NAME}::{_FORWARD_KERNEL_NAME}", mutates_args={"output"}) -def embedding_forward_triton( - input_ids: torch.Tensor, - weight: torch.Tensor, - output: torch.Tensor, - BLOCK_SIZE_B: int, - BLOCK_SIZE_H: int, -) -> None: - num_elements = input_ids.numel() - hidden_size = weight.size(-1) - - with torch.device(input_ids.device): - _embedding_forward_triton_kernel[ - (ceil_divide(num_elements, BLOCK_SIZE_B), ceil_divide(hidden_size, BLOCK_SIZE_H)) - ]( - x_ptr=input_ids, - weight_ptr=weight, - output_ptr=output, - B=num_elements, - H=hidden_size, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) - - -@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_NAME}", mutates_args={"weight_grad"}) -def embedding_backward_triton( - input_ids: torch.Tensor, - output_grad: torch.Tensor, - weight_grad: torch.Tensor, - BLOCK_SIZE_B: int, - BLOCK_SIZE_H: int, -) -> None: - num_elements = input_ids.numel() - hidden_size = weight_grad.size(-1) - - accumulate_in_fp32 = weight_grad.dtype == torch.bfloat16 - if accumulate_in_fp32: - weight_grad = weight_grad.float() - - with torch.device(input_ids.device): - _embedding_backward_triton_kernel[ - (ceil_divide(num_elements, BLOCK_SIZE_B), ceil_divide(hidden_size, BLOCK_SIZE_H)) - ]( - x_ptr=input_ids, - output_grad_ptr=output_grad, - weight_grad_ptr=weight_grad, - B=num_elements, - H=hidden_size, - accumulate_in_fp32=accumulate_in_fp32, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) +from .backward import embedding_backward_triton +from .forward import embedding_forward_triton diff --git a/cute_kernels/kernels/embedding/triton_implementation/backward.py b/cute_kernels/kernels/embedding/triton_implementation/backward.py new file mode 100644 index 00000000..5eb0d964 --- /dev/null +++ b/cute_kernels/kernels/embedding/triton_implementation/backward.py @@ -0,0 +1,75 @@ +import torch +import triton +import triton.language as tl + +from ....constants import LIBRARY_NAME +from ....math import ceil_divide +from ....utils import cute_op + + +_KERNEL_NAME = "embedding_backward_triton" + + +@triton.jit +def 
_embedding_backward_triton_kernel( + x_ptr, + output_grad_ptr, + weight_grad_ptr, + B, + H, + accumulate_in_fp32: tl.constexpr, + BLOCK_SIZE_B: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + indices_b = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) + indices_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + + mask_b = indices_b < B + mask_h = indices_h < H + mask_bh = mask_b[:, None] & mask_h[None, :] + + x_ptrs = x_ptr + indices_b + x = tl.load(x_ptrs, mask=mask_b) + + output_grad_ptrs = output_grad_ptr + indices_b[:, None] * H + indices_h[None, :] + output_grad = tl.load(output_grad_ptrs, mask=mask_bh) + + weight_grad_ptrs = weight_grad_ptr + x[:, None] * H + indices_h[None, :] + + if accumulate_in_fp32: + output_grad = output_grad.to(tl.float32) + + tl.atomic_add(weight_grad_ptrs, output_grad, mask=mask_bh) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"weight_grad"}) +def embedding_backward_triton( + input_ids: torch.Tensor, + output_grad: torch.Tensor, + weight_grad: torch.Tensor, + BLOCK_SIZE_B: int, + BLOCK_SIZE_H: int, +) -> None: + num_elements = input_ids.numel() + hidden_size = weight_grad.size(-1) + + accumulate_in_fp32 = weight_grad.dtype == torch.bfloat16 + if accumulate_in_fp32: + weight_grad = weight_grad.float() + + with torch.device(input_ids.device): + _embedding_backward_triton_kernel[ + (ceil_divide(num_elements, BLOCK_SIZE_B), ceil_divide(hidden_size, BLOCK_SIZE_H)) + ]( + x_ptr=input_ids, + output_grad_ptr=output_grad, + weight_grad_ptr=weight_grad, + B=num_elements, + H=hidden_size, + accumulate_in_fp32=accumulate_in_fp32, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) diff --git a/cute_kernels/kernels/embedding/triton_implementation/forward.py b/cute_kernels/kernels/embedding/triton_implementation/forward.py new file mode 100644 index 00000000..c85c2975 --- /dev/null +++ b/cute_kernels/kernels/embedding/triton_implementation/forward.py @@ -0,0 +1,64 @@ +import torch +import triton +import triton.language as tl + +from ....constants import LIBRARY_NAME +from ....math import ceil_divide +from ....utils import cute_op + + +_KERNEL_NAME = "embedding_forward_triton" + + +@triton.jit +def _embedding_forward_triton_kernel( + x_ptr, + weight_ptr, + output_ptr, + B, + H, + BLOCK_SIZE_B: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + pid_h = tl.program_id(axis=1) + + indices_b = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) + indices_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + + mask_b = indices_b < B + mask_h = indices_h < H + + x_ptrs = x_ptr + indices_b + x = tl.load(x_ptrs, mask=mask_b) + + weight_ptrs = weight_ptr + x[:, None] * H + indices_h[None, :] + word_embeddings = tl.load(weight_ptrs, mask=mask_h[None, :]) + + output_ptrs = output_ptr + indices_b[:, None] * H + indices_h[None, :] + tl.store(output_ptrs, word_embeddings, mask=mask_b[:, None] & mask_h[None, :]) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +def embedding_forward_triton( + input_ids: torch.Tensor, + weight: torch.Tensor, + output: torch.Tensor, + BLOCK_SIZE_B: int, + BLOCK_SIZE_H: int, +) -> None: + num_elements = input_ids.numel() + hidden_size = weight.size(-1) + + with torch.device(input_ids.device): + _embedding_forward_triton_kernel[ + (ceil_divide(num_elements, BLOCK_SIZE_B), ceil_divide(hidden_size, BLOCK_SIZE_H)) + ]( + x_ptr=input_ids, + weight_ptr=weight, + output_ptr=output, + B=num_elements, + 
H=hidden_size, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) diff --git a/cute_kernels/kernels/embedding/triton_implementation/kernels_backward.py b/cute_kernels/kernels/embedding/triton_implementation/kernels_backward.py deleted file mode 100644 index 1245fffb..00000000 --- a/cute_kernels/kernels/embedding/triton_implementation/kernels_backward.py +++ /dev/null @@ -1,37 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def _embedding_backward_triton_kernel( - x_ptr, - output_grad_ptr, - weight_grad_ptr, - B, - H, - accumulate_in_fp32: tl.constexpr, - BLOCK_SIZE_B: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - pid_h = tl.program_id(axis=1) - - indices_b = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) - indices_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - - mask_b = indices_b < B - mask_h = indices_h < H - mask_bh = mask_b[:, None] & mask_h[None, :] - - x_ptrs = x_ptr + indices_b - x = tl.load(x_ptrs, mask=mask_b) - - output_grad_ptrs = output_grad_ptr + indices_b[:, None] * H + indices_h[None, :] - output_grad = tl.load(output_grad_ptrs, mask=mask_bh) - - weight_grad_ptrs = weight_grad_ptr + x[:, None] * H + indices_h[None, :] - - if accumulate_in_fp32: - output_grad = output_grad.to(tl.float32) - - tl.atomic_add(weight_grad_ptrs, output_grad, mask=mask_bh) diff --git a/cute_kernels/kernels/embedding/triton_implementation/kernels_forward.py b/cute_kernels/kernels/embedding/triton_implementation/kernels_forward.py deleted file mode 100644 index 48ac822d..00000000 --- a/cute_kernels/kernels/embedding/triton_implementation/kernels_forward.py +++ /dev/null @@ -1,31 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def _embedding_forward_triton_kernel( - x_ptr, - weight_ptr, - output_ptr, - B, - H, - BLOCK_SIZE_B: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - pid_h = tl.program_id(axis=1) - - indices_b = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) - indices_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) - - mask_b = indices_b < B - mask_h = indices_h < H - - x_ptrs = x_ptr + indices_b - x = tl.load(x_ptrs, mask=mask_b) - - weight_ptrs = weight_ptr + x[:, None] * H + indices_h[None, :] - word_embeddings = tl.load(weight_ptrs, mask=mask_h[None, :]) - - output_ptrs = output_ptr + indices_b[:, None] * H + indices_h[None, :] - tl.store(output_ptrs, word_embeddings, mask=mask_b[:, None] & mask_h[None, :]) diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py b/cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py index f0cf398d..9aa6ba47 100644 --- a/cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py +++ b/cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py @@ -1,164 +1,2 @@ -import torch - -from ....constants import LIBRARY_NAME, TORCH_TO_TRITON_DTYPE -from ....cutotune import cutotune -from ....math import ceil_divide -from ....utils import cute_op, get_num_elements_and_hidden_size, get_sm_count -from .kernels_backward import _rmsnorm_backward_triton_kernel -from .kernels_forward import _rmsnorm_forward_triton_kernel -from .parameters import get_cutotune_parameters - - -_FORWARD_KERNEL_NAME = "rmsnorm_forward_triton" -_BACKWARD_KERNEL_NO_WEIGHT_NAME = "rmsnorm_backward_no_weight_triton" -_BACKWARD_KERNEL_WEIGHTED_NAME = "rmsnorm_backward_triton" - - -@cutotune(**get_cutotune_parameters()) -@cute_op(f"{LIBRARY_NAME}::{_FORWARD_KERNEL_NAME}", mutates_args={"output", "rmsnorm_denominator"}) 
-def rmsnorm_forward_triton( - x: torch.Tensor, - weight: torch.Tensor | None, - output: torch.Tensor, - eps: float, - rmsnorm_denominator: torch.Tensor | None, - BLOCK_SIZE_B: int, - BLOCK_SIZE_H: int, -) -> None: - num_elements, hidden_size = get_num_elements_and_hidden_size(x) - - if BLOCK_SIZE_H < hidden_size: - raise ValueError(f"hidden_size should be more than the BLOCK_SIZE_H") - - with torch.device(x.device): - _rmsnorm_forward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE_B),)]( - x_ptr=x, - x_dtype=TORCH_TO_TRITON_DTYPE[x.dtype], - has_weight=weight is not None, - weight_ptr=weight, - output_ptr=output, - eps=eps, - memory_efficient=rmsnorm_denominator is None, - rmsnorm_denominator_ptr=rmsnorm_denominator, - B=num_elements, - H=hidden_size, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) - - -@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_NO_WEIGHT_NAME}", mutates_args={"x_grad"}) -def _rmsnorm_backward_no_weight_triton( - x: torch.Tensor, - output_grad: torch.Tensor, - rmsnorm_denominator: torch.Tensor | None, - x_grad: torch.Tensor, - eps: float, - BLOCK_SIZE_B: int, - BLOCK_SIZE_H: int, -) -> None: - num_elements, hidden_size = get_num_elements_and_hidden_size(x) - - if BLOCK_SIZE_H < hidden_size: - raise ValueError(f"hidden_size should be more than the BLOCK_SIZE_H") - - sm_count = get_sm_count(x.device) - num_programs = min(sm_count, ceil_divide(num_elements, BLOCK_SIZE_B)) - - with torch.device(x.device): - _rmsnorm_backward_triton_kernel[(num_programs,)]( - x_ptr=x, - x_dtype=TORCH_TO_TRITON_DTYPE[x.dtype], - has_weight=False, - weight_ptr=None, - output_grad_ptr=output_grad, - x_grad_ptr=x_grad, - weight_grad_ptr=None, - eps=eps, - memory_efficient=rmsnorm_denominator is None, - rmsnorm_denominator_ptr=rmsnorm_denominator, - B=num_elements, - H=hidden_size, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) - - -@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_WEIGHTED_NAME}", mutates_args={"x_grad", "weight_grad"}) -def _rmsnorm_backward_triton( - x: torch.Tensor, - weight: torch.Tensor, - output_grad: torch.Tensor, - rmsnorm_denominator: torch.Tensor, - x_grad: torch.Tensor, - weight_grad: torch.Tensor, - eps: float, - BLOCK_SIZE_B: int, - BLOCK_SIZE_H: int, -) -> None: - num_elements, hidden_size = get_num_elements_and_hidden_size(x) - - if BLOCK_SIZE_H < hidden_size: - raise ValueError(f"hidden_size should be more than the BLOCK_SIZE_H") - - sm_count = get_sm_count(x.device) - num_programs = min(sm_count, ceil_divide(num_elements, BLOCK_SIZE_B)) - - with torch.device(x.device): - _rmsnorm_backward_triton_kernel[(num_programs,)]( - x_ptr=x, - x_dtype=TORCH_TO_TRITON_DTYPE[x.dtype], - has_weight=True, - weight_ptr=weight, - output_grad_ptr=output_grad, - x_grad_ptr=x_grad, - weight_grad_ptr=weight_grad, - eps=eps, - memory_efficient=rmsnorm_denominator is None, - rmsnorm_denominator_ptr=rmsnorm_denominator, - B=num_elements, - H=hidden_size, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) - - -@cutotune(**get_cutotune_parameters()) -def rmsnorm_backward_triton( - x: torch.Tensor, - weight: torch.Tensor | None, - output_grad: torch.Tensor, - rmsnorm_denominator: torch.Tensor, - x_grad: torch.Tensor, - eps: float, - BLOCK_SIZE_B: int, - BLOCK_SIZE_H: int, -) -> torch.Tensor | None: - if weight is None: - weight_grad = None - _rmsnorm_backward_no_weight_triton( - x=x, - output_grad=output_grad, - rmsnorm_denominator=rmsnorm_denominator, - x_grad=x_grad, - eps=eps, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) - else: 
- weight_grad = torch.zeros_like(weight, dtype=torch.float32) - _rmsnorm_backward_triton( - x=x, - weight=weight, - output_grad=output_grad, - rmsnorm_denominator=rmsnorm_denominator, - x_grad=x_grad, - weight_grad=weight_grad, - eps=eps, - BLOCK_SIZE_B=BLOCK_SIZE_B, - BLOCK_SIZE_H=BLOCK_SIZE_H, - ) - - weight_grad = weight_grad.type_as(weight) - - return weight_grad +from .backward import rmsnorm_backward_triton +from .forward import rmsnorm_forward_triton diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/backward.py b/cute_kernels/kernels/rmsnorm/triton_implementation/backward.py new file mode 100644 index 00000000..639be0c0 --- /dev/null +++ b/cute_kernels/kernels/rmsnorm/triton_implementation/backward.py @@ -0,0 +1,209 @@ +import torch +import triton +import triton.language as tl + +from ....constants import LIBRARY_NAME, TORCH_TO_TRITON_DTYPE +from ....cutotune import cutotune +from ....math import ceil_divide +from ....utils import cute_op, get_num_elements_and_hidden_size, get_sm_count +from .parameters import get_cutotune_parameters + + +_KERNEL_NO_WEIGHT_NAME = "rmsnorm_backward_no_weight_triton" +_KERNEL_WEIGHTED_NAME = "rmsnorm_backward_triton" + + +@triton.jit +def _rmsnorm_backward_triton_kernel( + x_ptr, + x_dtype: tl.constexpr, + has_weight: tl.constexpr, + weight_ptr, + output_grad_ptr, + x_grad_ptr, + weight_grad_ptr, + eps, + memory_efficient: tl.constexpr, + rmsnorm_denominator_ptr, + B, + H, + BLOCK_SIZE_B: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + num_elements_per_program = tl.cdiv(B, num_programs) + + indices_h = tl.arange(0, BLOCK_SIZE_H) + mask_h = indices_h < H + + program_start = pid * num_elements_per_program + program_end = min(program_start + num_elements_per_program, B) + num_elements_in_current_program = program_end - program_start + + num_loops = tl.cdiv(num_elements_in_current_program, BLOCK_SIZE_B) + + if has_weight: + weight = tl.load(weight_ptr + indices_h, mask=mask_h)[None, :] + weight_grad = tl.zeros((BLOCK_SIZE_H,), dtype=tl.float32) + else: + weight = 1 + weight_grad = 0 + + for i in range(num_loops): + indices_b = program_start + i * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) + mask_b = indices_b < program_end + + mask_bh = mask_b[:, None] & mask_h[None, :] + + x_ptrs = x_ptr + indices_b[:, None] * H + indices_h[None, :] + x = tl.load(x_ptrs, mask=mask_bh).to(tl.float32) + + if memory_efficient: + squared_sum = tl.sum(x * x, axis=1) + inverse_rms = tl.rsqrt(squared_sum / H + eps) + else: + inverse_rms = tl.load(rmsnorm_denominator_ptr + indices_b, mask=mask_b) + + output_grad_ptrs = output_grad_ptr + indices_b[:, None] * H + indices_h[None, :] + output_grad = tl.load(output_grad_ptrs, mask=mask_bh) + + output_grad_weight = (output_grad * weight).to(tl.float32) + + x_grad = inverse_rms[:, None] * output_grad_weight + x_grad -= ( + (1 / H) + * inverse_rms[:, None] + * inverse_rms[:, None] + * inverse_rms[:, None] + * x + * tl.sum(output_grad_weight * x, axis=1, keep_dims=True) + ) + x_grad = x_grad.to(x_dtype) + + x_grad_ptrs = x_grad_ptr + indices_b[:, None] * H + indices_h[None, :] + tl.store(x_grad_ptrs, x_grad, mask=mask_bh) + + if has_weight: + weight_grad += tl.sum(output_grad * (x * inverse_rms[:, None]).to(x_dtype), axis=0) + + if has_weight: + tl.atomic_add(weight_grad_ptr + indices_h, weight_grad, mask=mask_h) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NO_WEIGHT_NAME}", mutates_args={"x_grad"}) +def _rmsnorm_backward_no_weight_triton( + x: 
torch.Tensor, + output_grad: torch.Tensor, + rmsnorm_denominator: torch.Tensor | None, + x_grad: torch.Tensor, + eps: float, + BLOCK_SIZE_B: int, + BLOCK_SIZE_H: int, +) -> None: + num_elements, hidden_size = get_num_elements_and_hidden_size(x) + + if BLOCK_SIZE_H < hidden_size: + raise ValueError(f"hidden_size should be more than the BLOCK_SIZE_H") + + sm_count = get_sm_count(x.device) + num_programs = min(sm_count, ceil_divide(num_elements, BLOCK_SIZE_B)) + + with torch.device(x.device): + _rmsnorm_backward_triton_kernel[(num_programs,)]( + x_ptr=x, + x_dtype=TORCH_TO_TRITON_DTYPE[x.dtype], + has_weight=False, + weight_ptr=None, + output_grad_ptr=output_grad, + x_grad_ptr=x_grad, + weight_grad_ptr=None, + eps=eps, + memory_efficient=rmsnorm_denominator is None, + rmsnorm_denominator_ptr=rmsnorm_denominator, + B=num_elements, + H=hidden_size, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_WEIGHTED_NAME}", mutates_args={"x_grad", "weight_grad"}) +def _rmsnorm_backward_triton( + x: torch.Tensor, + weight: torch.Tensor, + output_grad: torch.Tensor, + rmsnorm_denominator: torch.Tensor, + x_grad: torch.Tensor, + weight_grad: torch.Tensor, + eps: float, + BLOCK_SIZE_B: int, + BLOCK_SIZE_H: int, +) -> None: + num_elements, hidden_size = get_num_elements_and_hidden_size(x) + + if BLOCK_SIZE_H < hidden_size: + raise ValueError(f"hidden_size should be more than the BLOCK_SIZE_H") + + sm_count = get_sm_count(x.device) + num_programs = min(sm_count, ceil_divide(num_elements, BLOCK_SIZE_B)) + + with torch.device(x.device): + _rmsnorm_backward_triton_kernel[(num_programs,)]( + x_ptr=x, + x_dtype=TORCH_TO_TRITON_DTYPE[x.dtype], + has_weight=True, + weight_ptr=weight, + output_grad_ptr=output_grad, + x_grad_ptr=x_grad, + weight_grad_ptr=weight_grad, + eps=eps, + memory_efficient=rmsnorm_denominator is None, + rmsnorm_denominator_ptr=rmsnorm_denominator, + B=num_elements, + H=hidden_size, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) + + +@cutotune(**get_cutotune_parameters()) +def rmsnorm_backward_triton( + x: torch.Tensor, + weight: torch.Tensor | None, + output_grad: torch.Tensor, + rmsnorm_denominator: torch.Tensor, + x_grad: torch.Tensor, + eps: float, + BLOCK_SIZE_B: int, + BLOCK_SIZE_H: int, +) -> torch.Tensor | None: + if weight is None: + weight_grad = None + _rmsnorm_backward_no_weight_triton( + x=x, + output_grad=output_grad, + rmsnorm_denominator=rmsnorm_denominator, + x_grad=x_grad, + eps=eps, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) + else: + weight_grad = torch.zeros_like(weight, dtype=torch.float32) + _rmsnorm_backward_triton( + x=x, + weight=weight, + output_grad=output_grad, + rmsnorm_denominator=rmsnorm_denominator, + x_grad=x_grad, + weight_grad=weight_grad, + eps=eps, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) + + weight_grad = weight_grad.type_as(weight) + + return weight_grad diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py->rmsnorm_backward_triton.yml b/cute_kernels/kernels/rmsnorm/triton_implementation/backward.py->rmsnorm_backward_triton.yml similarity index 100% rename from cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py->rmsnorm_backward_triton.yml rename to cute_kernels/kernels/rmsnorm/triton_implementation/backward.py->rmsnorm_backward_triton.yml diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/forward.py b/cute_kernels/kernels/rmsnorm/triton_implementation/forward.py new file mode 100644 index 
00000000..11595b5d --- /dev/null +++ b/cute_kernels/kernels/rmsnorm/triton_implementation/forward.py @@ -0,0 +1,89 @@ +import torch +import triton +import triton.language as tl + +from ....constants import LIBRARY_NAME, TORCH_TO_TRITON_DTYPE +from ....cutotune import cutotune +from ....math import ceil_divide +from ....utils import cute_op, get_num_elements_and_hidden_size +from .parameters import get_cutotune_parameters + + +_KERNEL_NAME = "rmsnorm_forward_triton" + + +@triton.jit +def _rmsnorm_forward_triton_kernel( + x_ptr, + x_dtype: tl.constexpr, + has_weight: tl.constexpr, + weight_ptr, + output_ptr, + eps, + memory_efficient: tl.constexpr, + rmsnorm_denominator_ptr, + B, + H, + BLOCK_SIZE_B: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, +): + pid_b = tl.program_id(axis=0) + + indices_b = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) + indices_h = tl.arange(0, BLOCK_SIZE_H) + + mask_b = indices_b < B + mask_h = indices_h < H + + mask_bh = mask_b[:, None] & mask_h[None, :] + + x_ptrs = x_ptr + indices_b[:, None] * H + indices_h[None, :] + x = tl.load(x_ptrs, mask=mask_bh).to(tl.float32) + + squared_sum = tl.sum(x * x, axis=1) + inverse_rms = tl.rsqrt((squared_sum / H) + eps) + + if not memory_efficient: + tl.store(rmsnorm_denominator_ptr + indices_b, inverse_rms, mask=mask_b) + + x *= inverse_rms[:, None] + + if has_weight: + weight = tl.load(weight_ptr + indices_h, mask=mask_h) + x = x.to(x_dtype) * weight[None, :] + + output_ptrs = output_ptr + indices_b[:, None] * H + indices_h[None, :] + tl.store(output_ptrs, x, mask=mask_bh) + + +@cutotune(**get_cutotune_parameters()) +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output", "rmsnorm_denominator"}) +def rmsnorm_forward_triton( + x: torch.Tensor, + weight: torch.Tensor | None, + output: torch.Tensor, + eps: float, + rmsnorm_denominator: torch.Tensor | None, + BLOCK_SIZE_B: int, + BLOCK_SIZE_H: int, +) -> None: + num_elements, hidden_size = get_num_elements_and_hidden_size(x) + + if BLOCK_SIZE_H < hidden_size: + raise ValueError(f"hidden_size should be more than the BLOCK_SIZE_H") + + with torch.device(x.device): + _rmsnorm_forward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE_B),)]( + x_ptr=x, + x_dtype=TORCH_TO_TRITON_DTYPE[x.dtype], + has_weight=weight is not None, + weight_ptr=weight, + output_ptr=output, + eps=eps, + memory_efficient=rmsnorm_denominator is None, + rmsnorm_denominator_ptr=rmsnorm_denominator, + B=num_elements, + H=hidden_size, + BLOCK_SIZE_B=BLOCK_SIZE_B, + BLOCK_SIZE_H=BLOCK_SIZE_H, + ) diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py->rmsnorm_forward_triton.yml b/cute_kernels/kernels/rmsnorm/triton_implementation/forward.py->rmsnorm_forward_triton.yml similarity index 100% rename from cute_kernels/kernels/rmsnorm/triton_implementation/__init__.py->rmsnorm_forward_triton.yml rename to cute_kernels/kernels/rmsnorm/triton_implementation/forward.py->rmsnorm_forward_triton.yml diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/kernels_backward.py b/cute_kernels/kernels/rmsnorm/triton_implementation/kernels_backward.py deleted file mode 100644 index 337006d1..00000000 --- a/cute_kernels/kernels/rmsnorm/triton_implementation/kernels_backward.py +++ /dev/null @@ -1,81 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def _rmsnorm_backward_triton_kernel( - x_ptr, - x_dtype: tl.constexpr, - has_weight: tl.constexpr, - weight_ptr, - output_grad_ptr, - x_grad_ptr, - weight_grad_ptr, - eps, - memory_efficient: tl.constexpr, - 
rmsnorm_denominator_ptr, - B, - H, - BLOCK_SIZE_B: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_programs = tl.num_programs(axis=0) - - num_elements_per_program = tl.cdiv(B, num_programs) - - indices_h = tl.arange(0, BLOCK_SIZE_H) - mask_h = indices_h < H - - program_start = pid * num_elements_per_program - program_end = min(program_start + num_elements_per_program, B) - num_elements_in_current_program = program_end - program_start - - num_loops = tl.cdiv(num_elements_in_current_program, BLOCK_SIZE_B) - - if has_weight: - weight = tl.load(weight_ptr + indices_h, mask=mask_h)[None, :] - weight_grad = tl.zeros((BLOCK_SIZE_H,), dtype=tl.float32) - else: - weight = 1 - weight_grad = 0 - - for i in range(num_loops): - indices_b = program_start + i * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) - mask_b = indices_b < program_end - - mask_bh = mask_b[:, None] & mask_h[None, :] - - x_ptrs = x_ptr + indices_b[:, None] * H + indices_h[None, :] - x = tl.load(x_ptrs, mask=mask_bh).to(tl.float32) - - if memory_efficient: - squared_sum = tl.sum(x * x, axis=1) - inverse_rms = tl.rsqrt(squared_sum / H + eps) - else: - inverse_rms = tl.load(rmsnorm_denominator_ptr + indices_b, mask=mask_b) - - output_grad_ptrs = output_grad_ptr + indices_b[:, None] * H + indices_h[None, :] - output_grad = tl.load(output_grad_ptrs, mask=mask_bh) - - output_grad_weight = (output_grad * weight).to(tl.float32) - - x_grad = inverse_rms[:, None] * output_grad_weight - x_grad -= ( - (1 / H) - * inverse_rms[:, None] - * inverse_rms[:, None] - * inverse_rms[:, None] - * x - * tl.sum(output_grad_weight * x, axis=1, keep_dims=True) - ) - x_grad = x_grad.to(x_dtype) - - x_grad_ptrs = x_grad_ptr + indices_b[:, None] * H + indices_h[None, :] - tl.store(x_grad_ptrs, x_grad, mask=mask_bh) - - if has_weight: - weight_grad += tl.sum(output_grad * (x * inverse_rms[:, None]).to(x_dtype), axis=0) - - if has_weight: - tl.atomic_add(weight_grad_ptr + indices_h, weight_grad, mask=mask_h) diff --git a/cute_kernels/kernels/rmsnorm/triton_implementation/kernels_forward.py b/cute_kernels/kernels/rmsnorm/triton_implementation/kernels_forward.py deleted file mode 100644 index 89d06cf9..00000000 --- a/cute_kernels/kernels/rmsnorm/triton_implementation/kernels_forward.py +++ /dev/null @@ -1,46 +0,0 @@ -import triton -import triton.language as tl - - -@triton.jit -def _rmsnorm_forward_triton_kernel( - x_ptr, - x_dtype: tl.constexpr, - has_weight: tl.constexpr, - weight_ptr, - output_ptr, - eps, - memory_efficient: tl.constexpr, - rmsnorm_denominator_ptr, - B, - H, - BLOCK_SIZE_B: tl.constexpr, - BLOCK_SIZE_H: tl.constexpr, -): - pid_b = tl.program_id(axis=0) - - indices_b = pid_b * BLOCK_SIZE_B + tl.arange(0, BLOCK_SIZE_B) - indices_h = tl.arange(0, BLOCK_SIZE_H) - - mask_b = indices_b < B - mask_h = indices_h < H - - mask_bh = mask_b[:, None] & mask_h[None, :] - - x_ptrs = x_ptr + indices_b[:, None] * H + indices_h[None, :] - x = tl.load(x_ptrs, mask=mask_bh).to(tl.float32) - - squared_sum = tl.sum(x * x, axis=1) - inverse_rms = tl.rsqrt((squared_sum / H) + eps) - - if not memory_efficient: - tl.store(rmsnorm_denominator_ptr + indices_b, inverse_rms, mask=mask_b) - - x *= inverse_rms[:, None] - - if has_weight: - weight = tl.load(weight_ptr + indices_h, mask=mask_h) - x = x.to(x_dtype) * weight[None, :] - - output_ptrs = output_ptr + indices_b[:, None] * H + indices_h[None, :] - tl.store(output_ptrs, x, mask=mask_bh) diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py 
b/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py index d161ea2b..688b08d8 100644 --- a/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py +++ b/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py @@ -1,26 +1,2 @@ -import torch - -from ....constants import LIBRARY_NAME -from ....jit import cpp_jit -from ....utils import cute_op - - -_FORWARD_KERNEL_NAME = "swiglu_forward_cuda" -_BACKWARD_KERNEL_NAME = "swiglu_backward_cuda" - - -@cute_op(f"{LIBRARY_NAME}::{_FORWARD_KERNEL_NAME}", mutates_args={"output"}) -@cpp_jit(_FORWARD_KERNEL_NAME) -def swiglu_forward_cuda(gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: ... - - -@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"}) -@cpp_jit(_BACKWARD_KERNEL_NAME) -def swiglu_backward_cuda( - gate: torch.Tensor, - up: torch.Tensor, - output_grad: torch.Tensor, - gate_grad: torch.Tensor, - up_grad: torch.Tensor, - BLOCK_SIZE: int, -) -> None: ... +from .backward import swiglu_backward_cuda +from .forward import swiglu_forward_cuda diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.cu b/cute_kernels/kernels/swiglu/cuda_implementation/backward.cu similarity index 100% rename from cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.cu rename to cute_kernels/kernels/swiglu/cuda_implementation/backward.cu diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/backward.py b/cute_kernels/kernels/swiglu/cuda_implementation/backward.py new file mode 100644 index 00000000..e9badce0 --- /dev/null +++ b/cute_kernels/kernels/swiglu/cuda_implementation/backward.py @@ -0,0 +1,20 @@ +import torch + +from ....constants import LIBRARY_NAME +from ....jit import cpp_jit +from ....utils import cute_op + + +_KERNEL_NAME = "swiglu_backward_cuda" + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"}) +@cpp_jit(_KERNEL_NAME) +def swiglu_backward_cuda( + gate: torch.Tensor, + up: torch.Tensor, + output_grad: torch.Tensor, + gate_grad: torch.Tensor, + up_grad: torch.Tensor, + BLOCK_SIZE: int, +) -> None: ... diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.cu b/cute_kernels/kernels/swiglu/cuda_implementation/forward.cu similarity index 100% rename from cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.cu rename to cute_kernels/kernels/swiglu/cuda_implementation/forward.cu diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/forward.py b/cute_kernels/kernels/swiglu/cuda_implementation/forward.py new file mode 100644 index 00000000..106a125e --- /dev/null +++ b/cute_kernels/kernels/swiglu/cuda_implementation/forward.py @@ -0,0 +1,13 @@ +import torch + +from ....constants import LIBRARY_NAME +from ....jit import cpp_jit +from ....utils import cute_op + + +_KERNEL_NAME = "swiglu_forward_cuda" + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"}) +@cpp_jit(_KERNEL_NAME) +def swiglu_forward_cuda(gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: ... 
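(Illustrative sketch, not part of the patch.) The layout every kernel package converges on after this refactor is the same: the `@triton.jit` kernel and its `cute_op`-registered launcher live together in `forward.py` (or `backward.py`), and `__init__.py` is reduced to a one-line re-export. A condensed sketch with a hypothetical `example` kernel, mirroring the add_scalar layout shown earlier in this patch:

# forward.py (inside a hypothetical cute_kernels/kernels/example/triton_implementation/ package)
import torch
import triton
import triton.language as tl

from ....constants import LIBRARY_NAME
from ....math import ceil_divide
from ....utils import cute_op


_KERNEL_NAME = "example_forward_triton"


@triton.jit
def _example_forward_triton_kernel(x_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)

    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = indices < num_elements

    # toy computation for the sketch: double the input
    x = tl.load(x_ptr + indices, mask=mask)
    tl.store(output_ptr + indices, x * 2, mask=mask)


@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
def example_forward_triton(x: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None:
    num_elements = x.numel()

    with torch.device(x.device):
        _example_forward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE),)](
            x_ptr=x, output_ptr=output, num_elements=num_elements, BLOCK_SIZE=BLOCK_SIZE
        )


# __init__.py then contains only:
# from .forward import example_forward_triton
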
diff --git a/cute_kernels/kernels/swiglu/triton_implementation/__init__.py b/cute_kernels/kernels/swiglu/triton_implementation/__init__.py index 7f512e6d..64bc0d26 100644 --- a/cute_kernels/kernels/swiglu/triton_implementation/__init__.py +++ b/cute_kernels/kernels/swiglu/triton_implementation/__init__.py @@ -1,48 +1,2 @@ -import torch - -from ....constants import LIBRARY_NAME -from ....math import ceil_divide -from ....utils import cute_op -from .kernels_backward import _swiglu_backward_triton_kernel -from .kernels_forward import _swiglu_forward_triton_kernel - - -_FORWARD_KERNEL_NAME = "swiglu_forward_triton" -_BACKWARD_KERNEL_NAME = "swiglu_backward_triton" - - -@cute_op(f"{LIBRARY_NAME}::{_FORWARD_KERNEL_NAME}", mutates_args={"output"}) -def swiglu_forward_triton(gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None: - num_elements = gate.numel() - - with torch.device(gate.device): - _swiglu_forward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE),)]( - gate_ptr=gate, - up_ptr=up, - output_ptr=output, - num_elements=num_elements, - BLOCK_SIZE=BLOCK_SIZE, - ) - - -@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"}) -def swiglu_backward_triton( - gate: torch.Tensor, - up: torch.Tensor, - output_grad: torch.Tensor, - gate_grad: torch.Tensor, - up_grad: torch.Tensor, - BLOCK_SIZE: int, -) -> None: - num_elements = gate.numel() - - with torch.device(gate.device): - _swiglu_backward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE),)]( - gate_ptr=gate, - up_ptr=up, - output_grad_ptr=output_grad, - gate_grad_ptr=gate_grad, - up_grad_ptr=up_grad, - num_elements=num_elements, - BLOCK_SIZE=BLOCK_SIZE, - ) +from .backward import swiglu_backward_triton +from .forward import swiglu_forward_triton diff --git a/cute_kernels/kernels/swiglu/triton_implementation/backward.py b/cute_kernels/kernels/swiglu/triton_implementation/backward.py new file mode 100644 index 00000000..c000a1ba --- /dev/null +++ b/cute_kernels/kernels/swiglu/triton_implementation/backward.py @@ -0,0 +1,56 @@ +import torch +import triton +import triton.language as tl + +from ....constants import LIBRARY_NAME +from ....math import ceil_divide +from ....utils import cute_op + + +_KERNEL_NAME = "swiglu_backward_triton" + + +@triton.jit +def _swiglu_backward_triton_kernel( + gate_ptr, up_ptr, output_grad_ptr, gate_grad_ptr, up_grad_ptr, num_elements, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + + indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = indices < num_elements + + gate = tl.load(gate_ptr + indices, mask=mask).to(tl.float32) + up = tl.load(up_ptr + indices, mask=mask) + output_grad = tl.load(output_grad_ptr + indices, mask=mask) + + gate_sigmoid = tl.sigmoid(gate) + gate_silu = gate * gate_sigmoid + + gate_grad = output_grad * up * (gate_sigmoid + gate_silu * (1 - gate_sigmoid)) + up_grad = output_grad * gate_silu + + tl.store(gate_grad_ptr + indices, gate_grad, mask=mask) + tl.store(up_grad_ptr + indices, up_grad, mask=mask) + + +@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"}) +def swiglu_backward_triton( + gate: torch.Tensor, + up: torch.Tensor, + output_grad: torch.Tensor, + gate_grad: torch.Tensor, + up_grad: torch.Tensor, + BLOCK_SIZE: int, +) -> None: + num_elements = gate.numel() + + with torch.device(gate.device): + _swiglu_backward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE),)]( + gate_ptr=gate, + up_ptr=up, + output_grad_ptr=output_grad, + gate_grad_ptr=gate_grad, + 
+            up_grad_ptr=up_grad,
+            num_elements=num_elements,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
diff --git a/cute_kernels/kernels/swiglu/triton_implementation/forward.py b/cute_kernels/kernels/swiglu/triton_implementation/forward.py
new file mode 100644
index 00000000..bf52bcfd
--- /dev/null
+++ b/cute_kernels/kernels/swiglu/triton_implementation/forward.py
@@ -0,0 +1,39 @@
+import torch
+import triton
+import triton.language as tl
+
+from ....constants import LIBRARY_NAME
+from ....math import ceil_divide
+from ....utils import cute_op
+
+
+_KERNEL_NAME = "swiglu_forward_triton"
+
+
+@triton.jit
+def _swiglu_forward_triton_kernel(gate_ptr, up_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(axis=0)
+
+    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = indices < num_elements
+
+    gate = tl.load(gate_ptr + indices, mask=mask).to(tl.float32)
+    up = tl.load(up_ptr + indices, mask=mask)
+
+    output = up * gate * tl.sigmoid(gate)
+
+    tl.store(output_ptr + indices, output, mask=mask)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def swiglu_forward_triton(gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, BLOCK_SIZE: int) -> None:
+    num_elements = gate.numel()
+
+    with torch.device(gate.device):
+        _swiglu_forward_triton_kernel[(ceil_divide(num_elements, BLOCK_SIZE),)](
+            gate_ptr=gate,
+            up_ptr=up,
+            output_ptr=output,
+            num_elements=num_elements,
+            BLOCK_SIZE=BLOCK_SIZE,
+        )
diff --git a/cute_kernels/kernels/swiglu/triton_implementation/kernels_backward.py b/cute_kernels/kernels/swiglu/triton_implementation/kernels_backward.py
deleted file mode 100644
index 72fabfa8..00000000
--- a/cute_kernels/kernels/swiglu/triton_implementation/kernels_backward.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def _swiglu_backward_triton_kernel(
-    gate_ptr, up_ptr, output_grad_ptr, gate_grad_ptr, up_grad_ptr, num_elements, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-
-    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = indices < num_elements
-
-    gate = tl.load(gate_ptr + indices, mask=mask).to(tl.float32)
-    up = tl.load(up_ptr + indices, mask=mask)
-    output_grad = tl.load(output_grad_ptr + indices, mask=mask)
-
-    gate_sigmoid = tl.sigmoid(gate)
-    gate_silu = gate * gate_sigmoid
-
-    gate_grad = output_grad * up * (gate_sigmoid + gate_silu * (1 - gate_sigmoid))
-    up_grad = output_grad * gate_silu
-
-    tl.store(gate_grad_ptr + indices, gate_grad, mask=mask)
-    tl.store(up_grad_ptr + indices, up_grad, mask=mask)
diff --git a/cute_kernels/kernels/swiglu/triton_implementation/kernels_forward.py b/cute_kernels/kernels/swiglu/triton_implementation/kernels_forward.py
deleted file mode 100644
index 361195c5..00000000
--- a/cute_kernels/kernels/swiglu/triton_implementation/kernels_forward.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import triton
-import triton.language as tl
-
-
-@triton.jit
-def _swiglu_forward_triton_kernel(gate_ptr, up_ptr, output_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
-    pid = tl.program_id(axis=0)
-
-    indices = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = indices < num_elements
-
-    gate = tl.load(gate_ptr + indices, mask=mask).to(tl.float32)
-    up = tl.load(up_ptr + indices, mask=mask)
-
-    output = up * gate * tl.sigmoid(gate)
-
-    tl.store(output_ptr + indices, output, mask=mask)
diff --git a/cute_kernels/kernels/swiglu_unchunked/triton_implementation/__init__.py b/cute_kernels/kernels/swiglu_unchunked/triton_implementation/__init__.py
index 41c9563c..2d90a138 100644
--- a/cute_kernels/kernels/swiglu_unchunked/triton_implementation/__init__.py
+++ b/cute_kernels/kernels/swiglu_unchunked/triton_implementation/__init__.py
@@ -1,50 +1,2 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....math import ceil_divide
-from ....utils import cute_op, get_num_elements_and_hidden_size
-from .kernels_backward import _swiglu_unchunked_backward_triton_kernel
-from .kernels_forward import _swiglu_unchunked_forward_triton_kernel
-
-
-_FORWARD_KERNEL_NAME = "swiglu_unchunked_forward_triton"
-_BACKWARD_KERNEL_NAME = "swiglu_unchunked_backward_triton"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_FORWARD_KERNEL_NAME}", mutates_args={"output"})
-def swiglu_unchunked_forward_triton(
-    x: torch.Tensor, output: torch.Tensor, BLOCK_SIZE_B: int, BLOCK_SIZE_H: int
-) -> None:
-    B, H = get_num_elements_and_hidden_size(x)
-
-    with torch.device(x.device):
-        _swiglu_unchunked_forward_triton_kernel[(ceil_divide(B, BLOCK_SIZE_B), ceil_divide(H, BLOCK_SIZE_H))](
-            x_ptr=x,
-            output_ptr=output,
-            B=B,
-            H=H,
-            BLOCK_SIZE_B=BLOCK_SIZE_B,
-            BLOCK_SIZE_H=BLOCK_SIZE_H,
-        )
-
-
-@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_NAME}", mutates_args={"x_grad"})
-def swiglu_unchunked_backward_triton(
-    x: torch.Tensor,
-    output_grad: torch.Tensor,
-    x_grad: torch.Tensor,
-    BLOCK_SIZE_B: int,
-    BLOCK_SIZE_H: int,
-) -> None:
-    B, H = get_num_elements_and_hidden_size(x)
-
-    with torch.device(x.device):
-        _swiglu_unchunked_backward_triton_kernel[(ceil_divide(B, BLOCK_SIZE_B), ceil_divide(H, BLOCK_SIZE_H))](
-            x_ptr=x,
-            output_grad_ptr=output_grad,
-            x_grad_ptr=x_grad,
-            B=B,
-            H=H,
-            BLOCK_SIZE_B=BLOCK_SIZE_B,
-            BLOCK_SIZE_H=BLOCK_SIZE_H,
-        )
+from .backward import swiglu_unchunked_backward_triton
+from .forward import swiglu_unchunked_forward_triton
diff --git a/cute_kernels/kernels/swiglu_unchunked/triton_implementation/kernels_backward.py b/cute_kernels/kernels/swiglu_unchunked/triton_implementation/backward.py
similarity index 60%
rename from cute_kernels/kernels/swiglu_unchunked/triton_implementation/kernels_backward.py
rename to cute_kernels/kernels/swiglu_unchunked/triton_implementation/backward.py
index 95163434..46ddf502 100644
--- a/cute_kernels/kernels/swiglu_unchunked/triton_implementation/kernels_backward.py
+++ b/cute_kernels/kernels/swiglu_unchunked/triton_implementation/backward.py
@@ -1,6 +1,14 @@
+import torch
 import triton
 import triton.language as tl
 
+from ....constants import LIBRARY_NAME
+from ....math import ceil_divide
+from ....utils import cute_op, get_num_elements_and_hidden_size
+
+
+_KERNEL_NAME = "swiglu_unchunked_backward_triton"
+
 
 @triton.jit
 def _swiglu_unchunked_backward_triton_kernel(
@@ -38,3 +46,25 @@ def _swiglu_unchunked_backward_triton_kernel(
 
     gate_grad_ptrs = up_grad_ptrs + (H >> 1)
     tl.store(gate_grad_ptrs, gate_grad, mask=mask_bh)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"x_grad"})
+def swiglu_unchunked_backward_triton(
+    x: torch.Tensor,
+    output_grad: torch.Tensor,
+    x_grad: torch.Tensor,
+    BLOCK_SIZE_B: int,
+    BLOCK_SIZE_H: int,
+) -> None:
+    B, H = get_num_elements_and_hidden_size(x)
+
+    with torch.device(x.device):
+        _swiglu_unchunked_backward_triton_kernel[(ceil_divide(B, BLOCK_SIZE_B), ceil_divide(H, BLOCK_SIZE_H))](
+            x_ptr=x,
+            output_grad_ptr=output_grad,
+            x_grad_ptr=x_grad,
+            B=B,
+            H=H,
+            BLOCK_SIZE_B=BLOCK_SIZE_B,
+            BLOCK_SIZE_H=BLOCK_SIZE_H,
+        )
diff --git a/cute_kernels/kernels/swiglu_unchunked/triton_implementation/kernels_forward.py b/cute_kernels/kernels/swiglu_unchunked/triton_implementation/forward.py
similarity index 53%
rename from cute_kernels/kernels/swiglu_unchunked/triton_implementation/kernels_forward.py
rename to cute_kernels/kernels/swiglu_unchunked/triton_implementation/forward.py
index 217ba4bd..a34f6806 100644
--- a/cute_kernels/kernels/swiglu_unchunked/triton_implementation/kernels_forward.py
+++ b/cute_kernels/kernels/swiglu_unchunked/triton_implementation/forward.py
@@ -1,6 +1,14 @@
+import torch
 import triton
 import triton.language as tl
 
+from ....constants import LIBRARY_NAME
+from ....math import ceil_divide
+from ....utils import cute_op, get_num_elements_and_hidden_size
+
+
+_KERNEL_NAME = "swiglu_unchunked_forward_triton"
+
 
 @triton.jit
 def _swiglu_unchunked_forward_triton_kernel(
@@ -28,3 +36,20 @@ def _swiglu_unchunked_forward_triton_kernel(
 
     output_ptrs = output_ptr + indices_b[:, None] * half_H + indices_h[None, :]
     tl.store(output_ptrs, output, mask=mask_bh)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def swiglu_unchunked_forward_triton(
+    x: torch.Tensor, output: torch.Tensor, BLOCK_SIZE_B: int, BLOCK_SIZE_H: int
+) -> None:
+    B, H = get_num_elements_and_hidden_size(x)
+
+    with torch.device(x.device):
+        _swiglu_unchunked_forward_triton_kernel[(ceil_divide(B, BLOCK_SIZE_B), ceil_divide(H, BLOCK_SIZE_H))](
+            x_ptr=x,
+            output_ptr=output,
+            B=B,
+            H=H,
+            BLOCK_SIZE_B=BLOCK_SIZE_B,
+            BLOCK_SIZE_H=BLOCK_SIZE_H,
+        )
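
Editor's note (appended after the patch, not part of it): the swiglu hunks above only move code, so the math they preserve is easy to state. The forward kernel computes up * gate * sigmoid(gate), i.e. up * silu(gate), and the relocated backward kernel uses gate_grad = dO * up * (sigmoid(g) + silu(g) * (1 - sigmoid(g))) and up_grad = dO * silu(g). The sketch below is a minimal, CPU-only PyTorch reference that re-derives those expressions and checks them against autograd; it assumes only PyTorch, and the reference function names are illustrative, not part of the cute_kernels API.

import torch


def swiglu_forward_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # Same expression as _swiglu_forward_triton_kernel: up * gate * sigmoid(gate)
    return up * gate * torch.sigmoid(gate)


def swiglu_backward_reference(gate: torch.Tensor, up: torch.Tensor, output_grad: torch.Tensor):
    # Same expressions as the relocated backward kernel:
    # d(silu(g))/dg = sigmoid(g) + silu(g) * (1 - sigmoid(g))
    gate_sigmoid = torch.sigmoid(gate)
    gate_silu = gate * gate_sigmoid
    gate_grad = output_grad * up * (gate_sigmoid + gate_silu * (1 - gate_sigmoid))
    up_grad = output_grad * gate_silu
    return gate_grad, up_grad


if __name__ == "__main__":
    gate = torch.randn(1024, dtype=torch.float64, requires_grad=True)
    up = torch.randn(1024, dtype=torch.float64, requires_grad=True)

    # autograd gradients of the forward expression
    out = swiglu_forward_reference(gate, up)
    out.backward(torch.ones_like(out))

    # manual gradients, matching the kernel's closed-form backward
    gate_grad, up_grad = swiglu_backward_reference(gate.detach(), up.detach(), torch.ones_like(out))

    assert torch.allclose(gate.grad, gate_grad)
    assert torch.allclose(up.grad, up_grad)

The same closed-form backward is what the CUDA path in this PR relies on as well, so this check doubles as a sanity test when touching either implementation.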