
Commit bb13b8a

jeejeelee authored and HwwwwwwwH committed
[Kernel] Support MulAndSilu (vllm-project#11624)
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: hzh <[email protected]>
1 parent 58d45cd commit bb13b8a

File tree

7 files changed: +83 -36 lines changed


csrc/activation_kernels.cu

+25 -7
@@ -9,8 +9,16 @@
 
 namespace vllm {
 
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }
 
@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                    \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)         \
   int d = input.size(-1) / 2;                                    \
   int64_t num_tokens = input.numel() / input.size(-1);           \
   dim3 grid(num_tokens);                                         \
@@ -64,27 +74,35 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();          \
   VLLM_DISPATCH_FLOATING_TYPES(                                          \
       input.scalar_type(), "act_and_mul_kernel", [&] {                   \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>             \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>  \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),        \
                                         input.data_ptr<scalar_t>(), d);  \
       });
 
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+}
+
+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }
 
 namespace vllm {
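
For readers skimming the CUDA change: the new act_first template parameter only decides which half of the gated input the activation is applied to, and both launches reuse the same silu_kernel. A minimal PyTorch-level sketch of the two variants (act_and_mul_reference is an illustrative name, not part of this commit):

import torch
import torch.nn.functional as F

def act_and_mul_reference(x: torch.Tensor, act_first: bool) -> torch.Tensor:
    # Mirrors compute<scalar_t, ACT_FN, act_first> with ACT_FN = silu_kernel.
    d = x.shape[-1] // 2
    a, b = x[..., :d], x[..., d:]
    # act_first=True  -> silu_and_mul: silu(a) * b
    # act_first=False -> mul_and_silu: a * silu(b)
    return F.silu(a) * b if act_first else a * F.silu(b)

x = torch.randn(4, 2 * 8)
assert act_and_mul_reference(x, act_first=True).shape == (4, 8)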

csrc/ops.h

+2
@@ -86,6 +86,8 @@ void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
 
 void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
+void mul_and_silu(torch::Tensor& out, torch::Tensor& input);
+
 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input);

csrc/torch_bindings.cpp

+3
@@ -55,6 +55,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
 
+  ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
+  ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);
+
   // Activation function used in GeGLU with `none` approximation.
   ops.def("gelu_and_mul(Tensor! out, Tensor input) -> ()");
   ops.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
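
With the binding in place, the op is reachable from Python under the same out-variant schema as silu_and_mul. A hedged usage sketch, assuming a CUDA build of vLLM whose compiled _C extension has been loaded (importing vllm normally does this):

import torch
import vllm  # noqa: F401  # loads the extension that registers torch.ops._C.*

x = torch.randn(16, 2 * 128, device="cuda", dtype=torch.float16)
out = torch.empty(16, 128, device="cuda", dtype=torch.float16)
# Schema: "mul_and_silu(Tensor! out, Tensor input) -> ()" -- writes into out.
torch.ops._C.mul_and_silu(out, x)  # out now holds x[:, :128] * silu(x[:, 128:])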

tests/kernels/test_activation.py

+13 -7
@@ -6,8 +6,9 @@
 
 from tests.kernels.utils import opcheck
 from vllm.model_executor.layers.activation import (FastGELU, FatreluAndMul,
-                                                    GeluAndMul, NewGELU,
-                                                    QuickGELU, SiluAndMul)
+                                                    GeluAndMul, MulAndSilu,
+                                                    NewGELU, QuickGELU,
+                                                    SiluAndMul)
 from vllm.platforms import current_platform
 
 from .allclose_default import get_default_atol, get_default_rtol
@@ -21,8 +22,9 @@
 ]
 
 
-@pytest.mark.parametrize("activation",
-                         ["silu", "gelu", "gelu_tanh", "fatrelu"])
+@pytest.mark.parametrize(
+    "activation",
+    ["silu_and_mul", "mul_and_silu", "gelu", "gelu_tanh", "fatrelu"])
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -40,9 +42,12 @@ def test_act_and_mul(
     current_platform.seed_everything(seed)
     torch.set_default_device(device)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype)
-    if activation == "silu":
+    if activation == "silu_and_mul":
         layer = SiluAndMul()
         fn = torch.ops._C.silu_and_mul
+    if activation == "mul_and_silu":
+        layer = MulAndSilu()
+        fn = torch.ops._C.mul_and_silu
     elif activation == "gelu":
         layer = GeluAndMul(approximate="none")
         fn = torch.ops._C.gelu_and_mul
@@ -55,8 +60,9 @@ def test_act_and_mul(
         fn = torch.ops._C.fatrelu_and_mul
     out = layer(x)
     ref_out = layer.forward_native(x)
-    # The SiLU, GELU and FatReLU implementations are equivalent to the native
-    # PyTorch implementations, so we can do exact comparison.
+    # The SiluAndMul, MulAndSilu, GELU and FatReLU implementations are
+    # equivalent to the native PyTorch implementations, so we can do exact
+    # comparison.
     torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)
 
     d = x.shape[-1] // 2
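
To exercise only the new parametrization locally, pytest's -k filter can select it by the parameter id, for example:

pytest tests/kernels/test_activation.py -k mul_and_silu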

vllm/model_executor/layers/activation.py

+35
@@ -87,6 +87,41 @@ def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
         return out
 
 
+@CustomOp.register("mul_and_silu")
+class MulAndSilu(CustomOp):
+    """An activation function for SwiGLU.
+
+    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
+
+    Shapes:
+        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
+        return: (num_tokens, d) or (batch_size, seq_len, d)
+    """
+
+    def __init__(self):
+        super().__init__()
+        if current_platform.is_cuda_alike() or current_platform.is_cpu():
+            self.op = torch.ops._C.mul_and_silu
+        elif current_platform.is_xpu():
+            from vllm._ipex_ops import ipex_ops
+            self.op = ipex_ops.silu_and_mul
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        """PyTorch-native implementation equivalent to forward()."""
+        d = x.shape[-1] // 2
+        return x[..., :d] * F.silu(x[..., d:])
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = (x.shape[:-1] + (d, ))
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        self.op(out, x)
+        return out
+
+    # TODO implement forward_xpu for MulAndSilu
+    # def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+
+
 @CustomOp.register("gelu_and_mul")
 class GeluAndMul(CustomOp):
     """An activation function for GeGLU.

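At the layer level, MulAndSilu follows the shape contract spelled out in its docstring. A minimal sketch, assuming a vLLM install where the compiled mul_and_silu op is available (e.g. a CUDA build):

import torch
from vllm.model_executor.layers.activation import MulAndSilu

layer = MulAndSilu()
x = torch.randn(2, 5, 2 * 64, device="cuda")  # (batch_size, seq_len, 2 * d)
y = layer(x)                                  # x[..., :d] * silu(x[..., d:])
assert y.shape == (2, 5, 64)                  # (batch_size, seq_len, d)
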
vllm/model_executor/models/molmo.py

+3 -11
@@ -23,7 +23,8 @@
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                          InputContext, token_inputs)
 from vllm.model_executor import SamplingMetadata
-from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
+from vllm.model_executor.layers.activation import (MulAndSilu, QuickGELU,
+                                                   SiluAndMul)
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
@@ -462,15 +463,6 @@ def forward(
         return output
 
 
-class SwiGLU(nn.Module):
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x, gate = x.chunk(2, dim=-1)
-        # Note that the order is reversed compared to
-        # SiluAndMul.
-        return x * F.silu(gate)
-
-
 class LanuageModelMLP(nn.Module):
     """Molmo's LLM mlp."""
 
@@ -489,7 +481,7 @@ def __init__(self,
             quant_config=quant_config,
         )
         # Activation function.
-        self.act_fn = SwiGLU()
+        self.act_fn = MulAndSilu()
         # Feed-forward output projection.
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
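
The deleted SwiGLU module and MulAndSilu compute the same value: chunking the input into (first, gate) and returning first * silu(gate) is the slice-based x[..., :d] * silu(x[..., d:]). A quick equivalence check, under the same environment assumptions as above:

import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import MulAndSilu

x = torch.randn(3, 2 * 16)
first, gate = x.chunk(2, dim=-1)   # what the removed SwiGLU module did
torch.testing.assert_close(MulAndSilu().forward_native(x), first * F.silu(gate))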

vllm/model_executor/models/ultravox.py

+2 -11
@@ -16,7 +16,7 @@
 from vllm import envs
 from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
-from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn
+from vllm.model_executor.layers.activation import MulAndSilu, get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
@@ -248,15 +248,6 @@ def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
         return audio_embeds
 
 
-class FlippedSiluAndMul(SiluAndMul):
-    """Ultravox is trained with SwiGLU with flipped halves."""
-
-    def forward(self, x: torch.Tensor):
-        a, b = x.chunk(2, dim=-1)
-        flipped = torch.cat((b, a), dim=-1)
-        return super().forward(flipped)
-
-
 class UltravoxProjector(nn.Module):
 
     def __init__(self, config: UltravoxConfig):
@@ -269,7 +260,7 @@ def __init__(self, config: UltravoxConfig):
         dim = self.hidden_dim
 
         if config.projector_act == "swiglu":
-            self.act = FlippedSiluAndMul()
+            self.act = MulAndSilu()
             dim = dim // 2
         else:
             self.act = get_act_fn(config.projector_act)
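
Likewise, the removed FlippedSiluAndMul swapped the two halves before calling SiluAndMul, which is exactly what MulAndSilu computes directly. A small sketch of that identity, under the same assumptions as the earlier examples:

import torch
from vllm.model_executor.layers.activation import MulAndSilu, SiluAndMul

x = torch.randn(3, 2 * 16)
a, b = x.chunk(2, dim=-1)
flipped = torch.cat((b, a), dim=-1)  # what FlippedSiluAndMul fed to SiluAndMul
torch.testing.assert_close(SiluAndMul().forward_native(flipped),  # silu(b) * a
                           MulAndSilu().forward_native(x))        # a * silu(b)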
