diff --git a/python/csrc/flashinfer_ops.cu b/python/csrc/flashinfer_ops.cu
index 42c93c1f..8ffa03ea 100644
--- a/python/csrc/flashinfer_ops.cu
+++ b/python/csrc/flashinfer_ops.cu
@@ -20,10 +20,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("single_decode_with_kv_cache", &single_decode_with_kv_cache,
         "Single-request decode with KV-Cache operator");
-  m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
-        "Single-request prefill with KV-Cache operator");
-  m.def("single_prefill_with_kv_cache_return_lse", &single_prefill_with_kv_cache_return_lse,
-        "Single-request prefill with KV-Cache operator, return logsumexp");
+// m.def("single_prefill_with_kv_cache", &single_prefill_with_kv_cache,
+//       "Single-request prefill with KV-Cache operator");
+// m.def("single_prefill_with_kv_cache_return_lse", &single_prefill_with_kv_cache_return_lse,
+//       "Single-request prefill with KV-Cache operator, return logsumexp");
   m.def("merge_state", &merge_state, "Merge two self-attention states");
   m.def("merge_states", &merge_states, "Merge multiple self-attention states");
   m.def("batch_decode_with_padded_kv_cache", &batch_decode_with_padded_kv_cache,
@@ -31,18 +31,18 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("batch_decode_with_padded_kv_cache_return_lse",
         &batch_decode_with_padded_kv_cache_return_lse,
         "Multi-request batch decode with padded KV-Cache operator, return logsumexp");
-  m.def("batch_prefill_with_paged_kv_cache", &batch_prefill_with_paged_kv_cache,
-        "Multi-request batch prefill with paged KV-Cache operator");
+// m.def("batch_prefill_with_paged_kv_cache", &batch_prefill_with_paged_kv_cache,
+//       "Multi-request batch prefill with paged KV-Cache operator");
   py::class_<BatchDecodeWithPagedKVCachePyTorchWrapper>(m,
                                                         "BatchDecodeWithPagedKVCachePyTorchWrapper")
       .def(py::init(&BatchDecodeWithPagedKVCachePyTorchWrapper::Create))
      .def("begin_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward)
      .def("end_forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::EndForward)
      .def("forward", &BatchDecodeWithPagedKVCachePyTorchWrapper::Forward);
-  py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
-      m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
-      .def(py::init(&BatchPrefillWithPagedKVCachePyTorchWrapper::Create))
-      .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
-      .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
-      .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward);
+// py::class_<BatchPrefillWithPagedKVCachePyTorchWrapper>(
+//     m, "BatchPrefillWithPagedKVCachePyTorchWrapper")
+//     .def(py::init(&BatchPrefillWithPagedKVCachePyTorchWrapper::Create))
+//     .def("begin_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::BeginForward)
+//     .def("end_forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::EndForward)
+//     .def("forward", &BatchPrefillWithPagedKVCachePyTorchWrapper::Forward);
 }
diff --git a/python/csrc/flashinfer_ops.h b/python/csrc/flashinfer_ops.h
index c89e157b..f02f7446 100644
--- a/python/csrc/flashinfer_ops.h
+++ b/python/csrc/flashinfer_ops.h
@@ -23,15 +23,15 @@ torch::Tensor single_decode_with_kv_cache(torch::Tensor q, torch::Tensor k, torc
                                           unsigned int layout, float sm_scale, float rope_scale,
                                           float rope_theta);
 
-torch::Tensor single_prefill_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
-                                           torch::Tensor tmp, bool causal, unsigned int layout,
-                                           unsigned int rotary_mode, bool allow_fp16_qk_reduction,
-                                           float rope_scale, float rope_theta);
+// torch::Tensor single_prefill_with_kv_cache(torch::Tensor q, torch::Tensor k, torch::Tensor v,
+//                                            torch::Tensor tmp, bool causal, unsigned int layout,
+//                                            unsigned int rotary_mode, bool allow_fp16_qk_reduction,
+//                                            float rope_scale, float rope_theta);
 
-std::vector<torch::Tensor> single_prefill_with_kv_cache_return_lse(
-    torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
-    unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale,
-    float rope_theta);
+// std::vector<torch::Tensor> single_prefill_with_kv_cache_return_lse(
+//     torch::Tensor q, torch::Tensor k, torch::Tensor v, torch::Tensor tmp, bool causal,
+//     unsigned int layout, unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale,
+//     float rope_theta);
 
 std::vector<torch::Tensor> merge_state(torch::Tensor v_a, torch::Tensor s_a, torch::Tensor v_b,
                                        torch::Tensor s_b);
@@ -47,10 +47,10 @@ std::vector<torch::Tensor> batch_decode_with_padded_kv_cache_return_lse(
     torch::Tensor q, torch::Tensor k_padded, torch::Tensor v_padded, unsigned int layout,
     unsigned int rotary_mode, float sm_scale, float rope_scale, float rope_theta);
 
-torch::Tensor batch_prefill_with_paged_kv_cache(
-    torch::Tensor q, torch::Tensor q_indptr, torch::Tensor kv_data, torch::Tensor kv_indptr,
-    torch::Tensor kv_indices, torch::Tensor kv_last_page_len, bool causal, unsigned int layout,
-    unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale, float rope_theta);
+// torch::Tensor batch_prefill_with_paged_kv_cache(
+//     torch::Tensor q, torch::Tensor q_indptr, torch::Tensor kv_data, torch::Tensor kv_indptr,
+//     torch::Tensor kv_indices, torch::Tensor kv_last_page_len, bool causal, unsigned int layout,
+//     unsigned int rotary_mode, bool allow_fp16_qk_reduction, float rope_scale, float rope_theta);
 
 class BatchDecodeWithPagedKVCachePyTorchWrapper {
  public:
@@ -74,24 +74,24 @@ class BatchDecodeWithPagedKVCachePyTorchWrapper {
   flashinfer::QKVLayout kv_layout_;
 };
 
-class BatchPrefillWithPagedKVCachePyTorchWrapper {
- public:
-  static BatchPrefillWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
-    return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
-  }
-  void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
-                    unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
-  void EndForward();
-  std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
-                                     torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
-                                     torch::Tensor paged_kv_indices,
-                                     torch::Tensor paged_kv_last_page_len, bool causal,
-                                     unsigned int rotary_mode, bool allow_fp16_qk_reduction,
-                                     float rope_scale, float rope_theta, bool return_lse);
+// class BatchPrefillWithPagedKVCachePyTorchWrapper {
+//  public:
+//   static BatchPrefillWithPagedKVCachePyTorchWrapper Create(unsigned int layout) {
+//     return BatchPrefillWithPagedKVCachePyTorchWrapper(layout);
+//   }
+//   void BeginForward(torch::Tensor workspace_buffer, torch::Tensor qo_indptr,
+//                     unsigned int batch_size, unsigned int num_qo_heads, unsigned int num_kv_heads);
+//   void EndForward();
+//   std::vector<torch::Tensor> Forward(torch::Tensor q, torch::Tensor qo_indptr,
+//                                      torch::Tensor paged_kv_data, torch::Tensor paged_kv_indptr,
+//                                      torch::Tensor paged_kv_indices,
+//                                      torch::Tensor paged_kv_last_page_len, bool causal,
+//                                      unsigned int rotary_mode, bool allow_fp16_qk_reduction,
+//                                      float rope_scale, float rope_theta, bool return_lse);
 
- private:
-  BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout)
-      : kv_layout_(flashinfer::QKVLayout(layout)) {}
-  flashinfer::BatchPrefillHandler handler_;
-  flashinfer::QKVLayout kv_layout_;
-};
+//  private:
+//   BatchPrefillWithPagedKVCachePyTorchWrapper(unsigned int layout)
+//       : kv_layout_(flashinfer::QKVLayout(layout)) {}
+//   flashinfer::BatchPrefillHandler handler_;
+//   flashinfer::QKVLayout kv_layout_;
+// };
diff --git a/python/setup.py b/python/setup.py
index de01bc61..f224f69b 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -78,11 +78,11 @@ def remove_unwanted_pytorch_nvcc_flags():
         name="flashinfer.ops._kernels",
         sources=[
             "csrc/single_decode.cu",
-            "csrc/single_prefill.cu",
-            "csrc/cascade.cu",
+            # "csrc/single_prefill.cu",
+            # "csrc/cascade.cu",
             "csrc/batch_decode.cu",
             "csrc/flashinfer_ops.cu",
-            "csrc/batch_prefill.cu",
+            # "csrc/batch_prefill.cu",
         ],
         include_dirs=[
             str(root.resolve() / "include"),
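
For context, here is a minimal smoke test of the effect of this patch. It is an illustrative sketch, not part of the diff: it assumes the trimmed extension still builds and imports as `flashinfer.ops._kernels` (the name declared in `setup.py` above), and checks that the prefill entry points are gone while the decode and merge ones remain. All op and class names are taken from the bindings in `flashinfer_ops.cu`.

```python
# Illustrative sketch (not part of the patch): verify which ops the trimmed
# extension still exposes. Assumes the extension imports as
# flashinfer.ops._kernels, the name declared in setup.py above.
from flashinfer.ops import _kernels

# Bindings left intact by this patch.
for op in ("single_decode_with_kv_cache",
           "merge_state",
           "merge_states",
           "batch_decode_with_padded_kv_cache",
           "batch_decode_with_padded_kv_cache_return_lse"):
    assert hasattr(_kernels, op), f"missing op: {op}"

# Bindings commented out by this patch.
for op in ("single_prefill_with_kv_cache",
           "single_prefill_with_kv_cache_return_lse",
           "batch_prefill_with_paged_kv_cache"):
    assert not hasattr(_kernels, op), f"op should be disabled: {op}"

# The decode wrapper class stays bound; the prefill wrapper does not.
assert hasattr(_kernels, "BatchDecodeWithPagedKVCachePyTorchWrapper")
assert not hasattr(_kernels, "BatchPrefillWithPagedKVCachePyTorchWrapper")
```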