Add multiprocessing HPU executor (#559)
Signed-off-by: Konrad Zawora <[email protected]>
1 parent d1c2e15 · commit 0933513
Showing 4 changed files with 58 additions and 2 deletions.
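As a usage note (a hedged sketch, not part of the commit): with this executor in place, multi-HPU inference would typically be requested through the engine's parallel settings. The model name and flag spellings below are assumptions for illustration, not taken from this commit.

from vllm import LLM

# Hypothetical invocation: the "hpu" device and "mp" (multiprocessing)
# backend values are assumed here; adjust to the fork's actual configuration.
llm = LLM(model="facebook/opt-125m",
          device="hpu",
          tensor_parallel_size=2,
          distributed_executor_backend="mp")
outputs = llm.generate("Hello, my name is")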
@@ -0,0 +1,48 @@
from typing import Callable, Optional, Tuple, Type

import habana_frameworks.torch  # noqa: F401
import torch

from vllm.executor.multiproc_gpu_executor import (
    MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync)
from vllm.logger import init_logger
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase

logger = init_logger(__name__)


class MultiprocessingHPUExecutor(MultiprocessingGPUExecutor):
    """Python multiprocessing-based multi-HPU executor"""

    def _get_worker_module_and_class(
            self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
        # Use the speculative-decoding worker when a speculative config is
        # present; otherwise fall back to the plain HPU worker.
        worker_class_fn = None
        if self.speculative_config is not None:
            module_name = "vllm.spec_decode.spec_decode_worker"
            class_name = "create_spec_worker"
        else:
            module_name = "vllm.worker.hpu_worker"
            class_name = "HPUWorker"
        return (module_name, class_name, worker_class_fn)

    def _check_executor_parameters(self):
        world_size = self.parallel_config.world_size
        tensor_parallel_size = self.parallel_config.tensor_parallel_size

        # Both the tensor-parallel degree and the world size must fit on
        # the HPUs visible to this host.
        hpu_device_count = torch.hpu.device_count()
        assert tensor_parallel_size <= hpu_device_count, (
            f"please set tensor_parallel_size ({tensor_parallel_size}) "
            f"to at most the max local hpu count ({hpu_device_count})")

        assert world_size <= hpu_device_count, (
            f"please ensure that world_size ({world_size}) "
            f"is less than or equal to the max local hpu count "
            f"({hpu_device_count})")


class MultiprocessingHPUExecutorAsync(MultiprocessingHPUExecutor,
                                      MultiprocessingGPUExecutorAsync):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Wrap the driver worker's blocking execute_model so the async
        # engine can await it without blocking the event loop.
        self.driver_exec_model = make_async(self.driver_worker.execute_model)
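For context, here is a minimal sketch of how the (module_name, class_name) pair returned by _get_worker_module_and_class is typically resolved into a worker class. The dynamic import below illustrates the pattern, not the exact vLLM code path, and load_worker_class is a hypothetical helper name.

import importlib

def load_worker_class(module_name: str, class_name: str):
    # Import the module lazily, then pull the worker class (or a factory
    # such as create_spec_worker) off it by name.
    module = importlib.import_module(module_name)
    return getattr(module, class_name)

# Example: load_worker_class("vllm.worker.hpu_worker", "HPUWorker")

On the async side, vllm.utils.make_async wraps the driver worker's blocking execute_model so that each call returns an awaitable; roughly, it dispatches the call through the event loop's run_in_executor on a background thread, which is what lets MultiprocessingHPUExecutorAsync await model execution without stalling the event loop.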