From 84c57696f3e92e9a4493bff2ce3dd5b31ed1c62e Mon Sep 17 00:00:00 2001
From: Jee Jee Li
Date: Wed, 15 Jan 2025 10:29:53 +0800
Subject: [PATCH] [Kernel] Support MulAndSilu (#11624)

Signed-off-by: Jee Jee Li
Signed-off-by: ice-tong
---
 csrc/activation_kernels.cu               | 32 +++++++++++++++++-----
 csrc/ops.h                               |  2 ++
 csrc/torch_bindings.cpp                  |  3 ++
 tests/kernels/test_activation.py         | 20 +++++++++-----
 vllm/model_executor/layers/activation.py | 35 ++++++++++++++++++++++++
 vllm/model_executor/models/molmo.py      | 14 ++--------
 vllm/model_executor/models/ultravox.py   | 13 ++-------
 7 files changed, 83 insertions(+), 36 deletions(-)

diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu
index 839dc36ba4e29..88275dbdd83a1 100644
--- a/csrc/activation_kernels.cu
+++ b/csrc/activation_kernels.cu
@@ -9,8 +9,16 @@
 
 namespace vllm {
 
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }
 
@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                            \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                 \
   int d = input.size(-1) / 2;                                            \
   int64_t num_tokens = input.numel() / input.size(-1);                   \
   dim3 grid(num_tokens);                                                 \
@@ -64,7 +74,7 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
   VLLM_DISPATCH_FLOATING_TYPES(                                          \
       input.scalar_type(), "act_and_mul_kernel", [&] {                   \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
       });
@@ -72,19 +82,27 @@
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+}
+
+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }
 
 namespace vllm {
diff --git a/csrc/ops.h b/csrc/ops.h
index 9efd9b0c24700..5a194a0dd3654 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -86,6 +86,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
+
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 956258c1001d3..fb53d122487d3 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -55,6 +55,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
+  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
+
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py
index a84501f9c303f..dac26efe866b8 100644
--- a/tests/kernels/test_activation.py
+++ b/tests/kernels/test_activation.py
@@ -6,8 +6,9 @@
 
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
-                                                   GeluAndMul, NewGELU,
-                                                   QuickGELU, SiluAndMul)
+                                                   GeluAndMul, MulAndSilu,
+                                                   NewGELU, QuickGELU,
+                                                   SiluAndMul)
 from vllm.platforms import current_platform
 
 from .allclose_default import get_default_atol, get_default_rtol
@@ -21,8 +22,9 @@
 ]
 
 
-@pytest.mark.parametrize("activation",
-                         ["silu", "gelu", "gelu_tanh", "fatrelu"])
+@pytest.mark.parametrize(
+    "activation",
+    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -40,9 +42,12 @@ def test_act_and_mul(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    if activation == "silu":
+    if activation == "silu_and_mul":
         layer = SiluAndMul()
         fn = torch.ops._C.silu_and_mul
+    if activation == "mul_and_silu":
+        layer = MulAndSilu()
+        fn = torch.ops._C.mul_and_silu
     elif activation == "gelu":
         layer = GeluAndMul(approximate="none")
         fn = torch.ops._C.gelu_and_mul
@@ -55,8 +60,9 @@ def test_act_and_mul(
         fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiLU, GELU and FatReLU implementations are equivalent to the native
-    # PyTorch implementations, so we can do exact comparison.
+    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+    # equivalent to the native PyTorch implementations, so we can do exact
+    # comparison.
     torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index 2475190d197d3..af7894b42c560 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -87,6 +87,41 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+@CustomOp.register("mul_and_silu")
+class MulAndSilu(CustomOp):
+    """An activation function for SwiGLU.
+
+    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.mul_and_silu
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.silu_and_mul
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return x[..., :d] * F.silu(x[..., d:])
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x)
+        return out
+
+    # TODO implement forward_xpu for MulAndSilu
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
 @CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index a2fd1701316f2..5c7ae0deefcd8 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -23,7 +23,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
+                                                   SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -462,15 +463,6 @@ def forward(
         return output
 
 
-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        # Note that the order is reversed compared to
-        # SiluAndMul.
-        return x * F.silu(gate)
-
-
 class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
@@ -489,7 +481,7 @@ def __init__(self,
             quant_config=quant_config,
         )
         # Activation function.
-        self.act_fn = SwiGLU()
+        self.act_fn = MulAndSilu()
         # Feed-forward output projection.
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 3edfb5107683a..587f18ccaf98f 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -16,7 +16,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -248,15 +248,6 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         return audio_embeds
 
 
-class FlippedSiluAndMul(SiluAndMul):
-    """Ultravox is trained with SwiGLU with flipped halves."""
-
-    def forward(self, x: torch.Tensor):
-        a, b = x.chunk(2, dim=-1)
-        flipped = torch.cat((b, a), dim=-1)
-        return super().forward(flipped)
-
-
 class UltravoxProjector(nn.Module):
 
     def __init__(self, config: UltravoxConfig):
@@ -269,7 +260,7 @@ def __init__(self, config: UltravoxConfig):
         dim = self.hidden_dim
 
         if config.projector_act == "swiglu":
-            self.act = FlippedSiluAndMul()
+            self.act = MulAndSilu()
             dim = dim // 2
         else:
             self.act = get_act_fn(config.projector_act)
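
For reference, a minimal PyTorch sketch of what the two ops compute, written only from the forward_native formulas in the patch; it is not part of the commit, and the function and tensor names are illustrative:

import torch
import torch.nn.functional as F

def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # SiluAndMul: SiLU on the first half, the second half is the gate.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def mul_and_silu_native(x: torch.Tensor) -> torch.Tensor:
    # MulAndSilu: SiLU on the second half, the first half is the gate.
    d = x.shape[-1] // 2
    return x[..., :d] * F.silu(x[..., d:])

x = torch.randn(4, 2 * 8)
a, b = x.chunk(2, dim=-1)
# MulAndSilu on x equals SiluAndMul on the flipped halves, which is why the
# removed FlippedSiluAndMul (ultravox) and SwiGLU (molmo) helpers can be
# replaced by the new MulAndSilu op.
assert torch.allclose(mul_and_silu_native(x),
                      silu_and_mul_native(torch.cat((b, a), dim=-1)))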