diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index 85f2b856e049e..d151d62516b07 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -37,12 +37,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 direct_register_custom_op(
-    library_name="silly",
     op_name="attention",
     op_func=silly_attention,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
-    lib=silly_lib,
+    target_lib=silly_lib,
 )
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index a75303c45ccc4..e3e5a7d0fc5a5 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -35,12 +35,11 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
 direct_register_custom_op(
-    library_name="silly",
     op_name="attention",
     op_func=silly_attention,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
-    lib=silly_lib,
+    target_lib=silly_lib,
 )
diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py
index 6244faadb5d37..c294fcf7f08fe 100644
--- a/vllm/attention/backends/flash_attn.py
+++ b/vllm/attention/backends/flash_attn.py
@@ -774,7 +774,6 @@ def unified_flash_attention_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
     mutates_args=["kv_cache"],
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index ad0b66cc59547..234c87d5c4edb 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -924,7 +924,6 @@ def unified_flash_infer_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="unified_flash_infer",
     op_func=unified_flash_infer,
     mutates_args=["kv_cache"],
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index fc984a510b76a..94ba41a016f6d 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -110,7 +110,6 @@ def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
     return
 direct_register_custom_op(
-    library_name="vllm",
     op_name="inplace_all_reduce",
     op_func=inplace_all_reduce,
     mutates_args=["tensor"],
@@ -130,7 +129,6 @@ def outplace_all_reduce_fake(tensor: torch.Tensor,
     return torch.empty_like(tensor)
 direct_register_custom_op(
-    library_name="vllm",
     op_name="outplace_all_reduce",
     op_func=outplace_all_reduce,
     mutates_args=[],
diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
index b0828f3e57d74..4741d69de11ac 100644
--- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -136,7 +136,6 @@ def single_marlin_moe_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="single_marlin_moe",
     op_func=single_marlin_moe,
     mutates_args=[],
@@ -353,7 +352,6 @@ def fused_marlin_moe_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="fused_marlin_moe",
     op_func=fused_marlin_moe,
     mutates_args=[],
diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index 710778a78b10c..340da32263c1c 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -499,7 +499,6 @@ def inplace_fused_experts_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="inplace_fused_experts",
     op_func=inplace_fused_experts,
     mutates_args=["hidden_states"],
@@ -540,7 +539,6 @@ def outplace_fused_experts_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="outplace_fused_experts",
     op_func=outplace_fused_experts,
     mutates_args=[],
diff --git a/vllm/utils.py b/vllm/utils.py
index 68e8c0efa4514..e989257302858 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1520,12 +1520,11 @@ def weak_ref_tensors(
 def direct_register_custom_op(
-    library_name: str,
     op_name: str,
     op_func: Callable,
     mutates_args: List[str],
     fake_impl: Optional[Callable] = None,
-    lib: Optional[Library] = None,
+    target_lib: Optional[Library] = None,
 ):
     """
     `torch.library.custom_op` can have significant overhead because it
@@ -1534,17 +1533,16 @@ def direct_register_custom_op(
     See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
     for more details.
 
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
     IMPORTANT: the lifetime of the operator is tied to the lifetime of the
-    library object. It is important to have one line of code
-    `my_lib = Library(library_name, "FRAGMENT")` outside of the function
-    to keep the library object alive.
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
     """
     schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
-    if library_name == "vllm":
-        my_lib = vllm_lib
-    else:
-        assert lib is not None
-        my_lib = lib
+    my_lib = target_lib or vllm_lib
     my_lib.define(op_name + schema_str)
     my_lib.impl(op_name, op_func, "CUDA")
     if fake_impl is not None:
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 86f4665b27a91..b2af89ebf854a 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -236,7 +236,6 @@ def unified_flash_attention_fake(
 direct_register_custom_op(
-    library_name="vllm",
     op_name="unified_flash_attention",
     op_func=unified_flash_attention,
     mutates_args=["kv_cache"],