From 13da3e45a7f8904689f9ba7edd20301bdd48f08f Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 19 Dec 2024 11:08:42 -0500
Subject: [PATCH 1/3] cleanup swiglu

Signed-off-by: Mayank Mishra
---
 .../swiglu/cuda_implementation/__init__.py    | 33 +++++++++++++++++--
 .../cuda_implementation/kernels_backward.py   | 23 -------------
 .../cuda_implementation/kernels_forward.py    | 15 ---------
 3 files changed, 31 insertions(+), 40 deletions(-)
 delete mode 100644 cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.py
 delete mode 100644 cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.py

diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py b/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py
index 2c61998a..ed807415 100644
--- a/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py
+++ b/cute_kernels/kernels/swiglu/cuda_implementation/__init__.py
@@ -1,2 +1,31 @@
-from .kernels_backward import swiglu_backward_cuda
-from .kernels_forward import swiglu_forward_cuda
+import torch
+
+from ....constants import LIBRARY_NAME
+from ....kernel_registry import KernelRegistry
+from ....utils import cute_op
+
+
+_FORWARD_KERNEL_NAME = "swiglu_forward_cuda"
+_BACKWARD_KERNEL_NAME = "swiglu_backward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_FORWARD_KERNEL_NAME}", mutates_args={"output"})
+def swiglu_forward_cuda(
+    gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None:
+    KernelRegistry.get_kernel(_FORWARD_KERNEL_NAME)(gate, up, output, vector_instruction_width, BLOCK_SIZE)
+
+
+@cute_op(f"{LIBRARY_NAME}::{_BACKWARD_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"})
+def swiglu_backward_cuda(
+    gate: torch.Tensor,
+    up: torch.Tensor,
+    output_grad: torch.Tensor,
+    gate_grad: torch.Tensor,
+    up_grad: torch.Tensor,
+    vector_instruction_width: int,
+    BLOCK_SIZE: int,
+) -> None:
+    KernelRegistry.get_kernel(_BACKWARD_KERNEL_NAME)(
+        gate, up, output_grad, gate_grad, up_grad, vector_instruction_width, BLOCK_SIZE
+    )
diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.py b/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.py
deleted file mode 100644
index d3664a96..00000000
--- a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_backward.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....kernel_registry import KernelRegistry
-from ....utils import cute_op
-
-
-_KERNEL_NAME = "swiglu_backward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"gate_grad", "up_grad"})
-def swiglu_backward_cuda(
-    gate: torch.Tensor,
-    up: torch.Tensor,
-    output_grad: torch.Tensor,
-    gate_grad: torch.Tensor,
-    up_grad: torch.Tensor,
-    vector_instruction_width: int,
-    BLOCK_SIZE: int,
-) -> None:
-    KernelRegistry.get_kernel(_KERNEL_NAME)(
-        gate, up, output_grad, gate_grad, up_grad, vector_instruction_width, BLOCK_SIZE
-    )
diff --git a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.py b/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.py
deleted file mode 100644
index 472275a6..00000000
--- a/cute_kernels/kernels/swiglu/cuda_implementation/kernels_forward.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch
-
-from ....constants import LIBRARY_NAME
-from ....kernel_registry import KernelRegistry
-from ....utils import cute_op
-
-
-_KERNEL_NAME = "swiglu_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def swiglu_forward_cuda(
-    gate: torch.Tensor, up: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None:
-    KernelRegistry.get_kernel(_KERNEL_NAME)(gate, up, output, vector_instruction_width, BLOCK_SIZE)
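
For reference, the consolidated wrappers above keep the out-parameter convention of the deleted modules: the caller preallocates output (and the gradient buffers for the backward op) and the registered kernel writes into them. A minimal usage sketch, assuming the CUDA extension is built and the kernels are reachable through KernelRegistry; the shapes, dtype, vector_instruction_width, and BLOCK_SIZE values below are illustrative only, not prescribed by this patch:

    import torch

    from cute_kernels.kernels.swiglu.cuda_implementation import swiglu_backward_cuda, swiglu_forward_cuda

    # illustrative inputs; any matching CUDA tensors of a supported float dtype work the same way
    gate = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
    up = torch.randn_like(gate)

    # the ops mutate preallocated buffers instead of returning new tensors
    output = torch.empty_like(gate)
    swiglu_forward_cuda(gate, up, output, 4, 1024)  # vector width / block size chosen for illustration

    output_grad = torch.ones_like(gate)
    gate_grad = torch.empty_like(gate)
    up_grad = torch.empty_like(gate)
    swiglu_backward_cuda(gate, up, output_grad, gate_grad, up_grad, 4, 1024)
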
From 1c31acd9e67954852cc7ffd66b6a8ebf34cbbc2c Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 19 Dec 2024 11:09:48 -0500
Subject: [PATCH 2/3] cleanup add

Signed-off-by: Mayank Mishra
---
 .../add_scalar/cuda_implementation/__init__.py  | 16 +++++++++++++++-
 .../cuda_implementation/kernels_forward.py      | 15 ---------------
 .../add_tensor/cuda_implementation/__init__.py  | 16 +++++++++++++++-
 .../cuda_implementation/kernels_forward.py      | 15 ---------------
 4 files changed, 30 insertions(+), 32 deletions(-)
 delete mode 100644 cute_kernels/kernels/add/add_scalar/cuda_implementation/kernels_forward.py
 delete mode 100644 cute_kernels/kernels/add/add_tensor/cuda_implementation/kernels_forward.py

diff --git a/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py b/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py
index 31c0502f..9f0a02c9 100644
--- a/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py
+++ b/cute_kernels/kernels/add/add_scalar/cuda_implementation/__init__.py
@@ -1 +1,15 @@
-from .kernels_forward import add_scalar_forward_cuda
+import torch
+
+from .....constants import LIBRARY_NAME
+from .....kernel_registry import KernelRegistry
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_scalar_forward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def add_scalar_forward_cuda(
+    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None:
+    KernelRegistry.get_kernel(_KERNEL_NAME)(x, y, output, vector_instruction_width, BLOCK_SIZE)
diff --git a/cute_kernels/kernels/add/add_scalar/cuda_implementation/kernels_forward.py b/cute_kernels/kernels/add/add_scalar/cuda_implementation/kernels_forward.py
deleted file mode 100644
index 9f0a02c9..00000000
--- a/cute_kernels/kernels/add/add_scalar/cuda_implementation/kernels_forward.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....kernel_registry import KernelRegistry
-from .....utils import cute_op
-
-
-_KERNEL_NAME = "add_scalar_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def add_scalar_forward_cuda(
-    x: torch.Tensor, y: float, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None:
-    KernelRegistry.get_kernel(_KERNEL_NAME)(x, y, output, vector_instruction_width, BLOCK_SIZE)
diff --git a/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py b/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py
index 3c69e774..e87c9b65 100644
--- a/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py
+++ b/cute_kernels/kernels/add/add_tensor/cuda_implementation/__init__.py
@@ -1 +1,15 @@
-from .kernels_forward import add_tensor_forward_cuda
+import torch
+
+from .....constants import LIBRARY_NAME
+from .....kernel_registry import KernelRegistry
+from .....utils import cute_op
+
+
+_KERNEL_NAME = "add_tensor_forward_cuda"
+
+
+@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
+def add_tensor_forward_cuda(
+    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
+) -> None:
+    KernelRegistry.get_kernel(_KERNEL_NAME)(x, y, output, vector_instruction_width, BLOCK_SIZE)
diff --git a/cute_kernels/kernels/add/add_tensor/cuda_implementation/kernels_forward.py b/cute_kernels/kernels/add/add_tensor/cuda_implementation/kernels_forward.py
deleted file mode 100644
index e87c9b65..00000000
--- a/cute_kernels/kernels/add/add_tensor/cuda_implementation/kernels_forward.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import torch
-
-from .....constants import LIBRARY_NAME
-from .....kernel_registry import KernelRegistry
-from .....utils import cute_op
-
-
-_KERNEL_NAME = "add_tensor_forward_cuda"
-
-
-@cute_op(f"{LIBRARY_NAME}::{_KERNEL_NAME}", mutates_args={"output"})
-def add_tensor_forward_cuda(
-    x: torch.Tensor, y: torch.Tensor, output: torch.Tensor, vector_instruction_width: int, BLOCK_SIZE: int
-) -> None:
-    KernelRegistry.get_kernel(_KERNEL_NAME)(x, y, output, vector_instruction_width, BLOCK_SIZE)
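
Like the swiglu change, this commit only moves the cute_op registrations for add_scalar_forward_cuda and add_tensor_forward_cuda into the package __init__.py; the blob hashes (9f0a02c9, e87c9b65) show the wrapper bodies are identical to the deleted modules, so behavior is unchanged. Assuming the underlying kernel computes the elementwise sum its name suggests, a plain PyTorch comparison is enough to sanity-check the relocated op; the size and launch parameters below are illustrative:

    import torch

    from cute_kernels.kernels.add.add_tensor.cuda_implementation import add_tensor_forward_cuda

    x = torch.randn(1 << 20, device="cuda", dtype=torch.float32)
    y = torch.randn_like(x)

    # out-parameter style: the op writes into a preallocated buffer
    output = torch.empty_like(x)
    add_tensor_forward_cuda(x, y, output, 4, 1024)  # illustrative vector width / block size

    # elementwise fp32 addition should agree with eager PyTorch
    torch.testing.assert_close(output, x + y)
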
From 89b007cc5090384d251b9d5ee7ac34c9244ca7f8 Mon Sep 17 00:00:00 2001
From: Mayank Mishra
Date: Thu, 19 Dec 2024 11:11:45 -0500
Subject: [PATCH 3/3] cleanup add

Signed-off-by: Mayank Mishra
---
 .../kernels/contiguous_count/__init__.py      |   5 +-
 .../cuda_implementation/__init__.py           |   0
 .../cuda_implementation/kernels_forward.cu    | 170 ++++++++++++++++++
 .../cuda_implementation/ops.cpp               |  11 ++
 4 files changed, 185 insertions(+), 1 deletion(-)
 create mode 100644 cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py
 create mode 100644 cute_kernels/kernels/contiguous_count/cuda_implementation/kernels_forward.cu
 create mode 100644 cute_kernels/kernels/contiguous_count/cuda_implementation/ops.cpp

diff --git a/cute_kernels/kernels/contiguous_count/__init__.py b/cute_kernels/kernels/contiguous_count/__init__.py
index 5833f941..2836ca07 100644
--- a/cute_kernels/kernels/contiguous_count/__init__.py
+++ b/cute_kernels/kernels/contiguous_count/__init__.py
@@ -2,6 +2,7 @@ import torch
 
 from ...enums import KernelBackend
 from ...utils import ensure_contiguous
+from .cuda_implementation import contiguous_count_cuda
 from .triton_implementation import contiguous_count_triton
 
 
@@ -16,7 +17,9 @@ def contiguous_count_cute(
     assert x.dim() == 1, "x should be 1-dimensional"
     assert x.dtype in [torch.int32, torch.long]
 
-    if kernel_backend == KernelBackend.triton:
+    if kernel_backend == KernelBackend.cuda:
+        output = contiguous_count_cuda(x=x, size=size, BLOCK_SIZE_B=BLOCK_SIZE_B)
+    elif kernel_backend == KernelBackend.triton:
         output = contiguous_count_triton(x=x, size=size, BLOCK_SIZE_B=BLOCK_SIZE_B)
     else:
         raise ValueError(f"unexpected kernel_backend ({kernel_backend})")
diff --git a/cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py b/cute_kernels/kernels/contiguous_count/cuda_implementation/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/cute_kernels/kernels/contiguous_count/cuda_implementation/kernels_forward.cu b/cute_kernels/kernels/contiguous_count/cuda_implementation/kernels_forward.cu
new file mode 100644
index 00000000..5e9f0598
--- /dev/null
+++ b/cute_kernels/kernels/contiguous_count/cuda_implementation/kernels_forward.cu
@@ -0,0 +1,170 @@
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <torch/extension.h>
+
+#include "../../../include/activations.h"
+#include "../../../include/dtypes/all.h"
+#include "../../../include/threads.h"
+
+template <typename scalar_t, typename vector_t>
+__global__ void _swiglu_forward_cuda_kernel(const scalar_t *gate,
+                                            const scalar_t *up,
+                                            scalar_t *output,
+                                            const int64_t num_elements) {
+    constexpr int vector_instruction_width = sizeof(vector_t) / sizeof(scalar_t);
+    static_assert(vector_instruction_width == 1 || vector_instruction_width == 2 || vector_instruction_width == 4 ||
+                  vector_instruction_width == 8);
+
+    const int64_t thread_id = get_global_thread_id();
+    using dtype = DType<scalar_t>;
+
+    if constexpr (vector_instruction_width == 1) {
+        if (thread_id < num_elements) {
+            fp32 _gate_upcast = dtype::upcast(gate[thread_id]);
+
+            // up is upcasted automatically
+            _gate_upcast = up[thread_id] * _gate_upcast * sigmoid(_gate_upcast);
+            output[thread_id] = dtype::downcast(_gate_upcast);
+        }
+    } else {
+        int64_t end = (thread_id + 1) * vector_instruction_width - 1;  // inclusive of last element
+
+        if (end < num_elements) {
+            vector_t *output_vec = (vector_t *)output;
+
+            if constexpr (std::is_same_v<scalar_t, fp32>) {
+                const fp32 *gate_vec = (fp32 *)&((vector_t *)gate)[thread_id];
+                const fp32 *up_vec = (fp32 *)&((vector_t *)up)[thread_id];
+                fp32 output_buffer[vector_instruction_width];
+
+                // clang-format off
+                #pragma unroll
+                // clang-format on
+                for (int i = 0; i < vector_instruction_width; i++) {
+                    output_buffer[i] = up_vec[i] * gate_vec[i] * sigmoid(gate_vec[i]);
+                }
+
+                if constexpr (vector_instruction_width == 2) {
+                    output_vec[thread_id] = dtype::make2(output_buffer);
+                } else if constexpr (vector_instruction_width == 4) {
+                    output_vec[thread_id] = dtype::make4(output_buffer);
+                } else {
+                    static_assert("vector_instruction_width is invalid for fp32");
+                }
+            } else {
+                using T2 = typename dtype::nv_dtype2;
+
+                if constexpr (vector_instruction_width == 2) {
+                    T2 _gate = ((vector_t *)gate)[thread_id];
+                    T2 _up = ((vector_t *)up)[thread_id];
+
+                    fp32_2 _gate_upcast = dtype::upcast(_gate);
+                    fp32_2 _up_upcast = dtype::upcast(_up);
+
+                    _gate_upcast =
+                        DType<fp32>::make2(_up_upcast.x * _gate_upcast.x * sigmoid(_gate_upcast.x),
+                                           _up_upcast.y * _gate_upcast.y * sigmoid(_gate_upcast.y));
+
+                    output_vec[thread_id] = dtype::downcast(_gate_upcast);
+                } else {
+                    const fp32 *gate_vec = (fp32 *)&((vector_t *)gate)[thread_id];
+                    const fp32 *up_vec = (fp32 *)&((vector_t *)up)[thread_id];
+
+                    const int n = vector_instruction_width >> 1;
+                    fp32 output_buffer[n];
+
+                    // clang-format off
+                    #pragma unroll
+                    // clang-format on
+                    for (int i = 0; i < n; i++) {
+                        fp32_2 _gate_upcast = dtype::upcast(dtype::reinterpret_32_bits_as_2x16(gate_vec[i]));
+                        fp32_2 _up_upcast = dtype::upcast(dtype::reinterpret_32_bits_as_2x16(up_vec[i]));
+
+                        _gate_upcast =
+                            DType<fp32>::make2(_up_upcast.x * _gate_upcast.x * sigmoid(_gate_upcast.x),
+                                               _up_upcast.y * _gate_upcast.y * sigmoid(_gate_upcast.y));
+
+                        output_buffer[i] = dtype::reinterpret_2x16_as_32_bits(dtype::downcast(_gate_upcast));
+                    }
+
+                    if constexpr (vector_instruction_width == 4) {
+                        output_vec[thread_id] = DType<fp32>::make2(output_buffer);
+                    } else if constexpr (vector_instruction_width == 8) {
+                        output_vec[thread_id] = DType<fp32>::make4(output_buffer);
+                    } else {
+                        static_assert("vector_instruction_width is invalid for fp16 & bf16");
+                    }
+                }
+            }
+        }
+
+        // use first warp for computing the last elements
+        if (thread_id < WARP_SIZE) {
+            // NOTE end is same as start since we don't use vector load stores here
+            end = (num_elements / vector_instruction_width) * vector_instruction_width + thread_id;
+            if (end < num_elements) {
+                fp32 _gate_upcast = dtype::upcast(gate[end]);
+
+                // up is upcasted automatically
+                _gate_upcast = up[end] * _gate_upcast * sigmoid(_gate_upcast);
+                output[end] = dtype::downcast(_gate_upcast);
+            }
+        }
+    }
+}
+
+void swiglu_forward_cuda(const torch::Tensor &gate,
+                         const torch::Tensor &up,
+                         torch::Tensor &output,
+                         const int &vector_instruction_width,
+                         const int &BLOCK_SIZE) {
+    const int64_t num_elements = gate.numel();
+
+    AT_DISPATCH_CUSTOM_FLOAT_TYPES(
+        gate.scalar_type(), "swiglu_forward_cuda_kernel", ([&] {
+            const int num_elements_per_block = BLOCK_SIZE * vector_instruction_width;
+            const int NUM_BLOCKS = (num_elements + num_elements_per_block - 1) / num_elements_per_block;
+
+            switch (vector_instruction_width) {
+                case 1:
+                    _swiglu_forward_cuda_kernel<scalar_t, scalar_t><<<NUM_BLOCKS, BLOCK_SIZE>>>(
+                        gate.data_ptr<scalar_t>(), up.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), num_elements);
+                    break;
+                case 2:
+                    using vector_t = typename DType<scalar_t>::nv_dtype2;
+                    _swiglu_forward_cuda_kernel<scalar_t, vector_t><<<NUM_BLOCKS, BLOCK_SIZE>>>(
+                        gate.data_ptr<scalar_t>(), up.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), num_elements);
+                    break;
+                case 4:
+                    if constexpr (std::is_same_v<scalar_t, fp32>) {
+                        _swiglu_forward_cuda_kernel<scalar_t, fp32_4>
+                            <<<NUM_BLOCKS, BLOCK_SIZE>>>(gate.data_ptr<scalar_t>(),
+                                                         up.data_ptr<scalar_t>(),
+                                                         output.data_ptr<scalar_t>(),
+                                                         num_elements);
+                    } else {
+                        _swiglu_forward_cuda_kernel<scalar_t, fp32_2>
+                            <<<NUM_BLOCKS, BLOCK_SIZE>>>(gate.data_ptr<scalar_t>(),
+                                                         up.data_ptr<scalar_t>(),
+                                                         output.data_ptr<scalar_t>(),
+                                                         num_elements);
+                    }
+                    break;
+                case 8:
+                    if constexpr (std::is_same_v<scalar_t, fp32>) {
+                        throw std::runtime_error("fp32 doesn't support vector_instruction_width = 8");
+                    } else {
+                        _swiglu_forward_cuda_kernel<scalar_t, fp32_4>
+                            <<<NUM_BLOCKS, BLOCK_SIZE>>>(gate.data_ptr<scalar_t>(),
+                                                         up.data_ptr<scalar_t>(),
+                                                         output.data_ptr<scalar_t>(),
+                                                         num_elements);
+                    }
+                    break;
+                default:
+                    throw std::runtime_error("invalid vector_instruction_width");
+                    break;
+            }
+        }));
+}
diff --git a/cute_kernels/kernels/contiguous_count/cuda_implementation/ops.cpp b/cute_kernels/kernels/contiguous_count/cuda_implementation/ops.cpp
new file mode 100644
index 00000000..b2ba4e18
--- /dev/null
+++ b/cute_kernels/kernels/contiguous_count/cuda_implementation/ops.cpp
@@ -0,0 +1,11 @@
+#include <torch/extension.h>
+
+void swiglu_forward_cuda(const torch::Tensor &gate,
+                         const torch::Tensor &up,
+                         torch::Tensor &output,
+                         const int &vector_instruction_width,
+                         const int &BLOCK_SIZE);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("swiglu_forward_cuda", &swiglu_forward_cuda, "SwiGLU forward (CUDA)");
+}
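
Every dispatch branch of the kernel added above computes the same elementwise function: upcast to fp32, evaluate up * gate * sigmoid(gate), then downcast to the storage dtype; only the load/store width differs. A short PyTorch reference of that math, useful for checking the vectorized paths against the binding exposed in ops.cpp; the shapes and dtype below are illustrative:

    import torch

    def swiglu_forward_reference(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
        # mirror the kernel: do the multiply-sigmoid in fp32, then cast back to the storage dtype
        gate_fp32 = gate.float()
        return (up.float() * gate_fp32 * torch.sigmoid(gate_fp32)).to(gate.dtype)

    gate = torch.randn(8192, device="cuda", dtype=torch.bfloat16)
    up = torch.randn_like(gate)
    expected = swiglu_forward_reference(gate, up)
    # compare "expected" against the buffer written by swiglu_forward_cuda for each
    # vector_instruction_width in {1, 2, 4, 8}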