Skip to content

Commit 27a2924

Browse files
shen-shanshan
authored and
Ubuntu
committed
[Platform] Move get_punica_wrapper() function to Platform (vllm-project#11516)
Signed-off-by: Shanshan Shen <[email protected]>
1 parent 7033d82 commit 27a2924

File tree

6 files changed

+32
-17
lines changed

6 files changed

+32
-17
lines changed
+9-17
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
11
from vllm.logger import init_logger
22
from vllm.platforms import current_platform
3+
from vllm.utils import resolve_obj_by_qualname
34

45
from .punica_base import PunicaWrapperBase
56

67
logger = init_logger(__name__)
78

89

910
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
    """Instantiate the punica wrapper for the current platform.

    The concrete wrapper class is selected by the platform layer:
    ``current_platform.get_punica_wrapper()`` returns its fully qualified
    name, which is resolved and imported lazily so that platforms without
    the corresponding backend never import it.

    Args:
        *args: Positional arguments forwarded to the wrapper constructor.
        **kwargs: Keyword arguments forwarded to the wrapper constructor.

    Returns:
        The constructed platform-specific ``PunicaWrapperBase`` subclass
        instance.

    Raises:
        NotImplementedError: If the current platform does not provide a
            punica wrapper (see ``Platform.get_punica_wrapper``).
    """
    punica_wrapper_qualname = current_platform.get_punica_wrapper()
    punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
    punica_wrapper = punica_wrapper_cls(*args, **kwargs)
    # No need to assert the instance is non-None: a successful constructor
    # call cannot return None, and resolve_obj_by_qualname already raises
    # if the qualified name is wrong.
    wrapper_name = punica_wrapper_qualname.rsplit(".", 1)[1]
    logger.info_once(f"Using {wrapper_name}.")
    return punica_wrapper

vllm/platforms/cpu.py

+4
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
109109
def is_pin_memory_available(cls) -> bool:
110110
logger.warning("Pin memory is not supported on CPU.")
111111
return False
112+
113+
@classmethod
def get_punica_wrapper(cls) -> str:
    """Fully qualified class name of the CPU LoRA punica wrapper."""
    wrapper_qualname = "vllm.lora.punica_wrapper.punica_cpu.PunicaWrapperCPU"
    return wrapper_qualname

vllm/platforms/cuda.py

+4
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,10 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
218218
logger.info("Using Flash Attention backend.")
219219
return "vllm.attention.backends.flash_attn.FlashAttentionBackend"
220220

221+
@classmethod
def get_punica_wrapper(cls) -> str:
    """Fully qualified class name of the GPU LoRA punica wrapper."""
    wrapper_qualname = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
    return wrapper_qualname
224+
221225

222226
# NVML utils
223227
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,

vllm/platforms/hpu.py

+4
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
6363
def is_pin_memory_available(cls):
6464
logger.warning("Pin memory is not supported on HPU.")
6565
return False
66+
67+
@classmethod
def get_punica_wrapper(cls) -> str:
    """Fully qualified class name of the HPU LoRA punica wrapper."""
    wrapper_qualname = "vllm.lora.punica_wrapper.punica_hpu.PunicaWrapperHPU"
    return wrapper_qualname

vllm/platforms/interface.py

+7
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,13 @@ def is_pin_memory_available(cls) -> bool:
276276
return False
277277
return True
278278

279+
@classmethod
def get_punica_wrapper(cls) -> str:
    """
    Return the fully qualified class name of the punica wrapper to use
    on the current platform, e.g.
    "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU".

    Platform subclasses that support LoRA override this; the base
    implementation raises NotImplementedError.
    """
    raise NotImplementedError
285+
279286

280287
class UnspecifiedPlatform(Platform):
281288
_enum = PlatformEnum.UNSPECIFIED

vllm/platforms/rocm.py

+4
Original file line numberDiff line numberDiff line change
@@ -153,3 +153,7 @@ def verify_quantization(cls, quant: str) -> None:
153153
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
154154
" is not set, enabling VLLM_USE_TRITON_AWQ.")
155155
envs.VLLM_USE_TRITON_AWQ = True
156+
157+
@classmethod
def get_punica_wrapper(cls) -> str:
    """Fully qualified class name of the LoRA punica wrapper for ROCm.

    ROCm reuses the GPU wrapper implementation (punica_gpu).
    """
    wrapper_qualname = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
    return wrapper_qualname

0 commit comments

Comments
 (0)