Commit 7059043

mzusman authored and Alvant committed
[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM (vllm-project#7651)
Signed-off-by: Alvant <[email protected]>
1 parent 9b96b34 commit 7059043

20 files changed: +2815 -31 lines

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -203,6 +203,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   FetchContent_MakeAvailable(cutlass)
 
   list(APPEND VLLM_EXT_SRC
+    "csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
+    "csrc/mamba/causal_conv1d/causal_conv1d.cu"
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"

Dockerfile

Lines changed: 0 additions & 23 deletions
@@ -42,9 +42,6 @@ COPY requirements-cuda.txt requirements-cuda.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-cuda.txt
 
-COPY requirements-mamba.txt requirements-mamba.txt
-RUN python3 -m pip install packaging
-RUN python3 -m pip install -r requirements-mamba.txt
 
 # cuda arch list used by torch
 # can be useful for both `dev` and `test`
@@ -127,22 +124,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-dev.txt
 
 #################### DEV IMAGE ####################
-#################### MAMBA Build IMAGE ####################
-FROM dev as mamba-builder
-# max jobs used for build
-ARG max_jobs=2
-ENV MAX_JOBS=${max_jobs}
-
-WORKDIR /usr/src/mamba
-
-COPY requirements-mamba.txt requirements-mamba.txt
-
-# Download the wheel or build it if a pre-compiled release doesn't exist
-RUN pip --verbose wheel -r requirements-mamba.txt \
-    --no-build-isolation --no-deps --no-cache-dir
-
-#################### MAMBA Build IMAGE ####################
-
 #################### vLLM installation IMAGE ####################
 # image with vLLM installed
 FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
@@ -179,10 +160,6 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
     --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install dist/*.whl --verbose
 
-RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \
-    --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
-
 RUN --mount=type=cache,target=/root/.cache/pip \
     . /etc/environment && \
     python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl

csrc/mamba/causal_conv1d/causal_conv1d.cu

Lines changed: 700 additions & 0 deletions
Large diffs are not rendered by default.

csrc/mamba/causal_conv1d/causal_conv1d.h

Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+// clang-format off
+// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct ConvParamsBase {
+    using index_t = uint32_t;
+
+    int batch, dim, seqlen, width;
+    bool silu_activation;
+
+    index_t x_batch_stride;
+    index_t x_c_stride;
+    index_t x_l_stride;
+    index_t weight_c_stride;
+    index_t weight_width_stride;
+    index_t out_batch_stride;
+    index_t out_c_stride;
+    index_t out_l_stride;
+
+    index_t conv_state_batch_stride;
+    index_t conv_state_c_stride;
+    index_t conv_state_l_stride;
+
+    // Common data pointers.
+    void *__restrict__ x_ptr;
+    void *__restrict__ weight_ptr;
+    void *__restrict__ bias_ptr;
+    void *__restrict__ out_ptr;
+
+    void *__restrict__ conv_state_ptr;
+
+    void *__restrict__ seq_idx_ptr;
+
+    // No __restrict__ since initial_states could be the same as final_states.
+    void * initial_states_ptr;
+    index_t initial_states_batch_stride;
+    index_t initial_states_l_stride;
+    index_t initial_states_c_stride;
+
+    void * final_states_ptr;
+    index_t final_states_batch_stride;
+    index_t final_states_l_stride;
+    index_t final_states_c_stride;
+};
+
+
+#ifndef USE_ROCM
+    #include <cuda_bf16.h>
+
+    template<typename T>
+    __device__ inline T shuffle_xor(T val, int offset) {
+        return __shfl_xor_sync(uint32_t(-1), val, offset);
+    }
+
+    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
+    {
+        return std::max(ilist);
+    }
+
+    template<typename T>
+    constexpr T constexpr_min(T a, T b) {
+        return std::min(a, b);
+    }
+
+#else
+    #include <hip/hip_bf16.h>
+
+    template<typename T>
+    __device__ inline T shuffle_xor(T val, int offset) {
+        return __shfl_xor(val, offset);
+    }
+    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
+    {
+        return *std::max_element(ilist.begin(), ilist.end());
+    }
+
+    template<typename T>
+    constexpr T constexpr_min(T a, T b) {
+        return a < b ? a : b;
+    }
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<int BYTES> struct BytesToType {};
+
+template<> struct BytesToType<16> {
+    using Type = uint4;
+    static_assert(sizeof(Type) == 16);
+};
+
+template<> struct BytesToType<8> {
+    using Type = uint64_t;
+    static_assert(sizeof(Type) == 8);
+};
+
+template<> struct BytesToType<4> {
+    using Type = uint32_t;
+    static_assert(sizeof(Type) == 4);
+};
+
+template<> struct BytesToType<2> {
+    using Type = uint16_t;
+    static_assert(sizeof(Type) == 2);
+};
+
+template<> struct BytesToType<1> {
+    using Type = uint8_t;
+    static_assert(sizeof(Type) == 1);
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template<typename T>
+struct SumOp {
+    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
+};
+
+template<int THREADS>
+struct Allreduce {
+    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        constexpr int OFFSET = THREADS / 2;
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
+        return Allreduce<OFFSET>::run(x, op);
+    }
+};
+
+template<>
+struct Allreduce<2> {
+    template<typename T, typename Operator>
+    static __device__ inline T run(T x, Operator &op) {
+        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
+        return x;
+    }
+};
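
Note: the SumOp and Allreduce helpers at the end of this header implement a butterfly (XOR-shuffle) warp reduction. At each step a lane combines its value with the lane whose index differs in one bit, so after log2(THREADS) steps every participating lane holds the full result. The sketch below shows how a kernel might call them; it is an illustration only (the kernel name, launch configuration, and host code are not part of this commit), and it re-declares the two helpers so it compiles on its own.

// Minimal usage sketch (not from this commit): warp-wide sum using the
// SumOp/Allreduce pattern from causal_conv1d.h above.
// Compile with: nvcc -std=c++17 warp_sum_example.cu
#include <cstdio>

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const &x, T const &y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        // Combine with the lane OFFSET positions away (XOR pattern), then recurse.
        x = op(x, __shfl_xor_sync(0xffffffffu, x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(0xffffffffu, x, 1));
        return x;
    }
};

// Hypothetical kernel: every lane contributes its lane index; after the
// all-reduce, every lane holds 0 + 1 + ... + 31 = 496.
__global__ void warp_sum_example(float *out) {
    SumOp<float> op;
    float v = static_cast<float>(threadIdx.x);
    v = Allreduce<32>::run(v, op);
    if (threadIdx.x == 0) *out = v;
}

int main() {
    float *d_out = nullptr, h_out = 0.0f;
    cudaMalloc(&d_out, sizeof(float));
    warp_sum_example<<<1, 32>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp sum = %.0f\n", h_out);  // expected: 496
    cudaFree(d_out);
    return 0;
}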

csrc/mamba/causal_conv1d/static_switch.h

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+// Inspired by
+// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
+// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
+// clang-format off
+// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
+
+#pragma once
+
+/// @param COND       - a boolean expression to switch by
+/// @param CONST_NAME - a name given for the constexpr bool variable.
+/// @param ...        - code to execute for true and false
+///
+/// Usage:
+/// ```
+/// BOOL_SWITCH(flag, BoolConst, [&] {
+///     some_function<BoolConst>(...);
+/// });
+/// ```
+#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
+    [&] {                                             \
+        if (COND) {                                   \
+            static constexpr bool CONST_NAME = true;  \
+            return __VA_ARGS__();                     \
+        } else {                                      \
+            static constexpr bool CONST_NAME = false; \
+            return __VA_ARGS__();                     \
+        }                                             \
+    }()
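
Note: BOOL_SWITCH turns a runtime boolean into a compile-time constant by compiling both branches and selecting one at run time, which lets a kernel template specialize on flags such as ConvParamsBase::silu_activation without branching inside the hot loop. Below is a small, self-contained illustration of the dispatch pattern; the kernel, the launcher, and the SiLU flag are hypothetical stand-ins, not the actual launch code added by this commit (only the macro is reproduced from the header above).

// Minimal usage sketch (not from this commit): dispatching a templated kernel
// on a runtime flag via BOOL_SWITCH. Compile with: nvcc -std=c++17 example.cu
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            static constexpr bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        } else {                                      \
            static constexpr bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

// Hypothetical elementwise kernel, specialized at compile time on whether a
// SiLU activation (x * sigmoid(x)) is applied to the scaled value.
template<bool kSilu>
__global__ void scale_kernel(const float *in, float *out, float scale, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float v = in[i] * scale;
        if constexpr (kSilu) { v = v / (1.0f + expf(-v)); }  // SiLU: v * sigmoid(v)
        out[i] = v;
    }
}

// Host launcher: both template instantiations are compiled; the runtime flag
// only selects which one gets launched.
void launch_scale(const float *in, float *out, float scale, int n, bool silu) {
    const int threads = 256;
    const int blocks = (n + threads - 1) / threads;
    BOOL_SWITCH(silu, kSilu, [&] {
        scale_kernel<kSilu><<<blocks, threads>>>(in, out, scale, n);
    });
}

int main() {
    const int n = 4;
    float h_in[n] = {1.0f, 2.0f, 3.0f, 4.0f}, h_out[n];
    float *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    launch_scale(d_in, d_out, 2.0f, n, /*silu=*/true);
    cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) printf("out[%d] = %f\n", i, h_out[i]);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}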
