Refactor Pipeline and BasePipeline classes (#704)
* Move classes to different files

* Add `_send_to_step` abstractmethod

* Move `_request_initial_batches` method to `BasePipeline`

* Move `_notify_step_to_stop` method to `BasePipeline`

* Move `_handle_batch_on_stop` method to `BasePipeline`

* Move `LAST_BATCH_FLAG_SENT` constant

* Move `_request_more_batches_if_needed` method to `BasePipeline`

* Move `_register_batch` method to `BasePipeline`

* Move `_get_successors` method to `BasePipeline`

* Move `_get_step_from_batch` method to `BasePipeline`

* Move `_manage_batch_flow` method to `BasePipeline`

* Add `_get_from_step` abstract method

* Add `_add_batches_back_to_batch_manager` method

* Add `_consume_output_queue` method

* Add `_create_step_input_queue` method

* Add `_run_step` abstract method

* Move `_handle_keyboard_interrupt` method

* Add `_load_queue`

* Add `_init_steps_load_status` method

* Move `_all_steps_loaded` method

* Move `_check_step_not_loaded_or_finished` method

* Move `_handle_stop` method

* Move `_run_output_queue_loop` method

* Remove unused variables

* Fix unit tests

* Remove shared dict info and update `CudaDevicePlacementMixin`

* Add `unload` method

* Add `portalocker` dependency

* Add missing unload

* Add `_OLD_IMPORT_MODULE_ATTR` dict

* Fix `override` import

* Remove log message

* Add missing call to `unload`
gabrielmbmb authored Jun 12, 2024
1 parent d32d664 commit 2ea3f43
Showing 24 changed files with 5,134 additions and 4,356 deletions.
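The commit message above describes pulling the shared orchestration logic up into `BasePipeline` and leaving only the execution-specific pieces (`_send_to_step`, `_get_from_step`, `_run_step`, and friends) as abstract methods for the concrete `Pipeline` to implement. A minimal sketch of that shape, with hypothetical signatures and a toy queue-based subclass rather than the actual distilabel classes:

# Illustrative sketch only: the method names come from the commit message, but the
# signatures, queue types, and orchestration details are assumptions.
from abc import ABC, abstractmethod
from multiprocessing import Queue
from typing import Any, Dict, Iterable, Tuple


class BasePipeline(ABC):
    """Owns the shared orchestration: batch bookkeeping, stop handling, etc."""

    def run(self, batches: Iterable[Tuple[str, Any]]) -> None:
        # The shared flow lives in the base class and only delegates transport.
        for step_name, batch in batches:
            self._send_to_step(step_name, batch)

    @abstractmethod
    def _send_to_step(self, step_name: str, batch: Any) -> None:
        """How a batch physically reaches a step is left to the subclass."""


class MultiprocessingPipeline(BasePipeline):
    """Toy subclass: routes batches through per-step multiprocessing queues."""

    def __init__(self, steps: Iterable[str]) -> None:
        self._queues: Dict[str, "Queue[Any]"] = {name: Queue() for name in steps}

    def _send_to_step(self, step_name: str, batch: Any) -> None:
        self._queues[step_name].put(batch)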
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ dependencies = [
     "tblib >= 3.0.0",
     "orjson >= 3.10.0",
     "universal_pathlib >= 0.2.2",
+    "portalocker >= 2.8.2",
 ]
 dynamic = ["version"]
4 changes: 4 additions & 0 deletions src/distilabel/llms/base.py
@@ -85,6 +85,10 @@ def load(self) -> None:
         """Method to be called to initialize the `LLM`, its logger and optionally the structured output generator."""
         self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}")
 
+    def unload(self) -> None:
+        """Method to be called to unload the `LLM` and release any resources."""
+        pass
+
     @property
     @abstractmethod
     def model_name(self) -> str:
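The new `unload` hook on the base `LLM` is a no-op; subclasses override it to release whatever `load` acquired and chain up so every class in the hierarchy gets a chance to clean up. A self-contained sketch of that lifecycle (the `ResourceLLM` class and its `_client` attribute are hypothetical, not distilabel code):

# Minimal sketch of the load/unload lifecycle the new hook enables. The real
# distilabel subclasses in the diffs below follow the same override-then-chain pattern.
class BaseLLM:
    def load(self) -> None:
        """Acquire resources (logger, model weights, clients, ...)."""

    def unload(self) -> None:
        """No-op by default; subclasses release whatever `load` acquired."""


class ResourceLLM(BaseLLM):
    def load(self) -> None:
        super().load()
        self._client = object()  # stand-in for an expensive resource

    def unload(self) -> None:
        self._client = None  # release the resource first...
        super().unload()     # ...then let parent classes clean up too


llm = ResourceLLM()
llm.load()
# ... generate ...
llm.unload()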
5 changes: 5 additions & 0 deletions src/distilabel/llms/huggingface/transformers.py
@@ -140,6 +140,11 @@ def load(self) -> None:
 
         super().load()
 
+    def unload(self) -> None:
+        """Unloads the `vLLM` model."""
+        CudaDevicePlacementMixin.unload(self)
+        super().unload()
+
     @property
     def model_name(self) -> str:
         """Returns the model name used for the LLM."""
88 changes: 49 additions & 39 deletions src/distilabel/llms/mixins.py
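This file drops the multiprocessing shared dict (`DictProxy` plus `Lock`) and instead coordinates CUDA device placement across processes through a JSON file in the temp directory, guarded by a `portalocker` exclusive lock. A standalone sketch of that locked read-modify-write pattern, mirroring the `_device_llm_placement_map` context manager in the diff below (the file name and function name here are illustrative, not the distilabel API):

# Standalone sketch of the locked read-modify-write pattern used by
# CudaDevicePlacementMixin._device_llm_placement_map; file and function
# names are illustrative only.
import json
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Generator, List

import portalocker

_PLACEMENT_FILE = Path(tempfile.gettempdir()) / "example_device_placement.json"


@contextmanager
def device_placement_map() -> Generator[Dict[str, List[int]], None, None]:
    """Yield the shared mapping under an exclusive file lock, then persist it."""
    _PLACEMENT_FILE.touch()
    with portalocker.Lock(
        _PLACEMENT_FILE, "r+", flags=portalocker.LockFlags.EXCLUSIVE
    ) as f:
        try:
            content = json.load(f)
        except json.JSONDecodeError:  # empty or corrupted file
            content = {}
        yield content  # caller mutates the dict in place
        f.seek(0)
        f.truncate()
        f.write(json.dumps(content))


if __name__ == "__main__":
    with device_placement_map() as placement:
        placement["llm-0"] = [0]  # claim CUDA device 0 for this process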
@@ -12,14 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
-from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union
+import tempfile
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Dict, Generator, List, Literal, Union
 
+import portalocker
 from pydantic import BaseModel, Field, PrivateAttr
 
-if TYPE_CHECKING:
-    from multiprocessing.managers import DictProxy
-    from multiprocessing.synchronize import Lock
+_CUDA_DEVICE_PLACEMENT_MIXIN_FILE = (
+    Path(tempfile.gettempdir()) / "distilabel_cuda_device_placement_mixin.json"
+)
 
 
 class CudaDevicePlacementMixin(BaseModel):
@@ -44,11 +49,7 @@ class CudaDevicePlacementMixin(BaseModel):
     cuda_devices: Union[List[int], Literal["auto"]] = Field(default="auto")
 
     _llm_identifier: Union[str, None] = PrivateAttr(default=None)
-    _device_llm_placement_map: Union["DictProxy[str, Any]", None] = PrivateAttr(
-        default=None
-    )
-    _device_llm_placement_lock: Union["Lock", None] = PrivateAttr(default=None)
-    _available_cuda_devices: Union[List[int], None] = PrivateAttr(default=None)
+    _available_cuda_devices: List[int] = PrivateAttr(default_factory=list)
     _can_check_cuda_devices: bool = PrivateAttr(default=False)
 
     def load(self) -> None:
@@ -77,29 +78,40 @@ def load(self) -> None:
 
         self._assign_cuda_devices()
 
-    def set_device_placement_info(
-        self,
-        llm_identifier: str,
-        device_llm_placement_map: "DictProxy[str, Any]",
-        device_llm_placement_lock: "Lock",
-    ) -> None:
-        """Sets the value of `_device_llm_placement_map` to be used to assign CUDA devices
-        to the LLM.
-
-        Args:
-            llm_identifier: the identifier of the LLM to be used as key in the device
-                placement information.
-            device_llm_placement_map: a dictionary with the device placement information for
-                each LLM. It should have two keys. The first key is "lock" and its value is
-                a lock object to be used to synchronize the access to the device placement
-                information. The second key is "value" and its value is a dictionary with the
-                device placement information for each LLM.
-            device_llm_placement_lock: a lock object to be used to synchronize the access to
-                `_device_llm_placement_map`.
-        """
-        self._llm_identifier = llm_identifier
-        self._device_llm_placement_map = device_llm_placement_map
-        self._device_llm_placement_lock = device_llm_placement_lock
+    def unload(self) -> None:
+        """Unloads the LLM and removes the CUDA devices assigned to it from the device
+        placement information provided in `_device_llm_placement_map`."""
+        with self._device_llm_placement_map() as device_map:
+            if self._llm_identifier in device_map:
+                self._logger.debug(
+                    f"Removing '{self._llm_identifier}' from the CUDA device map file"
+                    f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'."
+                )
+                del device_map[self._llm_identifier]
+
+    @contextmanager
+    def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]:
+        """Reads the content of the device placement file of the node with a lock, yields
+        the content, and writes the content back to the file after the context manager is
+        closed. If the file doesn't exist, an empty dictionary will be yielded.
+
+        Yields:
+            The content of the device placement file.
+        """
+        _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch()
+        with portalocker.Lock(
+            _CUDA_DEVICE_PLACEMENT_MIXIN_FILE,
+            "r+",
+            flags=portalocker.LockFlags.EXCLUSIVE,
+        ) as f:
+            try:
+                content = json.load(f)
+            except json.JSONDecodeError:
+                content = {}
+            yield content
+            f.seek(0)
+            f.truncate()
+            f.write(json.dumps(content))
 
     def _assign_cuda_devices(self) -> None:
         """Assigns CUDA devices to the LLM based on the device placement information provided
@@ -109,16 +121,14 @@ def _assign_cuda_devices(self) -> None:
         checked if the devices are available to be used by the LLM. If not, a warning will be
         logged."""
 
-        if self._device_llm_placement_map is not None:
-            with self._device_llm_placement_lock:  # type: ignore
-                if self.cuda_devices == "auto":
-                    self.cuda_devices = [
-                        self._get_cuda_device(self._device_llm_placement_map)
-                    ]
-                else:
-                    self._check_cuda_devices(self._device_llm_placement_map)
+        # Take the lock and read the device placement information for each LLM.
+        with self._device_llm_placement_map() as device_map:
+            if self.cuda_devices == "auto":
+                self.cuda_devices = [self._get_cuda_device(device_map)]
+            else:
+                self._check_cuda_devices(device_map)
 
-                self._device_llm_placement_map[self._llm_identifier] = self.cuda_devices  # type: ignore
+            device_map[self._llm_identifier] = self.cuda_devices  # type: ignore
 
         # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices`
         # attribute. In this case, the `cuda_devices` attribute will be set to an empty list.
5 changes: 5 additions & 0 deletions src/distilabel/llms/vllm.py
@@ -189,6 +189,11 @@ def load(self) -> None:
                 self.structured_output
             )
 
+    def unload(self) -> None:
+        """Unloads the `vLLM` model."""
+        CudaDevicePlacementMixin.unload(self)
+        super().unload()
+
     @property
     def model_name(self) -> str:
         """Returns the model name used for the LLM."""