[Kernel/Model] Migrate mamba_ssm and causal_conv1d kernels to vLLM #7651
Merged

Changes from all commits (50 commits, all by mzusman):
59e6abf  Migrate mamba_ssm and causal_conv1d kernels to vLLM
d2348ec  Casual conv1d compiles
66ee5af  Add casual_conv1d to _custom_ops
7a0d206  Add mamba ops and triton kernels
145b6b7  Add casual_conv1d update
2bdd7f5  setup selective scan fwd pass
e25dbfe  Format
64b6160  Do not have a mamba layer for now, push in a future PR
2ff36cb  Format
5f9c383  Take off mamba from image and requirements
ac8354e  Add tests
ea80282  Some small fixes, tests still do not pass
2f15495  Fix tests
b51fd28  Causal conv1d tests are passing
0cc2252  Import
d65dfb6  Tests
e7b2b32  Format
2c9fe00  Cleanup
c82cc30  Align with main
6c83e5f  Format
cd78cf6  Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
b6a00cb  Add init py files
ef69b6c  Move kernels to cuda only
152f331  Revert "Move kernels to cuda only"
39f0fa0  move kernels to if cuda
42f94b7  Fix tests
f050781  Revert formating
c8ffba5  Format
04f947b  Add comments on adapted from mamba/casual conv1d repos
732db18  pare down number of w/i dtype combinations
fdca1ff  Clean up not used
fe70a39  Rename typo
9a0e538  Add comment on einops
619a40a  Remove requirement for einops
5d0d2db  Fix tests after paring down kernels
c622375  format
cdc9205  Fix typo
42d9c59  Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
308c922  register meta functions to the kernels
d921a48  Revert "register meta functions to the kernels"
a8078e7  move to ifndef ROCm
2ca8db7  Format
abf02fa  Reduce combinations of bool switch to reduce wheel size
633225c  Fix, use float as weight dtype
ec0112b  Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
1f35bbe  Take down seq_pos_idx, not used atm, will comeback in a following PR
bed44c4  Add comments and guard checks on disabled "features"
950701a  Fix header file
4e5d6b4  Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
d23a429  Merge remote-tracking branch 'github/main' into mamba_kernels_migrate
Large diffs are not rendered by default.
New file, 144 lines (causal_conv1d header, adapted from Dao-AILab/causal-conv1d):
@@ -0,0 +1,144 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
#pragma once

#include <cuda_bf16.h>
#include <cuda_fp16.h>
////////////////////////////////////////////////////////////////////////////////////////////////////

struct ConvParamsBase {
    using index_t = uint32_t;

    int batch, dim, seqlen, width;
    bool silu_activation;

    index_t x_batch_stride;
    index_t x_c_stride;
    index_t x_l_stride;
    index_t weight_c_stride;
    index_t weight_width_stride;
    index_t out_batch_stride;
    index_t out_c_stride;
    index_t out_l_stride;

    index_t conv_state_batch_stride;
    index_t conv_state_c_stride;
    index_t conv_state_l_stride;

    // Common data pointers.
    void *__restrict__ x_ptr;
    void *__restrict__ weight_ptr;
    void *__restrict__ bias_ptr;
    void *__restrict__ out_ptr;

    void *__restrict__ conv_state_ptr;

    void *__restrict__ seq_idx_ptr;

    // No __restrict__ since initial_states could be the same as final_states.
    void * initial_states_ptr;
    index_t initial_states_batch_stride;
    index_t initial_states_l_stride;
    index_t initial_states_c_stride;

    void * final_states_ptr;
    index_t final_states_batch_stride;
    index_t final_states_l_stride;
    index_t final_states_c_stride;
};


#ifndef USE_ROCM
    #include <cuda_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor_sync(uint32_t(-1), val, offset);
    }

    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return std::max(ilist);
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return std::min(a, b);
    }

#else
    #include <hip/hip_bf16.h>

    template<typename T>
    __device__ inline T shuffle_xor(T val, int offset) {
        return __shfl_xor(val, offset);
    }
    constexpr size_t custom_max(std::initializer_list<size_t> ilist)
    {
        return *std::max_element(ilist.begin(), ilist.end());
    }

    template<typename T>
    constexpr T constexpr_min(T a, T b) {
        return a < b ? a : b;
    }
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int BYTES> struct BytesToType {};

template<> struct BytesToType<16> {
    using Type = uint4;
    static_assert(sizeof(Type) == 16);
};

template<> struct BytesToType<8> {
    using Type = uint64_t;
    static_assert(sizeof(Type) == 8);
};

template<> struct BytesToType<4> {
    using Type = uint32_t;
    static_assert(sizeof(Type) == 4);
};

template<> struct BytesToType<2> {
    using Type = uint16_t;
    static_assert(sizeof(Type) == 2);
};

template<> struct BytesToType<1> {
    using Type = uint8_t;
    static_assert(sizeof(Type) == 1);
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
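For orientation: the SumOp / Allreduce / shuffle_xor helpers at the bottom of this header implement a standard warp-level butterfly reduction — each round exchanges values between lanes whose ids differ by one bit, halving the offset until every lane holds the combined result. Below is a minimal, self-contained CUDA sketch of that pattern; the kernel name and launch setup are illustrative only and are not part of this PR, and the SumOp / Allreduce definitions are repeated so the snippet compiles on its own.

// --- Illustrative usage sketch (not part of the PR) ------------------------
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// SumOp / Allreduce repeated from the header above so this file is
// self-contained.
template<typename T>
struct SumOp {
    __device__ inline T operator()(T const &x, T const &y) { return x + y; }
};

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

// Hypothetical demo kernel: each lane contributes its lane id; after the
// butterfly reduction every lane holds the warp-wide sum 0+1+...+31 = 496.
__global__ void warp_sum_demo(int *out) {
    int lane = threadIdx.x & 31;
    SumOp<int> op;
    int total = Allreduce<32>::run(lane, op);
    if (lane == 0) *out = total;
}

int main() {
    int *d_out = nullptr, h_out = 0;
    cudaMalloc(&d_out, sizeof(int));
    warp_sum_demo<<<1, 32>>>(d_out);
    cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("warp sum = %d\n", h_out);  // expected: 496
    cudaFree(d_out);
    return 0;
}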
New file, 28 lines (static_switch.h, adapted from Dao-AILab/causal-conv1d):
@@ -0,0 +1,28 @@
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
// clang-format off
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h

#pragma once

/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    [&] {                                               \
        if (COND) {                                     \
            static constexpr bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            static constexpr bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()
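As the usage note in the header suggests, BOOL_SWITCH converts a runtime boolean into a constexpr value so that a single call site can reach either template instantiation. Here is a minimal, self-contained sketch of that dispatch pattern; launch_conv and silu_activation are made-up stand-ins for this example, not APIs introduced by the PR, and the macro is copied from the header so the snippet builds standalone.

// --- Illustrative BOOL_SWITCH usage (not part of the PR) --------------------
#include <cstdio>

#define BOOL_SWITCH(COND, CONST_NAME, ...)              \
    [&] {                                               \
        if (COND) {                                     \
            static constexpr bool CONST_NAME = true;    \
            return __VA_ARGS__();                       \
        } else {                                        \
            static constexpr bool CONST_NAME = false;   \
            return __VA_ARGS__();                       \
        }                                               \
    }()

// Hypothetical launcher: the template parameter would normally select a
// specialized kernel variant (e.g. with or without the SiLU activation).
template <bool kSilu>
void launch_conv(int n) {
    std::printf("launch_conv<%s>(%d)\n", kSilu ? "true" : "false", n);
}

int main() {
    bool silu_activation = true;  // runtime value, e.g. taken from the op's arguments
    // One call site expands into both the true and false instantiations;
    // only the branch matching the runtime flag actually runs.
    BOOL_SWITCH(silu_activation, kSiluConst, [&] {
        launch_conv<kSiluConst>(128);
    });
    return 0;
}

Because each BOOL_SWITCH emits both instantiations, and switches multiply across nested flags and dtypes, paring down the switched combinations (as the later commits "pare down number of w/i dtype combinations" and "Reduce combinations of bool switch to reduce wheel size" do) directly shrinks the compiled binary and wheel size.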