Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with vLLM with tensor_parallel_size argument #805

Merged
merged 10 commits into from
Jul 23, 2024
42 changes: 38 additions & 4 deletions src/distilabel/pipeline/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import multiprocessing as mp
import signal
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
from multiprocessing.pool import Pool
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, cast

import tblib

Expand Down Expand Up @@ -48,6 +49,40 @@ def _init_worker(log_queue: "Queue[Any]") -> None:
setup_logging(log_queue)


# We create a custom `Pool` class so the created processes are not daemons, allowing
# them to created child processes if necessary (for example when using `vLLM` with `tensor_parallel_size`)
gabrielmbmb marked this conversation as resolved.
Show resolved Hide resolved
# https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic
class _NoDaemonProcess(mp.Process):
@property
def daemon(self) -> bool:
return False

@daemon.setter
def daemon(self, value: bool) -> None: # type: ignore
pass


class _NoDaemonContext(type(mp.get_context())):
Process = _NoDaemonProcess


class _NoDaemonPool(Pool):
def __init__(
self,
processes: Union[int, None] = None,
initializer: Union[Callable[..., object], None] = None,
initargs: Iterable[Any] = ..., # type: ignore
maxtasksperchild: Union[int, None] = None,
) -> None:
super().__init__(
processes=processes,
initializer=initializer,
initargs=initargs,
maxtasksperchild=maxtasksperchild,
context=_NoDaemonContext(), # type: ignore
)


class Pipeline(BasePipeline):
"""Local pipeline implementation using `multiprocessing`."""

Expand Down Expand Up @@ -133,10 +168,9 @@ def run(
return distiset

num_processes = self.dag.get_total_replica_count()
ctx = mp.get_context() # type: ignore
with (
ctx.Manager() as manager,
ctx.Pool(
mp.Manager() as manager,
_NoDaemonPool(
num_processes,
initializer=_init_worker,
initargs=(self._log_queue,),
Expand Down
Loading