Commit: quick test

yzh119 committed Jan 29, 2024
1 parent d7a0701 commit 82c8059
Showing 3 changed files with 47 additions and 47 deletions.
24 changes: 12 additions & 12 deletions python/csrc/flashinfer_ops.cu
@@ -20,29 +20,29 @@
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
         "Single-request decode with KV-Cache operator");
-  m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
-        "Single-request prefill with KV-Cache operator");
-  m.def("single_prefill_with_kv_cache_return_lse", &single_prefill_with_kv_cache_return_lse,
-        "Single-request prefill with KV-Cache operator, return logsumexp");
+  // m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
+  //       "Single-request prefill with KV-Cache operator");
+  // m.def("single_prefill_with_kv_cache_return_lse", &single_prefill_with_kv_cache_return_lse,
+  //       "Single-request prefill with KV-Cache operator, return logsumexp");
   m.def("merge_state", &merge_state, "Merge two self-attention states");
   m.def("merge_states", &merge_states, "Merge multiple self-attention states");
   m.def("batch_decode_with_padded_kv_cache", &batch_decode_with_padded_kv_cache,
         "Multi-request batch decode with padded KV-Cache operator");
   m.def("batch_decode_with_padded_kv_cache_return_lse",
         &batch_decode_with_padded_kv_cache_return_lse,
         "Multi-request batch decode with padded KV-Cache operator, return logsumexp");
-  m.def("batch_prefill_with_paged_kv_cache", &batch_prefill_with_paged_kv_cache,
-        "Multi-request batch prefill with paged KV-Cache operator");
+  // m.def("batch_prefill_with_paged_kv_cache", &batch_prefill_with_paged_kv_cache,
+  //       "Multi-request batch prefill with paged KV-Cache operator");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                         "BatchDecodeWithPagedKVCachePyTorchWrapper")
       .def(py::init(&BatchDecodeWithPagedKVCachePyTorchWrapper::Create))
       .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
       .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
       .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
-  py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
-      m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
-      .def(py::init(&BatchPrefillWithPagedKVCachePyTorchWrapper::Create))
-      .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
-      .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
-      .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward);
+  // py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
+  //     m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
+  //     .def(py::init(&BatchPrefillWithPagedKVCachePyTorchWrapper::Create))
+  //     .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
+  //     .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
+  //     .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward);
 }
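
This change comments out the single- and batch-prefill bindings (and the prefill wrapper class), leaving only the decode and merge entry points exposed by the module. A minimal smoke test for that state might look like the sketch below; the import path flashinfer.ops._kernels is an assumption taken from the ext name in setup.py further down, not something this page confirms.

import importlib

# Assumed import path, from setup.py's name="flashinfer.ops._kernels".
_kernels = importlib.import_module("flashinfer.ops._kernels")

# Bindings kept by this commit should still be exposed...
for kept in ("single_decode_with_kv_cache",
             "batch_decode_with_padded_kv_cache",
             "merge_state",
             "BatchDecodeWithPagedKVCachePyTorchWrapper"):
    assert hasattr(_kernels, kept), kept

# ...while the prefill bindings commented out above should be gone.
for dropped in ("single_prefill_with_kv_cache",
                "batch_prefill_with_paged_kv_cache",
                "BatchPrefillWithPagedKVCachePyTorchWrapper"):
    assert not hasattr(_kernels, dropped), dropped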
64 changes: 32 additions & 32 deletions python/csrc/flashinfer_ops.h
@@ -23,15 +23,15 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
                                           unsigned int layout, float sm_scale, float rope_scale,
                                           float rope_theta);
 
-torch::Tensor single_prefill_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
-                                           torch::Tensor tmp, bool causal, unsigned int layout,
-                                           unsigned int rotary_mode, bool allow_fp16_qk_reduction,
-                                           float rope_scale, float rope_theta);
+// torch::Tensor single_prefill_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
+//                                            torch::Tensor tmp, bool causal, unsigned int layout,
+//                                            unsigned int rotary_mode, bool allow_fp16_qk_reduction,
+//                                            float rope_scale, float rope_theta);
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache_return_lse(
-    torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
-    unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale,
-    float rope_theta);
+// std::vector<torch::Tensor> single_prefill_with_kv_cache_return_lse(
+//     torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
+//     unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale,
+//     float rope_theta);
 
 std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b,
                                        torch::Tensor s_b);
@@ -47,10 +47,10 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache_return_lse(
     torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout,
     unsigned int rotary_mode, float sm_scale, float rope_scale, float rope_theta);
 
-torch::Tensor batch_prefill_with_paged_kv_cache(
-    torch::Tensor q, torch::Tensor q_indptr, torch::Tensor kv_data, torch::Tensor kv_indptr,
-    torch::Tensor kv_indices, torch::Tensor kv_last_page_len, bool causal, unsigned int layout,
-    unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale, float rope_theta);
+// torch::Tensor batch_prefill_with_paged_kv_cache(
+//     torch::Tensor q, torch::Tensor q_indptr, torch::Tensor kv_data, torch::Tensor kv_indptr,
+//     torch::Tensor kv_indices, torch::Tensor kv_last_page_len, bool causal, unsigned int layout,
+//     unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale, float rope_theta);
 
 class BatchDecodeWithPagedKVCachePyTorchWrapper {
  public:
@@ -74,24 +74,24 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
   flashinfer::QKVLayout kv_layout_;
 };
 
-class BatchPrefillWithPagedKVCachePyTorchWrapper {
- public:
-  static BatchPrefillWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
-    return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
-  }
-  void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
-  void EndForward();
-  std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
-                                     torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
-                                     torch::Tensor paged_kv_indices,
-                                     torch::Tensor paged_kv_last_page_len, bool causal,
-                                     unsigned int rotary_mode, bool allow_fp16_qk_reduction,
-                                     float rope_scale, float rope_theta, bool return_lse);
+// class BatchPrefillWithPagedKVCachePyTorchWrapper {
+//  public:
+//   static BatchPrefillWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
+//     return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
+//   }
+//   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
+//                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+//   void EndForward();
+//   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
+//                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
+//                                      torch::Tensor paged_kv_indices,
+//                                      torch::Tensor paged_kv_last_page_len, bool causal,
+//                                      unsigned int rotary_mode, bool allow_fp16_qk_reduction,
+//                                      float rope_scale, float rope_theta, bool return_lse);
 
- private:
-  BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout)
-      : kv_layout_(flashinfer::QKVLayout(layout)) {}
-  flashinfer::BatchPrefillHandler handler_;
-  flashinfer::QKVLayout kv_layout_;
-};
+//  private:
+//   BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout)
+//       : kv_layout_(flashinfer::QKVLayout(layout)) {}
+//   flashinfer::BatchPrefillHandler handler_;
+//   flashinfer::QKVLayout kv_layout_;
+// };
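
The decode wrapper that survives this change is driven through a begin_forward / forward / end_forward lifecycle, mirroring the bindings in flashinfer_ops.cu above. A rough sketch of that flow follows; the import path and the layout value are assumptions, and the decode-side argument lists are elided in this diff, so the calls are left as comments rather than guessed.

import importlib

_kernels = importlib.import_module("flashinfer.ops._kernels")  # assumed path

# py::init(&Create) binds the constructor, which takes the QKV layout as an
# unsigned int (the enum value used here is an assumption).
wrapper = _kernels.BatchDecodeWithPagedKVCachePyTorchWrapper(0)

# The decode-side signatures are not shown in this diff, so only the shape of
# the lifecycle is sketched; see flashinfer_ops.h for the real parameter lists.
# wrapper.begin_forward(workspace_buffer, ...)  # plan per-batch workspace
# out = wrapper.forward(q, paged_kv_data, ...)  # run the paged decode kernel
# wrapper.end_forward()                         # release per-batch state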
6 changes: 3 additions & 3 deletions python/setup.py
@@ -78,11 +78,11 @@ def remove_unwanted_pytorch_nvcc_flags():
name="flashinfer.ops._kernels",
sources=[
"csrc/single_decode.cu",
"csrc/single_prefill.cu",
"csrc/cascade.cu",
# "csrc/single_prefill.cu",
# "csrc/cascade.cu",
"csrc/batch_decode.cu",
"csrc/flashinfer_ops.cu",
"csrc/batch_prefill.cu",
# "csrc/batch_prefill.cu",
],
include_dirs=[
str(root.resolve() / "include"),
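
Trimming the source list means nvcc skips three translation units, so the extension rebuilds faster, which is presumably the point of this "quick test". A hypothetical variant (not in this commit) would gate those sources behind an environment variable instead of comment toggles; the variable name below is invented for illustration.

import os

# Hypothetical toggle, not part of this commit: opt back into the heavy
# kernels with FLASHINFER_BUILD_ALL=1 instead of editing setup.py.
sources = [
    "csrc/single_decode.cu",
    "csrc/batch_decode.cu",
    "csrc/flashinfer_ops.cu",
]
if os.environ.get("FLASHINFER_BUILD_ALL") == "1":
    sources += [
        "csrc/single_prefill.cu",
        "csrc/cascade.cu",
        "csrc/batch_prefill.cu",
    ]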
