 
 namespace vllm {
 
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
+__device__ __forceinline__ scalar_t compute(const scalar_t& x,
+                                            const scalar_t& y) {
+  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
+}
 // Activation and gating kernel template.
-template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+
+template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
+          bool act_first>
 __global__ void act_and_mul_kernel(
     scalar_t* __restrict__ out,          // [..., d]
     const scalar_t* __restrict__ input,  // [..., 2, d]
@@ -19,7 +27,7 @@ __global__ void act_and_mul_kernel(
   for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
     const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
     const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
-    out[token_idx * d + idx] = ACT_FN(x) * y;
+    out[token_idx * d + idx] = compute<scalar_t, ACT_FN, act_first>(x, y);
   }
 }
 
@@ -55,7 +63,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
 }  // namespace vllm
 
 // Launch activation and gating kernel.
-#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL)                                \
+// Use ACT_FIRST (bool) indicating whether to apply the activation function
+// first.
+#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST)                     \
   int d = input.size(-1) / 2;                                                \
   int64_t num_tokens = input.numel() / input.size(-1);                       \
   dim3 grid(num_tokens);                                                     \
@@ -64,27 +74,35 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();              \
   VLLM_DISPATCH_FLOATING_TYPES(                                              \
       input.scalar_type(), "act_and_mul_kernel", [&] {                       \
-        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>>                 \
+        vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST>      \
            <<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(),            \
                                         input.data_ptr<scalar_t>(), d);      \
       });
 
 void silu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
+}
+
+void mul_and_silu(torch::Tensor& out,    // [..., d]
+                  torch::Tensor& input)  // [..., 2 * d]
+{
+  // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
+  // applies the silu to the latter half of the input.
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
 }
 
 void gelu_and_mul(torch::Tensor& out,    // [..., d]
                   torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
 }
 
 void gelu_tanh_and_mul(torch::Tensor& out,    // [..., d]
                        torch::Tensor& input)  // [..., 2 * d]
 {
-  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel);
+  LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
 }
 
 namespace vllm {
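
For reference, a minimal host-side C++ sketch (not part of the diff) of what the new act_first switch selects: silu_and_mul applies the activation to the first half of the last dimension (act_first = true), while the new mul_and_silu applies it to the second half (act_first = false). The silu() helper and the printed values below are illustrative assumptions standing in for vllm::silu_kernel; only the selection logic mirrors the device-side compute<>() added above.

// Standalone illustration only; compiles and runs on the host, no GPU needed.
#include <cmath>
#include <cstdio>

// Hypothetical host-side SiLU, standing in for vllm::silu_kernel.
float silu(const float& v) { return v / (1.0f + std::exp(-v)); }

// Same shape as the new compute<>() helper in the diff: the act_first
// template parameter picks which operand the activation is applied to.
template <typename T, T (*ACT_FN)(const T&), bool act_first>
T compute(const T& x, const T& y) {
  return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}

int main() {
  float x = 1.0f;  // element from the first half of the last dimension
  float y = 2.0f;  // element from the second half
  // silu_and_mul semantics: silu(x) * y
  std::printf("act_first=true : %f\n", compute<float, silu, true>(x, y));
  // mul_and_silu semantics: x * silu(y)
  std::printf("act_first=false: %f\n", compute<float, silu, false>(x, y));
  return 0;
}

The two calls produce different results, which is the point of the new template parameter: one kernel body now serves both silu_and_mul and mul_and_silu, with the choice resolved at compile time rather than per element at runtime.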