
Commit

Compatibility with vLLM with tensor_parallel_size argument (#805)
* Add `_NoDaemonPool` class

* Use `Union`

* Update src/distilabel/pipeline/local.py

Co-authored-by: Agus <[email protected]>

* Update dependency version to `vllm>=0.5.3` and add `setuptools`

* Remove pinned `outlines==0.0.34`

* Fix docstring

* Add docs about `vLLM` with `ray`

---------

Co-authored-by: Agus <[email protected]>
gabrielmbmb and plaguss authored Jul 23, 2024
1 parent ea1c44b commit b7f124f
Showing 4 changed files with 80 additions and 8 deletions.
33 changes: 33 additions & 0 deletions docs/sections/how_to_guides/advanced/scaling_with_ray.md
@@ -208,3 +208,36 @@ ray job submit --address http://localhost:8265 --working-dir ray-pipeline -- pyt
1. In this case, we just want two nodes: one to run the Ray head node and one to run a worker.
2. We just want to run one task per node, i.e. the Ray command that starts the head/worker node.
3. We have selected 1 GPU per node, but we could have selected more depending on the pipeline.

## `vLLM` and `tensor_parallel_size`

To use the multi-GPU and multi-node capabilities of `vLLM` with `ray`, we need to make a couple of changes to the example pipeline above. The first is to specify a value for `tensor_parallel_size` (i.e. "across how many GPUs do I want the model to be loaded"), and the second is to set `ray` as the `distributed_executor_backend`, since `vLLM` defaults to `multiprocessing`:

```python
with Pipeline(name="text-generation-ray-pipeline") as pipeline:
    load_data_from_hub = LoadDataFromHub(output_mappings={"prompt": "instruction"})

    text_generation = TextGeneration(
        llm=vLLM(
            model="meta-llama/Meta-Llama-3.1-70B-Instruct",
            tokenizer="meta-llama/Meta-Llama-3.1-70B-Instruct",
            extra_kwargs={
                "tensor_parallel_size": 8,
                "distributed_executor_backend": "ray",
            }
        )
    )

    load_data_from_hub >> text_generation
```

Finally, we need to define two environment variables in our `runtime_env.yaml` file:

```yaml
env_vars:
  VLLM_USE_RAY_COMPILED_DAG: "1"
  VLLM_USE_RAY_SPMD_WORKER: "1"
```
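
One way to make sure this `runtime_env.yaml` is picked up is to pass it explicitly when submitting the job. A minimal sketch, assuming the same cluster address and working directory as in the example earlier in this guide (the script name `pipeline.py` is illustrative):

```bash
# Sketch: submit the pipeline, attaching the runtime environment defined above.
ray job submit \
    --address http://localhost:8265 \
    --runtime-env runtime_env.yaml \
    --working-dir ray-pipeline \
    -- python pipeline.py
```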

More information about distributed inference with `vLLM` can be found here: [vLLM - Distributed Serving](https://docs.vllm.ai/en/latest/serving/distributed_serving.html)
7 changes: 6 additions & 1 deletion pyproject.toml
@@ -85,7 +85,12 @@ openai = ["openai >= 1.0.0"]
outlines = ["outlines >= 0.0.40"]
ray = ["ray[default] >= 2.31.0"]
vertexai = ["google-cloud-aiplatform >= 1.38.0"]
vllm = ["vllm >= 0.4.0", "outlines == 0.0.34", "filelock >= 3.13.4"]
vllm = [
    "vllm >= 0.5.3",
    "filelock >= 3.13.4",
    # `setuptools` needs to be installed when installing with `uv pip install distilabel[vllm]`
    "setuptools",
]

[project.urls]
Documentation = "https://distilabel.argilla.io/"
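
As the comment in the `vllm` extra above notes, `setuptools` has to be present when installing with `uv`. For illustration only, a hypothetical install combining the extras touched by this commit could look like:

```bash
# Illustrative only: install distilabel with the vLLM and Ray extras using uv.
uv pip install "distilabel[vllm,ray]"
```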
6 changes: 3 additions & 3 deletions src/distilabel/llms/huggingface/inference_endpoints.py
@@ -516,7 +516,7 @@ async def agenerate( # type: ignore
input: a single input in chat format to generate responses for.
max_new_tokens: the maximum number of new tokens that the model will generate.
Defaults to `128`.
frequence_penalty: a value between `-2.0` and `2.0`. Positive values penalize
frequency_penalty: a value between `-2.0` and `2.0`. Positive values penalize
new tokens based on their existing frequency in the text so far, decreasing
the model's likelihood to repeat the same line verbatim. Defaults to `None`.
logit_bias: modify the likelihood of specified tokens appearing in the completion.
@@ -545,8 +545,8 @@ async def agenerate( # type: ignore
only if `tokenizer_id` is `None`. Defaults to `None`.
top_p: the top-p value to use for the generation. Defaults to `1.0`.
do_sample: whether to use sampling for the generation. This argument is exclusive
of the `text_generation` method and will be only used if `tokenizer_id` is not
`None`. Defaults to `False`.
repetition_penalty: the repetition penalty to use for the generation. This argument
is exclusive of the `text_generation` method and will be only used if `tokenizer_id`
is not `None`. Defaults to `None`.
42 changes: 38 additions & 4 deletions src/distilabel/pipeline/local.py
@@ -15,7 +15,8 @@
import multiprocessing as mp
import signal
import sys
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union, cast
from multiprocessing.pool import Pool
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, Union, cast

import tblib

@@ -48,6 +49,40 @@ def _init_worker(log_queue: "Queue[Any]") -> None:
    setup_logging(log_queue)


# We create a custom `Pool` class so the created processes are not daemons, allowing
# them to create child processes if necessary (for example when using `vLLM` with `tensor_parallel_size`)
# https://stackoverflow.com/questions/6974695/python-process-pool-non-daemonic
class _NoDaemonProcess(mp.Process):
    @property
    def daemon(self) -> bool:
        return False

    @daemon.setter
    def daemon(self, value: bool) -> None:  # type: ignore
        pass


class _NoDaemonContext(type(mp.get_context())):
    Process = _NoDaemonProcess


class _NoDaemonPool(Pool):
    def __init__(
        self,
        processes: Union[int, None] = None,
        initializer: Union[Callable[..., object], None] = None,
        initargs: Iterable[Any] = ...,  # type: ignore
        maxtasksperchild: Union[int, None] = None,
    ) -> None:
        super().__init__(
            processes=processes,
            initializer=initializer,
            initargs=initargs,
            maxtasksperchild=maxtasksperchild,
            context=_NoDaemonContext(),  # type: ignore
        )
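
As a side note (not part of the diff), a minimal sketch of why the non-daemonic pool matters: a worker running inside `_NoDaemonPool` may spawn its own child processes, which a worker in a regular daemonic `multiprocessing.Pool` would refuse with "daemonic processes are not allowed to have children".

```python
# Hypothetical usage sketch, assuming `_NoDaemonPool` from above is in scope.
import multiprocessing as mp


def _child() -> None:
    print("child process executed")


def _worker(i: int) -> int:
    # A worker in a daemonic pool would fail here with
    # "daemonic processes are not allowed to have children".
    child = mp.Process(target=_child)
    child.start()
    child.join()
    return i


if __name__ == "__main__":
    with _NoDaemonPool(processes=2) as pool:
        print(pool.map(_worker, range(2)))
```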


class Pipeline(BasePipeline):
"""Local pipeline implementation using `multiprocessing`."""

@@ -133,10 +168,9 @@ def run(
            return distiset

        num_processes = self.dag.get_total_replica_count()
        ctx = mp.get_context()  # type: ignore
        with (
            ctx.Manager() as manager,
            ctx.Pool(
            mp.Manager() as manager,
            _NoDaemonPool(
                num_processes,
                initializer=_init_worker,
                initargs=(self._log_queue,),
