|
1 | 1 | from vllm.logger import init_logger
|
2 | 2 | from vllm.platforms import current_platform
|
| 3 | +from vllm.utils import resolve_obj_by_qualname |
3 | 4 |
|
4 | 5 | from .punica_base import PunicaWrapperBase
|
5 | 6 |
|
6 | 7 | logger = init_logger(__name__)
|
7 | 8 |
|
8 | 9 |
|
9 | 10 | def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
|
10 |
| - if current_platform.is_cuda_alike(): |
11 |
| - # Lazy import to avoid ImportError |
12 |
| - from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU |
13 |
| - logger.info_once("Using PunicaWrapperGPU.") |
14 |
| - return PunicaWrapperGPU(*args, **kwargs) |
15 |
| - elif current_platform.is_cpu(): |
16 |
| - # Lazy import to avoid ImportError |
17 |
| - from vllm.lora.punica_wrapper.punica_cpu import PunicaWrapperCPU |
18 |
| - logger.info_once("Using PunicaWrapperCPU.") |
19 |
| - return PunicaWrapperCPU(*args, **kwargs) |
20 |
| - elif current_platform.is_hpu(): |
21 |
| - # Lazy import to avoid ImportError |
22 |
| - from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU |
23 |
| - logger.info_once("Using PunicaWrapperHPU.") |
24 |
| - return PunicaWrapperHPU(*args, **kwargs) |
25 |
| - else: |
26 |
| - raise NotImplementedError |
| 11 | + punica_wrapper_qualname = current_platform.get_punica_wrapper() |
| 12 | + punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname) |
| 13 | + punica_wrapper = punica_wrapper_cls(*args, **kwargs) |
| 14 | + assert punica_wrapper is not None, \ |
| 15 | + "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong." |
| 16 | + logger.info_once("Using " + punica_wrapper_qualname.rsplit(".", 1)[1] + |
| 17 | + ".") |
| 18 | + return punica_wrapper |
0 commit comments