From 65945e0cf906eff5ee9da5a655ebbbe8862f8107 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 31 Oct 2024 21:17:31 -0700
Subject: [PATCH] add lib arg

Signed-off-by: youkaichao
---
 tests/compile/piecewise/test_simple.py    |  1 +
 tests/compile/piecewise/test_toy_llama.py |  1 +
 vllm/__init__.py                          |  5 -----
 vllm/utils.py                             | 11 ++++++++++-
 4 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index e7d5c55cd997d..85f2b856e049e 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -42,6 +42,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
     op_func=silly_attention,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
+    lib=silly_lib,
 )
 
 
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index 5290aad2658ab..a75303c45ccc4 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -40,6 +40,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
     op_func=silly_attention,
     mutates_args=["out"],
     fake_impl=silly_attention_fake,
+    lib=silly_lib,
 )
 
 
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 4b3026fc47fc2..8f477ea84756d 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -1,7 +1,5 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
 
-from torch.library import Library
-
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
@@ -16,9 +14,6 @@
 
 from .version import __version__, __version_tuple__
 
-# create a library to hold the custom op
-vllm_lib = Library("vllm", "FRAGMENT")  # noqa
-
 __all__ = [
     "__version__",
     "__version_tuple__",
diff --git a/vllm/utils.py b/vllm/utils.py
index 1ff542fdf5934..68e8c0efa4514 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1515,12 +1515,17 @@ def weak_ref_tensors(
     raise ValueError("Invalid type for tensors")
 
 
+# create a library to hold the custom op
+vllm_lib = Library("vllm", "FRAGMENT")  # noqa
+
+
 def direct_register_custom_op(
     library_name: str,
     op_name: str,
     op_func: Callable,
     mutates_args: List[str],
     fake_impl: Optional[Callable] = None,
+    lib: Optional[Library] = None,
 ):
     """
     `torch.library.custom_op` can have significant overhead because it
@@ -1535,7 +1540,11 @@ def direct_register_custom_op(
     to keep the library object alive.
     """
     schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
-    my_lib = Library(library_name, "FRAGMENT")
+    if library_name == "vllm":
+        my_lib = vllm_lib
+    else:
+        assert lib is not None
+        my_lib = lib
     my_lib.define(op_name + schema_str)
     my_lib.impl(op_name, op_func, "CUDA")
     if fake_impl is not None:
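
Usage note (not part of the patch): a minimal sketch of how a caller outside the
"vllm" namespace would use the new `lib` argument, modeled on the test changes
above. The library/op names ("silly", "attention") and the toy kernel bodies are
assumptions for illustration; only `direct_register_custom_op` and
`torch.library.Library` come from the patch itself.

    import torch
    from torch.library import Library

    from vllm.utils import direct_register_custom_op

    # The caller owns the Library object and must keep it alive for as long as
    # the registered op is used (assumed name; mirrors the tests above).
    silly_lib = Library("silly", "FRAGMENT")


    def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        out: torch.Tensor) -> None:
        # Toy CUDA-path implementation: write q + k + v into the
        # pre-allocated output tensor (mutated in place).
        out.copy_(q + k + v)


    def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                             out: torch.Tensor) -> None:
        # Fake (meta) implementation: no computation is needed because the
        # output buffer is already allocated with the correct shape.
        return


    direct_register_custom_op(
        library_name="silly",  # not "vllm", so `lib` must be provided
        op_name="attention",
        op_func=silly_attention,
        mutates_args=["out"],
        fake_impl=silly_attention_fake,
        lib=silly_lib,  # the new argument introduced by this patch
    )

Under these assumed names, the op then becomes callable as
torch.ops.silly.attention(q, k, v, out).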