[Kernel] Support MulAndSilu (vllm-project#11624)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: ice-tong <[email protected]>
jeejeelee authored and ice-tong committed Jan 18, 2025
1 parent 70dbb2e commit 84c5769
Showing 7 changed files with 83 additions and 36 deletions.
32 changes: 25 additions & 7 deletions csrc/activation_kernels.cu
@@ -9,8 +9,16 @@

 namespace vllm {

+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }

@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm

 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
   int d = input.size(-1) / 2; \
   int64_t num_tokens = input.numel() / input.size(-1); \
   dim3 grid(num_tokens); \
@@ -64,27 +74,35 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
   VLLM_DISPATCH_FLOATING_TYPES( \
       input.scalar_type(), "act_and_mul_kernel", [&] { \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>> \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
                                         input.data_ptr<scalar_t>(), d); \
       });

 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
 }

+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
+}
+
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }

 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }

 namespace vllm {
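For reference, the two launch variants differ only in which half of the input the activation is applied to; a plain-PyTorch sketch of the semantics (helper names are illustrative, not part of the diff):

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # ACT_FIRST == true: activation on the first half, multiply by the second.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def mul_and_silu_ref(x: torch.Tensor) -> torch.Tensor:
    # ACT_FIRST == false: first half multiplied by the activation of the second.
    d = x.shape[-1] // 2
    return x[..., :d] * F.silu(x[..., d:])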
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -86,6 +86,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,

 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);

+void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
+
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);

 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);
3 changes: 3 additions & 0 deletions csrc/torch_bindings.cpp
@@ -55,6 +55,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);

+  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
+  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
+
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
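Once registered, the op is called through the PyTorch dispatcher the same way the tests and the MulAndSilu layer below call it. A usage sketch, assuming a CUDA build of vLLM in which the vllm._C extension is importable:

import torch
import vllm._C  # noqa: F401  (loads the compiled custom ops into torch.ops._C)

x = torch.randn(8, 2 * 128, dtype=torch.float16, device="cuda")
out = torch.empty(8, 128, dtype=torch.float16, device="cuda")
# Writes x[..., :128] * silu(x[..., 128:]) into the preallocated `out` tensor.
torch.ops._C.mul_and_silu(out, x)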
20 changes: 13 additions & 7 deletions tests/kernels/test_activation.py
@@ -6,8 +6,9 @@

 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
-                                                    GeluAndMul, NewGELU,
-                                                    QuickGELU, SiluAndMul)
+                                                    GeluAndMul, MulAndSilu,
+                                                    NewGELU, QuickGELU,
+                                                    SiluAndMul)
 from vllm.platforms import current_platform

 from .allclose_default import get_default_atol, get_default_rtol
@@ -21,8 +22,9 @@
 ]


-@pytest.mark.parametrize("activation",
-                         ["silu", "gelu", "gelu_tanh", "fatrelu"])
+@pytest.mark.parametrize(
+    "activation",
+    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -40,9 +42,12 @@ def test_act_and_mul(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    if activation == "silu":
+    if activation == "silu_and_mul":
         layer = SiluAndMul()
         fn = torch.ops._C.silu_and_mul
+    if activation == "mul_and_silu":
+        layer = MulAndSilu()
+        fn = torch.ops._C.mul_and_silu
     elif activation == "gelu":
         layer = GeluAndMul(approximate="none")
         fn = torch.ops._C.gelu_and_mul
@@ -55,8 +60,9 @@ def test_act_and_mul(
         fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiLU, GELU and FatReLU implementations are equivalent to the native
-    # PyTorch implementations, so we can do exact comparison.
+    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+    # equivalent to the native PyTorch implementations, so we can do exact
+    # comparison.
     torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)

     d = x.shape[-1] // 2
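The exact-comparison claim above can be reproduced in isolation. A condensed sketch, assuming a CUDA device and a vLLM build with the custom op registered; sizes are arbitrary:

import torch
from vllm.model_executor.layers.activation import MulAndSilu

layer = MulAndSilu()
x = torch.randn(7, 2 * 512, dtype=torch.float16, device="cuda")
# layer(x) dispatches to the custom CUDA kernel; forward_native is the
# PyTorch reference implementation.
torch.testing.assert_close(layer(x), layer.forward_native(x), atol=0.0, rtol=0.0)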
35 changes: 35 additions & 0 deletions vllm/model_executor/layers/activation.py
@@ -87,6 +87,41 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out


+@CustomOp.register("mul_and_silu")
+class MulAndSilu(CustomOp):
+    """An activation function for SwiGLU.
+
+    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.mul_and_silu
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.silu_and_mul
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return x[..., :d] * F.silu(x[..., d:])
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x)
+        return out
+
+    # TODO implement forward_xpu for MulAndSilu
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
 @CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.
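A minimal usage sketch of the new layer (sizes are illustrative; forward_native is the device-independent reference path, assuming a standard vLLM installation):

import torch
from vllm.model_executor.layers.activation import MulAndSilu

layer = MulAndSilu()
x = torch.randn(4, 2 * 11008)      # e.g. a merged gate/up projection output
y = layer.forward_native(x)        # x[..., :d] * silu(x[..., d:])
assert y.shape == (4, 11008)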
14 changes: 3 additions & 11 deletions vllm/model_executor/models/molmo.py
@@ -23,7 +23,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
+                                                    SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -462,15 +463,6 @@ def forward(
         return output


-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        # Note that the order is reversed compared to
-        # SiluAndMul.
-        return x * F.silu(gate)
-
-
 class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""

@@ -489,7 +481,7 @@ def __init__(self,
             quant_config=quant_config,
         )
         # Activation function.
-        self.act_fn = SwiGLU()
+        self.act_fn = MulAndSilu()
         # Feed-forward output projection.
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
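The removed SwiGLU module and MulAndSilu.forward_native compute the same value, since chunk(2, dim=-1) splits the tensor into exactly the halves the kernel indexes. A quick illustrative check (sizes arbitrary, standard vLLM install assumed):

import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import MulAndSilu

t = torch.randn(3, 2 * 16)
x, gate = t.chunk(2, dim=-1)
# Same result as the old SwiGLU.forward: x * silu(gate).
torch.testing.assert_close(x * F.silu(gate), MulAndSilu().forward_native(t))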
13 changes: 2 additions & 11 deletions vllm/model_executor/models/ultravox.py
@@ -16,7 +16,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -248,15 +248,6 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         return audio_embeds


-class FlippedSiluAndMul(SiluAndMul):
-    """Ultravox is trained with SwiGLU with flipped halves."""
-
-    def forward(self, x: torch.Tensor):
-        a, b = x.chunk(2, dim=-1)
-        flipped = torch.cat((b, a), dim=-1)
-        return super().forward(flipped)
-
-
 class UltravoxProjector(nn.Module):

     def __init__(self, config: UltravoxConfig):
@@ -269,7 +260,7 @@ def __init__(self, config: UltravoxConfig):
         dim = self.hidden_dim

         if config.projector_act == "swiglu":
-            self.act = FlippedSiluAndMul()
+            self.act = MulAndSilu()
             dim = dim // 2
         else:
             self.act = get_act_fn(config.projector_act)
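FlippedSiluAndMul swapped the two halves before applying SiluAndMul, which is algebraically the same as MulAndSilu: for a, b = x.chunk(2, dim=-1) it computed silu(b) * a. An illustrative check (sizes arbitrary, standard vLLM install assumed):

import torch
from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul

x = torch.randn(2, 2 * 8)
a, b = x.chunk(2, dim=-1)
# Old path: swap the halves, then silu-and-mul.
flipped = SiluAndMul().forward_native(torch.cat((b, a), dim=-1))
torch.testing.assert_close(flipped, MulAndSilu().forward_native(x))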
