Skip to content

Commit

Permalink
add lib arg
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao committed Nov 1, 2024
1 parent 2ef5e40 commit 65945e0
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
1 change: 1 addition & 0 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
lib=silly_lib,
)


Expand Down
1 change: 1 addition & 0 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
op_func=silly_attention,
mutates_args=["out"],
fake_impl=silly_attention_fake,
lib=silly_lib,
)


Expand Down
5 changes: 0 additions & 5 deletions vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""

from torch.library import Library

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
Expand All @@ -16,9 +14,6 @@

from .version import __version__, __version_tuple__

# create a library to hold the custom op
vllm_lib = Library("vllm", "FRAGMENT") # noqa

__all__ = [
"__version__",
"__version_tuple__",
Expand Down
11 changes: 10 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1515,12 +1515,17 @@ def weak_ref_tensors(
raise ValueError("Invalid type for tensors")


# Create a library to hold the custom op. Binding it to a module-level name
# keeps the Library object referenced; `direct_register_custom_op` uses it
# when `library_name == "vllm"`.
vllm_lib = Library("vllm", "FRAGMENT") # noqa


def direct_register_custom_op(
library_name: str,
op_name: str,
op_func: Callable,
mutates_args: List[str],
fake_impl: Optional[Callable] = None,
lib: Optional[Library] = None,
):
"""
`torch.library.custom_op` can have significant overhead because it
Expand All @@ -1535,7 +1540,11 @@ def direct_register_custom_op(
to keep the library object alive.
"""
schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
my_lib = Library(library_name, "FRAGMENT")
if library_name == "vllm":
my_lib = vllm_lib
else:
assert lib is not None
my_lib = lib
my_lib.define(op_name + schema_str)
my_lib.impl(op_name, op_func, "CUDA")
if fake_impl is not None:
Expand Down

0 comments on commit 65945e0

Please sign in to comment.