Skip to content

Commit 0f3e0eb

Browse files
committed
[platform] support pytorch custom op pluggable
Signed-off-by: wangxiyuan <[email protected]>
1 parent 8936316 commit 0f3e0eb

File tree

1 file changed

+19
-12
lines changed

1 file changed

+19
-12
lines changed

vllm/model_executor/custom_op.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@ def __init__(self):
2020
super().__init__()
2121
self._forward_method = self.dispatch_forward()
2222

23+
@classmethod
24+
def set_foward_method(cls, method):
25+
"""Provide a way to register a custom forward method for a specific
26+
backend."""
27+
if getattr(cls, f"forward_{current_platform.device_name}", None):
28+
raise ValueError(
29+
f"Custom op {cls.__class__.__name__} already has a "
30+
f"forward_{current_platform.device_name} method")
31+
setattr(cls, f"forward_{current_platform.device_name}", method)
32+
2333
def forward(self, *args, **kwargs):
2434
return self._forward_method(*args, **kwargs)
2535

@@ -72,18 +82,15 @@ def dispatch_forward(self):
7282
if not enabled:
7383
return self.forward_native
7484

75-
if current_platform.is_rocm():
76-
return self.forward_hip
77-
elif current_platform.is_cpu():
78-
return self.forward_cpu
79-
elif current_platform.is_hpu():
80-
return self.forward_hpu
81-
elif current_platform.is_tpu():
82-
return self.forward_tpu
83-
elif current_platform.is_xpu():
84-
return self.forward_xpu
85-
else:
86-
return self.forward_cuda
85+
custom_forward_func = \
86+
getattr(self, f"forward_{current_platform.device_name}", None)
87+
if not custom_forward_func:
88+
logger.warning(
89+
"Custom op %s is not supported on %s, falling back "
90+
"to native.", self.__class__.__name__,
91+
current_platform.device_name)
92+
return self.forward_native
93+
return custom_forward_func
8794

8895
@classmethod
8996
def enabled(cls) -> bool:

0 commit comments

Comments
 (0)