From 78e26e47b95bea994ad2a47e1b1f42810363429c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 20 Oct 2024 01:27:09 -0700 Subject: [PATCH] bugfix: backward compatibility (#542) We recently changed the `plan` function signature and removed the `data_type` argument, which is incompatible with callers written against older versions. This PR keeps the `data_type` argument (but marks it as deprecated in the documentation) for backward compatibility. It also fixes a bug in the `gen_single_decode_cu` function (it now returns the uri instead of the file name). --- python/flashinfer/decode.py | 8 ++++++++ python/flashinfer/jit/attention.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index dd91808f..ac9dc30e 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -493,6 +493,7 @@ def plan( pos_encoding_mode: str = "NONE", window_left: int = -1, logits_soft_cap: Optional[float] = None, + data_type: Optional[Union[str, torch.dtype]] = "float16", q_data_type: Optional[Union[str, torch.dtype]] = "float16", kv_data_type: Optional[Union[str, torch.dtype]] = None, sm_scale: Optional[float] = None, @@ -536,6 +537,9 @@ def plan( kv_data_type : Optional[Union[str, torch.dtype]] The data type of the key/value tensor. If None, will be set to ``q_data_type``. Defaults to ``None``. + data_type: Optional[Union[str, torch.dtype]] + The data type of both the query and key/value tensors. Defaults to torch.float16. + data_type is deprecated, please use q_data_type and kv_data_type instead. 
Note ---- @@ -580,6 +584,10 @@ def plan( qo_indptr = qo_indptr.to("cpu", non_blocking=True) indptr = indptr.to("cpu", non_blocking=True) + if data_type is not None: + q_data_type = data_type + kv_data_type = data_type + q_data_type = canonicalize_torch_dtype(q_data_type) if kv_data_type is None: kv_data_type = q_data_type diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py index 42781064..ac96a28b 100644 --- a/python/flashinfer/jit/attention.py +++ b/python/flashinfer/jit/attention.py @@ -85,7 +85,7 @@ def gen_single_decode_cu(*args) -> Tuple[str, pathlib.Path]: path, get_single_decode_cu_str(*args), ) - return file_name, path + return uri, path def get_batch_decode_cu_str(