From d3fc7c0ee163b015d0f3555b24de85214306903b Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Thu, 24 Oct 2024 16:18:27 +0800
Subject: [PATCH] [Kernel] add kernel for FATReLU (#9610)

Signed-off-by: Jee Jee Li
Signed-off-by: NickLucche
---
 csrc/activation_kernels.cu               | 42 ++++++++++++++++++++++++
 csrc/ops.h                               |  3 ++
 csrc/torch_bindings.cpp                  |  4 +++
 tests/kernels/test_activation.py         | 23 +++++++++----
 vllm/_custom_ops.py                      |  6 ++++
 vllm/model_executor/layers/activation.py |  8 ++++-
 6 files changed, 78 insertions(+), 8 deletions(-)

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 5ed1dc3b8f792..839dc36ba4e29 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -89,6 +89,48 @@ void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
 
 namespace vllm {
 
+template <typename T>
+__device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
+  const float f = (float)x;
+  return (T)(f > threshold ? f : 0.0f);
+}
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
+__global__ void act_and_mul_kernel_with_param(
+    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
+    const float param) {
+  const int64_t token_idx = blockIdx.x;
+  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x, param) * y;
+  }
+}
+
+}  // namespace vllm
+
+#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM)         \
+  int d = input.size(-1) / 2;                                           \
+  int64_t num_tokens = input.numel() / input.size(-1);                  \
+  dim3 grid(num_tokens);                                                \
+  dim3 block(std::min(d, 1024));                                        \
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));     \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();         \
+  VLLM_DISPATCH_FLOATING_TYPES(                                         \
+      input.scalar_type(), "act_and_mul_kernel_with_param", [&] {       \
+        vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
+            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),      \
+                                         input.data_ptr<scalar_t>(), d, \
+                                         PARAM);                        \
+      });
+
+void fatrelu_and_mul(torch::Tensor& out,    // [..., d],
+                     torch::Tensor& input,  // [..., 2 * d]
+                     double threshold) {
+  LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
+}
+namespace vllm {
+
 // Element-wise activation kernel template.
 template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
 __global__ void activation_kernel(
diff --git a/csrc/ops.h b/csrc/ops.h
index c10c34e085750..11a2970695545 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -48,6 +48,9 @@ void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void fatrelu_and_mul(torch::Tensor& out, torch::Tensor& input,
+                     double threshold);
+
 void gelu_new(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_fast(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index b999028fe06a9..826f918c82e78 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -60,6 +60,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("gelu_tanh_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
 
+  // FATReLU implementation.
+  ops.def("fatrelu_and_mul(Tensor! out, Tensor input, float threshold) -> ()");
+  ops.impl("fatrelu_and_mul", torch::kCUDA, &fatrelu_and_mul);
+
   // GELU implementation used in GPT-2.
   ops.def("gelu_new(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_new", torch::kCUDA, &gelu_new);
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index 9b476585fa19e..0e3d3c3a2e987 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -1,12 +1,13 @@
+import random
 from typing import Type
 
 import pytest
 import torch
 
 from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.activation import (FastGELU, GeluAndMul,
-                                                   NewGELU, QuickGELU,
-                                                   SiluAndMul)
+from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
+                                                   GeluAndMul, NewGELU,
+                                                   QuickGELU, SiluAndMul)
 from vllm.utils import seed_everything
 
 from .allclose_default import get_default_atol, get_default_rtol
@@ -20,7 +21,8 @@
 ]
 
 
-@pytest.mark.parametrize("activation", ["silu", "gelu", "gelu_tanh"])
+@pytest.mark.parametrize("activation",
+                         ["silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -47,16 +49,23 @@ def test_act_and_mul(
     elif activation == "gelu_tanh":
         layer = GeluAndMul(approximate="tanh")
         fn = torch.ops._C.gelu_tanh_and_mul
+    elif activation == "fatrelu":
+        threshold = random.uniform(0, 1)
+        layer = FatreluAndMul(threshold)
+        fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiLU and GELU implementations are equivalent to the native PyTorch
-    # implementations, so we can do exact comparison.
+    # The SiLU, GELU and FatReLU implementations are equivalent to the native
+    # PyTorch implementations, so we can do exact comparison.
    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
     output_shape = (x.shape[:-1] + (d, ))
     out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-    opcheck(fn, (out, x))
+    if activation == "fatrelu":
+        opcheck(fn, (out, x, threshold))
+    else:
+        opcheck(fn, (out, x))
 
 
 @pytest.mark.parametrize("activation", [(FastGELU, torch.ops._C.gelu_fast),
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index a25f7abca5498..60f458096c70c 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -79,6 +79,12 @@ def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_tanh_and_mul(out, x)
 
 
+def fatrelu_and_mul(out: torch.Tensor,
+                    x: torch.Tensor,
+                    threshold: float = 0.0) -> None:
+    torch.ops._C.fatrelu_and_mul(out, x, threshold)
+
+
 def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None:
     torch.ops._C.gelu_fast(out, x)
 
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 8de3385a257f8..658a3700f33d6 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -39,7 +39,13 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         return x1 * x2
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        return self.forward_native(x)
+        from vllm import _custom_ops as ops
+
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        ops.fatrelu_and_mul(out, x, self.threshold)
+        return out
 
 
 @CustomOp.register("silu_and_mul")