
Commit

fix
Signed-off-by: youkaichao <[email protected]>
youkaichao committed Nov 1, 2024
1 parent 65945e0 commit 529b28a
Showing 9 changed files with 10 additions and 23 deletions.
3 changes: 1 addition & 2 deletions tests/compile/piecewise/test_simple.py
@@ -37,12 +37,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 
 
 direct_register_custom_op(
-    library_name="silly",
     op_name="attention",
     op_func=silly_attention,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
-    lib=silly_lib,
+    target_lib=silly_lib,
 )
 
 

3 changes: 1 addition & 2 deletions tests/compile/piecewise/test_toy_llama.py
@@ -35,12 +35,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 
 
 direct_register_custom_op(
-    library_name="silly",
     op_name="attention",
     op_func=silly_attention,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
-    lib=silly_lib,
+    target_lib=silly_lib,
 )
 
 

1 change: 0 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -774,7 +774,6 @@ def unified_flash_attention_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
     mutates_args=["kv_cache"],

1 change: 0 additions & 1 deletion vllm/attention/backends/flashinfer.py
@@ -924,7 +924,6 @@ def unified_flash_infer_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="unified_flash_infer",
     op_func=unified_flash_infer,
     mutates_args=["kv_cache"],

2 changes: 0 additions & 2 deletions vllm/distributed/parallel_state.py
@@ -110,7 +110,6 @@ def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
     return
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="inplace_all_reduce",
     op_func=inplace_all_reduce,
     mutates_args=["tensor"],
@@ -130,7 +129,6 @@ def outplace_all_reduce_fake(tensor: torch.Tensor,
     return torch.empty_like(tensor)
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="outplace_all_reduce",
     op_func=outplace_all_reduce,
     mutates_args=[],

2 changes: 0 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -136,7 +136,6 @@ def single_marlin_moe_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="single_marlin_moe",
     op_func=single_marlin_moe,
     mutates_args=[],
@@ -353,7 +352,6 @@ def fused_marlin_moe_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="fused_marlin_moe",
     op_func=fused_marlin_moe,
     mutates_args=[],

2 changes: 0 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -499,7 +499,6 @@ def inplace_fused_experts_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="inplace_fused_experts",
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
@@ -540,7 +539,6 @@ def outplace_fused_experts_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="outplace_fused_experts",
     op_func=outplace_fused_experts,
     mutates_args=[],

18 changes: 8 additions & 10 deletions vllm/utils.py
@@ -1520,12 +1520,11 @@ def weak_ref_tensors(
 
 
 def direct_register_custom_op(
-    library_name: str,
     op_name: str,
     op_func: Callable,
     mutates_args: List[str],
     fake_impl: Optional[Callable] = None,
-    lib: Optional[Library] = None,
+    target_lib: Optional[Library] = None,
 ):
     """
     `torch.library.custom_op` can have significant overhead because it
@@ -1534,17 +1533,16 @@ def direct_register_custom_op(
     See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
     for more details.
 
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
     IMPORTANT: the lifetime of the operator is tied to the lifetime of the
-    library object. It is important to have one line of code
-    `my_lib = Library(library_name, "FRAGMENT")` outside of the function
-    to keep the library object alive.
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
     """
     schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
-    if library_name == "vllm":
-        my_lib = vllm_lib
-    else:
-        assert lib is not None
-        my_lib = lib
+    my_lib = target_lib or vllm_lib
     my_lib.define(op_name + schema_str)
     my_lib.impl(op_name, op_func, "CUDA")
     if fake_impl is not None:

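Note: after the vllm/utils.py change above, callers no longer pass library_name; registration goes to the default vLLM library unless a Library object is supplied via target_lib. Below is a minimal sketch of the new calling pattern, modeled on the test files in this commit; the "silly" library and the silly_add op are illustrative only, not part of the commit.

import torch
from torch.library import Library

from vllm.utils import direct_register_custom_op

# Keep the Library object alive at module scope: the operator's lifetime
# is tied to the lifetime of this object.
silly_lib = Library("silly", "FRAGMENT")


def silly_add(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor) -> None:
    # Real implementation, registered for the CUDA backend.
    out.copy_(x + y)


def silly_add_fake(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor) -> None:
    # Fake implementation used when tracing/compiling with meta tensors.
    return


# New signature: no library_name argument. Pass target_lib to register to a
# custom library, or omit it to register to the default vLLM library.
direct_register_custom_op(
    op_name="add",
    op_func=silly_add,
    mutates_args=["out"],
    fake_impl=silly_add_fake,
    target_lib=silly_lib,
)

# The op is then callable as torch.ops.silly.add(x, y, out) on CUDA tensors.

Since the real kernel is only registered for "CUDA" and the fake implementation is registered via _register_fake, tracing sees the fake while eager CUDA execution runs the real op, matching how the test files exercise the helper.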
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -236,7 +236,6 @@ def unified_flash_attention_fake(
 
 
 direct_register_custom_op(
-    library_name="vllm",
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
     mutates_args=["kv_cache"],

