diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py
index dd91808f..ac9dc30e 100644
--- a/python/flashinfer/decode.py
+++ b/python/flashinfer/decode.py
@@ -493,6 +493,7 @@ def plan(
         pos_encoding_mode: str = "NONE",
         window_left: int = -1,
         logits_soft_cap: Optional[float] = None,
+        data_type: Optional[Union[str, torch.dtype]] = "float16",
         q_data_type: Optional[Union[str, torch.dtype]] = "float16",
         kv_data_type: Optional[Union[str, torch.dtype]] = None,
         sm_scale: Optional[float] = None,
@@ -536,6 +537,9 @@ def plan(
         kv_data_type : Optional[Union[str, torch.dtype]]
             The data type of the key/value tensor. If None, will be set to
             ``q_data_type``. Defaults to ``None``.
+        data_type : Optional[Union[str, torch.dtype]]
+            The data type of both the query and key/value tensors. Defaults to ``torch.float16``.
+            ``data_type`` is deprecated; use ``q_data_type`` and ``kv_data_type`` instead.

         Note
         ----
@@ -580,6 +584,10 @@ def plan(
         qo_indptr = qo_indptr.to("cpu", non_blocking=True)
         indptr = indptr.to("cpu", non_blocking=True)

+        if data_type is not None:
+            q_data_type = data_type
+            kv_data_type = data_type
+
         q_data_type = canonicalize_torch_dtype(q_data_type)
         if kv_data_type is None:
             kv_data_type = q_data_type
diff --git a/python/flashinfer/jit/attention.py b/python/flashinfer/jit/attention.py
index 42781064..ac96a28b 100644
--- a/python/flashinfer/jit/attention.py
+++ b/python/flashinfer/jit/attention.py
@@ -85,7 +85,7 @@ def gen_single_decode_cu(*args) -> Tuple[str, pathlib.Path]:
         path,
         get_single_decode_cu_str(*args),
     )
-    return file_name, path
+    return uri, path


 def get_batch_decode_cu_str(
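For reference, the sketch below summarizes how the dtype arguments are resolved by plan() after this patch. It is a minimal stand-alone illustration, not flashinfer code: resolve_plan_dtypes and _canonicalize are hypothetical helpers standing in for the logic inside plan() and for flashinfer's canonicalize_torch_dtype.

# Minimal sketch (assumption, not flashinfer code) of the dtype resolution in the diff above.
from typing import Optional, Union

import torch


def _canonicalize(dtype: Union[str, torch.dtype]) -> torch.dtype:
    # Stand-in for canonicalize_torch_dtype: map "float16" -> torch.float16, pass torch.dtype through.
    return getattr(torch, dtype) if isinstance(dtype, str) else dtype


def resolve_plan_dtypes(
    data_type: Optional[Union[str, torch.dtype]] = "float16",
    q_data_type: Optional[Union[str, torch.dtype]] = "float16",
    kv_data_type: Optional[Union[str, torch.dtype]] = None,
):
    # Deprecated path added by the patch: a non-None data_type overrides both dtypes.
    if data_type is not None:
        q_data_type = data_type
        kv_data_type = data_type
    q_data_type = _canonicalize(q_data_type)
    # kv_data_type falls back to q_data_type when left as None.
    kv_data_type = _canonicalize(kv_data_type) if kv_data_type is not None else q_data_type
    return q_data_type, kv_data_type


# Because data_type defaults to "float16", an explicit q_data_type/kv_data_type only takes
# effect when the caller also passes data_type=None with this version of plan():
print(resolve_plan_dtypes(q_data_type=torch.bfloat16))                  # (torch.float16, torch.float16)
print(resolve_plan_dtypes(data_type=None, q_data_type=torch.bfloat16))  # (torch.bfloat16, torch.bfloat16)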