From 2ea3f43cd4f1c73e249d6fe7db626a70e2f7d238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Wed, 12 Jun 2024 16:39:03 +0200 Subject: [PATCH] Refactor `Pipeline` and `BasePipeline` classes (#704) * Move classes to different files * Add `_send_to_step` abstractmethod * Move `_request_initial_batches` method to `BasePipeline` * Move `_notify_step_to_stop` method to `BasePipeline` * Move `_handle_batch_on_stop` method to `BasePipeline` * Move `LAST_BATCH_FLAG_SENT` constant * Move `_request_more_batches_if_needed` method to `BasePipeline` * Move `_register_batch` method to `BasePipeline` * Move `_get_successors` method to `BasePipeline` * Move `_get_step_from_batch` method to `BasePipeline` * Move `_manage_batch_flow` method to `BasePipeline` * Add `_get_from_step` abstract method * Add `_add_batches_back_to_batch_manager` method * Add `_consume_output_queue` method * Add `_create_step_input_queue` method * Add `_run_step` abstract method * Move `_handle_keyboard_interrupt` method * Add `_load_queue` * Add `_init_steps_load_status` method * Move `_all_steps_loaded` method * Move `_check_step_not_loaded_or_finished` method * Move `_handle_stop` method * Move `_run_output_queue_loop` method * Remove unused variables * Fix unit tests * Remove shared dict info and update `CudaDevicePlacementMixin` * Add `unload` method * Add `portalocker` dependency * Add missing unload * Add `_OLD_IMPORT_MODULE_ATTR` dict * Fix `override` import * Remove log message * Add missing call to `unload` --- pyproject.toml | 1 + src/distilabel/llms/base.py | 4 + .../llms/huggingface/transformers.py | 5 + src/distilabel/llms/mixins.py | 88 +- src/distilabel/llms/vllm.py | 5 + src/distilabel/pipeline/base.py | 1564 +++----- src/distilabel/pipeline/batch.py | 233 ++ src/distilabel/pipeline/batch_manager.py | 896 +++++ src/distilabel/pipeline/constants.py | 1 + src/distilabel/pipeline/local.py | 621 +--- .../pipeline/routing_batch_function.py | 2 +- 
src/distilabel/pipeline/typing.py | 10 +- src/distilabel/pipeline/write_buffer.py | 168 + src/distilabel/steps/base.py | 6 + src/distilabel/steps/tasks/base.py | 9 +- src/distilabel/utils/serialization.py | 31 +- tests/unit/llms/test_mixins.py | 131 +- tests/unit/pipeline/test_base.py | 3142 +++-------------- tests/unit/pipeline/test_batch.py | 172 + tests/unit/pipeline/test_batch_manager.py | 2214 ++++++++++++ tests/unit/pipeline/test_local.py | 33 +- .../pipeline/test_routing_batch_function.py | 2 +- tests/unit/pipeline/test_write_buffer.py | 150 + tests/unit/pipeline/utils.py | 2 +- 24 files changed, 5134 insertions(+), 4356 deletions(-) create mode 100644 src/distilabel/pipeline/batch.py create mode 100644 src/distilabel/pipeline/batch_manager.py create mode 100644 src/distilabel/pipeline/write_buffer.py create mode 100644 tests/unit/pipeline/test_batch.py create mode 100644 tests/unit/pipeline/test_batch_manager.py create mode 100644 tests/unit/pipeline/test_write_buffer.py diff --git a/pyproject.toml b/pyproject.toml index 1e29af9d94..df771ae20c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "tblib >= 3.0.0", "orjson >= 3.10.0", "universal_pathlib >= 0.2.2", + "portalocker >= 2.8.2", ] dynamic = ["version"] diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index e94e06e8ec..55250da791 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -85,6 +85,10 @@ def load(self) -> None: """Method to be called to initialize the `LLM`, its logger and optionally the structured output generator.""" self._logger = logging.getLogger(f"distilabel.llm.{self.model_name}") + def unload(self) -> None: + """Method to be called to unload the `LLM` and release any resources.""" + pass + @property @abstractmethod def model_name(self) -> str: diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 1b73eaa469..dc428f4283 100644 --- 
a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -140,6 +140,11 @@ def load(self) -> None: super().load() + def unload(self) -> None: + """Unloads the `vLLM` model.""" + CudaDevicePlacementMixin.unload(self) + super().unload() + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" diff --git a/src/distilabel/llms/mixins.py b/src/distilabel/llms/mixins.py index 9146a0ef4f..1d2e8b35a0 100644 --- a/src/distilabel/llms/mixins.py +++ b/src/distilabel/llms/mixins.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Union +import tempfile +from contextlib import contextmanager +from pathlib import Path +from typing import Dict, Generator, List, Literal, Union +import portalocker from pydantic import BaseModel, Field, PrivateAttr -if TYPE_CHECKING: - from multiprocessing.managers import DictProxy - from multiprocessing.synchronize import Lock +_CUDA_DEVICE_PLACEMENT_MIXIN_FILE = ( + Path(tempfile.gettempdir()) / "distilabel_cuda_device_placement_mixin.json" +) class CudaDevicePlacementMixin(BaseModel): @@ -44,11 +49,7 @@ class CudaDevicePlacementMixin(BaseModel): cuda_devices: Union[List[int], Literal["auto"]] = Field(default="auto") _llm_identifier: Union[str, None] = PrivateAttr(default=None) - _device_llm_placement_map: Union["DictProxy[str, Any]", None] = PrivateAttr( - default=None - ) - _device_llm_placement_lock: Union["Lock", None] = PrivateAttr(default=None) - _available_cuda_devices: Union[List[int], None] = PrivateAttr(default=None) + _available_cuda_devices: List[int] = PrivateAttr(default_factory=list) _can_check_cuda_devices: bool = PrivateAttr(default=False) def load(self) -> None: @@ -77,29 +78,40 @@ def load(self) -> None: self._assign_cuda_devices() - def set_device_placement_info( - self, - llm_identifier: str, 
- device_llm_placement_map: "DictProxy[str, Any]", - device_llm_placement_lock: "Lock", - ) -> None: - """Sets the value of `_device_llm_placement_map` to be used to assign CUDA devices - to the LLM. + def unload(self) -> None: + """Unloads the LLM and removes the CUDA devices assigned to it from the device + placement information provided in `_device_llm_placement_map`.""" + with self._device_llm_placement_map() as device_map: + if self._llm_identifier in device_map: + self._logger.debug( + f"Removing '{self._llm_identifier}' from the CUDA device map file" + f" '{_CUDA_DEVICE_PLACEMENT_MIXIN_FILE}'." + ) + del device_map[self._llm_identifier] - Args: - llm_identifier: the identifier of the LLM to be used as key in the device - placement information. - device_llm_placement_map: a dictionary with the device placement information for - each LLM. It should have two keys. The first key is "lock" and its value is - a lock object to be used to synchronize the access to the device placement - information. The second key is "value" and its value is a dictionary with the - device placement information for each LLM. - device_llm_placement_lock: a lock object to be used to synchronize the access to - `_device_llm_placement_map`. + @contextmanager + def _device_llm_placement_map(self) -> Generator[Dict[str, List[int]], None, None]: + """Reads the content of the device placement file of the node with a lock, yields + the content, and writes the content back to the file after the context manager is + closed. If the file doesn't exist, an empty dictionary will be yielded. + + Yields: + The content of the device placement file. 
""" - self._llm_identifier = llm_identifier - self._device_llm_placement_map = device_llm_placement_map - self._device_llm_placement_lock = device_llm_placement_lock + _CUDA_DEVICE_PLACEMENT_MIXIN_FILE.touch() + with portalocker.Lock( + _CUDA_DEVICE_PLACEMENT_MIXIN_FILE, + "r+", + flags=portalocker.LockFlags.EXCLUSIVE, + ) as f: + try: + content = json.load(f) + except json.JSONDecodeError: + content = {} + yield content + f.seek(0) + f.truncate() + f.write(json.dumps(content)) def _assign_cuda_devices(self) -> None: """Assigns CUDA devices to the LLM based on the device placement information provided @@ -109,16 +121,14 @@ def _assign_cuda_devices(self) -> None: checked if the devices are available to be used by the LLM. If not, a warning will be logged.""" - if self._device_llm_placement_map is not None: - with self._device_llm_placement_lock: # type: ignore - if self.cuda_devices == "auto": - self.cuda_devices = [ - self._get_cuda_device(self._device_llm_placement_map) - ] - else: - self._check_cuda_devices(self._device_llm_placement_map) + # Take the lock and read the device placement information for each LLM. + with self._device_llm_placement_map() as device_map: + if self.cuda_devices == "auto": + self.cuda_devices = [self._get_cuda_device(device_map)] + else: + self._check_cuda_devices(device_map) - self._device_llm_placement_map[self._llm_identifier] = self.cuda_devices # type: ignore + device_map[self._llm_identifier] = self.cuda_devices # type: ignore # `_device_llm_placement_map` was not provided and user didn't set the `cuda_devices` # attribute. In this case, the `cuda_devices` attribute will be set to an empty list. 
diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index de8ea44fd7..4a9ce88f75 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -189,6 +189,11 @@ def load(self) -> None: self.structured_output ) + def unload(self) -> None: + """Unloads the `vLLM` model.""" + CudaDevicePlacementMixin.unload(self) + super().unload() + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 88eb127ebb..cb3b0625a7 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -12,57 +12,58 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import hashlib import logging import os -from collections import defaultdict -from dataclasses import dataclass, field +import signal +import threading +import time +from abc import ABC, abstractmethod from pathlib import Path from typing import ( TYPE_CHECKING, Any, + Callable, Dict, - Iterable, List, Optional, - Set, Tuple, TypedDict, Union, ) import fsspec -import pyarrow as pa -import pyarrow.parquet as pq from typing_extensions import Self from upath import UPath from distilabel import __version__ from distilabel.distiset import create_distiset from distilabel.pipeline._dag import DAG +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager from distilabel.pipeline.constants import ( + CONVERGENCE_STEP_ATTR_NAME, + INPUT_QUEUE_ATTR_NAME, + LAST_BATCH_SENT_FLAG, RECEIVES_ROUTED_BATCHES_ATTR_NAME, ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME, ) -from distilabel.utils.dicts import flatten_dict -from distilabel.utils.files import list_files_in_dir +from distilabel.pipeline.write_buffer import _WriteBuffer from distilabel.utils.logging import setup_logging, stop_logging from distilabel.utils.serialization import ( TYPE_INFO_KEY, - _check_is_dir, 
_Serializable, - read_json, ) if TYPE_CHECKING: from os import PathLike + from queue import Queue from distilabel.distiset import Distiset from distilabel.pipeline.routing_batch_function import RoutingBatchFunction - from distilabel.steps.base import _Step - from distilabel.utils.serialization import StrOrPath + from distilabel.pipeline.typing import StepLoadStatus + from distilabel.steps.base import Step, _Step BASE_CACHE_DIR = Path.home() / ".cache" / "distilabel" / "pipelines" @@ -114,7 +115,11 @@ def get_pipeline(cls) -> Union["BasePipeline", None]: return cls._context_global_pipeline -class BasePipeline(_Serializable): +_STEP_LOAD_FAILED_CODE = -666 +_STEP_NOT_LOADED_CODE = -999 + + +class BasePipeline(ABC, _Serializable): """Base class for a `distilabel` pipeline. Attributes: @@ -142,8 +147,14 @@ class BasePipeline(_Serializable): Defaults to `False`. _dry_run: A flag to indicate if the pipeline is running in dry run mode. Defaults to `False`. + output_queue: A queue to store the output of the steps while running the pipeline. + load_queue: A queue used by each `Step` to notify the main process it has finished + loading or it the step has been unloaded. 
""" + _output_queue: "Queue[Any]" + _load_queue: "Queue[Union[StepLoadStatus, None]]" + def __init__( self, name: str, @@ -182,10 +193,18 @@ def __init__( "filename": self._cache_location["log_file"] } + self._steps_load_status: Dict[str, int] = {} + self._steps_load_status_lock = threading.Lock() + + self._stop_called = False + self._stop_called_lock = threading.Lock() + self._stop_calls = 0 + self._fs: Optional[fsspec.AbstractFileSystem] = None self._storage_base_path: Optional[str] = None self._use_fs_to_pass_data: bool = False - self._dry_run: bool = False + + self._dry_run = False def __enter__(self) -> Self: """Set the global pipeline instance when entering a pipeline context.""" @@ -310,6 +329,8 @@ def run( } ) + self._init_steps_load_status() + # Validate the pipeline DAG to check that all the steps are chainable, there are # no missing runtime parameters, batch sizes are correct, etc. self.dag.validate() @@ -390,6 +411,12 @@ def get_runtime_parameters_info(self) -> Dict[str, List[Dict[str, Any]]]: runtime_parameters[step_name] = step.get_runtime_parameters_info() return runtime_parameters + def _init_steps_load_status(self) -> None: + """Initialize the `_steps_load_status` dictionary assigning 0 to every step of + the pipeline.""" + for step_name in self.dag: + self._steps_load_status[step_name] = _STEP_NOT_LOADED_CODE + def _setup_fsspec( self, storage_parameters: Optional[Dict[str, Any]] = None ) -> None: @@ -451,6 +478,14 @@ def _add_edge(self, from_step: str, to_step: str) -> None: value=routing_batch_function is not None, ) + def _is_convergence_step(self, step_name: str) -> None: + """Checks if a step is a convergence step. + + Args: + step_name: The name of the step. 
+ """ + return self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME) + def _add_routing_batch_function( self, step_name: str, routing_batch_function: "RoutingBatchFunction" ) -> None: @@ -576,1244 +611,521 @@ def _setup_write_buffer(self) -> None: self._logger.info(f"📝 Pipeline data will be written to '{buffer_data_path}'") self._write_buffer = _WriteBuffer(buffer_data_path, self.dag.leaf_steps) - def _send_batch_to_step(self, batch: "_Batch") -> None: - """Sends a batch to the input queue of a step, writing the data of the batch - to the filesystem and setting `batch.data_path` with the path where the data - was written (if requiered i.e. the step is a global step or `use_fs_to_pass_data`) + def _run_output_queue_loop_in_thread(self) -> threading.Thread: + """Runs the output queue loop in a separate thread to receive the output batches + from the steps. This is done to avoid the signal handler to block the loop, which + would prevent the pipeline from stopping correctly.""" + thread = threading.Thread(target=self._output_queue_loop) + thread.start() + return thread + + def _output_queue_loop(self) -> None: + """Loop to receive the output batches from the steps and manage the flow of the + batches through the pipeline.""" + while self._batch_manager.can_generate() and not self._stop_called: # type: ignore + self._logger.debug("Waiting for output batch from step...") + if (batch := self._output_queue.get()) is None: + self._logger.debug("Received `None` from output queue. Breaking loop.") + break - This method should be extended by the specific pipeline implementation, adding - the logic to send the batch to the step. - - Args: - batch: The batch to send. 
- """ - self._logger.debug( - f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}" - ) - self._batch_manager.set_last_batch_sent(batch) # type: ignore - - step: "_Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - if not step.is_generator and (step.is_global or self._use_fs_to_pass_data): - base_path = UPath(self._storage_base_path) / step.name # type: ignore self._logger.debug( - f"Writing {batch.seq_no} batch for '{batch.step_name}' step to filesystem: {base_path}" + f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'" + f" from output queue: {batch}" ) - batch.write_batch_data_to_fs(self._fs, base_path) # type: ignore - - self._logger.debug( - f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" - ) - - -@dataclass -class _Batch(_Serializable): - """Dataclass to represent a batch of data to be processed by a `_Step`. - - Attributes: - seq_no: The sequence number of the batch. - step_name: The name of the step that will process the batch. - last_batch: A flag to indicate if the batch is the last one. - data: The data to be processed. - data_hash: The hash of the data. Defaults to `None`. - data_path: The path where the data of the batch is stored. Defaults to `None`. - accumulated: A flag to indicate if the batch is accumulated. - created_from: A dictionary containing the `seq_no` of the batches of the steps that - were used to create this batch. - size: The size of the batch. 
- """ - - seq_no: int - step_name: str - last_batch: bool - data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False) - data_hash: Optional[str] = None - data_path: Optional[str] = None - accumulated: bool = False - created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict) - batch_routed_to: List[str] = field(default_factory=list) - size: int = 0 - _fs: Optional[fsspec.AbstractFileSystem] = None - - def next_batch(self) -> "_Batch": - """Create a new `_Batch` instance with the next batch of data. - Args: - data: The data to be processed. - - Returns: - A `_Batch` instance. - """ - return _Batch( - seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch - ) - - def set_data(self, data: List[List[Dict[str, Any]]]) -> None: - """Sets the data of the batch and updates the size of the batch. - - Args: - data: The data of the batch. - """ - self.data = data - self.size = len(data[0]) - self._update_data_hash() - - def get_data(self, num_rows: Union[int, None] = None) -> List[Dict[str, Any]]: - """Takes `num_rows` from the data of the batch and returns it. This method will - also remove the data from the batch and update the hash of the data. - - Args: - num_rows: The number of rows to take from the data. If `None`, then all the - data will be taken. Defaults to `None`. - - Returns: - A list with the data taken from the batch. 
- """ - - if self.data == [] and self.data_path is not None: - pass - - if num_rows is None: - data = self.data[0] - self.data = [] - else: - data = self.data[0][:num_rows] - self.data[0] = self.data[0][num_rows:] - - self._update_data_hash() - return data - - def _update_data_hash(self) -> None: - """Updates the hash of the data of the batch.""" - self.data_hash = hashlib.sha1(str(self.data).encode()).hexdigest() + if batch.data_path: + self._logger.debug( + f"Reading {batch.seq_no} batch data from '{batch.step_name}': '{batch.data_path}'" + ) + batch.read_batch_data_from_fs() - @classmethod - def accumulate(cls, step_name: str, batches: List[List["_Batch"]]) -> "_Batch": - """Creates a `_Batch` instance using the data from the list of batches that - were received from another steps. The batches will be accumulated in a single - list of data. + if batch.step_name in self.dag.leaf_steps: + self._write_buffer.add_batch(batch) # type: ignore - Args: - step_name: The name of the step that will process the batch. - batches: a list containing the list of batches received from the predecessors. + # If `_stop_called` was set to `True` while waiting for the output queue, then + # we need to handle the stop of the pipeline and break the loop to avoid + # propagating the batches through the pipeline and making the stop process + # slower. + if self._stop_called: + self._handle_batch_on_stop(batch) + break - Returns: - A `_Batch` instance. - """ + self._manage_batch_flow(batch) - data = [] - for step_batches in batches: - accumulated_data = [row for batch in step_batches for row in batch.data[0]] - data.append(accumulated_data) - return cls( - seq_no=0, step_name=step_name, last_batch=True, data=data, accumulated=True - ) + if self._stop_called: + self._handle_stop() - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: - """Dumps the content of the `_Batch` to a dictionary, using the `dataclass` helper function. 
+ self._cache() - Args: - obj: Unused, just kept to match the signature of the parent method. - kwargs: Additional arguments that are kept to match the signature of the parent method. + def _run_load_queue_loop_in_thread(self) -> threading.Thread: + """Runs a background thread that reads from the `load_queue` to update the status + of the number of workers loaded for each step. Returns: - A `dict` containing the internal representation of the `_Batch`. - """ - - include_batch_data = kwargs.get("include_batch_data", True) - - dump = { - "seq_no": self.seq_no, - "step_name": self.step_name, - "last_batch": self.last_batch, - "data_hash": self.data_hash, - "accumulated": self.accumulated, - "created_from": self.created_from, - "batch_routed_to": self.batch_routed_to, - "size": self.size, - } - - if include_batch_data: - dump["data"] = self.data + The thread that was started. + """ + thread = threading.Thread(target=self._run_load_queue_loop) + thread.start() + return thread + + def _run_load_queue_loop(self) -> None: + """Runs a loop that reads from the `load_queue` to update the status of the number + of workers loaded for each step.""" + while True: + if (load_info := self._load_queue.get()) is None: + self._logger.debug("Received `None` from load queue. Breaking loop.") + break + + with self._steps_load_status_lock: + step_name, status = load_info["name"], load_info["status"] + if status == "loaded": + if self._steps_load_status[step_name] == _STEP_NOT_LOADED_CODE: + self._steps_load_status[step_name] = 1 + else: + self._steps_load_status[step_name] += 1 + elif status == "unloaded": + self._steps_load_status[step_name] -= 1 + else: + # load failed + self._steps_load_status[step_name] = _STEP_LOAD_FAILED_CODE - return dump + self._logger.debug( + f"Step '{step_name}' loaded workers: {self._steps_load_status[step_name]}" + ) - def copy(self) -> "_Batch": - """Creates a copy of the `_Batch` instance. 
+ def _all_steps_loaded(self) -> bool: + """Waits for all the steps to load. Returns: - A copy of the `_Batch` instance. - """ - return copy.deepcopy(self) - - def write_batch_data_to_fs( - self, - fs: Optional[fsspec.AbstractFileSystem] = None, - base_path: Optional[UPath] = None, - ) -> None: - """Writes the content of the batch to the filesystem. - - Args - fs: The `fsspec` filesystem to be used to write the data. If not provided, the - one set in the `_fs` attribute will be used. Defaults to `None`. - base_path: The base path where the data of the batch will be stored. If not - provided, the one set in the `data_path` attribute will be used. Defaults - to `None`. - - Raises: - ValueError: If `fs` is not provided and the `_fs` attribute is not set. - """ - - if not fs and not self._fs: - raise ValueError( - "The `fs` parameter must be provided if the `_fs` attribute is not set." - ) - - if fs: - self._fs = fs - - if not base_path and not self.data_path: - raise ValueError( - "The `base_path` parameter must be provided if the `data_path` attribute" - " is not set." - ) - - seq_no_dir = ( - base_path / f"seq_no_{self.seq_no}" if base_path else UPath(self.data_path) - ) - seq_no_dir._fs_cached = self._fs # type: ignore - seq_no_dir.mkdir(parents=True, exist_ok=True) - - for i, data in enumerate(self.data): - table = pa.Table.from_pylist(data) - with self._fs.open(seq_no_dir / f"data_index_{i}.parquet", "wb") as f: # type: ignore - pq.write_table(table, f) + `True` if all the steps have been loaded correctly, `False` otherwise. 
+ """ + + self._logger.info("⏳ Waiting for all the steps to load...") + previous_message = None + while not self._stop_called: + with self._steps_load_status_lock: + self._logger.debug(f"Steps loaded: {self._steps_load_status}") + + if any( + num_workers_loaded == _STEP_LOAD_FAILED_CODE + for num_workers_loaded in self._steps_load_status.values() + ): + self._logger.error("❌ Failed to load all the steps") + return False + + num_steps_loaded = 0 + workers_message = "" + for step_name, num_workers_loaded in self._steps_load_status.items(): + # TODO: update condition once we allow more than one worker per step + if num_workers_loaded == 1: + num_steps_loaded += 1 + workers_message += ( + f"\n * '{step_name}' workers: {max(0, num_workers_loaded)}" + ) - self.data = [] - self.data_path = str(seq_no_dir) + message = f"⏳ Steps loaded: {num_steps_loaded}/{len(self.dag)}{workers_message}" + if num_steps_loaded > 0 and message != previous_message: + self._logger.info(message) + previous_message = message - def read_batch_data_from_fs(self) -> None: - """Reads the content of the batch from the filesystem.""" - if not self.data_path: - raise ValueError( - "`data_path` attribute must be set to read the data from the filesystem." - " Use `write_batch_data_to_fs` method to set the `data_path` attribute." - ) + if num_steps_loaded == len(self.dag): + self._logger.info("✅ All the steps have been loaded!") + return True - if not self._fs: - raise ValueError( - "`_fs` attribute must be set to read the data from the filesystem." - " Use `write_batch_data_to_fs` method to set the `_fs` attribute." 
- ) + time.sleep(2.5) - for file in self._fs.ls(self.data_path): - with self._fs.open(file, "rb") as f: - table = pq.read_table(f) - self.data.append(table.to_pylist()) + return not self._stop_called - self._fs.rm(self.data_path, recursive=True) + def _handle_stop(self) -> None: + """Handles the stop of the pipeline execution, which will stop the steps from + processing more batches and wait for the output queue to be empty, to not lose + any data that was already processed by the steps before the stop was called.""" + self._logger.debug("Handling stop of the pipeline execution...") + self._add_batches_back_to_batch_manager() -@dataclass -class _BatchManagerStep(_Serializable): - """A class that will accumulate data for a step from the predecessors and create - batches for the step to process when there is enough data. + # Wait for the input queue to be empty, which means that all the steps finished + # processing the batches that were sent before the stop flag. + for step_name in self.dag: + self._wait_step_input_queue_empty(step_name) - Attributes: - step_name: The name of the step that will process the data. - accumulate: A flag to indicate if the data should be accumulated and create a - batch with all the data received from the predecessors instead of creating - batches with the `input_batch_size`. - input_batch_size: The size of the batch to be created for the step to process. - If `None`, then `accumulate` must be `True`. Defaults to `None`. - data: A dictionary with the predecessor step name as the key and a list of - dictionaries (rows) as the value. - built_batches: A list with the batches that were built and sent to the step queue, - but the step was stopped before processing the batch, so the batch doesn't get - lost. Defaults to an empty list. - seq_no: The sequence number of the next batch to be created. It will be - incremented for each batch created. - last_batch_received: A list with the names of the steps that sent the last - batch of data. 
- convergence_step: A flag to indicate if the step is a convergence step. An - `Step` is a convergence step if all its predecessors are receiving routed - batches. Defaults to `False`. - convergence_step_batches_consumed: A dictionary in which the key is the `seq_no` - of the batch created by step A, that was used by step B and C and obtained from - the `created_from` of the batches created by them. It's used to know if all - the batches from B and C steps created from batches of A have been consumed - by D, in order to not mess up the order of the batches. Only used if `convergence_step=True`. - Defaults to an empty dictionary. - next_expected_created_from_batch_seq_no: The next expected sequence number of the - batch from step A used by steps B and C and obtained from the `created_from` - of the batches created by them. It's used to avoid messing up the order of the - batches. Only used if `convergence_step=True`. Defaults to `0`. - """ + self._consume_output_queue() - step_name: str - accumulate: bool - input_batch_size: Union[int, None] = None - data: Dict[str, List[_Batch]] = field(default_factory=dict) - built_batches: List[_Batch] = field(default_factory=list) - seq_no: int = 0 - last_batch_received: List[str] = field(default_factory=list) - convergence_step: bool = False - convergence_step_batches_consumed: Dict[str, Dict[str, int]] = field( - default_factory=dict - ) - next_expected_created_from_batch_seq_no: int = 0 - - def add_batch(self, batch: _Batch, prepend: bool = False) -> None: - """Add a batch of data from `batch.step_name` to the step. It will accumulate the - data and keep track of the last batch received from the predecessors. + def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]: + """Waits for the input queue of a step to be empty. Args: - batch: The output batch of an step to be processed by the step. - prepend: If `True`, the content of the batch will be added to the `built_batches` - list. 
This is done so if a `_Batch` was already built and send to the step - queue, and the step is stopped before processing the batch, the batch doesn't - get lost. Defaults to `False`. - """ - from_step = batch.step_name - - if prepend: - self.built_batches.append(batch) - else: - self.data[from_step].append(batch) - - if batch.last_batch: - self.last_batch_received.append(from_step) - - def get_batch(self) -> Union[_Batch, None]: - """Create a new batch of data for the step to process. It will return `None` if - there is not enough data to create a batch. + step_name: The name of the step. Returns: - A `_Batch` instance if there is enough data to create a batch. Otherwise, - `None`. + The input queue of the step if it's not loaded or finished, `None` otherwise. """ - if not self._ready_to_create_batch(): + if self._check_step_not_loaded_or_finished(step_name): return None - # If there are batches in the `built_batches` list, then return the first one - # and remove it from the list. - if self.built_batches: - return self.built_batches.pop(0) - - # `_last_batch` must be called before `_get_data`, as `_get_data` will update the - # list of data which is used to determine if the batch to be created is the last one. - # TODO: remove `_last_batch` method and integrate logic in `_get_data` - last_batch = self._last_batch() - data, created_from, batch_routed_to = self._get_data() - - return _Batch( - seq_no=self._get_seq_no(), - step_name=self.step_name, - last_batch=last_batch, - data=data, - accumulated=self.accumulate, - created_from=created_from, - batch_routed_to=batch_routed_to, - ) - - def empty_buffers(self) -> List[str]: - """Checks if the input buffer for the step is empty. + if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): + while input_queue.qsize() != 0: + pass + return input_queue - Returns: - The name of the previous steps for which the input buffer for this step is - empty. 
- """ - if self.accumulate: - return [ - previous_step - for previous_step in self.data.keys() - if previous_step not in self.last_batch_received - ] - - return [ - previous_step - for previous_step, batches in self.data.items() - if previous_step not in self.last_batch_received - and sum(len(batch.data[0]) for batch in batches) < self.input_batch_size # type: ignore - ] - - @classmethod - def from_step( - cls, step: "_Step", predecessors: Iterable[str], convergence_step: bool = False - ) -> "_BatchManagerStep": - """Creates a `_BatchManagerStep` instance from a `_Step` instance and its - predecessors. + def _check_step_not_loaded_or_finished(self, step_name: str) -> bool: + """Checks if a step is not loaded or already finished. Args: - step: The `_Step` instance. - predecessors: The names of the predecessors of the step. - convergence_step: A flag to indicate if the step is a convergence step. An - `Step` is a convergence step if all its predecessors are receiving routed - batches. Defaults to `False`. - - Returns: - A `_BatchManagerStep` instance. - """ - return cls( - step_name=step.name, # type: ignore - accumulate=step.is_global, - input_batch_size=getattr(step, "input_batch_size", None), - data={predecessor: [] for predecessor in predecessors}, - convergence_step=convergence_step, - ) - - def _get_seq_no(self) -> int: - """Gets the sequence number for the next batch to be created and increments it. - - Returns: - The sequence number for the next batch to be created. - """ - seq_no = self.seq_no - self.seq_no += 1 - return seq_no - - def _get_data( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: - """Gets the data needed to create a batch for the step to process. If the step is - accumulating data, then it will return a list with all the data received from the - predecessors. Otherwise, it will return a list of data with the `input_batch_size` - for each predecessor. 
In addition, it will remove the data used to create the - batch from the step's data. - - Returns: - A tuple containing the list of data needed to create a batch for the step to - process, a dictionary with the sequence numbers of the batches that were used - to create the batch and the list of steps to which the batch was routed to if - the step is a normal step. - """ - if self.accumulate: - # Steps accumulating cannot receive routed batches - return self._get_data_for_accumulate() + ([],) - - if self.convergence_step: - # Convergence steps will receive routed batches, but we need to clean the - # `batch_routed_to` list - return self._get_data_for_convergence_step() + ([],) - - return self._get_data_normal() - - def _get_data_for_accumulate( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: - """Gets the data needed to create a batch for the step to process when the step - is accumulating data. It will return a list with all the data received from the - predecessors. In addition, it will remove the data used to create the batch from - the step's data. - - Returns: - A tuple containing the list of data needed to create a batch for the step to - process and a dictionary with the sequence numbers of the batches that were - used to create the batch. - """ - data = [] - batches_used = {} - for step_name, batches in self.data.items(): - batches_used[step_name] = [] - for batch in batches: - batches_used[step_name].append((batch.seq_no, batch.size)) - data.append([row for batch in batches for row in batch.get_data()]) - # Reset the data buffer - self.data = {step_name: [] for step_name in self.data} - return data, batches_used - - def _get_data_for_convergence_step( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: - """Gets the data needed to create a batch for the step to process when the step is - a convergence step. 
- - Returns: - A tuple containing the list of data needed to create a batch for the step to - process and a dictionary with the sequence numbers of the batches that were - used to create the batch. - """ - grouped_batches = self._group_batches_by_created_from() - seq_no, batches = grouped_batches[0] - str_seq_no = str(seq_no) - - remaining_rows_per_step: Dict[str, int] = { - step_name: self.input_batch_size - for step_name in self.data # type: ignore - } - batches_used = defaultdict(list) - data = defaultdict(list) - for batch, batch_size in batches: - remaining_rows = remaining_rows_per_step[batch.step_name] - selected_data = batch.get_data(remaining_rows) - data[batch.step_name].extend(selected_data) - - # If A -> [B, C] -> D, then in D (this step) we keep track of the remaining - # rows from the batches of A that B and C used to create the `batches`. - batch_size = self.convergence_step_batches_consumed.setdefault( - str_seq_no, {} - ).get(batch.step_name, batch_size) - remaining_rows_in_batch = batch_size - len(selected_data) - self.convergence_step_batches_consumed[str_seq_no].update( - {batch.step_name: remaining_rows_in_batch} - ) - - # Update the remaining rows - num_rows = len(selected_data) - remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore - - # Keep track of the batches used to create the batch - batches_used[batch.step_name].append((batch.seq_no, batch.size)) - - # If the batch was entirely consumed, then remove it from the buffer - if len(batch.data[0]) == 0: - self.data[batch.step_name].remove(batch) - continue - - # If all the batches grouped by the `seq_no` in `created_from` were consumed, then - # we can update the `next_expected_created_from_batch_seq_no` to the next one - # to avoid skipping batches. 
- no_remaining_rows = all( - count == 0 - for count in self.convergence_step_batches_consumed[str_seq_no].values() - ) - if no_remaining_rows: - self.next_expected_created_from_batch_seq_no += 1 - - return list(data.values()), dict(batches_used) - - def _get_data_normal( - self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: - """Gets the data needed to create a batch for the step to process when the step is - not accumulating data. It will return a list of data with the `input_batch_size` - for each predecessor. In addition, it will remove the data used to create the batch - from the step's data. + step_name: The name of the step. Returns: - A tuple containing the list of data needed to create a batch for the step to - process, a dictionary with the sequence numbers of the batches that were used - to create the batch and the list of steps to which the batch was routed to if - the step is a convergence step. + `True` if the step is not loaded or already finished, `False` otherwise. """ - data = [] - batches_used = defaultdict(list) - batch_routed_to = [] - for step_name in self.data: - # For each step batches buffer, we will create a batch with the `input_batch_size` - # using the data from the buffer. We will remove the consumed batches (no data - # left) and update the batch data with the remaining data. 
- step_data = [] - idx_drop_batches = [] - remaining_rows: int = self.input_batch_size # type: ignore - for idx, batch in enumerate(self.data[step_name]): - if remaining_rows == 0: - break - - # Get `remaining_rows` or the remaining rows in the batch and add it to - # the step data that will be used to create the batch - selected_data = batch.get_data(remaining_rows) - step_data.extend(selected_data) - batch_routed_to = batch.batch_routed_to - - # Update the remaining rows - num_rows = len(selected_data) - remaining_rows -= num_rows - - # Keep track of the batches used to create the batch - batches_used[step_name].append((batch.seq_no, batch.size)) - - # If the batch was entirely consumed, then remove it from the buffer - if len(batch.data[0]) == 0: - idx_drop_batches.append(idx) - continue - - # Remove the batches that were entirely consumed - idx_drop_batches.reverse() - for idx in idx_drop_batches: - self.data[step_name].pop(idx) - - data.append(step_data) - - return data, dict(batches_used), batch_routed_to + with self._steps_load_status_lock: + num_workers = self._steps_load_status[step_name] - def _ready_to_create_batch(self) -> bool: - """Checks if there is enough data to create a batch for the step. + # The step has finished (workers = 0) or it has failed to load + if num_workers in [0, _STEP_LOAD_FAILED_CODE]: + return True - Returns: - `True` if there is enough data to create a batch for the step. Otherwise, - `False`. - """ - if self.accumulate: - return self._ready_to_create_batch_accumulate() + return False - if self.convergence_step: - return self._ready_to_create_batch_convergence_step() + @property + @abstractmethod + def QueueClass(self) -> Callable: + """The class of the queue to use in the pipeline.""" + pass - return self._ready_to_create_batch_normal() + def _create_step_input_queue(self, step_name: str) -> "Queue[Any]": + """Creates an input queue for a step. 
- def _ready_to_create_batch_accumulate(self) -> bool: - """Checks if there is enough data for an step accumulating data. It will return - `True` if the last batch was received from all the predecessors. + Args: + step_name: The name of the step. Returns: - `True` if ready to create a batch, `False` otherwise. + The input queue created. """ - return all( - step in self.last_batch_received - and sum(len(batch.data[0]) for batch in batches) >= 0 - for step, batches in self.data.items() - ) + input_queue = self.QueueClass() + self.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, input_queue) + return input_queue - def _ready_to_create_batch_convergence_step(self) -> bool: - """Checks if there is enough data for creating a batch for an step in which output - batches that were generated by steps that received routed batches are received. - It will return `True`, if all the output batches that were generated from a routed - batch have been received. + @abstractmethod + def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None: + """Runs the `Step` instance. - Returns: - `True` if ready to create a batch, `False` otherwise. + Args: + step: The `Step` instance to run. + input_queue: The input queue where the step will receive the batches. 
""" - grouped_batches = self._group_batches_by_created_from() - if not grouped_batches: - return False - seq_no, batches = grouped_batches[0] - - # If the `seq_no` from the `created_from` field is not the expected one, then - # we cannot create a batch yet or the order will be messed up - if seq_no != self.next_expected_created_from_batch_seq_no: - return False - - # Not all output batches to which the input batch was routed to haven't been - # received - batch_routed_to = batches[0][0].batch_routed_to - batches_received_from = {batch.step_name for batch, _ in batches} - if any(step_name not in batches_received_from for step_name in batch_routed_to): - return False - - # There are output batches to which the input batch was routed to from all - # the steps. Check if there is enough data for creating a batch with `input_batch_size` - rows_per_step = defaultdict(lambda: 0) - for batch, _ in batches: - num_rows = len(batch.data[0]) - rows_per_step[batch.step_name] += num_rows - - # If there aren't at least `input_batch_size` rows from each step, then there - # isn't enough data to create a batch - if not all( - num_rows >= self.input_batch_size or step_name in self.last_batch_received # type: ignore - for step_name, num_rows in rows_per_step.items() - ): - return False - - return True - - def _ready_to_create_batch_normal(self) -> bool: - """Checks if there is enough data for creating a batch for a normal step. It will - be `True` it there are at least `input_batch_size` rows from each predecessor step. + pass - Returns: - `True` if ready to create a batch, `False` otherwise. + def _run_steps(self) -> None: + """Runs the `Step`s of the pipeline, creating first an input queue for each step + that will be used to send the batches. 
""" - for step_name, batches in self.data.items(): - num_rows = sum(len(batch.data[0]) for batch in batches) - - # If there are now rows but the last batch was already received, then there - # are no more batch to be created - if num_rows == 0 and step_name in self.last_batch_received: - return False - - # If there are not enough rows and the last batch was not received yet, then - # there is not enough data yet to creata a batch - if ( - self.input_batch_size - and num_rows < self.input_batch_size - and step_name not in self.last_batch_received - ): - return False + for step_name in self.dag: + step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME] + input_queue = self._create_step_input_queue(step_name=step_name) - return True + # Set `pipeline` to `None` as in some Python environments the pipeline is not + # picklable and it will raise an error when trying to send the step to the process. + # `TypeError: cannot pickle 'code' object` + step.pipeline = None - def _last_batch(self) -> bool: - """Checks if the batch to be created is the last one i.e. if the last batch was - received from all the predecessors. + self._logger.debug(f"Running 1 instance of step '{step.name}'...") + self._run_step(step=step, input_queue=input_queue) - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. - """ - if self.accumulate: - return self._last_batch_accumulate() + def _add_batches_back_to_batch_manager(self) -> None: + """Add the `Batch`es that were sent to a `Step` back to the `_BatchManager`. 
This + method should be used when the pipeline has been stopped prematurely.""" + for step_name in self.dag: + node = self.dag.get_step(step_name) + step: "_Step" = node[STEP_ATTR_NAME] + if step.is_generator: + continue + if input_queue := node.get(INPUT_QUEUE_ATTR_NAME): + while not input_queue.empty(): + batch = input_queue.get() + if batch is None: + continue + self._batch_manager.add_batch( # type: ignore + to_step=step_name, batch=batch, prepend=True + ) + self._logger.debug( + f"Adding batch back to the batch manager: {batch}" + ) + input_queue.put(None) + + def _consume_output_queue(self) -> None: + """Consumes the `Batch`es from the output queue until it's empty. This method should + be used when the pipeline has been stopped prematurely to consume and to not lose + the `Batch`es that were processed by the leaf `Step`s before stopping the pipeline.""" + while not self._output_queue.empty(): + batch = self._output_queue.get() + if batch is None: + continue - if self.convergence_step: - return self._last_batch_convergence_step() + if batch.step_name in self.dag.leaf_steps: + self._write_buffer.add_batch(batch) # type: ignore - return self._last_batch_normal() + self._handle_batch_on_stop(batch) - def _last_batch_accumulate(self) -> bool: - """Checks if the batch to be created is the last one for an step accumulating data. - `True` if the last batch was received from all the predecessors. + def _manage_batch_flow(self, batch: "_Batch") -> None: + """Checks if the step that generated the batch has more data in its buffer to + generate a new batch. If there's data, then a new batch is sent to the step. If + the step has no data in its buffer, then the predecessors generator steps are + requested to send a new batch. - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. + Args: + batch: The batch that was processed. 
""" - return all(step in self.last_batch_received for step in self.data.keys()) + assert self._batch_manager, "Batch manager is not set" - def _last_batch_convergence_step(self) -> bool: - """Checks if the batch to be created is the last one for a convergence step. `True` - if the last batch of all the steps (`batch_routed_to`) in the last routed batch - have been received. + # Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence + # step if the batch is the last one, so they stop their processing loop even if + # they haven't received the last batch because of the routing function. + if self._is_convergence_step(batch.step_name) and batch.last_batch: + for step_name in self.dag.get_step_predecessors(batch.step_name): + self._send_last_batch_flag_to_step(step_name) - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. - """ - grouped_batches = self._group_batches_by_created_from() - if not grouped_batches: - return False - _, batches = grouped_batches[0] + route_to, routed = self._get_successors(batch) - for batch, _ in batches: - if not batch.last_batch: - return False + # Keep track of the steps that the batch was routed to + if routed: + batch.batch_routed_to = route_to - if len(batch.data[0]) > self.input_batch_size: # type: ignore - return False + self._register_batch(batch) - return True + step = self._get_step_from_batch(batch) - def _last_batch_normal(self) -> bool: - """Checks if the batch to be created is the last one for a normal step. `True` if - there is no more data to be received from the predecessors. + # Add the batch to the successors input buffers + for successor in route_to: + # Copy batch to avoid modifying the same reference in the batch manager + batch_to_add = batch.copy() if len(route_to) > 1 else batch - Returns: - `True` if the batch to be created is the last one. Otherwise, `False`. 
- """ - for step_name, batches in self.data.items(): - if step_name not in self.last_batch_received: - return False - - num_rows = sum(len(batch.data[0]) for batch in batches) + self._batch_manager.add_batch(successor, batch_to_add) + # Check if the step is a generator and if there are successors that need data + # from this step. This usually happens when the generator `batch_size` is smaller + # than the `input_batch_size` of the successor steps. if ( - self.input_batch_size - and num_rows > self.input_batch_size - and step_name in self.last_batch_received + step.is_generator + and step.name in self._batch_manager.step_empty_buffers(successor) ): - return False - - return True - - def _group_batches_by_created_from( - self, - ) -> List[Tuple[int, List[Tuple["_Batch", int]]]]: - """Group the batches by the first key of `created_from` field. This method is - meant to be used only with a `convergence_step`. - - Returns: - A list of the batches grouped by the `seq_no` of the first step name in `created_from`. - The list is sorted by the `seq_no`. - """ - grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list) - for batches in self.data.values(): - for batch in batches: - first_key = next(iter(batch.created_from)) - batch_seq_no, batch_size = batch.created_from[first_key][0] - grouped_batches[batch_seq_no].append((batch, batch_size)) - return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items()) - - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: - """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function. - - Args: - obj: Unused, just kept to match the signature of the parent method. - kwargs: Additional arguments that are kept to match the signature of the parent method. - - Returns: - Internal representation of the `_BatchManagerStep`. 
- """ - return { - "step_name": self.step_name, - "accumulate": self.accumulate, - "input_batch_size": self.input_batch_size, - "data": { - step_name: [batch.dump(**kwargs) for batch in batches] - for step_name, batches in self.data.items() - }, - "built_batches": [batch.dump(**kwargs) for batch in self.built_batches], - "seq_no": self.seq_no, - "last_batch_received": self.last_batch_received, - "convergence_step": self.convergence_step, - "convergence_step_batches_consumed": self.convergence_step_batches_consumed, - "next_expected_created_from_batch_seq_no": self.next_expected_created_from_batch_seq_no, - } - - -LAST_BATCH_SENT_FLAG = "last_batch_sent" - - -class _BatchManager(_Serializable): - """Class to manage the batches received from the steps. It keeps track of the - received batches and returns new batches for the steps to process based on their - input batch size and the batches received from the predecessors. - - Attributes: - steps: A dictionary with the step name as the key and a `_BatchManagerStep` - instance as the value. - last_batch_received: A dictionary with the step name as the key and a flag to - indicate whether we received the last batch from the step. - """ - - def __init__( - self, - steps: Dict[str, _BatchManagerStep], - last_batch_received: Dict[str, Union[_Batch, None]], - last_batch_sent: Dict[str, Union[_Batch, None]], - last_batch_flag_sent_to: List[str], - ) -> None: - """Initialize the `_BatchManager` instance. - - Args: - steps: A dictionary with the step name as the key and a dictionary with the - predecessor step name as the key and a list of batches as the value. - last_batch_received: A dictionary with the step name as the key and a the last - `_Batch` received from the step. - last_batch_sent: A dictionary with the step name as the key and a the last - `_Batch` sent to the step. - last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG` - was sent. 
- """ - - self._steps = steps - self._last_batch_received = last_batch_received - self._last_batch_sent = last_batch_sent - self._last_batch_flag_sent_to = last_batch_flag_sent_to - - def can_generate(self) -> bool: - """Checks if there are still batches to be processed by the steps. - - Returns: - `True` if there are still batches to be processed by the steps. Otherwise, - `False`. - """ - - for step_name, batch in self._last_batch_received.items(): - if step_name not in self._last_batch_flag_sent_to: - if not batch: - return True - - if not batch.last_batch: - return True - - if not self.get_last_batch_sent(step_name): - return True - - return False - - def register_batch(self, batch: _Batch) -> None: - """Method to register a batch received from a step. It will keep track of the - sequence number and the last batch received from the step in the internal maps. + last_batch_sent = self._batch_manager.get_last_batch_sent(step.name) + self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore + + # If successor step has enough data in its buffer to create a new batch, then + # send the batch to the step. + if new_batch := self._batch_manager.get_batch(successor): + self._send_batch_to_step(new_batch) + + if not step.is_generator: + # Step ("this", the one from which the batch was received) has enough data on its + # buffers to create a new batch + if new_batch := self._batch_manager.get_batch(step.name): # type: ignore + self._send_batch_to_step(new_batch) + else: + self._request_more_batches_if_needed(step) - Args: - batch: _Batch from which we will register the sequence number and the last batch received. - """ - self._last_batch_received[batch.step_name] = batch + self._cache() - def get_last_batch(self, step_name: str) -> Union[_Batch, None]: - """Gets the last batch received from a step. + def _send_to_step(self, step_name: str, to_send: Any) -> None: + """Sends something to the input queue of a step. Args: step_name: The name of the step. 
- - Returns: - The last batch received from the step or `None` if no batch was received. - """ - return self._last_batch_received.get(step_name) - - def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None: - """Add an output batch from `batch.step_name` to `to_step`. - - Args: - to_step: The name of the step that will process the batch. - batch: The output batch of an step to be processed by `to_step`. - prepend: If `True`, the content of the batch will be added at the start of - the buffer. - - Raises: - ValueError: If `to_step` is not found in the batch manager. + to_send: The object to send. """ - if to_step not in self._steps: - raise ValueError(f"Step '{to_step}' not found in the batch manager.") + input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME] + input_queue.put(to_send) - step = self._steps[to_step] - step.add_batch(batch, prepend) + def _send_batch_to_step(self, batch: "_Batch") -> None: + """Sends a batch to the input queue of a step, writing the data of the batch + to the filesystem and setting `batch.data_path` with the path where the data + was written (if requiered i.e. the step is a global step or `use_fs_to_pass_data`) - def get_batch(self, step_name: str) -> Union[_Batch, None]: - """Get the next batch to be processed by the step. + This method should be extended by the specific pipeline implementation, adding + the logic to send the batch to the step. Args: - step_name: The name of the step that will process the batch. - - Returns: - A `_Batch` instance if there is a batch to be processed by the step. Otherwise, - `None`. + batch: The batch to send. 
""" - if step_name not in self._steps: - raise ValueError(f"Step '{step_name}' not found in the batch manager.") - - return self._steps[step_name].get_batch() + self._logger.debug( + f"Setting batch {batch.seq_no} as last batch sent to '{batch.step_name}': {batch}" + ) + self._batch_manager.set_last_batch_sent(batch) # type: ignore - def step_empty_buffers(self, step_name: str) -> List[str]: - """Checks if the input buffer for a step is empty. + step: "_Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] + if not step.is_generator and (step.is_global or self._use_fs_to_pass_data): + base_path = UPath(self._storage_base_path) / step.name # type: ignore + self._logger.debug( + f"Writing {batch.seq_no} batch for '{batch.step_name}' step to filesystem: {base_path}" + ) + batch.write_batch_data_to_fs(self._fs, base_path) # type: ignore - Args: - step_name: The name of the step. + self._logger.debug( + f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" + ) - Returns: - The name of the previous steps for which the input buffer for this step is - empty. - """ - return self._steps[step_name].empty_buffers() + self._send_to_step(batch.step_name, batch) - def set_last_batch_sent(self, batch: "_Batch") -> None: - """Set the last batch sent to a step. + def _register_batch(self, batch: "_Batch") -> None: + """Registers a batch in the batch manager. Args: - batch: The last batch sent to a step. + batch: The batch to register. """ - self._last_batch_sent[batch.step_name] = batch + self._batch_manager.register_batch(batch) # type: ignore + self._logger.debug( + f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch" + " manager" + ) - def get_last_batch_sent(self, step_name: str) -> Union["_Batch", None]: - """Get the last batch sent to a step. + def _send_last_batch_flag_to_step(self, step_name: str) -> None: + """Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches. Args: step_name: The name of the step. 
- - Returns: - The last batch sent to a step or `None` if no batch was sent. """ - return self._last_batch_sent.get(step_name, None) + batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore + if batch and batch.last_batch: + return - def set_last_batch_flag_sent_to(self, step_name: str) -> None: - """Set the flag to indicate that the last batch was sent to a step. + self._logger.debug( + f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing" + " batches..." + ) - Args: - step_name: The name of the step. - """ - self._last_batch_flag_sent_to.append(step_name) + self._send_to_step(step_name, LAST_BATCH_SENT_FLAG) + self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore - @classmethod - def from_dag(cls, dag: "DAG") -> "_BatchManager": - """Create a `_BatchManager` instance from a `DAG` instance. + def _request_initial_batches(self) -> None: + """Requests the initial batches to the generator steps.""" + assert self._batch_manager, "Batch manager is not set" - Args: - dag: The `DAG` instance. + for step in self._batch_manager._steps.values(): + if batch := step.get_batch(): + self._logger.debug( + f"Sending initial batch to '{step.step_name}' step: {batch}" + ) + self._send_batch_to_step(batch) - Returns: - A `_BatchManager` instance. 
- """ - steps = {} - last_batch_received = {} - last_batch_sent = {} - for step_name in dag: - step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME] - last_batch_received[step.name] = None - last_batch_sent[step.name] = None - if step.is_generator: - continue - predecessors = list(dag.get_step_predecessors(step_name)) - convergence_step = all( - dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False) - for predecessor in predecessors - ) - batch_manager_step = _BatchManagerStep.from_step( - step=step, - predecessors=predecessors, - convergence_step=convergence_step, + for step_name in self.dag.root_steps: + seq_no = 0 + if last_batch := self._batch_manager.get_last_batch(step_name): + seq_no = last_batch.seq_no + 1 + batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=self._dry_run) + self._logger.debug( + f"Requesting initial batch to '{step_name}' generator step: {batch}" ) - steps[step_name] = batch_manager_step - return cls(steps, last_batch_received, last_batch_sent, []) - - def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: - """Dumps the content of the `_BatchManager` to a dictionary. - - Args: - obj (Any): Unused, just kept to match the signature of the parent method. - kwargs (Any): Additional arguments that are kept to match the signature of the parent method. - - Returns: - Dict[str, Any]: Internal representation of the `_BatchManager`. - """ - return { - "steps": {name: step.dump(**kwargs) for name, step in self._steps.items()}, - "last_batch_received": { - step_name: batch.dump(**kwargs) if batch is not None else None - for step_name, batch in self._last_batch_received.items() - }, - "last_batch_sent": { - step_name: batch.dump(**kwargs) if batch is not None else None - for step_name, batch in self._last_batch_sent.items() - }, - "last_batch_flag_sent_to": self._last_batch_flag_sent_to, - } + self._send_batch_to_step(batch) - def cache(self, path: "StrOrPath") -> None: - """Cache the `_BatchManager` to a file. 
+ def _request_more_batches_if_needed(self, step: "Step") -> None: + """Request more batches to the predecessors steps of `step` if needed. Args: - path: The path to the file where the `_BatchManager` will be cached. If `None`, - then the `_BatchManager` will be cached in the default cache folder. - """ - - def save_batch( - batches_dir: Path, batch_dump: Dict[str, Any], batch_list: List[_Batch] - ) -> Path: - seq_no = batch_dump["seq_no"] - data_hash = batch_dump["data_hash"] - batch_file = batches_dir / f"batch_{seq_no}_{data_hash}.json" - - # Save the batch if it doesn't exist - if not batch_file.exists(): - # Get the data of the batch before saving it - batch = next(batch for batch in batch_list if batch.seq_no == seq_no) - batch_dump["data"] = batch.data - self.save(path=batch_file, format="json", dump=batch_dump) - - return batch_file - - def remove_files(keep_files: List[str], dir: Path) -> None: - files = list_files_in_dir(dir, key=None) - remove = set(files) - {Path(file) for file in keep_files} - for file in remove: - file.unlink() - - path = Path(path) - - # Do not include `_Batch` data so `dump` is fast - dump = self.dump(include_batch_data=False) - batch_manager_step_files = {} - - # Do this to avoid modifying the dictionary while iterating over it - batch_manager_steps = set(dump["steps"].keys()) - for step_name in batch_manager_steps: - step_dump = dump["steps"].pop(step_name) - - # Create a directory for each batch manager step to store their batches - batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name - batch_manager_step_dir.mkdir(parents=True, exist_ok=True) - - # Store each built `_Batch` in a separate file - built_batches_dir = batch_manager_step_dir / "built_batches" - built_batches_dir.mkdir(parents=True, exist_ok=True) - step_dump["built_batches"] = [ - str( - save_batch( - batches_dir=built_batches_dir, - batch_dump=batch_dump, - batch_list=self._steps[step_name].built_batches, - ) - ) - for batch_dump in 
step_dump["built_batches"] - ] - # Remove built `_Batch`es that were consumed from cache - remove_files(step_dump["built_batches"], built_batches_dir) - - # Store each `_BatchManagerStep` `_Batch`es in a separate file - for buffered_step_name in step_dump["data"]: - step_batches_dir = batch_manager_step_dir / buffered_step_name - step_batches_dir.mkdir(parents=True, exist_ok=True) - - # Store each `_Batch` in a separate file - step_dump["data"][buffered_step_name] = [ - str( - save_batch( - batches_dir=step_batches_dir, - batch_dump=batch_dump, - batch_list=self._steps[step_name].data[buffered_step_name], - ) - ) - for batch_dump in step_dump["data"][buffered_step_name] - ] + step: The step of which it has to be checked if more batches are needed from + its predecessors. + """ + empty_buffers = self._batch_manager.step_empty_buffers(step.name) # type: ignore + for previous_step_name in empty_buffers: + # Only more batches can be requested to the `GeneratorStep`s as they are the + # only kind of steps that lazily generate batches. + if previous_step_name not in self.dag.root_steps: + continue - # Remove `_Batch`es that were consumed from cache - remove_files(step_dump["data"][buffered_step_name], step_batches_dir) + # Get the last batch that the previous step sent to generate the next batch + # (next `seq_no`). + last_batch = self._batch_manager.get_last_batch_sent(previous_step_name) # type: ignore + if last_batch is None: + continue - # Store the `_BatchManagerStep` info - batch_manager_step_file = str( - path.parent / f"batch_manager_steps/{step_name}/batch_manager_step.json" + self._logger.debug( + f"Step '{step.name}' input buffer for step '{previous_step_name}' is" + " empty. Requesting new batch..." 
) - self.save(path=batch_manager_step_file, format="json", dump=step_dump) + self._send_batch_to_step(last_batch.next_batch()) - # Store the path to the `_BatchManagerStep` file - batch_manager_step_files[step_name] = batch_manager_step_file - - dump["steps"] = batch_manager_step_files - self.save(path=path, format="json", dump=dump) - - @classmethod - def load_from_cache(cls, path: "StrOrPath") -> "_BatchManager": - """Loads the `_BatchManager` from a cache file. + def _handle_batch_on_stop(self, batch: "_Batch") -> None: + """Handles a batch that was received from the output queue when the pipeline was + stopped. It will add and register the batch in the batch manager. Args: - path: The path to the cache file. + batch: The batch to handle. """ - _check_is_dir(path) - content = read_json(path) - - # Read each `_BatchManagerStep` from file - steps = {} - for step_name, step_file in content["steps"].items(): - steps[step_name] = read_json(step_file) - - # Read each `_Batch` from file - steps[step_name]["built_batches"] = [ - read_json(batch) for batch in steps[step_name]["built_batches"] - ] - - for buffered_step_name, batch_files in steps[step_name]["data"].items(): - steps[step_name]["data"][buffered_step_name] = [ - read_json(batch_file) for batch_file in batch_files - ] - - content["steps"] = steps - return cls.from_dict(content) + assert self._batch_manager, "Batch manager is not set" + self._batch_manager.register_batch(batch) + step: "Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] + for successor in self.dag.get_step_successors(step.name): # type: ignore + self._batch_manager.add_batch(successor, batch) -class _WriteBuffer: - """Class in charge of sending the batched contents to a buffer and writing - those to files under a given folder. + def _get_step_from_batch(self, batch: "_Batch") -> "Step": + """Gets the `Step` instance from a batch. 
- As batches are received, they are added to the buffer and once each buffer - is full, the content is written to a parquet file. - """ - - def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None: - """ Args: - path: Folder where the files will be written, the idea - is for this path to be in the cache folder under /data. - leaf_steps: Leaf steps from either the DAG of the Pipeline. + batch: The batch to get the step from. - Raises: - ValueError: If the path is not a directory. + Returns: + The `Step` instance. """ - self._path = Path(path) - if not self._path.exists(): - self._path.mkdir(parents=True, exist_ok=True) - for step in leaf_steps: - (self._path / step).mkdir(parents=True, exist_ok=True) - - if not self._path.is_dir(): - raise ValueError(f"The path should be a directory, not a file: {path}") + return self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - self._buffers: Dict[str, List[Dict[str, Any]]] = { - step: [] for step in leaf_steps - } - # TODO: make this configurable - self._buffers_dump_batch_size: Dict[str, int] = { - step: 50 for step in leaf_steps - } - self._buffer_last_schema = {} - self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps} - self._logger = logging.getLogger("distilabel.write_buffer") + def _notify_steps_to_stop(self) -> None: + """Notifies the steps to stop their infinite running loop by sending `None` to + their input queues.""" + for step_name in self.dag: + self._send_to_step(step_name, None) - def _get_filename(self, step_name: str) -> Path: - """Creates the filename for the step. + def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]: + """Gets the successors and the successors to which the batch has to be routed. Args: - step_name: Name of the step to which the data belongs to. - - Returns: - Filename for the step. 
- """ - return self._path / f"{step_name}.parquet" - - def is_full(self, step_name: str) -> bool: - """Checks the buffers that are full so that those can be written to the file. + batch: The batch to which the successors will be determined. Returns: - Whether the buffer is full. - """ - return len(self._buffers[step_name]) >= self._buffers_dump_batch_size[step_name] - - def add_batch(self, batch: "_Batch") -> None: - """Adds a batch to the buffer and writes the buffer to the file if it's full. - - Args: - batch: batch to add to the buffer. - """ - step_name = batch.step_name - data = batch.data[0] - self._buffers[step_name].extend(data) - self._logger.debug( - f"Added batch to write buffer for step '{step_name}' with {len(data)} rows." - ) - if self.is_full(step_name): - self._logger.debug( - f"Buffer for step '{step_name}' is full (rows: {len(self._buffers[step_name])}," - f" full: {self._buffers_dump_batch_size[step_name]}), writing to file..." + The successors to route the batch to and whether the batch was routed using + a routing function. + """ + node = self.dag.get_step(batch.step_name) + step: "Step" = node[STEP_ATTR_NAME] + successors = list(self.dag.get_step_successors(step.name)) # type: ignore + route_to = successors + + # Check if the step has a routing function to send the batch to specific steps + if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME): + route_to = routing_batch_function(batch, successors) + successors_str = ", ".join(f"'{successor}'" for successor in route_to) + self._logger.info( + f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}" ) - self._write(step_name) - def _write(self, step_name: str) -> None: - """Writes the content to the file and cleans the buffer. + return route_to, route_to != successors - Args: - step_name (str): Name of the step to which the data pertains. 
- """ - step_parquet_dir = Path(self._path, step_name) - if not step_parquet_dir.exists(): - self._logger.debug( - f"Creating directory for step '{step_name}' parquet files..." - ) - step_parquet_dir.mkdir() + @abstractmethod + def _stop(self) -> None: + """Stops the pipeline in a controlled way.""" + pass - try: - table = pa.Table.from_pylist(self._buffers[step_name]) - except pa.lib.ArrowInvalid as pae: - if ( - repr(pae) - != "ArrowInvalid('cannot mix struct and non-struct, non-null values')" - ): - raise pae - flattened_buffers = [flatten_dict(buf) for buf in self._buffers[step_name]] - table = pa.Table.from_pylist(flattened_buffers) + def _stop_load_queue_loop(self) -> None: + """Stops the `_load_queue` loop sending a `None`.""" + self._logger.debug("Sending `None` to the load queue to notify stop...") + self._load_queue.put(None) - last_schema = self._buffer_last_schema.get(step_name) - if last_schema is None: - self._buffer_last_schema[step_name] = table.schema - else: - if not last_schema.equals(table.schema): - new_schema = pa.unify_schemas([last_schema, table.schema]) - self._buffer_last_schema[step_name] = new_schema - table = table.cast(new_schema) + def _stop_output_queue_loop(self) -> None: + """Stops the `_output_queue` loop sending a `None`.""" + self._logger.debug("Sending `None` to the output queue to notify stop...") + self._output_queue.put(None) - next_file_number = self._buffers_last_file[step_name] - self._buffers_last_file[step_name] = next_file_number + 1 + def _handle_keyboard_interrupt(self) -> Any: + """Handles KeyboardInterrupt signal sent during the Pipeline.run method. - parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet" - pq.write_table(table, parquet_file) - self._logger.debug(f"Written to file '{parquet_file}'") + It will try to call self._stop (if the pipeline didn't started yet, it won't + have any effect), and if the pool is already started, will close it before exiting + the program. 
- self._clean_buffer(step_name) + Returns: + The original `signal.SIGINT` handler. + """ - def _clean_buffer(self, step_name: str) -> None: - """Cleans the buffer by setting it's content to `None`. + def signal_handler(signumber: int, frame: Any) -> None: + self._stop() - Args: - step_name: The name of the buffer to clean. - """ - self._buffers[step_name] = [] - - def close(self) -> None: - """Closes the buffer by writing the remaining content to the file.""" - for step_name in self._buffers: - if self._buffers[step_name]: - self._write(step_name) - - # We need to read the parquet files and write them again to ensure the schema - # is correct. Otherwise, the first parquets won't have the last schema and - # then we will have issues when reading them. - for file in list_files_in_dir(self._path / step_name): - if step_name in self._buffer_last_schema: - table = pq.read_table( - file, schema=self._buffer_last_schema[step_name] - ) - pq.write_table(table, file) + return signal.signal(signal.SIGINT, signal_handler) diff --git a/src/distilabel/pipeline/batch.py b/src/distilabel/pipeline/batch.py new file mode 100644 index 0000000000..d8ad4312ae --- /dev/null +++ b/src/distilabel/pipeline/batch.py @@ -0,0 +1,233 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import hashlib +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple, Union + +import fsspec +import pyarrow as pa +import pyarrow.parquet as pq +from upath import UPath + +from distilabel.utils.serialization import _Serializable + + +@dataclass +class _Batch(_Serializable): + """Dataclass to represent a batch of data to be processed by a `_Step`. + + Attributes: + seq_no: The sequence number of the batch. + step_name: The name of the step that will process the batch. + last_batch: A flag to indicate if the batch is the last one. + data: The data to be processed. + data_hash: The hash of the data. Defaults to `None`. + data_path: The path where the data of the batch is stored. Defaults to `None`. + accumulated: A flag to indicate if the batch is accumulated. + created_from: A dictionary containing the `seq_no` of the batches of the steps that + were used to create this batch. + size: The size of the batch. + """ + + seq_no: int + step_name: str + last_batch: bool + data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False) + data_hash: Optional[str] = None + data_path: Optional[str] = None + accumulated: bool = False + created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict) + batch_routed_to: List[str] = field(default_factory=list) + size: int = 0 + _fs: Optional[fsspec.AbstractFileSystem] = None + + def next_batch(self) -> "_Batch": + """Create a new `_Batch` instance with the next batch of data. + + Args: + data: The data to be processed. + + Returns: + A `_Batch` instance. + """ + return _Batch( + seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch + ) + + def set_data(self, data: List[List[Dict[str, Any]]]) -> None: + """Sets the data of the batch and updates the size of the batch. + + Args: + data: The data of the batch. 
+ """ + self.data = data + self.size = len(data[0]) + self._update_data_hash() + + def get_data(self, num_rows: Union[int, None] = None) -> List[Dict[str, Any]]: + """Takes `num_rows` from the data of the batch and returns it. This method will + also remove the data from the batch and update the hash of the data. + + Args: + num_rows: The number of rows to take from the data. If `None`, then all the + data will be taken. Defaults to `None`. + + Returns: + A list with the data taken from the batch. + """ + + if self.data == [] and self.data_path is not None: + pass + + if num_rows is None: + data = self.data[0] + self.data = [] + else: + data = self.data[0][:num_rows] + self.data[0] = self.data[0][num_rows:] + + self._update_data_hash() + return data + + def _update_data_hash(self) -> None: + """Updates the hash of the data of the batch.""" + self.data_hash = hashlib.sha1(str(self.data).encode()).hexdigest() + + @classmethod + def accumulate(cls, step_name: str, batches: List[List["_Batch"]]) -> "_Batch": + """Creates a `_Batch` instance using the data from the list of batches that + were received from another steps. The batches will be accumulated in a single + list of data. + + Args: + step_name: The name of the step that will process the batch. + batches: a list containing the list of batches received from the predecessors. + + Returns: + A `_Batch` instance. + """ + + data = [] + for step_batches in batches: + accumulated_data = [row for batch in step_batches for row in batch.data[0]] + data.append(accumulated_data) + return cls( + seq_no=0, step_name=step_name, last_batch=True, data=data, accumulated=True + ) + + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: + """Dumps the content of the `_Batch` to a dictionary, using the `dataclass` helper function. + + Args: + obj: Unused, just kept to match the signature of the parent method. + kwargs: Additional arguments that are kept to match the signature of the parent method. 
+ + Returns: + A `dict` containing the internal representation of the `_Batch`. + """ + + include_batch_data = kwargs.get("include_batch_data", True) + + dump = { + "seq_no": self.seq_no, + "step_name": self.step_name, + "last_batch": self.last_batch, + "data_hash": self.data_hash, + "accumulated": self.accumulated, + "created_from": self.created_from, + "batch_routed_to": self.batch_routed_to, + "size": self.size, + } + + if include_batch_data: + dump["data"] = self.data + + return dump + + def copy(self) -> "_Batch": + """Creates a copy of the `_Batch` instance. + + Returns: + A copy of the `_Batch` instance. + """ + return copy.deepcopy(self) + + def write_batch_data_to_fs( + self, + fs: Optional[fsspec.AbstractFileSystem] = None, + base_path: Optional[UPath] = None, + ) -> None: + """Writes the content of the batch to the filesystem. + + Args: + fs: The `fsspec` filesystem to be used to write the data. If not provided, the + one set in the `_fs` attribute will be used. Defaults to `None`. + base_path: The base path where the data of the batch will be stored. If not + provided, the one set in the `data_path` attribute will be used. Defaults + to `None`. + + Raises: + ValueError: If `fs` is not provided and the `_fs` attribute is not set. + """ + + if not fs and not self._fs: + raise ValueError( + "The `fs` parameter must be provided if the `_fs` attribute is not set." + ) + + if fs: + self._fs = fs + + if not base_path and not self.data_path: + raise ValueError( + "The `base_path` parameter must be provided if the `data_path` attribute" + " is not set." 
+ ) + + seq_no_dir = ( + base_path / f"seq_no_{self.seq_no}" if base_path else UPath(self.data_path) + ) + seq_no_dir._fs_cached = self._fs # type: ignore + seq_no_dir.mkdir(parents=True, exist_ok=True) + + for i, data in enumerate(self.data): + table = pa.Table.from_pylist(data) + with self._fs.open(seq_no_dir / f"data_index_{i}.parquet", "wb") as f: # type: ignore + pq.write_table(table, f) + + self.data = [] + self.data_path = str(seq_no_dir) + + def read_batch_data_from_fs(self) -> None: + """Reads the content of the batch from the filesystem.""" + if not self.data_path: + raise ValueError( + "`data_path` attribute must be set to read the data from the filesystem." + " Use `write_batch_data_to_fs` method to set the `data_path` attribute." + ) + + if not self._fs: + raise ValueError( + "`_fs` attribute must be set to read the data from the filesystem." + " Use `write_batch_data_to_fs` method to set the `_fs` attribute." + ) + + for file in self._fs.ls(self.data_path): + with self._fs.open(file, "rb") as f: + table = pq.read_table(f) + self.data.append(table.to_pylist()) + + self._fs.rm(self.data_path, recursive=True) diff --git a/src/distilabel/pipeline/batch_manager.py b/src/distilabel/pipeline/batch_manager.py new file mode 100644 index 0000000000..cc14f0dd21 --- /dev/null +++ b/src/distilabel/pipeline/batch_manager.py @@ -0,0 +1,896 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, Union + +from distilabel.pipeline._dag import DAG +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.constants import ( + RECEIVES_ROUTED_BATCHES_ATTR_NAME, + STEP_ATTR_NAME, +) +from distilabel.steps.base import _Step +from distilabel.utils.files import list_files_in_dir +from distilabel.utils.serialization import ( + StrOrPath, + _check_is_dir, + _Serializable, + read_json, +) + +if TYPE_CHECKING: + from distilabel.utils.serialization import StrOrPath + + +@dataclass +class _BatchManagerStep(_Serializable): + """A class that will accumulate data for a step from the predecessors and create + batches for the step to process when there is enough data. + + Attributes: + step_name: The name of the step that will process the data. + accumulate: A flag to indicate if the data should be accumulated and create a + batch with all the data received from the predecessors instead of creating + batches with the `input_batch_size`. + input_batch_size: The size of the batch to be created for the step to process. + If `None`, then `accumulate` must be `True`. Defaults to `None`. + data: A dictionary with the predecessor step name as the key and a list of + dictionaries (rows) as the value. + built_batches: A list with the batches that were built and sent to the step queue, + but the step was stopped before processing the batch, so the batch doesn't get + lost. Defaults to an empty list. + seq_no: The sequence number of the next batch to be created. It will be + incremented for each batch created. + last_batch_received: A list with the names of the steps that sent the last + batch of data. + convergence_step: A flag to indicate if the step is a convergence step. An + `Step` is a convergence step if all its predecessors are receiving routed + batches. Defaults to `False`. 
+ convergence_step_batches_consumed: A dictionary in which the key is the `seq_no` + of the batch created by step A, that was used by step B and C and obtained from + the `created_from` of the batches created by them. It's used to know if all + the batches from B and C steps created from batches of A have been consumed + by D, in order to not mess up the order of the batches. Only used if `convergence_step=True`. + Defaults to an empty dictionary. + next_expected_created_from_batch_seq_no: The next expected sequence number of the + batch from step A used by steps B and C and obtained from the `created_from` + of the batches created by them. It's used to avoid messing up the order of the + batches. Only used if `convergence_step=True`. Defaults to `0`. + """ + + step_name: str + accumulate: bool + input_batch_size: Union[int, None] = None + data: Dict[str, List[_Batch]] = field(default_factory=dict) + built_batches: List[_Batch] = field(default_factory=list) + seq_no: int = 0 + last_batch_received: List[str] = field(default_factory=list) + convergence_step: bool = False + convergence_step_batches_consumed: Dict[str, Dict[str, int]] = field( + default_factory=dict + ) + next_expected_created_from_batch_seq_no: int = 0 + + def add_batch(self, batch: _Batch, prepend: bool = False) -> None: + """Add a batch of data from `batch.step_name` to the step. It will accumulate the + data and keep track of the last batch received from the predecessors. + + Args: + batch: The output batch of a step to be processed by the step. + prepend: If `True`, the content of the batch will be added to the `built_batches` + list. This is done so if a `_Batch` was already built and sent to the step + queue, and the step is stopped before processing the batch, the batch doesn't + get lost. Defaults to `False`. 
+ """ + from_step = batch.step_name + + if prepend: + self.built_batches.append(batch) + else: + self.data[from_step].append(batch) + + if batch.last_batch: + self.last_batch_received.append(from_step) + + def get_batch(self) -> Union[_Batch, None]: + """Create a new batch of data for the step to process. It will return `None` if + there is not enough data to create a batch. + + Returns: + A `_Batch` instance if there is enough data to create a batch. Otherwise, + `None`. + """ + if not self._ready_to_create_batch(): + return None + + # If there are batches in the `built_batches` list, then return the first one + # and remove it from the list. + if self.built_batches: + return self.built_batches.pop(0) + + # `_last_batch` must be called before `_get_data`, as `_get_data` will update the + # list of data which is used to determine if the batch to be created is the last one. + # TODO: remove `_last_batch` method and integrate logic in `_get_data` + last_batch = self._last_batch() + data, created_from, batch_routed_to = self._get_data() + + return _Batch( + seq_no=self._get_seq_no(), + step_name=self.step_name, + last_batch=last_batch, + data=data, + accumulated=self.accumulate, + created_from=created_from, + batch_routed_to=batch_routed_to, + ) + + def empty_buffers(self) -> List[str]: + """Checks if the input buffer for the step is empty. + + Returns: + The name of the previous steps for which the input buffer for this step is + empty. 
+ """ + if self.accumulate: + return [ + previous_step + for previous_step in self.data.keys() + if previous_step not in self.last_batch_received + ] + + return [ + previous_step + for previous_step, batches in self.data.items() + if previous_step not in self.last_batch_received + and sum(len(batch.data[0]) for batch in batches) < self.input_batch_size # type: ignore + ] + + @classmethod + def from_step( + cls, step: "_Step", predecessors: Iterable[str], convergence_step: bool = False + ) -> "_BatchManagerStep": + """Creates a `_BatchManagerStep` instance from a `_Step` instance and its + predecessors. + + Args: + step: The `_Step` instance. + predecessors: The names of the predecessors of the step. + convergence_step: A flag to indicate if the step is a convergence step. An + `Step` is a convergence step if all its predecessors are receiving routed + batches. Defaults to `False`. + + Returns: + A `_BatchManagerStep` instance. + """ + return cls( + step_name=step.name, # type: ignore + accumulate=step.is_global, + input_batch_size=getattr(step, "input_batch_size", None), + data={predecessor: [] for predecessor in predecessors}, + convergence_step=convergence_step, + ) + + def _get_seq_no(self) -> int: + """Gets the sequence number for the next batch to be created and increments it. + + Returns: + The sequence number for the next batch to be created. + """ + seq_no = self.seq_no + self.seq_no += 1 + return seq_no + + def _get_data( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: + """Gets the data needed to create a batch for the step to process. If the step is + accumulating data, then it will return a list with all the data received from the + predecessors. Otherwise, it will return a list of data with the `input_batch_size` + for each predecessor. In addition, it will remove the data used to create the + batch from the step's data. 
+ + Returns: + A tuple containing the list of data needed to create a batch for the step to + process, a dictionary with the sequence numbers of the batches that were used + to create the batch and the list of steps to which the batch was routed to if + the step is a normal step. + """ + if self.accumulate: + # Steps accumulating cannot receive routed batches + return self._get_data_for_accumulate() + ([],) + + if self.convergence_step: + # Convergence steps will receive routed batches, but we need to clean the + # `batch_routed_to` list + return self._get_data_for_convergence_step() + ([],) + + return self._get_data_normal() + + def _get_data_for_accumulate( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: + """Gets the data needed to create a batch for the step to process when the step + is accumulating data. It will return a list with all the data received from the + predecessors. In addition, it will remove the data used to create the batch from + the step's data. + + Returns: + A tuple containing the list of data needed to create a batch for the step to + process and a dictionary with the sequence numbers of the batches that were + used to create the batch. + """ + data = [] + batches_used = {} + for step_name, batches in self.data.items(): + batches_used[step_name] = [] + for batch in batches: + batches_used[step_name].append((batch.seq_no, batch.size)) + data.append([row for batch in batches for row in batch.get_data()]) + # Reset the data buffer + self.data = {step_name: [] for step_name in self.data} + return data, batches_used + + def _get_data_for_convergence_step( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: + """Gets the data needed to create a batch for the step to process when the step is + a convergence step. 
+ + Returns: + A tuple containing the list of data needed to create a batch for the step to + process and a dictionary with the sequence numbers of the batches that were + used to create the batch. + """ + grouped_batches = self._group_batches_by_created_from() + seq_no, batches = grouped_batches[0] + str_seq_no = str(seq_no) + + remaining_rows_per_step: Dict[str, int] = { + step_name: self.input_batch_size + for step_name in self.data # type: ignore + } + batches_used = defaultdict(list) + data = defaultdict(list) + for batch, batch_size in batches: + remaining_rows = remaining_rows_per_step[batch.step_name] + selected_data = batch.get_data(remaining_rows) + data[batch.step_name].extend(selected_data) + + # If A -> [B, C] -> D, then in D (this step) we keep track of the remaining + # rows from the batches of A that B and C used to create the `batches`. + batch_size = self.convergence_step_batches_consumed.setdefault( + str_seq_no, {} + ).get(batch.step_name, batch_size) + remaining_rows_in_batch = batch_size - len(selected_data) + self.convergence_step_batches_consumed[str_seq_no].update( + {batch.step_name: remaining_rows_in_batch} + ) + + # Update the remaining rows + num_rows = len(selected_data) + remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore + + # Keep track of the batches used to create the batch + batches_used[batch.step_name].append((batch.seq_no, batch.size)) + + # If the batch was entirely consumed, then remove it from the buffer + if len(batch.data[0]) == 0: + self.data[batch.step_name].remove(batch) + continue + + # If all the batches grouped by the `seq_no` in `created_from` were consumed, then + # we can update the `next_expected_created_from_batch_seq_no` to the next one + # to avoid skipping batches. 
+ no_remaining_rows = all( + count == 0 + for count in self.convergence_step_batches_consumed[str_seq_no].values() + ) + if no_remaining_rows: + self.next_expected_created_from_batch_seq_no += 1 + + return list(data.values()), dict(batches_used) + + def _get_data_normal( + self, + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: + """Gets the data needed to create a batch for the step to process when the step is + not accumulating data. It will return a list of data with the `input_batch_size` + for each predecessor. In addition, it will remove the data used to create the batch + from the step's data. + + Returns: + A tuple containing the list of data needed to create a batch for the step to + process, a dictionary with the sequence numbers of the batches that were used + to create the batch and the list of steps to which the batch was routed to if + the step is a convergence step. + """ + data = [] + batches_used = defaultdict(list) + batch_routed_to = [] + for step_name in self.data: + # For each step batches buffer, we will create a batch with the `input_batch_size` + # using the data from the buffer. We will remove the consumed batches (no data + # left) and update the batch data with the remaining data. 
+ step_data = [] + idx_drop_batches = [] + remaining_rows: int = self.input_batch_size # type: ignore + for idx, batch in enumerate(self.data[step_name]): + if remaining_rows == 0: + break + + # Get `remaining_rows` or the remaining rows in the batch and add it to + # the step data that will be used to create the batch + selected_data = batch.get_data(remaining_rows) + step_data.extend(selected_data) + batch_routed_to = batch.batch_routed_to + + # Update the remaining rows + num_rows = len(selected_data) + remaining_rows -= num_rows + + # Keep track of the batches used to create the batch + batches_used[step_name].append((batch.seq_no, batch.size)) + + # If the batch was entirely consumed, then remove it from the buffer + if len(batch.data[0]) == 0: + idx_drop_batches.append(idx) + continue + + # Remove the batches that were entirely consumed + idx_drop_batches.reverse() + for idx in idx_drop_batches: + self.data[step_name].pop(idx) + + data.append(step_data) + + return data, dict(batches_used), batch_routed_to + + def _ready_to_create_batch(self) -> bool: + """Checks if there is enough data to create a batch for the step. + + Returns: + `True` if there is enough data to create a batch for the step. Otherwise, + `False`. + """ + if self.accumulate: + return self._ready_to_create_batch_accumulate() + + if self.convergence_step: + return self._ready_to_create_batch_convergence_step() + + return self._ready_to_create_batch_normal() + + def _ready_to_create_batch_accumulate(self) -> bool: + """Checks if there is enough data for an step accumulating data. It will return + `True` if the last batch was received from all the predecessors. + + Returns: + `True` if ready to create a batch, `False` otherwise. 
+ """ + return all( + step in self.last_batch_received + and sum(len(batch.data[0]) for batch in batches) >= 0 + for step, batches in self.data.items() + ) + + def _ready_to_create_batch_convergence_step(self) -> bool: + """Checks if there is enough data for creating a batch for an step in which output + batches that were generated by steps that received routed batches are received. + It will return `True`, if all the output batches that were generated from a routed + batch have been received. + + Returns: + `True` if ready to create a batch, `False` otherwise. + """ + grouped_batches = self._group_batches_by_created_from() + if not grouped_batches: + return False + seq_no, batches = grouped_batches[0] + + # If the `seq_no` from the `created_from` field is not the expected one, then + # we cannot create a batch yet or the order will be messed up + if seq_no != self.next_expected_created_from_batch_seq_no: + return False + + # Not all output batches to which the input batch was routed to haven't been + # received + batch_routed_to = batches[0][0].batch_routed_to + batches_received_from = {batch.step_name for batch, _ in batches} + if any(step_name not in batches_received_from for step_name in batch_routed_to): + return False + + # There are output batches to which the input batch was routed to from all + # the steps. 
Check if there is enough data for creating a batch with `input_batch_size` + rows_per_step = defaultdict(lambda: 0) + for batch, _ in batches: + num_rows = len(batch.data[0]) + rows_per_step[batch.step_name] += num_rows + + # If there aren't at least `input_batch_size` rows from each step, then there + # isn't enough data to create a batch + if not all( + num_rows >= self.input_batch_size or step_name in self.last_batch_received # type: ignore + for step_name, num_rows in rows_per_step.items() + ): + return False + + return True + + def _ready_to_create_batch_normal(self) -> bool: + """Checks if there is enough data for creating a batch for a normal step. It will + be `True` if there are at least `input_batch_size` rows from each predecessor step. + + Returns: + `True` if ready to create a batch, `False` otherwise. + """ + for step_name, batches in self.data.items(): + num_rows = sum(len(batch.data[0]) for batch in batches) + + # If there are no rows but the last batch was already received, then there + # are no more batches to be created + if num_rows == 0 and step_name in self.last_batch_received: + return False + + # If there are not enough rows and the last batch was not received yet, then + # there is not enough data yet to create a batch + if ( + self.input_batch_size + and num_rows < self.input_batch_size + and step_name not in self.last_batch_received + ): + return False + + return True + + def _last_batch(self) -> bool: + """Checks if the batch to be created is the last one i.e. if the last batch was + received from all the predecessors. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + if self.accumulate: + return self._last_batch_accumulate() + + if self.convergence_step: + return self._last_batch_convergence_step() + + return self._last_batch_normal() + + def _last_batch_accumulate(self) -> bool: + """Checks if the batch to be created is the last one for a step accumulating data. 
+ `True` if the last batch was received from all the predecessors. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + return all(step in self.last_batch_received for step in self.data.keys()) + + def _last_batch_convergence_step(self) -> bool: + """Checks if the batch to be created is the last one for a convergence step. `True` + if the last batch of all the steps (`batch_routed_to`) in the last routed batch + have been received. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + grouped_batches = self._group_batches_by_created_from() + if not grouped_batches: + return False + _, batches = grouped_batches[0] + + for batch, _ in batches: + if not batch.last_batch: + return False + + if len(batch.data[0]) > self.input_batch_size: # type: ignore + return False + + return True + + def _last_batch_normal(self) -> bool: + """Checks if the batch to be created is the last one for a normal step. `True` if + there is no more data to be received from the predecessors. + + Returns: + `True` if the batch to be created is the last one. Otherwise, `False`. + """ + for step_name, batches in self.data.items(): + if step_name not in self.last_batch_received: + return False + + num_rows = sum(len(batch.data[0]) for batch in batches) + + if ( + self.input_batch_size + and num_rows > self.input_batch_size + and step_name in self.last_batch_received + ): + return False + + return True + + def _group_batches_by_created_from( + self, + ) -> List[Tuple[int, List[Tuple["_Batch", int]]]]: + """Group the batches by the first key of `created_from` field. This method is + meant to be used only with a `convergence_step`. + + Returns: + A list of the batches grouped by the `seq_no` of the first step name in `created_from`. + The list is sorted by the `seq_no`. 
+ """ + grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list) + for batches in self.data.values(): + for batch in batches: + first_key = next(iter(batch.created_from)) + batch_seq_no, batch_size = batch.created_from[first_key][0] + grouped_batches[batch_seq_no].append((batch, batch_size)) + return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items()) + + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: + """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function. + + Args: + obj: Unused, just kept to match the signature of the parent method. + kwargs: Additional arguments that are kept to match the signature of the parent method. + + Returns: + Internal representation of the `_BatchManagerStep`. + """ + return { + "step_name": self.step_name, + "accumulate": self.accumulate, + "input_batch_size": self.input_batch_size, + "data": { + step_name: [batch.dump(**kwargs) for batch in batches] + for step_name, batches in self.data.items() + }, + "built_batches": [batch.dump(**kwargs) for batch in self.built_batches], + "seq_no": self.seq_no, + "last_batch_received": self.last_batch_received, + "convergence_step": self.convergence_step, + "convergence_step_batches_consumed": self.convergence_step_batches_consumed, + "next_expected_created_from_batch_seq_no": self.next_expected_created_from_batch_seq_no, + } + + +class _BatchManager(_Serializable): + """Class to manage the batches received from the steps. It keeps track of the + received batches and returns new batches for the steps to process based on their + input batch size and the batches received from the predecessors. + + Attributes: + steps: A dictionary with the step name as the key and a `_BatchManagerStep` + instance as the value. + last_batch_received: A dictionary with the step name as the key and a flag to + indicate whether we received the last batch from the step. 
+ """ + + def __init__( + self, + steps: Dict[str, _BatchManagerStep], + last_batch_received: Dict[str, Union[_Batch, None]], + last_batch_sent: Dict[str, Union[_Batch, None]], + last_batch_flag_sent_to: List[str], + ) -> None: + """Initialize the `_BatchManager` instance. + + Args: + steps: A dictionary with the step name as the key and a dictionary with the + predecessor step name as the key and a list of batches as the value. + last_batch_received: A dictionary with the step name as the key and a the last + `_Batch` received from the step. + last_batch_sent: A dictionary with the step name as the key and a the last + `_Batch` sent to the step. + last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG` + was sent. + """ + + self._steps = steps + self._last_batch_received = last_batch_received + self._last_batch_sent = last_batch_sent + self._last_batch_flag_sent_to = last_batch_flag_sent_to + + def can_generate(self) -> bool: + """Checks if there are still batches to be processed by the steps. + + Returns: + `True` if there are still batches to be processed by the steps. Otherwise, + `False`. + """ + + for step_name, batch in self._last_batch_received.items(): + if step_name not in self._last_batch_flag_sent_to: + if not batch: + return True + + if not batch.last_batch: + return True + + if not self.get_last_batch_sent(step_name): + return True + + return False + + def register_batch(self, batch: _Batch) -> None: + """Method to register a batch received from a step. It will keep track of the + sequence number and the last batch received from the step in the internal maps. + + Args: + batch: _Batch from which we will register the sequence number and the last batch received. + """ + self._last_batch_received[batch.step_name] = batch + + def get_last_batch(self, step_name: str) -> Union[_Batch, None]: + """Gets the last batch received from a step. + + Args: + step_name: The name of the step. 
+ + Returns: + The last batch received from the step or `None` if no batch was received. + """ + return self._last_batch_received.get(step_name) + + def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None: + """Add an output batch from `batch.step_name` to `to_step`. + + Args: + to_step: The name of the step that will process the batch. + batch: The output batch of an step to be processed by `to_step`. + prepend: If `True`, the content of the batch will be added at the start of + the buffer. + + Raises: + ValueError: If `to_step` is not found in the batch manager. + """ + if to_step not in self._steps: + raise ValueError(f"Step '{to_step}' not found in the batch manager.") + + step = self._steps[to_step] + step.add_batch(batch, prepend) + + def get_batch(self, step_name: str) -> Union[_Batch, None]: + """Get the next batch to be processed by the step. + + Args: + step_name: The name of the step that will process the batch. + + Returns: + A `_Batch` instance if there is a batch to be processed by the step. Otherwise, + `None`. + """ + if step_name not in self._steps: + raise ValueError(f"Step '{step_name}' not found in the batch manager.") + + return self._steps[step_name].get_batch() + + def step_empty_buffers(self, step_name: str) -> List[str]: + """Checks if the input buffer for a step is empty. + + Args: + step_name: The name of the step. + + Returns: + The name of the previous steps for which the input buffer for this step is + empty. + """ + return self._steps[step_name].empty_buffers() + + def set_last_batch_sent(self, batch: "_Batch") -> None: + """Set the last batch sent to a step. + + Args: + batch: The last batch sent to a step. + """ + self._last_batch_sent[batch.step_name] = batch + + def get_last_batch_sent(self, step_name: str) -> Union["_Batch", None]: + """Get the last batch sent to a step. + + Args: + step_name: The name of the step. + + Returns: + The last batch sent to a step or `None` if no batch was sent. 
+ """ + return self._last_batch_sent.get(step_name, None) + + def set_last_batch_flag_sent_to(self, step_name: str) -> None: + """Set the flag to indicate that the last batch was sent to a step. + + Args: + step_name: The name of the step. + """ + self._last_batch_flag_sent_to.append(step_name) + + @classmethod + def from_dag(cls, dag: "DAG") -> "_BatchManager": + """Create a `_BatchManager` instance from a `DAG` instance. + + Args: + dag: The `DAG` instance. + + Returns: + A `_BatchManager` instance. + """ + steps = {} + last_batch_received = {} + last_batch_sent = {} + for step_name in dag: + step: "_Step" = dag.get_step(step_name)[STEP_ATTR_NAME] + last_batch_received[step.name] = None + last_batch_sent[step.name] = None + if step.is_generator: + continue + predecessors = list(dag.get_step_predecessors(step_name)) + convergence_step = all( + dag.get_step(predecessor).get(RECEIVES_ROUTED_BATCHES_ATTR_NAME, False) + for predecessor in predecessors + ) + batch_manager_step = _BatchManagerStep.from_step( + step=step, + predecessors=predecessors, + convergence_step=convergence_step, + ) + steps[step_name] = batch_manager_step + return cls(steps, last_batch_received, last_batch_sent, []) + + def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: + """Dumps the content of the `_BatchManager` to a dictionary. + + Args: + obj (Any): Unused, just kept to match the signature of the parent method. + kwargs (Any): Additional arguments that are kept to match the signature of the parent method. + + Returns: + Dict[str, Any]: Internal representation of the `_BatchManager`. 
+ """ + return { + "steps": {name: step.dump(**kwargs) for name, step in self._steps.items()}, + "last_batch_received": { + step_name: batch.dump(**kwargs) if batch is not None else None + for step_name, batch in self._last_batch_received.items() + }, + "last_batch_sent": { + step_name: batch.dump(**kwargs) if batch is not None else None + for step_name, batch in self._last_batch_sent.items() + }, + "last_batch_flag_sent_to": self._last_batch_flag_sent_to, + } + + def cache(self, path: "StrOrPath") -> None: + """Cache the `_BatchManager` to a file. + + Args: + path: The path to the file where the `_BatchManager` will be cached. If `None`, + then the `_BatchManager` will be cached in the default cache folder. + """ + + def save_batch( + batches_dir: Path, batch_dump: Dict[str, Any], batch_list: List[_Batch] + ) -> Path: + seq_no = batch_dump["seq_no"] + data_hash = batch_dump["data_hash"] + batch_file = batches_dir / f"batch_{seq_no}_{data_hash}.json" + + # Save the batch if it doesn't exist + if not batch_file.exists(): + # Get the data of the batch before saving it + batch = next(batch for batch in batch_list if batch.seq_no == seq_no) + batch_dump["data"] = batch.data + self.save(path=batch_file, format="json", dump=batch_dump) + + return batch_file + + def remove_files(keep_files: List[str], dir: Path) -> None: + files = list_files_in_dir(dir, key=None) + remove = set(files) - {Path(file) for file in keep_files} + for file in remove: + file.unlink() + + path = Path(path) + + # Do not include `_Batch` data so `dump` is fast + dump = self.dump(include_batch_data=False) + batch_manager_step_files = {} + + # Do this to avoid modifying the dictionary while iterating over it + batch_manager_steps = set(dump["steps"].keys()) + for step_name in batch_manager_steps: + step_dump = dump["steps"].pop(step_name) + + # Create a directory for each batch manager step to store their batches + batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name + 
batch_manager_step_dir.mkdir(parents=True, exist_ok=True) + + # Store each built `_Batch` in a separate file + built_batches_dir = batch_manager_step_dir / "built_batches" + built_batches_dir.mkdir(parents=True, exist_ok=True) + step_dump["built_batches"] = [ + str( + save_batch( + batches_dir=built_batches_dir, + batch_dump=batch_dump, + batch_list=self._steps[step_name].built_batches, + ) + ) + for batch_dump in step_dump["built_batches"] + ] + # Remove built `_Batch`es that were consumed from cache + remove_files(step_dump["built_batches"], built_batches_dir) + + # Store each `_BatchManagerStep` `_Batch`es in a separate file + for buffered_step_name in step_dump["data"]: + step_batches_dir = batch_manager_step_dir / buffered_step_name + step_batches_dir.mkdir(parents=True, exist_ok=True) + + # Store each `_Batch` in a separate file + step_dump["data"][buffered_step_name] = [ + str( + save_batch( + batches_dir=step_batches_dir, + batch_dump=batch_dump, + batch_list=self._steps[step_name].data[buffered_step_name], + ) + ) + for batch_dump in step_dump["data"][buffered_step_name] + ] + + # Remove `_Batch`es that were consumed from cache + remove_files(step_dump["data"][buffered_step_name], step_batches_dir) + + # Store the `_BatchManagerStep` info + batch_manager_step_file = str( + path.parent / f"batch_manager_steps/{step_name}/batch_manager_step.json" + ) + self.save(path=batch_manager_step_file, format="json", dump=step_dump) + + # Store the path to the `_BatchManagerStep` file + batch_manager_step_files[step_name] = batch_manager_step_file + + dump["steps"] = batch_manager_step_files + self.save(path=path, format="json", dump=dump) + + @classmethod + def load_from_cache(cls, path: "StrOrPath") -> "_BatchManager": + """Loads the `_BatchManager` from a cache file. + + Args: + path: The path to the cache file. 
+ """ + _check_is_dir(path) + content = read_json(path) + + # Read each `_BatchManagerStep` from file + steps = {} + for step_name, step_file in content["steps"].items(): + steps[step_name] = read_json(step_file) + + # Read each `_Batch` from file + steps[step_name]["built_batches"] = [ + read_json(batch) for batch in steps[step_name]["built_batches"] + ] + + for buffered_step_name, batch_files in steps[step_name]["data"].items(): + steps[step_name]["data"][buffered_step_name] = [ + read_json(batch_file) for batch_file in batch_files + ] + + content["steps"] = steps + return cls.from_dict(content) diff --git a/src/distilabel/pipeline/constants.py b/src/distilabel/pipeline/constants.py index 450ef0ed6d..3d400e4a1b 100644 --- a/src/distilabel/pipeline/constants.py +++ b/src/distilabel/pipeline/constants.py @@ -19,3 +19,4 @@ RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches" ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function" CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step" +LAST_BATCH_SENT_FLAG: Final[str] = "last_batch_sent" diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 36f41e16b9..51986c031d 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -15,53 +15,30 @@ import multiprocessing as mp import signal import sys -import threading -import time import traceback -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast import tblib from distilabel.distiset import create_distiset from distilabel.llms.mixins import CudaDevicePlacementMixin from distilabel.pipeline.base import ( - LAST_BATCH_SENT_FLAG, BasePipeline, - _Batch, ) +from distilabel.pipeline.batch import _Batch from distilabel.pipeline.constants import ( - CONVERGENCE_STEP_ATTR_NAME, - INPUT_QUEUE_ATTR_NAME, - ROUTING_BATCH_FUNCTION_ATTR_NAME, - STEP_ATTR_NAME, + 
LAST_BATCH_SENT_FLAG, ) -from distilabel.steps.base import Step +from distilabel.steps.tasks.base import Task from distilabel.utils.logging import setup_logging, stop_logging if TYPE_CHECKING: - from multiprocessing.managers import DictProxy, SyncManager - from multiprocessing.pool import Pool from queue import Queue from distilabel.distiset import Distiset - from distilabel.steps.base import GeneratorStep, _Step - - -_STEPS_LOADED_KEY = "steps_loaded" -_STEPS_LOADED_LOCK_KEY = "steps_loaded_lock" -_STEPS_LOADED_ERROR_CODE = -1 -_CUDA_LLM_DEVICE_PLACEMENT_KEY = "cuda_llm_device_placement" -_CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY = "cuda_llm_device_placement_lock" - -_STOP_CALLED = False -_STOP_CALLED_LOCK = threading.Lock() -_STOP_CALLS = 0 + from distilabel.pipeline.typing import StepLoadStatus + from distilabel.steps.base import GeneratorStep, Step, _Step -_STEPS_LOADED = set() -_STEPS_LOADED_LOCK = threading.Lock() - -_STEPS_FINISHED = set() -_STEPS_FINISHED_LOCK = threading.Lock() _SUBPROCESS_EXCEPTION: Union[Exception, None] = None @@ -129,17 +106,24 @@ def run( initializer=_init_worker, initargs=(log_queue,), ) as pool: - self.output_queue: "Queue[Any]" = manager.Queue() - self.shared_info = self._create_shared_info_dict(manager) - self._handle_keyboard_interrupt(manager=manager, pool=pool) + self._manager = manager + self._pool = pool + self._output_queue = self.QueueClass() + self._load_queue = self.QueueClass() + self._handle_keyboard_interrupt() # Run the steps using the pool of processes - self._run_steps_in_loop(pool, manager, self.output_queue, self.shared_info) + self._run_steps() + + # Run the loop for receiving the load status of each step + self._load_steps_thread = self._run_load_queue_loop_in_thread() # Wait for all the steps to be loaded correctly if not self._all_steps_loaded(): self._write_buffer.close() # type: ignore self._batch_manager = None + self._stop_load_queue_loop() + self._load_steps_thread.join() stop_logging() raise RuntimeError( 
"Failed to load all the steps. Could not run pipeline." @@ -150,15 +134,20 @@ def run( self._request_initial_batches() # Start a loop to receive the output batches from the steps - self._run_output_queue_loop_in_thread() + self._output_queue_thread = self._run_output_queue_loop_in_thread() + self._output_queue_thread.join() # Send `None` to steps `input_queue`s just in case some step is still waiting self._notify_steps_to_stop() + # Stop the load queue loop + self._stop_load_queue_loop() + # `Pool.__exit__` has already called `terminate`, `join` the pool to make sure # all the processes have finished - pool.join() - manager.join() + self._load_steps_thread.join() + self._pool.join() + self._manager.join() self._write_buffer.close() # type: ignore distiset = create_distiset( @@ -170,421 +159,35 @@ def run( stop_logging() return distiset - def _run_output_queue_loop_in_thread(self) -> None: - """Runs the output queue loop in a separate thread to receive the output batches - from the steps. This is done to avoid the signal handler to block the loop, which - would prevent the pipeline from stopping correctly.""" - thread = threading.Thread(target=self._output_queue_loop) - thread.start() - thread.join() - - def _notify_steps_to_stop(self) -> None: - """Notifies the steps to stop their infinite running loop by sending `None` to - their input queues.""" - for step_name in self.dag: - if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): - input_queue.put(None) - - def _output_queue_loop(self) -> None: - """Loop to receive the output batches from the steps and manage the flow of the - batches through the pipeline.""" - while self._batch_manager.can_generate() and not _STOP_CALLED: # type: ignore - self._logger.debug("Waiting for output batch from step...") - if (batch := self.output_queue.get()) is None: - self._logger.debug("Received `None` from output queue. 
Breaking loop.") - break - - self._logger.debug( - f"Received batch with seq_no {batch.seq_no} from step '{batch.step_name}'" - f" from output queue: {batch}" - ) - - if batch.data_path: - self._logger.debug( - f"Reading {batch.seq_no} batch data from '{batch.step_name}': '{batch.data_path}'" - ) - batch.read_batch_data_from_fs() - - if batch.step_name in self.dag.leaf_steps: - self._write_buffer.add_batch(batch) # type: ignore - - # If `_STOP_CALLED` was set to `True` while waiting for the output queue, then - # we need to handle the stop of the pipeline and break the loop to avoid - # propagating the batches through the pipeline and making the stop process - # slower. - if _STOP_CALLED: - self._handle_batch_on_stop(batch) - break - - self._manage_batch_flow(batch) - - if _STOP_CALLED: - self._handle_stop() - - def _manage_batch_flow(self, batch: "_Batch") -> None: - """Checks if the step that generated the batch has more data in its buffer to - generate a new batch. If there's data, then a new batch is sent to the step. If - the step has no data in its buffer, then the predecessors generator steps are - requested to send a new batch. - - Args: - batch: The batch that was processed. - """ - assert self._batch_manager, "Batch manager is not set" - - # Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence - # step if the batch is the last one, so they stop their processing loop even if - # they haven't received the last batch because of the routing function. 
- if self._is_convergence_step(batch.step_name) and batch.last_batch: - for step_name in self.dag.get_step_predecessors(batch.step_name): - self._send_last_batch_flag_to_step(step_name) - - route_to, routed = self._get_successors(batch) - - # Keep track of the steps that the batch was routed to - if routed: - batch.batch_routed_to = route_to - - self._register_batch(batch) - - step = self._get_step_from_batch(batch) - - # Add the batch to the successors input buffers - for successor in route_to: - # Copy batch to avoid modifying the same reference in the batch manager - batch_to_add = batch.copy() if len(route_to) > 1 else batch - - self._batch_manager.add_batch(successor, batch_to_add) - - # Check if the step is a generator and if there are successors that need data - # from this step. This usually happens when the generator `batch_size` is smaller - # than the `input_batch_size` of the successor steps. - if ( - step.is_generator - and step.name in self._batch_manager.step_empty_buffers(successor) - ): - last_batch_sent = self._batch_manager.get_last_batch_sent(step.name) - self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore - - # If successor step has enough data in its buffer to create a new batch, then - # send the batch to the step. - if new_batch := self._batch_manager.get_batch(successor): - self._send_batch_to_step(new_batch) - - if not step.is_generator: - # Step ("this", the one from which the batch was received) has enough data on its - # buffers to create a new batch - if new_batch := self._batch_manager.get_batch(step.name): # type: ignore - self._send_batch_to_step(new_batch) - else: - self._request_more_batches_if_needed(step) - - self._cache() - - def _register_batch(self, batch: "_Batch") -> None: - """Registers a batch in the batch manager. - - Args: - batch: The batch to register. 
- """ - self._batch_manager.register_batch(batch) # type: ignore - self._logger.debug( - f"Batch {batch.seq_no} from step '{batch.step_name}' registered in batch" - " manager" - ) - - def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]: - """Gets the successors and the successors to which the batch has to be routed. - - Args: - batch: The batch to which the successors will be determined. - - Returns: - The successors to route the batch to and whether the batch was routed using - a routing function. - """ - node = self.dag.get_step(batch.step_name) - step: "Step" = node[STEP_ATTR_NAME] - successors = list(self.dag.get_step_successors(step.name)) # type: ignore - route_to = successors - - # Check if the step has a routing function to send the batch to specific steps - if routing_batch_function := node.get(ROUTING_BATCH_FUNCTION_ATTR_NAME): - route_to = routing_batch_function(batch, successors) - successors_str = ", ".join(f"'{successor}'" for successor in route_to) - self._logger.info( - f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}" - ) - - return route_to, route_to != successors - - def _get_step_from_batch(self, batch: "_Batch") -> "Step": - """Gets the `Step` instance from a batch. - - Args: - batch: The batch to get the step from. - - Returns: - The `Step` instance. - """ - return self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - - def _request_more_batches_if_needed(self, step: "Step") -> None: - """Request more batches to the predecessors steps of `step` if needed. - - Args: - step: The step of which it has to be checked if more batches are needed from - its predecessors. 
- """ - empty_buffers = self._batch_manager.step_empty_buffers(step.name) # type: ignore - for previous_step_name in empty_buffers: - if previous_step_name not in self.dag.root_steps: - continue - - last_batch = self._batch_manager.get_last_batch_sent(previous_step_name) # type: ignore - if last_batch is None: - continue - - self._logger.debug( - f"Step '{step.name}' input buffer for step '{previous_step_name}' is" - " empty. Requesting new batch..." - ) - self._send_batch_to_step(last_batch.next_batch()) - - def _handle_stop(self) -> None: - """Handles the stop of the pipeline execution, which will stop the steps from - processing more batches and wait for the output queue to be empty, to not lose - any data that was already processed by the steps before the stop was called.""" - self._logger.debug("Handling stop of the pipeline execution...") - - # Add the remaining batches in the input queues back to the batch manager - for step_name in self.dag: - node = self.dag.get_step(step_name) - step: "_Step" = node[STEP_ATTR_NAME] - if step.is_generator: - continue - if input_queue := node.get(INPUT_QUEUE_ATTR_NAME): - while not input_queue.empty(): - batch = input_queue.get() - if batch is None: - continue - self._batch_manager.add_batch( # type: ignore - to_step=step_name, batch=batch, prepend=True - ) - self._logger.debug( - f"Adding batch back to the batch manager: {batch}" - ) - input_queue.put(None) - - # Wait for the input queue to be empty, which means that all the steps finished - # processing the batches that were sent before the stop flag. - for step_name in self.dag: - self._wait_step_input_queue_empty(step_name) - - # Consume the output queue until it's empty to not lose any data that was already - # processed by the steps before stop was called. 
- while not self.output_queue.empty(): - batch = self.output_queue.get() - if batch is None: - continue - - if batch.step_name in self.dag.leaf_steps: - self._write_buffer.add_batch(batch) # type: ignore - - self._handle_batch_on_stop(batch) - - self._cache() - - def _handle_batch_on_stop(self, batch: "_Batch") -> None: - """Handles a batch that was received from the output queue when the pipeline was - stopped. It will add and register the batch in the batch manager. - - Args: - batch: The batch to handle. - """ - self._batch_manager.register_batch(batch) # type: ignore - step: "Step" = self.dag.get_step(batch.step_name)[STEP_ATTR_NAME] - for successor in self.dag.get_step_successors(step.name): # type: ignore - self._batch_manager.add_batch(successor, batch) # type: ignore - - def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]: - """Waits for the input queue of a step to be empty. - - Args: - step_name: The name of the step. - - Returns: - The input queue of the step if it's not loaded or finished, `None` otherwise. - """ - if self._check_step_not_loaded_or_finished(step_name): - return None - - if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): - while input_queue.qsize() != 0: - pass - return input_queue - - def _create_shared_info_dict(self, manager: "SyncManager") -> "DictProxy[str, Any]": - """Creates the shared information dictionary to be used by the processes. - - Args: - manager: The manager to create the shared information. - - Returns: - The shared information dictionary. - """ - # TODO: not very important, but we could use a different lock for each matter - return manager.dict( - **{ - _STEPS_LOADED_KEY: manager.list(), - _STEPS_LOADED_LOCK_KEY: manager.Lock(), - _CUDA_LLM_DEVICE_PLACEMENT_KEY: manager.dict(**{}), - _CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY: manager.Lock(), - } - ) - - def _all_steps_loaded(self) -> bool: - """Waits for all the steps to load. 
+ @property + def QueueClass(self) -> Callable: + """The callable used to create the input and output queues. Returns: - `True` if all the steps have been loaded correctly, `False` otherwise. + The callable to create a `Queue`. """ + assert self._manager, "Manager is not initialized" + return self._manager.Queue - def _update_all_steps_loaded(steps_loaded: List[str]) -> None: - with _STEPS_LOADED_LOCK: - _STEPS_LOADED.update(steps_loaded) - - self._logger.info("⏳ Waiting for all the steps to load...") - previous_message = None - while not _STOP_CALLED: - with self.shared_info[_STEPS_LOADED_LOCK_KEY]: - steps_loaded = self.shared_info[_STEPS_LOADED_KEY] - num_steps_loaded = ( - len(steps_loaded) - if steps_loaded != [_STEPS_LOADED_ERROR_CODE] - else 0 - ) - self._logger.debug(f"Steps loaded: {steps_loaded}") - - message = f"⏳ Steps loaded: {num_steps_loaded}/{len(self.dag)}" - if num_steps_loaded > 0 and message != previous_message: - self._logger.info(message) - previous_message = message - - if num_steps_loaded == len(self.dag): - self._logger.info("✅ All the steps have been loaded!") - _update_all_steps_loaded(steps_loaded) - return True - - if steps_loaded == [_STEPS_LOADED_ERROR_CODE]: - self._logger.error("❌ Failed to load all the steps") - _update_all_steps_loaded(steps_loaded) - return False - - time.sleep(2.5) - - return not _STOP_CALLED - - def _request_initial_batches(self) -> None: - """Requests the initial batches to the generator steps.""" - assert self._batch_manager, "Batch manager is not set" - - for step in self._batch_manager._steps.values(): - if batch := step.get_batch(): - self._logger.debug( - f"Sending initial batch to '{step.step_name}' step: {batch}" - ) - self._send_batch_to_step(batch) - - for step_name in self.dag.root_steps: - seq_no = 0 - if last_batch := self._batch_manager.get_last_batch(step_name): - seq_no = last_batch.seq_no + 1 - batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=self._dry_run) - self._logger.debug( - 
f"Requesting initial batch to '{step_name}' generator step: {batch}" - ) - self._send_batch_to_step(batch) - - def _send_batch_to_step(self, batch: "_Batch") -> None: - """Sends a batch to the input queue of a step. - - Args: - batch: The batch to send. - """ - super()._send_batch_to_step(batch) - input_queue = self.dag.get_step(batch.step_name)[INPUT_QUEUE_ATTR_NAME] - input_queue.put(batch) - - def _is_convergence_step(self, step_name: str) -> None: - """Checks if a step is a convergence step. - - Args: - step_name: The name of the step. - """ - return self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME) - - def _send_last_batch_flag_to_step(self, step_name: str) -> None: - """Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches. + def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None: + """Runs the `Step` wrapped in a `_ProcessWrapper` in a separate process of the + `Pool`. Args: - step_name: The name of the step. + step: The step to run. + input_queue: The input queue to send the data to the step. """ - batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore - if batch and batch.last_batch: - return - - self._logger.debug( - f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing" - " batches..." 
+ assert self._pool, "Pool is not initialized" + + process_wrapper = _ProcessWrapper( + step=step, + input_queue=input_queue, + output_queue=self._output_queue, + load_queue=self._load_queue, + dry_run=self._dry_run, ) - input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME] - input_queue.put(LAST_BATCH_SENT_FLAG) - self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore - def _run_steps_in_loop( - self, - pool: "Pool", - manager: "SyncManager", - output_queue: "Queue[_Batch]", - shared_info: "DictProxy[str, Any]", - ) -> None: - """Using the `pool`, runs the steps in the DAG in an infinite loop waiting for - input batches and sending the output batches to the `output_queue`. - - Each `Step` is wrapped in a `_ProcessWrapper`, which will handle the lifecycle of - the `Step` and the communication with the `input_queue` and `output_queue`. The - `_ProcessWrapper.run` method is the target function of the process. - - Args: - pool: The pool of processes. - manager: The manager to create the queues. - output_queue: The queue to send the output batches. - shared_info: The shared information between the processes. - """ - for step_name in self.dag: - step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME] - input_queue = manager.Queue() - self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue) # type: ignore - - # Set `pipeline` to `None` as in some Python environments the pipeline is not - # picklable and it will raise an error when trying to send the step to the process. 
- # `TypeError: cannot pickle 'code' object` - step.pipeline = None - - process_wrapper = _ProcessWrapper( - step=step, - input_queue=input_queue, - output_queue=output_queue, - shared_info=shared_info, - dry_run=self._dry_run, - ) - - pool.apply_async( - process_wrapper.run, - callback=self._finished_callback, - error_callback=self._error_callback, - ) # type: ignore + self._pool.apply_async(process_wrapper.run, error_callback=self._error_callback) def _error_callback(self, e: BaseException) -> None: """Error callback that will be called when an error occurs in a `Step` process. @@ -603,8 +206,6 @@ def _error_callback(self, e: BaseException) -> None: if e.is_load_error: self._logger.error(f"❌ Failed to load step '{e.step.name}': {e.message}") - with self.shared_info[_STEPS_LOADED_LOCK_KEY]: - self.shared_info[_STEPS_LOADED_KEY] = [_STEPS_LOADED_ERROR_CODE] _SUBPROCESS_EXCEPTION = e.subprocess_exception _SUBPROCESS_EXCEPTION.__traceback__ = tblib.Traceback.from_string( # type: ignore e.formatted_traceback @@ -630,94 +231,51 @@ def _error_callback(self, e: BaseException) -> None: # Global step with successors failed self._logger.error(f"An error occurred in global step '{step_name}'") self._logger.error(f"Subprocess traceback:\n\n{e.formatted_traceback}") - self._cache() - self._stop() - - def _finished_callback(self, step_name: str) -> None: - """Callback that will be called when a `Step` process finishes. - - Args: - step_name: The name of the step that finished. - """ - with _STEPS_FINISHED_LOCK: - _STEPS_FINISHED.add(step_name) - - def _check_step_not_loaded_or_finished(self, step_name: str) -> bool: - """Checks if a step is not loaded or already finished. - - Args: - step_name: The name of the step. - - Returns: - `True` if the step is not loaded or already finished, `False` otherwise. 
- """ - with _STEPS_LOADED_LOCK: - if step_name not in _STEPS_LOADED: - return True - with _STEPS_FINISHED_LOCK: - if step_name in _STEPS_FINISHED: - return True - - return False + self._stop() - def _stop( - self, manager: Optional["SyncManager"] = None, pool: Optional["Pool"] = None - ) -> None: + def _stop(self) -> None: """Stops the pipeline execution. It will first send `None` to the input queues of all the steps and then wait until the output queue is empty i.e. all the steps finished processing the batches that were sent before the stop flag. Then it will send `None` to the output queue to notify the pipeline to stop.""" - global _STOP_CALLED - - with _STOP_CALLED_LOCK: - if _STOP_CALLED: - global _STOP_CALLS - _STOP_CALLS += 1 - if _STOP_CALLS == 1: + with self._stop_called_lock: + if self._stop_called: + self._stop_calls += 1 + if self._stop_calls == 1: self._logger.warning( "🛑 Press again to force the pipeline to stop." ) - elif _STOP_CALLS > 1: + elif self._stop_calls > 1: self._logger.warning("🛑 Forcing pipeline interruption.") - if pool: - pool.terminate() - pool.join() + if self._pool: + self._pool.terminate() + self._pool.join() + self._pool = None - if manager: - manager.shutdown() - manager.join() + if self._manager: + self._manager.shutdown() + self._manager.join() + self._manager = None stop_logging() sys.exit(1) return - _STOP_CALLED = True + self._stop_called = True - self._logger.debug(f"Steps loaded before calling `stop`: {_STEPS_LOADED}") + self._logger.debug( + f"Steps loaded before calling `stop`: {self._steps_load_status}" + ) self._logger.info( "🛑 Stopping pipeline. Waiting for steps to finish processing batches..." ) - self._logger.debug("Sending `None` to the output queue to notify stop...") - self.output_queue.put(None) - - def _handle_keyboard_interrupt( - self, manager: Optional["SyncManager"] = None, pool: Optional["Pool"] = None - ) -> None: - """Handles KeyboardInterrupt signal sent during the Pipeline.run method. 
- - It will try to call self._stop (if the pipeline didn't started yet, it won't - have any effect), and if the pool is already started, will close it before exiting - the program. - """ - - def signal_handler(signumber: int, frame: Any) -> None: - self._stop(manager=manager, pool=pool) - signal.signal(signal.SIGINT, signal_handler) + self._stop_load_queue_loop() + self._stop_output_queue_loop() class _ProcessWrapperException(Exception): @@ -781,15 +339,16 @@ class _ProcessWrapper: step: The step to run. input_queue: The queue to receive the input data. output_queue: The queue to send the output data. - shared_info: The shared information between the processes. + load_queue: The queue used to notify the main process that the step has been loaded, + has been unloaded or has failed to load. """ def __init__( self, - step: "Step", + step: "_Step", input_queue: "Queue[_Batch]", output_queue: "Queue[_Batch]", - shared_info: "DictProxy[str, Any]", + load_queue: "Queue[Union[StepLoadStatus, None]]", dry_run: bool = False, ) -> None: """Initializes the `_ProcessWrapper`. @@ -798,29 +357,22 @@ def __init__( step: The step to run. input_queue: The queue to receive the input data. output_queue: The queue to send the output data. - shared_info: The shared information between the processes. + load_queue: The queue used to notify the main process that the step has been + loaded, has been unloaded or has failed to load. dry_run: Flag to ensure we are forcing to run the last batch. """ self.step = step self.input_queue = input_queue self.output_queue = output_queue - self.shared_info = shared_info + self.load_queue = load_queue self._dry_run = dry_run - # If step is a task, and it's using a `CUDALLM`, then set the CUDA device map - # and the lock for that map. 
- if hasattr(self.step, "llm") and isinstance( - self.step.llm, CudaDevicePlacementMixin + if ( + isinstance(self.step, Task) + and hasattr(self.step, "llm") + and isinstance(self.step.llm, CudaDevicePlacementMixin) ): - self.step.llm.set_device_placement_info( - llm_identifier=self.step.name, - device_llm_placement_map=self.shared_info[ - _CUDA_LLM_DEVICE_PLACEMENT_KEY - ], - device_llm_placement_lock=self.shared_info[ - _CUDA_LLM_DEVICE_PLACEMENT_LOCK_KEY - ], - ) + self.step.llm._llm_identifier = self.step.name def run(self) -> str: """The target function executed by the process. This function will also handle @@ -836,6 +388,8 @@ def run(self) -> str: self.step.load() self.step._logger.debug(f"Step '{self.step.name}' loaded!") except Exception as e: + self.step.unload() + self._notify_load_failed() raise _ProcessWrapperException.create_load_error( str(e), self.step, e ) from e @@ -853,14 +407,25 @@ def run(self) -> str: except Exception: pass + self.step.unload() + + self._notify_unload() + self.step._logger.info(f"🏁 Finished running step '{self.step.name}'") return self.step.name # type: ignore def _notify_load(self) -> None: """Notifies that the step has finished executing its `load` function successfully.""" - with self.shared_info[_STEPS_LOADED_LOCK_KEY]: - self.shared_info[_STEPS_LOADED_KEY].append(self.step.name) + self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore + + def _notify_unload(self) -> None: + """Notifies that the step has been unloaded.""" + self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore + + def _notify_load_failed(self) -> None: + """Notifies that the step failed to load.""" + self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore def _generator_step_process_loop(self) -> None: """Runs the process loop for a generator step. 
It will call the `process` method diff --git a/src/distilabel/pipeline/routing_batch_function.py b/src/distilabel/pipeline/routing_batch_function.py index c66a2d82b0..9d074d9fb6 100644 --- a/src/distilabel/pipeline/routing_batch_function.py +++ b/src/distilabel/pipeline/routing_batch_function.py @@ -26,7 +26,7 @@ ) if TYPE_CHECKING: - from distilabel.pipeline.base import _Batch + from distilabel.pipeline.batch import _Batch from distilabel.pipeline.typing import DownstreamConnectableSteps from distilabel.steps.base import _Step diff --git a/src/distilabel/pipeline/typing.py b/src/distilabel/pipeline/typing.py index 2ebb9b4643..ebb4c68155 100644 --- a/src/distilabel/pipeline/typing.py +++ b/src/distilabel/pipeline/typing.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, Literal, TypedDict, TypeVar, Union if TYPE_CHECKING: from distilabel.steps.base import GeneratorStep, GlobalStep, Step @@ -32,3 +32,11 @@ covariant=True, ) """Type for the `Step` types that can be connected as downstream steps.""" + + +class StepLoadStatus(TypedDict): + """Dict containing information about if one step was loaded/unloaded or if it's load + failed""" + + name: str + status: Literal["loaded", "unloaded", "load_failed"] diff --git a/src/distilabel/pipeline/write_buffer.py b/src/distilabel/pipeline/write_buffer.py new file mode 100644 index 0000000000..a71ffdd9b2 --- /dev/null +++ b/src/distilabel/pipeline/write_buffer.py @@ -0,0 +1,168 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from os import PathLike +from pathlib import Path +from typing import Any, Dict, List, Set + +import pyarrow as pa +import pyarrow.parquet as pq + +from distilabel.pipeline.batch import _Batch +from distilabel.utils.dicts import flatten_dict +from distilabel.utils.files import list_files_in_dir + + +class _WriteBuffer: + """Class in charge of sending the batched contents to a buffer and writing + those to files under a given folder. + + As batches are received, they are added to the buffer and once each buffer + is full, the content is written to a parquet file. + """ + + def __init__(self, path: "PathLike", leaf_steps: Set[str]) -> None: + """ + Args: + path: Folder where the files will be written, the idea + is for this path to be in the cache folder under /data. + leaf_steps: Leaf steps from either the DAG of the Pipeline. + + Raises: + ValueError: If the path is not a directory. 
+ """ + self._path = Path(path) + if not self._path.exists(): + self._path.mkdir(parents=True, exist_ok=True) + for step in leaf_steps: + (self._path / step).mkdir(parents=True, exist_ok=True) + + if not self._path.is_dir(): + raise ValueError(f"The path should be a directory, not a file: {path}") + + self._buffers: Dict[str, List[Dict[str, Any]]] = { + step: [] for step in leaf_steps + } + # TODO: make this configurable + self._buffers_dump_batch_size: Dict[str, int] = { + step: 50 for step in leaf_steps + } + self._buffer_last_schema = {} + self._buffers_last_file: Dict[str, int] = {step: 1 for step in leaf_steps} + self._logger = logging.getLogger("distilabel.write_buffer") + + def _get_filename(self, step_name: str) -> Path: + """Creates the filename for the step. + + Args: + step_name: Name of the step to which the data belongs to. + + Returns: + Filename for the step. + """ + return self._path / f"{step_name}.parquet" + + def is_full(self, step_name: str) -> bool: + """Checks the buffers that are full so that those can be written to the file. + + Returns: + Whether the buffer is full. + """ + return len(self._buffers[step_name]) >= self._buffers_dump_batch_size[step_name] + + def add_batch(self, batch: "_Batch") -> None: + """Adds a batch to the buffer and writes the buffer to the file if it's full. + + Args: + batch: batch to add to the buffer. + """ + step_name = batch.step_name + data = batch.data[0] + self._buffers[step_name].extend(data) + self._logger.debug( + f"Added batch to write buffer for step '{step_name}' with {len(data)} rows." + ) + if self.is_full(step_name): + self._logger.debug( + f"Buffer for step '{step_name}' is full (rows: {len(self._buffers[step_name])}," + f" full: {self._buffers_dump_batch_size[step_name]}), writing to file..." + ) + self._write(step_name) + + def _write(self, step_name: str) -> None: + """Writes the content to the file and cleans the buffer. + + Args: + step_name (str): Name of the step to which the data pertains. 
+ """ + step_parquet_dir = Path(self._path, step_name) + if not step_parquet_dir.exists(): + self._logger.debug( + f"Creating directory for step '{step_name}' parquet files..." + ) + step_parquet_dir.mkdir() + + try: + table = pa.Table.from_pylist(self._buffers[step_name]) + except pa.lib.ArrowInvalid as pae: + if ( + repr(pae) + != "ArrowInvalid('cannot mix struct and non-struct, non-null values')" + ): + raise pae + flattened_buffers = [flatten_dict(buf) for buf in self._buffers[step_name]] + table = pa.Table.from_pylist(flattened_buffers) + + last_schema = self._buffer_last_schema.get(step_name) + if last_schema is None: + self._buffer_last_schema[step_name] = table.schema + else: + if not last_schema.equals(table.schema): + new_schema = pa.unify_schemas([last_schema, table.schema]) + self._buffer_last_schema[step_name] = new_schema + table = table.cast(new_schema) + + next_file_number = self._buffers_last_file[step_name] + self._buffers_last_file[step_name] = next_file_number + 1 + + parquet_file = step_parquet_dir / f"{str(next_file_number).zfill(5)}.parquet" + pq.write_table(table, parquet_file) + self._logger.debug(f"Written to file '{parquet_file}'") + + self._clean_buffer(step_name) + + def _clean_buffer(self, step_name: str) -> None: + """Cleans the buffer by setting it's content to `None`. + + Args: + step_name: The name of the buffer to clean. + """ + self._buffers[step_name] = [] + + def close(self) -> None: + """Closes the buffer by writing the remaining content to the file.""" + for step_name in self._buffers: + if self._buffers[step_name]: + self._write(step_name) + + # We need to read the parquet files and write them again to ensure the schema + # is correct. Otherwise, the first parquets won't have the last schema and + # then we will have issues when reading them. 
+ for file in list_files_in_dir(self._path / step_name): + if step_name in self._buffer_last_schema: + table = pq.read_table( + file, schema=self._buffer_last_schema[step_name] + ) + pq.write_table(table, file) diff --git a/src/distilabel/steps/base.py b/src/distilabel/steps/base.py index 9aab121815..d35cdb8b5d 100644 --- a/src/distilabel/steps/base.py +++ b/src/distilabel/steps/base.py @@ -303,6 +303,12 @@ def load(self) -> None: """ self._logger = logging.getLogger(f"distilabel.step.{self.name}") + def unload(self) -> None: + """Method to perform any cleanup logic after the `process` method is called. For + example, to close a connection to a database, etc. + """ + self._logger.debug("Executing step unload logic.") + @property def is_generator(self) -> bool: """Whether the step is a generator step or not. diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index fda1e1e248..a73ccbfbdb 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from pydantic import Field +from typing_extensions import override from distilabel.llms.base import LLM from distilabel.mixins.runtime_parameters import RuntimeParameter @@ -64,10 +65,16 @@ class _Task(_Step, ABC): ) def load(self) -> None: - """Loads the LLM via the `LLM.load()` method (done for safer serialization).""" + """Loads the LLM via the `LLM.load()` method.""" super().load() self.llm.load() + @override + def unload(self) -> None: + """Unloads the LLM.""" + self._logger.debug("Executing task unload logic.") + self.llm.unload() + @abstractmethod def format_output( self, diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py index 7862a8b9c5..714f52fbb0 100644 --- a/src/distilabel/utils/serialization.py +++ b/src/distilabel/utils/serialization.py @@ -25,7 +25,18 @@ from enum import EnumType from pathlib import Path -from typing import Any, Dict, 
List, Literal, Optional, Type, TypeVar, Union, get_args +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Type, + TypeVar, + Union, + get_args, +) import yaml from pydantic import BaseModel @@ -41,12 +52,30 @@ SaveFormats = Literal["json", "yaml"] +# Mapping to handle import paths that could have been serialized from previous versions +_OLD_IMPORT_MODULE_ATTR: Dict[Tuple[str, str], Tuple[str, str]] = { + ("distilabel.pipeline.base", "_Batch"): ("distilabel.pipeline.batch", "_Batch"), + ("distilabel.pipeline.base", "_BatchManager"): ( + "distilabel.pipeline.batch_manager", + "_BatchManager", + ), + ("distilabel.pipeline.base", "_BatchManagerStep"): ( + "distilabel.pipeline.batch_manager", + "_BatchManagerStep", + ), +} + + def _get_module_attr(module: str, name: str) -> Type: """Gets a class given the module and the name of the class. Returns: The type of the class. """ + + if (module, name) in _OLD_IMPORT_MODULE_ATTR: + module, name = _OLD_IMPORT_MODULE_ATTR[(module, name)] + mod = importlib.import_module(module) return getattr(mod, name) diff --git a/tests/unit/llms/test_mixins.py b/tests/unit/llms/test_mixins.py index feb8b00e01..c0c7b10671 100644 --- a/tests/unit/llms/test_mixins.py +++ b/tests/unit/llms/test_mixins.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import multiprocessing as mp import os import sys from typing import TYPE_CHECKING, Any, Generator, List, Union @@ -43,6 +42,10 @@ def load(self) -> None: super().load() CudaDevicePlacementMixin.load(self) + def unload(self) -> None: + super().unload() + CudaDevicePlacementMixin.unload(self) + @property def model_name(self) -> str: return "test" @@ -63,13 +66,7 @@ def test_set_cuda_visible_devices(self) -> None: assert os.environ["CUDA_VISIBLE_DEVICES"] == "0,1" - def test_cuda_visible_devices_not_cuda_devices(self) -> None: - llm = DummyCudaLLM() - llm._llm_identifier = "unit-test" - - llm.load() - - assert os.getenv("CUDA_VISIBLE_DEVICES") is None + llm.unload() def test_set_cuda_visible_devices_unvalid_devices(self) -> None: llm = DummyCudaLLM(cuda_devices=[5, 6]) @@ -80,84 +77,54 @@ def test_set_cuda_visible_devices_unvalid_devices(self) -> None: ): llm.load() - def test_set_device_placement_info(self) -> None: - llm = DummyCudaLLM(cuda_devices="auto") + llm.unload() + + def test_set_cuda_visible_devices_auto(self) -> None: + llm1 = DummyCudaLLM() + llm1._llm_identifier = "unit-test-1" + llm1.load() - with mp.Manager() as manager: - llm.set_device_placement_info( - llm_identifier="unit-test", - device_llm_placement_map=manager.dict(), - device_llm_placement_lock=manager.Lock(), # type: ignore - ) + assert os.environ["CUDA_VISIBLE_DEVICES"] == "0" - assert llm._llm_identifier == "unit-test" - assert llm._device_llm_placement_map is not None + llm2 = DummyCudaLLM() + llm2._llm_identifier = "unit-test-2" + llm2.load() - def test_set_cuda_visible_devices_auto(self) -> None: - with mp.Manager() as manager: - device_llm_placement_map = manager.dict() - lock = manager.Lock() - - llm1 = DummyCudaLLM() - llm1.set_device_placement_info( - llm_identifier="unit-test-1", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm1.load() - - assert os.environ["CUDA_VISIBLE_DEVICES"] == "0" - - llm2 = DummyCudaLLM() - 
llm2.set_device_placement_info( - llm_identifier="unit-test-2", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm2.load() - - assert os.environ["CUDA_VISIBLE_DEVICES"] == "1" + assert os.environ["CUDA_VISIBLE_DEVICES"] == "1" + + llm1.unload() + llm2.unload() def test_set_cuda_visible_devices_auto_not_enough_devices(self) -> None: - with mp.Manager() as manager: - device_llm_placement_map = manager.dict() - lock = manager.Lock() - - with pytest.raises( - RuntimeError, match="Couldn't find an available CUDA device" - ): - # 4 devices are available, but 5 LLMs are going to be loaded - for i in range(5): - llm = DummyCudaLLM() - llm.set_device_placement_info( - llm_identifier=f"unit-test-{i}", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm.load() + llms = [] + for i in range(5): + llm = DummyCudaLLM() + llm._llm_identifier = f"unit-test-{i}" + llms.append(llm) + + with pytest.raises( + RuntimeError, match="Couldn't find an available CUDA device" + ): + # 4 devices are available, but 5 LLMs are going to be loaded + for llm in llms: + llm.load() + + for llm in llms: + llm.unload() def test_check_cuda_devices(self, caplog) -> None: - with mp.Manager() as manager: - device_llm_placement_map = manager.dict() - lock = manager.Lock() - - llm1 = DummyCudaLLM(cuda_devices=[1]) - llm1.set_device_placement_info( - llm_identifier="unit-test-1", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm1.load() - - llm2 = DummyCudaLLM(cuda_devices=[1]) - llm2.set_device_placement_info( - llm_identifier="unit-test-2", - device_llm_placement_map=device_llm_placement_map, - device_llm_placement_lock=lock, # type: ignore - ) - llm2.load() - - assert ( - "LLM with identifier 'unit-test-1' is also going to use CUDA device '1'" - in caplog.text - ) + llm1 = DummyCudaLLM(cuda_devices=[1]) + llm1._llm_identifier 
= "unit-test-1" + llm1.load() + + llm2 = DummyCudaLLM(cuda_devices=[1]) + llm2._llm_identifier = "unit-test-2" + llm2.load() + + assert ( + "LLM with identifier 'unit-test-1' is also going to use CUDA device '1'" + in caplog.text + ) + + llm1.unload() + llm2.unload() diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 4c26db132d..c18a30e143 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -12,26 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os import tempfile from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from queue import Queue +from typing import Any, Callable, Dict, List, Optional from unittest import mock import pytest -from distilabel.distiset import Distiset, create_distiset from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.pipeline._dag import DAG from distilabel.pipeline.base import ( + _STEP_LOAD_FAILED_CODE, + _STEP_NOT_LOADED_CODE, BasePipeline, - _Batch, - _BatchManager, - _BatchManagerStep, _GlobalPipelineManager, - _WriteBuffer, ) -from distilabel.pipeline.local import Pipeline -from distilabel.steps.base import GlobalStep, Step, StepInput +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager +from distilabel.pipeline.constants import INPUT_QUEUE_ATTR_NAME, LAST_BATCH_SENT_FLAG +from distilabel.pipeline.routing_batch_function import ( + routing_batch_function, + sample_n_steps, +) +from distilabel.pipeline.write_buffer import _WriteBuffer +from distilabel.steps.base import Step, StepInput, _Step from distilabel.steps.typing import StepOutput from distilabel.utils.serialization import TYPE_INFO_KEY from fsspec.implementations.local import LocalFileSystem @@ -43,11 +48,19 @@ DummyGlobalStep, DummyStep1, DummyStep2, - batch_gen, ) -if TYPE_CHECKING: - from 
distilabel.steps.base import GeneratorStep + +class DummyPipeline(BasePipeline): + @property + def QueueClass(self) -> Callable: + return Queue + + def _run_step(self, step: "_Step", input_queue: "Queue[Any]") -> None: + pass + + def _stop(self) -> None: + pass class TestGlobalPipelineManager: @@ -55,7 +68,7 @@ def teardown_method(self) -> None: _GlobalPipelineManager.set_pipeline(None) def test_set_pipeline(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") _GlobalPipelineManager.set_pipeline(pipeline) assert _GlobalPipelineManager.get_pipeline() == pipeline @@ -64,7 +77,7 @@ def test_set_pipeline_none(self) -> None: assert _GlobalPipelineManager.get_pipeline() is None def test_get_pipeline(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") _GlobalPipelineManager.set_pipeline(pipeline) assert _GlobalPipelineManager.get_pipeline() == pipeline @@ -73,7 +86,7 @@ class TestBasePipeline: def test_context_manager(self) -> None: assert _GlobalPipelineManager.get_pipeline() is None - with BasePipeline(name="unit-test-pipeline") as pipeline: + with DummyPipeline(name="unit-test-pipeline") as pipeline: assert pipeline is not None assert _GlobalPipelineManager.get_pipeline() == pipeline @@ -81,7 +94,7 @@ def test_context_manager(self) -> None: @pytest.mark.parametrize("use_cache", [False, True]) def test_load_batch_manager(self, use_cache: bool) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") pipeline._load_batch_manager(use_cache=True) pipeline._cache() @@ -102,19 +115,19 @@ def test_load_batch_manager(self, use_cache: bool) -> None: mock_from_dag.assert_called_once_with(pipeline.dag) def test_setup_write_buffer(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") pipeline._setup_write_buffer() 
assert isinstance(pipeline._write_buffer, _WriteBuffer) def test_set_logging_parameters(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") pipeline._set_logging_parameters({"unit-test": "yes"}) assert pipeline._logging_parameters == {"unit-test": "yes"} def test_setup_fsspec(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") with mock.patch("fsspec.filesystem") as mock_filesystem: pipeline._setup_fsspec({"path": "gcs://my-bucket", "extra": "stuff"}) @@ -122,7 +135,7 @@ def test_setup_fsspec(self) -> None: mock_filesystem.assert_called_once_with("gcs", **{"extra": "stuff"}) def test_setup_fsspec_default(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") pipeline._setup_fsspec() assert isinstance(pipeline._fs, LocalFileSystem) @@ -132,13 +145,267 @@ def test_setup_fsspec_default(self) -> None: ) def test_setup_fsspec_raises_value_error(self) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") with pytest.raises(ValueError, match="The 'path' key must be present"): pipeline._setup_fsspec({"key": "random"}) + def test_init_steps_load_status(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._init_steps_load_status() + assert pipeline._steps_load_status == { + generator.name: _STEP_NOT_LOADED_CODE, + step.name: _STEP_NOT_LOADED_CODE, + step2.name: _STEP_NOT_LOADED_CODE, + step3.name: _STEP_NOT_LOADED_CODE, + } + + def test_run_load_queue_loop(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + + pipeline._load_queue = Queue() + pipeline._steps_load_status = {"dummy": 0} + pipeline._load_queue.put({"name": "dummy", 
"status": "loaded"}) + + thread = pipeline._run_load_queue_loop_in_thread() + pipeline._load_queue.put(None) + thread.join() + + assert pipeline._steps_load_status["dummy"] == 1 + + def test_run_load_queue_loop_receiving_none(self) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + + pipeline._load_queue = Queue() + pipeline._load_queue.put(None) + + thread = pipeline._run_load_queue_loop_in_thread() + thread.join() + + assert not thread.is_alive() + + def test_all_steps_loaded(self, caplog) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._steps_load_status = { # type: ignore + generator.name: 1, + step.name: 1, + step2.name: 1, + step3.name: 1, + } + caplog.set_level(logging.INFO) + + assert pipeline._all_steps_loaded() is True + assert "All the steps have been loaded!" in caplog.text + + def test_all_steps_loaded_with_failing_step(self, caplog) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._init_steps_load_status() + pipeline._steps_load_status[generator.name] = _STEP_LOAD_FAILED_CODE # type: ignore + caplog.set_level(logging.INFO) + + assert pipeline._all_steps_loaded() is False + assert "Failed to load all the steps" in caplog.text + + def test_all_steps_loaded_stop_aclled(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._init_steps_load_status() + pipeline._stop_called = True + + assert pipeline._all_steps_loaded() is False + + def test_handle_stop(self) -> None: + with DummyPipeline(name="dummy") as pipeline: + generator = DummyGeneratorStep() + step = 
DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + pipeline._add_batches_back_to_batch_manager = mock.MagicMock() + pipeline._wait_step_input_queue_empty = mock.MagicMock() + pipeline._consume_output_queue = mock.MagicMock() + + pipeline._handle_stop() + + pipeline._add_batches_back_to_batch_manager.assert_called_once() + pipeline._wait_step_input_queue_empty.assert_has_calls( + [ + mock.call(generator.name), + mock.call(step.name), + mock.call(step2.name), + mock.call(step3.name), + ], + any_order=True, + ) + pipeline._consume_output_queue.assert_called_once() + + @pytest.mark.parametrize( + "num_workers,expected", [(0, True), (_STEP_LOAD_FAILED_CODE, True), (1, False)] + ) + def test_check_step_not_loaded_or_finished( + self, num_workers: int, expected: bool + ) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") + pipeline._steps_load_status = {"dummy": num_workers} + + assert pipeline._check_step_not_loaded_or_finished("dummy") is expected + + def test_is_convergence_step(self) -> None: + sample_two_steps = sample_n_steps(2) + + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> sample_two_steps >> [step, step2] >> step3 + + pipeline.dag.validate() + + assert not pipeline._is_convergence_step(generator.name) # type: ignore + assert not pipeline._is_convergence_step(step.name) # type: ignore + assert not pipeline._is_convergence_step(step2.name) # type: ignore + assert pipeline._is_convergence_step(step3.name) # type: ignore + + def test_create_step_input_queue(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + generator_name: str = generator.name # type: ignore + input_queue = pipeline._create_step_input_queue(generator_name) + assert isinstance(input_queue, Queue) + 
assert isinstance( + pipeline.dag.get_step(generator_name)[INPUT_QUEUE_ATTR_NAME], Queue + ) + + def test_run_steps(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + pipeline._create_step_input_queue = mock.MagicMock() + pipeline._run_step = mock.MagicMock() + pipeline._run_steps() + + pipeline._create_step_input_queue.assert_has_calls( + [ + mock.call(step_name=step.name), + mock.call(step_name=generator.name), + ], + any_order=True, + ) + + pipeline._run_step.assert_has_calls( + [ + mock.call(step=mock.ANY, input_queue=mock.ANY), + mock.call(step=mock.ANY, input_queue=mock.ANY), + ] + ) + + def test_add_batches_back_to_batch_manager(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + generator_name: str = generator.name # type: ignore + step_name: str = step.name # type: ignore + + pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) + generator_queue = Queue() + pipeline.dag.set_step_attr( + generator_name, INPUT_QUEUE_ATTR_NAME, generator_queue + ) + step_queue = Queue() + pipeline.dag.set_step_attr(step_name, INPUT_QUEUE_ATTR_NAME, step_queue) + + generator_queue.put( + _Batch(seq_no=0, step_name=generator_name, last_batch=False) + ) + generator_queue.put( + _Batch(seq_no=1, step_name=generator_name, last_batch=False) + ) + + step_batch_0 = _Batch(seq_no=0, step_name=step_name, last_batch=False) + step_batch_1 = _Batch(seq_no=0, step_name=step_name, last_batch=False) + step_queue.put(step_batch_0) + step_queue.put(step_batch_1) + + pipeline._add_batches_back_to_batch_manager() + + assert pipeline._batch_manager._steps[step_name].built_batches == [ + step_batch_0, + step_batch_1, + ] + + def test_consume_output_queue(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = 
DummyStep1() + + generator >> step + + pipeline._output_queue = Queue() + pipeline._write_buffer = mock.MagicMock() + pipeline._handle_batch_on_stop = mock.MagicMock() + + generator_name: str = generator.name # type: ignore + step_name: str = step.name # type: ignore + + generator_batch = _Batch(seq_no=0, step_name=generator_name, last_batch=False) + step_batch = _Batch(seq_no=0, step_name=step_name, last_batch=False) + + pipeline._output_queue.put(generator_batch) + pipeline._output_queue.put(step_batch) + + pipeline._consume_output_queue() + + pipeline._write_buffer.add_batch.assert_called_once_with(step_batch) + pipeline._handle_batch_on_stop.assert_has_calls( + [ + mock.call(generator_batch), + mock.call(step_batch), + ] + ) + def test_send_batch_to_step(self) -> None: - with BasePipeline(name="unit-test-pipeline") as pipeline: + with DummyPipeline(name="unit-test-pipeline") as pipeline: generator = DummyGeneratorStep() step = DummyStep1() global_step = DummyGlobalStep() @@ -146,6 +413,7 @@ def test_send_batch_to_step(self) -> None: generator >> [step, global_step] pipeline._batch_manager = mock.MagicMock() + pipeline._send_to_step = mock.MagicMock() pipeline._setup_fsspec() with mock.patch( @@ -159,6 +427,8 @@ def test_send_batch_to_step(self) -> None: _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore ) + # `write_batch_data_to_fs` shouldn't have been called because last batch sent with + # `_send_batch_to_step` is from a non-global step. mock_write.assert_not_called() with mock.patch( @@ -168,6 +438,8 @@ def test_send_batch_to_step(self) -> None: _Batch(seq_no=0, step_name=global_step.name, last_batch=False) # type: ignore ) + # `write_batch_data_to_fs` should have been called because last batch sent with + # `_send_batch_to_step` is from a global step. 
mock_write.assert_called_once_with( pipeline._fs, UPath(pipeline._storage_base_path) / global_step.name, @@ -182,6 +454,8 @@ def test_send_batch_to_step(self) -> None: _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore ) + # `write_batch_data_to_fs` shouldn't have been called because generator receives + # empty batches, so there's no data to write. mock_write.assert_not_called() with mock.patch( @@ -207,6 +481,229 @@ def test_send_batch_to_step(self) -> None: ] ) + def test_register_batch(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + pipeline._batch_manager = mock.MagicMock() + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + pipeline._register_batch(batch) + + pipeline._batch_manager.register_batch.assert_called_once_with(batch) + + def test_send_last_batch_flag_to_step(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + step_name: str = step.name # type: ignore + + pipeline._batch_manager = _BatchManager( + steps={}, + last_batch_received={step_name: None}, + last_batch_sent={step_name: None}, + last_batch_flag_sent_to=[], + ) + + with mock.patch.object(pipeline, "_send_to_step") as mock_sent_to_step: + pipeline._send_last_batch_flag_to_step(step_name) + + mock_sent_to_step.assert_called_once_with(step_name, LAST_BATCH_SENT_FLAG) + + pipeline._batch_manager._last_batch_sent[step_name] = _Batch( + seq_no=0, + step_name=step_name, + last_batch=True, + ) + with mock.patch.object(pipeline, "_send_to_step") as mock_sent_to_step: + pipeline._send_last_batch_flag_to_step(step_name) + + mock_sent_to_step.assert_not_called() + + def test_request_initial_batches(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = 
DummyStep1(input_batch_size=5) + + generator >> step + + generator2 = DummyGeneratorStep() + step2 = DummyStep1(input_batch_size=5) + + generator2 >> step2 + + pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) + + # Simulate there were batches from the cache for the steps + batch_0 = _Batch( + seq_no=0, + step_name=generator.name, # type: ignore + last_batch=False, + data=[[{"a": i} for i in range(5)]], + ) + pipeline._batch_manager._steps[step.name].data[generator.name] = [ # type: ignore + batch_0 + ] + + batch_1 = _Batch( + seq_no=0, + step_name=generator2.name, # type: ignore + last_batch=False, + data=[[{"b": i} for i in range(5)]], + ) # type: ignore + pipeline._batch_manager._steps[step2.name].data[generator2.name] = [ # type: ignore + batch_1 + ] + + with mock.patch.object( + pipeline, "_send_batch_to_step" + ) as mock_send_batch_to_step: + pipeline._request_initial_batches() + + mock_send_batch_to_step.assert_has_calls( + [ + mock.call(mock.ANY), + mock.call(mock.ANY), + mock.call(_Batch(seq_no=0, step_name=generator.name, last_batch=False)), # type: ignore + mock.call( + _Batch(seq_no=0, step_name=generator2.name, last_batch=False) # type: ignore + ), + ], + any_order=True, + ) + + def test_request_more_batches_if_needed(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + generator_name: str = generator.name # type: ignore + + pipeline._batch_manager = _BatchManager.from_dag(pipeline.dag) + + batch = _Batch(seq_no=0, step_name=generator_name, last_batch=False) + pipeline._batch_manager._last_batch_sent[generator_name] = batch + + with mock.patch.object( + pipeline, "_send_batch_to_step" + ) as mock_send_batch_to_step: + pipeline._request_more_batches_if_needed(step) + + mock_send_batch_to_step.assert_called_once_with(batch.next_batch()) + + def test_handle_batch_on_stop(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as 
pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(input_batch_size=5) + step2 = DummyStep1(input_batch_size=5) + step3 = DummyStep1(input_batch_size=5) + + generator >> [step, step2, step3] + + batch_manager_mock = mock.MagicMock() + pipeline._batch_manager = batch_manager_mock + + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + pipeline._handle_batch_on_stop(batch) + + batch_manager_mock.register_batch.assert_called_once_with(batch) + batch_manager_mock.add_batch.assert_has_calls( + [ + mock.call(step.name, batch), + mock.call(step2.name, batch), + mock.call(step3.name, batch), + ] + ) + + def test_get_step_from_batch(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + + generator >> step + + batch = _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + assert pipeline._get_step_from_batch(batch) == generator + + batch = _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + assert pipeline._get_step_from_batch(batch) == step + + def test_notify_steps_to_stop(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(input_batch_size=5) + + generator >> step + + with mock.patch.object(pipeline, "_send_to_step") as mock_send_to_step: + pipeline._notify_steps_to_stop() + + mock_send_to_step.assert_has_calls( + [ + mock.call(generator.name, None), + mock.call(step.name, None), + ] + ) + + def test_get_successors(self) -> None: + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1() + step2 = DummyStep1() + step3 = DummyStep2() + + generator >> [step, step2] >> step3 + + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + ) == ([step.name, step2.name], False) + assert pipeline._get_successors( + 
_Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) == ([step3.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step2.name, last_batch=False) # type: ignore + ) == ([step3.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step3.name, last_batch=False) # type: ignore + ) == ([], False) + + def test_get_successors_with_routing_batch_function(self) -> None: + @routing_batch_function() + def fixed_routing_batch_function(steps: List[str]) -> List[str]: + return ["step_2", "step_3"] + + with DummyPipeline(name="unit-test-pipeline") as pipeline: + generator = DummyGeneratorStep() + step = DummyStep1(name="step_1") + step2 = DummyStep1(name="step_2") + step3 = DummyStep1(name="step_3") + step4 = DummyStep2(name="step_4") + + generator >> fixed_routing_batch_function >> [step, step2, step3] >> step4 + + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=generator.name, last_batch=False) # type: ignore + ) == (["step_2", "step_3"], True) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step.name, last_batch=False) # type: ignore + ) == ([step4.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step2.name, last_batch=False) # type: ignore + ) == ([step4.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step3.name, last_batch=False) # type: ignore + ) == ([step4.name], False) + assert pipeline._get_successors( + _Batch(seq_no=0, step_name=step4.name, last_batch=False) # type: ignore + ) == ([], False) + def test_get_runtime_parameters_info(self) -> None: class DummyStep1(Step): runtime_param1: RuntimeParameter[str] = Field( @@ -230,7 +727,7 @@ class DummyStep2(Step): def process(self, inputs: StepInput) -> None: pass - with BasePipeline(name="unit-test-pipeline") as pipeline: + with DummyPipeline(name="unit-test-pipeline") as pipeline: DummyStep1(name="dummy_step_1") DummyStep2(name="dummy_step_2") @@ -331,7 +828,7 @@ 
class DummyStep2(Step): def process(self, inputs: StepInput) -> StepOutput: # type: ignore yield [{}] - with BasePipeline(name="unit-test-pipeline") as pipeline: + with DummyPipeline(name="unit-test-pipeline") as pipeline: gen_step = DummyGeneratorStep(name="dummy_generator_step") step1 = DummyStep1(name="dummy_step_1") step2 = DummyStep2(name="dummy_step_2") @@ -348,7 +845,7 @@ def process(self, inputs: StepInput) -> StepOutput: # type: ignore def test_cache_dir_env_variable(self) -> None: with mock.patch.dict(os.environ, clear=True): os.environ["DISTILABEL_CACHE_DIR"] = "/tmp/unit-test" - pipeline = BasePipeline(name="unit-test-pipeline") + pipeline = DummyPipeline(name="unit-test-pipeline") assert pipeline._cache_dir == Path("/tmp/unit-test") @pytest.mark.parametrize( @@ -371,7 +868,7 @@ def test_cache_dir_env_variable(self) -> None: ) def test_step_names_inferred(self, in_pipeline: bool, names: List[str]) -> None: if in_pipeline: - with BasePipeline(name="unit-test-pipeline"): + with DummyPipeline(name="unit-test-pipeline"): gen_step = DummyGeneratorStep() step1_0 = DummyStep1() step2 = DummyStep2() @@ -391,2430 +888,84 @@ def test_step_names_inferred(self, in_pipeline: bool, names: List[str]) -> None: def test_infer_step_names_big_pipeline(self) -> None: # Tests that the name of the steps are inferred correctly when the pipeline is big (say 50 steps). 
- with BasePipeline(name="unit-test-pipeline") as pipe: + with DummyPipeline(name="unit-test-pipeline") as pipe: gen_step = DummyGeneratorStep() for _ in range(50): gen_step.connect(DummyStep1()) assert list(pipe.dag.G)[-1] == "dummy_step1_49" -class TestBatch: - def test_get_data(self) -> None: - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[ - [ - {"a": 0}, - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], +class TestPipelineSerialization: + def test_base_pipeline_dump(self): + pipeline = DummyPipeline(name="unit-test-pipeline") + dump = pipeline.dump() + assert len(dump.keys()) == 2 + assert "pipeline" in dump + assert "distilabel" in dump + assert TYPE_INFO_KEY in dump["pipeline"] + assert ( + dump["pipeline"][TYPE_INFO_KEY]["module"] == "tests.unit.pipeline.test_base" ) + assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "DummyPipeline" - batch.set_data( - [ - [ - {"a": 0}, - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ] - ) + def test_base_pipeline_from_dict(self): + pipeline = DummyPipeline(name="unit-test-pipeline") + pipe = DummyPipeline.from_dict(pipeline.dump()) + assert isinstance(pipe, DummyPipeline) - old_hash = batch.data_hash + def test_pipeline_dump(self): + from distilabel.pipeline.local import Pipeline - data = batch.get_data(5) - assert data == [{"a": 0}, {"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}] - assert batch.data == [[{"a": 5}, {"a": 6}]] - assert batch.data_hash != old_hash + pipeline = Pipeline(name="unit-test-pipeline") + dump = pipeline.dump() + assert len(dump.keys()) == 2 + assert "pipeline" in dump + assert "distilabel" in dump + assert TYPE_INFO_KEY in dump["pipeline"] + assert dump["pipeline"][TYPE_INFO_KEY]["module"] == "distilabel.pipeline.local" + assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "Pipeline" - def test_set_data(self) -> None: - batch = _Batch(seq_no=0, step_name="step1", last_batch=False) - data = [[{"i": i} for i in 
range(5000)]] - batch.set_data(data) + @pytest.mark.parametrize( + "format, name, loader", + [ + ("yaml", "pipe.yaml", DummyPipeline.from_yaml), + ("json", "pipe.json", DummyPipeline.from_json), + ("invalid", "pipe.invalid", None), + ], + ) + def test_pipeline_to_from_file_format( + self, + format: str, + name: str, + loader: Callable, + ) -> None: + pipeline = DummyPipeline(name="unit-test-pipeline") - assert batch.data == data - assert batch.size == 5000 + with tempfile.TemporaryDirectory() as tmpdirname: + filename = Path(tmpdirname) / name + if format == "invalid": + with pytest.raises(ValueError): + pipeline.save(filename, format=format) + else: + pipeline.save(filename, format=format) + assert filename.exists() + pipe_from_file = loader(filename) + assert isinstance(pipe_from_file, DummyPipeline) - def test_next_batch(self) -> None: - batch = _Batch(seq_no=0, step_name="step1", last_batch=False) - next_batch = batch.next_batch() + def test_base_pipeline_signature(self): + pipeline = DummyPipeline(name="unit-test-pipeline") + # Doesn't matter if it's exactly this or not, the test should fail if we change the + # way this is created. 
+ signature = pipeline._create_signature() + assert signature == "da39a3ee5e6b4b0d3255bfef95601890afd80709" - assert next_batch == _Batch(seq_no=1, step_name="step1", last_batch=False) + # Maybe not the best place for this test, but does the work for now + from distilabel.pipeline.local import Pipeline + from distilabel.pipeline.routing_batch_function import sample_n_steps - def test_accumulate(self) -> None: - batches = [ - [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - ), - _Batch( - seq_no=1, - step_name="step1", - last_batch=True, - data=[[{"a": 4}, {"a": 5}, {"a": 6}]], - ), - ], - [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}]], - ), - _Batch( - seq_no=1, - step_name="step2", - last_batch=True, - data=[[{"b": 4}, {"b": 5}, {"b": 6}]], - ), - ], - ] + from tests.unit.pipeline.utils import DummyGeneratorStep, DummyStep1, DummyStep2 - batch = _Batch.accumulate("step3", batches) - - assert batch.seq_no == 0 - assert batch.step_name == "step3" - assert batch.last_batch is True - assert batch.data == [ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}], - ] - - def test_dump(self) -> None: - batch = _Batch(seq_no=0, step_name="step1", last_batch=False) - assert batch.dump() == { - "seq_no": 0, - "size": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "data_hash": None, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": {"module": "distilabel.pipeline.base", "name": "_Batch"}, - } - - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - data_hash="hash", - accumulated=False, - created_from={"step0": [(0, 5), (1, 5)]}, - batch_routed_to=["step2", "step3"], - ) - assert batch.dump() == { - "seq_no": 0, - "size": 0, - "step_name": "step1", - "last_batch": False, - 
"data": [[{"a": 1}, {"a": 2}, {"a": 3}]], - "data_hash": "hash", - "accumulated": False, - "created_from": {"step0": [(0, 5), (1, 5)]}, - "batch_routed_to": ["step2", "step3"], - "type_info": {"module": "distilabel.pipeline.base", "name": "_Batch"}, - } - - def test_from_dict(self) -> None: - batch = _Batch.from_dict( - { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ) - - assert isinstance(batch, _Batch) - assert batch.seq_no == 0 - assert batch.step_name == "step1" - assert batch.last_batch is False - assert batch.data == [[{"a": 1}, {"a": 2}, {"a": 3}]] - assert batch.accumulated is False - - -class TestBatchManagerStep: - def test_add_batch(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} - ) - - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - ) - - batch_manager_step.add_batch(batch) - - assert batch_manager_step.data["step1"] == [batch] - assert batch_manager_step.last_batch_received == [] - - def test_add_batch_with_prepend(self) -> None: - batch_1 = _Batch( - seq_no=1, - step_name="step1", - last_batch=False, - data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], - ) - batch_manager_step = _BatchManagerStep( - step_name="step2", - accumulate=False, - input_batch_size=10, - data={"step1": [batch_1]}, - ) - - batch_0 = _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - batch_manager_step.add_batch(batch_0, prepend=True) - - assert batch_manager_step.built_batches == [batch_0] - assert batch_manager_step.data["step1"] == [batch_1] - assert batch_manager_step.last_batch_received == [] - - def test_add_batch_last_batch(self) -> None: - batch_manager_step = 
_BatchManagerStep( - step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} - ) - - batch = _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - ) - - batch_manager_step.add_batch(batch) - - assert batch_manager_step.data["step1"] == [batch] - assert batch_manager_step.last_batch_received == ["step1"] - - def test_get_batch(self) -> None: - previously_built_batch = _Batch( - seq_no=0, - step_name="step3", - last_batch=False, - data=[ - [ - {"a": -1}, - {"a": 0}, - ], - [ - {"b": -1}, - {"b": 0}, - ], - ], - ) - - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=2, - seq_no=1, - data={ - "step1": [ - _Batch( - seq_no=1, - step_name="step1", - last_batch=False, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - size=5, - ) - ], - "step2": [ - _Batch( - seq_no=1, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - ] - ], - size=5, - ) - ], - }, - built_batches=[previously_built_batch], - ) - - batch = batch_manager_step.get_batch() - - assert batch == previously_built_batch - - batch = batch_manager_step.get_batch() - - assert batch == _Batch( - step_name="step3", - seq_no=1, - last_batch=False, - data=[ - [ - {"a": 1}, - {"a": 2}, - ], - [ - {"b": 1}, - {"b": 2}, - ], - ], - created_from={"step1": [(1, 5)], "step2": [(1, 5)]}, - ) - - batch = batch_manager_step.get_batch() - - assert batch == _Batch( - step_name="step3", - seq_no=2, - last_batch=False, - data=[ - [ - {"a": 3}, - {"a": 4}, - ], - [ - {"b": 3}, - {"b": 4}, - ], - ], - created_from={"step1": [(1, 5)], "step2": [(1, 5)]}, - ) - - def test_get_batches_accumulate(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[ - [ - {"a": 1}, - {"a": 2}, - 
{"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - size=5, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - ] - ], - size=6, - ) - ], - }, - last_batch_received=["step1", "step2"], - ) - - batch = batch_manager_step.get_batch() - - assert batch == _Batch( - step_name="step3", - seq_no=0, - last_batch=True, - accumulated=True, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ], - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - ], - ], - created_from={"step1": [(0, 5)], "step2": [(0, 6)]}, - ) - - def test_get_batches_not_enough_data(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=2, - data={ - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[ - [ - {"a": 1}, - ] - ], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - ] - ], - ) - ], - }, - ) - - assert batch_manager_step.get_batch() is None - - def test_from_step(self, dummy_step_1: "Step") -> None: - batch_manager_step = _BatchManagerStep.from_step( - step=dummy_step_1, predecessors=["step1", "step2"] - ) - - assert batch_manager_step.step_name == "dummy_step_1" - assert batch_manager_step.accumulate is False - assert batch_manager_step.input_batch_size == 50 - assert batch_manager_step.data == {"step1": [], "step2": []} - assert batch_manager_step.seq_no == 0 - assert batch_manager_step.last_batch_received == [] - - def test_from_step_with_global_step(self, dummy_global_step: "GlobalStep") -> None: - batch_manager_step = _BatchManagerStep.from_step( - step=dummy_global_step, predecessors=["step1", "step2"] - ) - - assert batch_manager_step.step_name == "dummy_global_step" - assert batch_manager_step.accumulate is True - assert batch_manager_step.input_batch_size == 50 - 
assert batch_manager_step.data == {"step1": [], "step2": []} - assert batch_manager_step.seq_no == 0 - assert batch_manager_step.last_batch_received == [] - - def test_get_seq_no(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", accumulate=False, input_batch_size=5, data={"step1": []} - ) - - seq_no = batch_manager_step._get_seq_no() - - assert seq_no == 0 - assert batch_manager_step.seq_no == 1 - - def test_get_data(self) -> None: - batch_step_1 = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], - size=6, - batch_routed_to=["step1", "step2"], - ) - batch_step_2 = _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - size=7, - batch_routed_to=["step1", "step2"], - ) - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={ - "step1": [batch_step_1], - "step2": [batch_step_2], - }, - ) - - data, created_from, routed_to = batch_manager_step._get_data() - assert data == [ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}], - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}], - ] - assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} - assert routed_to == ["step1", "step2"] - - assert batch_manager_step.data == { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 6}]], - data_hash=batch_step_1.data_hash, - size=6, - batch_routed_to=["step1", "step2"], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 6}, {"b": 7}]], - data_hash=batch_step_2.data_hash, - size=7, - batch_routed_to=["step1", "step2"], - ) - ], - } - - def test_get_data_accumulate(self) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data={ - "step1": [ - _Batch( - seq_no=0, - 
step_name="step1", - last_batch=False, - data=[ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] - ], - size=6, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - size=7, - ) - ], - }, - ) - - data, created_from, routed_to = batch_manager_step._get_data() - - assert data == [ - [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}], - ] - assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} - assert routed_to == [] - - assert batch_manager_step.data == {"step1": [], "step2": []} - - def test_get_data_convergence_step(self) -> None: - batch_a_0 = _Batch( - seq_no=0, - step_name="A", - last_batch=False, - data=[ - [ - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - ] - ], - size=3, - created_from={"Z": [(0, 3)]}, - ) - - batch_a_1 = _Batch( - seq_no=1, - step_name="A", - last_batch=False, - data=[ - [ - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - ] - ], - size=3, - created_from={"Z": [(1, 3)]}, - ) - - batch_b_0 = _Batch( - seq_no=0, - step_name="B", - last_batch=False, - data=[ - [ - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - ] - ], - size=3, - created_from={"Z": [(0, 3)]}, - ) - - batch_c_0 = _Batch( - seq_no=0, - step_name="C", - last_batch=False, - data=[ - [ - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - ] - ], - size=3, - created_from={"Z": [(1, 3)]}, - ) - - batch_manager_step = _BatchManagerStep( - step_name="D", - input_batch_size=3, - convergence_step=True, - accumulate=False, - data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]}, - ) - - data, 
created_from, routed_to = batch_manager_step._get_data() - - assert data == [ - [ - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - {"generation": "Hello, I'm A 0"}, - ], - [ - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - {"generation": "Hello, I'm B 0"}, - ], - ] - assert created_from == {"A": [(0, 3)], "B": [(0, 3)]} - assert routed_to == [] - assert batch_manager_step.next_expected_created_from_batch_seq_no == 1 - - data, created_from, routed_to = batch_manager_step._get_data() - - assert data == [ - [ - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - {"generation": "Hello, I'm A 1"}, - ], - [ - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - {"generation": "Hello, I'm C 0"}, - ], - ] - assert created_from == {"A": [(1, 3)], "C": [(0, 3)]} - assert routed_to == [] - assert batch_manager_step.next_expected_created_from_batch_seq_no == 2 - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ] - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - ) - ] - }, - ["step1"], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ] - }, - ["step1"], - True, - ), - ], - ) - def test_last_batch( - self, - data: Dict[str, List[_Batch]], - 
last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", - accumulate=False, - input_batch_size=5, - data=data, - last_batch_received=last_batch_received, - ) - - assert batch_manager_step._last_batch() is expected - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1"], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1", "step2"], - True, - ), - ], - ) - def test_last_batch_accumulate( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data=data, - last_batch_received=last_batch_received, - ) - - assert batch_manager_step._last_batch() is expected - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, 
- step_name="step1", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}]], - created_from={"step0": [(0, 3)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}]], - created_from={"step0": [(0, 3)]}, - ) - ], - }, - [], - True, - ), - ], - ) - def test_last_batch_convergence_step( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - data=data, - last_batch_received=last_batch_received, - input_batch_size=3, - convergence_step=True, - ) - - assert batch_manager_step._last_batch() is expected - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - 
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1", "step2"], - True, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], - ) - ], - }, - ["step1", "step2"], - True, - ), - ], - ) - def test_ready_to_create_batch( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step2", - accumulate=False, - input_batch_size=5, - data=data, - last_batch_received=last_batch_received, - ) - - assert batch_manager_step._ready_to_create_batch() is expected - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1", "step2"], - True, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - ) - ], - }, - ["step1"], - False, - ), - ], - ) - def test_ready_to_create_batch_accumulate( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data=data, - last_batch_received=last_batch_received, - ) - 
- assert batch_manager_step._ready_to_create_batch() is expected - - def test_dump(self) -> None: - batch_step_1 = _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], - data_hash="hash0", - size=6, - ) - batch_step_2 = _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[ - [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}] - ], - data_hash="hash1", - size=7, - ) - batch_step_3 = _Batch( - seq_no=0, - step_name="step3", - last_batch=True, - data=[[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]], - data_hash="hash2", - size=5, - ) - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=True, - data={ - "step1": [batch_step_1], - "step2": [batch_step_2], - }, - built_batches=[batch_step_3], - ) - assert batch_manager_step.dump() == { - "step_name": "step3", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {}, - "input_batch_size": None, - "data": { - "step1": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": True, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - "data_hash": "hash0", - "size": 6, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "data_hash": "hash1", - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - }, - "built_batches": [ - { - "seq_no": 0, - "step_name": "step3", - "last_batch": True, - "data": [[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]], - "data_hash": "hash2", - "size": 5, - 
"accumulated": False, - "batch_routed_to": [], - "created_from": {}, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "seq_no": 0, - "last_batch_received": [], - "next_expected_created_from_batch_seq_no": 0, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - } - - @pytest.mark.parametrize( - "data, last_batch_received, expected", - [ - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - True, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - False, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=True, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=True, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], - batch_routed_to=["step1", 
"step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - }, - ["step1", "step2"], - True, - ), - ( - { - "step1": [ - _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 4)]}, - ) - ], - "step2": [ - _Batch( - seq_no=0, - step_name="step2", - last_batch=False, - data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], - batch_routed_to=["step1", "step2"], - created_from={"step0": [(0, 5)]}, - ) - ], - }, - [], - False, - ), - ], - ) - def test_ready_to_create_batch_convergence_step( - self, - data: Dict[str, List[_Batch]], - last_batch_received: List[str], - expected: bool, - ) -> None: - batch_manager_step = _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data=data, - last_batch_received=last_batch_received, - convergence_step=True, - ) - - assert batch_manager_step._ready_to_create_batch() is expected - - def test_from_dict(self) -> None: - batch_manager_step = _BatchManagerStep.from_dict( - { - "step_name": "step3", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {0: {"Z": 1234}}, - "input_batch_size": None, - "data": { - "step1": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": True, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - {"a": 6}, - ] - ], - "size": 6, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - } - ], - }, - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - } - ) - - assert isinstance(batch_manager_step, 
_BatchManagerStep) - assert batch_manager_step.step_name == "step3" - assert batch_manager_step.accumulate is True - assert batch_manager_step.convergence_step is False - assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}} - assert batch_manager_step.input_batch_size is None - assert batch_manager_step.seq_no == 0 - assert batch_manager_step.last_batch_received == [] - - -class TestBatchManager: - def test_add_batch(self) -> None: - batch_manager = _BatchManager( - steps={ - "step3": _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={"step1": [], "step2": []}, - ) - }, - last_batch_received={"step3": None}, - last_batch_sent={"step3": None}, - last_batch_flag_sent_to=[], - ) - - batch_from_step_1 = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - batch_manager.add_batch(to_step="step3", batch=batch_from_step_1) - - assert batch_manager._steps["step3"].data == { - "step1": [batch_from_step_1], - "step2": [], - } - - def test_add_batch_with_prepend(self) -> None: - batch_1 = _Batch( - seq_no=1, - step_name="step1", - last_batch=False, - data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], - ) - batch_manager = _BatchManager( - steps={ - "step3": _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={ - "step1": [batch_1], - "step2": [], - }, - ) - }, - last_batch_received={"step3": None}, - last_batch_sent={"step3": None}, - last_batch_flag_sent_to=[], - ) - batch_0 = _Batch( - seq_no=0, - step_name="step1", - last_batch=False, - data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], - ) - batch_manager.add_batch(to_step="step3", batch=batch_0, prepend=True) - assert batch_manager._steps["step3"].built_batches == [batch_0] - assert batch_manager._steps["step3"].data == { - "step1": [batch_1], - "step2": [], - } - - def test_from_dag( - self, - dummy_generator_step: 
"GeneratorStep", - dummy_step_1: "Step", - dummy_step_2: "Step", - dummy_global_step: "GlobalStep", - ) -> None: - dag = DAG() - dag.add_step(dummy_generator_step) - dag.add_step(dummy_step_1) - dag.add_step(dummy_step_2) - dag.add_step(dummy_global_step) - dag.add_edge("dummy_generator_step", "dummy_step_1") - dag.add_edge("dummy_generator_step", "dummy_global_step") - dag.add_edge("dummy_step_1", "dummy_step_2") - - batch_manager = _BatchManager.from_dag(dag) - - assert batch_manager._steps == { - "dummy_step_1": _BatchManagerStep( - step_name="dummy_step_1", - accumulate=False, - input_batch_size=50, - data={"dummy_generator_step": []}, - ), - "dummy_global_step": _BatchManagerStep( - step_name="dummy_global_step", - accumulate=True, - input_batch_size=50, - data={"dummy_generator_step": []}, - ), - "dummy_step_2": _BatchManagerStep( - step_name="dummy_step_2", - accumulate=False, - input_batch_size=50, - data={"dummy_step_1": []}, - ), - } - - def test_can_generate(self) -> None: - batch_manager = _BatchManager( - steps={}, - last_batch_received={ - "step_1": _Batch(seq_no=0, step_name="step_1", last_batch=False), - "step_2": _Batch(seq_no=0, step_name="step_2", last_batch=False), - "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False), - }, - last_batch_sent={"step_1": None, "step_2": None, "step_3": None}, - last_batch_flag_sent_to=[], - ) - - assert batch_manager.can_generate() - - batch_1 = _Batch(seq_no=0, step_name="step_1", last_batch=True) - batch_2 = _Batch(seq_no=0, step_name="step_2", last_batch=True) - batch_3 = _Batch(seq_no=0, step_name="step_3", last_batch=True) - - batch_manager = _BatchManager( - steps={}, - last_batch_received={ - "step_1": batch_1, - "step_2": batch_2, - "step_3": batch_3, - }, - last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3}, - last_batch_flag_sent_to=[], - ) - - assert not batch_manager.can_generate() - - def test_dump(self) -> None: - built_batch = _Batch( - seq_no=0, - 
last_batch=False, - step_name="step3", - data=[[]], - data_hash="hash", - ) - - batch_manager = _BatchManager( - steps={ - "step3": _BatchManagerStep( - step_name="step3", - accumulate=False, - input_batch_size=5, - data={"step1": [], "step2": []}, - built_batches=[built_batch], - seq_no=1, - ) - }, - last_batch_received={ - "step3": _Batch( - seq_no=0, - step_name="step3", - last_batch=False, - ) - }, - last_batch_sent={ - "step3": _Batch( - seq_no=1, - step_name="step3", - last_batch=False, - ) - }, - last_batch_flag_sent_to=["step99"], - ) - assert batch_manager.dump() == { - "steps": { - "step3": { - "step_name": "step3", - "accumulate": False, - "convergence_step": False, - "convergence_step_batches_consumed": {}, - "input_batch_size": 5, - "data": {"step1": [], "step2": []}, - "built_batches": [ - { - "seq_no": 0, - "step_name": "step3", - "last_batch": False, - "data": [[]], - "data_hash": "hash", - "size": 0, - "accumulated": False, - "batch_routed_to": [], - "created_from": {}, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "seq_no": 1, - "last_batch_received": [], - "next_expected_created_from_batch_seq_no": 0, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - }, - "last_batch_received": { - "step3": { - "seq_no": 0, - "step_name": "step3", - "batch_routed_to": [], - "created_from": {}, - "last_batch": False, - "data": [], - "data_hash": None, - "size": 0, - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - }, - "last_batch_sent": { - "step3": { - "seq_no": 1, - "step_name": "step3", - "batch_routed_to": [], - "created_from": {}, - "last_batch": False, - "data": [], - "data_hash": None, - "size": 0, - "accumulated": False, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - }, - "last_batch_flag_sent_to": ["step99"], - "type_info": { - "module": 
"distilabel.pipeline.base", - "name": "_BatchManager", - }, - } - - def test_from_dict(self) -> None: - batch_manager = _BatchManager.from_dict( - { - "steps": { - "step1": { - "step_name": "step1", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {0: {"Z": 1234}}, - "input_batch_size": None, - "data": { - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - }, - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - "step2": { - "step_name": "step2", - "accumulate": False, - "convergence_step": False, - "convergence_step_batches_consumed": {0: {"Z": 1234}}, - "input_batch_size": 50, - "data": { - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - }, - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - }, - "last_batch_received": { - "step1": { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - "step2": { - "seq_no": 0, - "step_name": "step2", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - 
"batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_sent": { - "step1": { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - "step2": { - "seq_no": 0, - "step_name": "step2", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_flag_sent_to": ["step3"], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManager", - }, - } - ) - - assert isinstance(batch_manager, _BatchManager) - - assert len(batch_manager._steps) == 2 - for step in batch_manager._steps.values(): - assert isinstance(step, _BatchManagerStep) - - assert len(batch_manager._last_batch_received) == 2 - for step in batch_manager._last_batch_received.values(): - assert isinstance(step, _Batch) - - assert len(batch_manager._last_batch_sent) == 2 - for step in batch_manager._last_batch_sent.values(): - assert isinstance(step, _Batch) - - assert batch_manager._last_batch_flag_sent_to == ["step3"] - - def test_cache(self) -> None: - batch_manager = _BatchManager.from_dict( - { - "steps": { - "step1": { - "step_name": "step1", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {"0": {"Z": 1234}}, - "input_batch_size": None, - "data": { - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "data_hash": "1234", - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - 
], - }, - "built_batches": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - "data_hash": "1234", - "size": 5, - "accumulated": False, - "batch_routed_to": [], - "created_from": {}, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - "step2": { - "step_name": "step2", - "accumulate": False, - "convergence_step": False, - "convergence_step_batches_consumed": {"0": {"Z": 1234}}, - "input_batch_size": 50, - "data": { - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "data_hash": "1234", - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - }, - "built_batches": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - "data_hash": "1234", - "size": 5, - "accumulated": False, - "batch_routed_to": [], - "created_from": {}, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - }, - "last_batch_received": { - "step1": { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - "step2": { - "seq_no": 0, - "step_name": "step2", - "last_batch": False, - "data": [], - "size": 
0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_sent": { - "step1": { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - "step2": { - "seq_no": 0, - "step_name": "step2", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_flag_sent_to": ["step3"], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManager", - }, - } - ) - - with tempfile.TemporaryDirectory() as tmp_dir: - batch_manager_path = Path(tmp_dir) / "batch_manager.json" - batch_manager.cache(batch_manager_path) - - assert batch_manager_path.exists() and batch_manager_path.is_file() - - for step_name, step in batch_manager._steps.items(): - batch_manager_step_dir = ( - Path(tmp_dir) / "batch_manager_steps" / step_name - ) - assert ( - batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir() - ) - - batch_manager_step_path = ( - batch_manager_step_dir / "batch_manager_step.json" - ) - assert ( - batch_manager_step_path.exists() - and batch_manager_step_path.is_file() - ) - - built_batches_dir = batch_manager_step_dir / "built_batches" - assert built_batches_dir.exists() - - for batch in step.built_batches: - batch_path = ( - built_batches_dir - / f"batch_{batch.seq_no}_{batch.data_hash}.json" - ) - assert batch_path.exists() and batch_path.is_file() - - for buffered_step_name in step.data: - buffered_step_dir = batch_manager_step_dir / buffered_step_name - assert buffered_step_dir.exists() and buffered_step_dir.is_dir() - - for batch in step.data[buffered_step_name]: - 
batch_path = ( - buffered_step_dir - / f"batch_{batch.seq_no}_{batch.data_hash}.json" - ) - assert batch_path.exists() and batch_path.is_file() - - def test_load_from_cache(self) -> None: - batch_manager = _BatchManager.from_dict( - { - "steps": { - "step1": { - "step_name": "step1", - "accumulate": True, - "convergence_step": False, - "convergence_step_batches_consumed": {"0": {"Z": 1234}}, - "input_batch_size": None, - "data": { - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "data_hash": "1234", - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - }, - "built_batches": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - "data_hash": "1234", - "size": 5, - "accumulated": False, - "batch_routed_to": [], - "created_from": {}, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - "step2": { - "step_name": "step2", - "accumulate": False, - "convergence_step": False, - "convergence_step_batches_consumed": {"0": {"Z": 1234}}, - "input_batch_size": 50, - "data": { - "step2": [ - { - "seq_no": 0, - "step_name": "step2", - "last_batch": True, - "data": [ - [ - {"b": 1}, - {"b": 2}, - {"b": 3}, - {"b": 4}, - {"b": 5}, - {"b": 6}, - {"b": 7}, - ] - ], - "data_hash": "1234", - "size": 7, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - }, - "built_batches": [ - { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, 
- "data": [ - [ - {"a": 1}, - {"a": 2}, - {"a": 3}, - {"a": 4}, - {"a": 5}, - ] - ], - "data_hash": "1234", - "size": 5, - "accumulated": False, - "batch_routed_to": [], - "created_from": {}, - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - } - ], - "seq_no": 0, - "last_batch_received": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManagerStep", - }, - }, - }, - "last_batch_received": { - "step1": { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - "step2": { - "seq_no": 0, - "step_name": "step2", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_sent": { - "step1": { - "seq_no": 0, - "step_name": "step1", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - "step2": { - "seq_no": 0, - "step_name": "step2", - "last_batch": False, - "data": [], - "size": 0, - "accumulated": False, - "created_from": {}, - "batch_routed_to": [], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_Batch", - }, - }, - }, - "last_batch_flag_sent_to": ["step3"], - "type_info": { - "module": "distilabel.pipeline.base", - "name": "_BatchManager", - }, - } - ) - - with tempfile.TemporaryDirectory() as tmp_dir: - batch_manager_path = Path(tmp_dir) / "batch_manager.json" - batch_manager.cache(batch_manager_path) - loaded_batch_manager = _BatchManager.load_from_cache(batch_manager_path) - - assert batch_manager.dump() == loaded_batch_manager.dump() - - -class TestPipelineSerialization: - 
def test_base_pipeline_dump(self): - pipeline = BasePipeline(name="unit-test-pipeline") - dump = pipeline.dump() - assert len(dump.keys()) == 2 - assert "pipeline" in dump - assert "distilabel" in dump - assert TYPE_INFO_KEY in dump["pipeline"] - assert dump["pipeline"][TYPE_INFO_KEY]["module"] == "distilabel.pipeline.base" - assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "BasePipeline" - - def test_base_pipeline_from_dict(self): - pipeline = BasePipeline(name="unit-test-pipeline") - pipe = BasePipeline.from_dict(pipeline.dump()) - assert isinstance(pipe, BasePipeline) - - def test_pipeline_dump(self): - from distilabel.pipeline.local import Pipeline - - pipeline = Pipeline(name="unit-test-pipeline") - dump = pipeline.dump() - assert len(dump.keys()) == 2 - assert "pipeline" in dump - assert "distilabel" in dump - assert TYPE_INFO_KEY in dump["pipeline"] - assert dump["pipeline"][TYPE_INFO_KEY]["module"] == "distilabel.pipeline.local" - assert dump["pipeline"][TYPE_INFO_KEY]["name"] == "Pipeline" - - @pytest.mark.parametrize( - "format, name, loader", - [ - ("yaml", "pipe.yaml", BasePipeline.from_yaml), - ("json", "pipe.json", BasePipeline.from_json), - ("invalid", "pipe.invalid", None), - ], - ) - def test_pipeline_to_from_file_format( - self, - format: str, - name: str, - loader: Callable, - ) -> None: - pipeline = BasePipeline(name="unit-test-pipeline") - - with tempfile.TemporaryDirectory() as tmpdirname: - filename = Path(tmpdirname) / name - if format == "invalid": - with pytest.raises(ValueError): - pipeline.save(filename, format=format) - else: - pipeline.save(filename, format=format) - assert filename.exists() - pipe_from_file = loader(filename) - assert isinstance(pipe_from_file, BasePipeline) - - def test_base_pipeline_signature(self): - pipeline = BasePipeline(name="unit-test-pipeline") - # Doesn't matter if it's exactly this or not, the test should fail if we change the - # way this is created. 
- signature = pipeline._create_signature() - assert signature == "da39a3ee5e6b4b0d3255bfef95601890afd80709" - - # Maybe not the best place for this test, but does the work for now - from distilabel.pipeline.local import Pipeline - from distilabel.pipeline.routing_batch_function import sample_n_steps - - from tests.unit.pipeline.utils import DummyGeneratorStep, DummyStep1, DummyStep2 - - sample_two_steps = sample_n_steps(2) + sample_two_steps = sample_n_steps(2) with Pipeline(name="unit-test-pipeline") as pipeline: dummy_generator = DummyGeneratorStep() @@ -2951,126 +1102,3 @@ def test_binary_operators(self) -> None: signature_2 = pipeline_2._create_signature() assert signature_1 == signature_2 - - -class TestWriteBuffer: - def test_create(self) -> None: - with tempfile.TemporaryDirectory() as tmpdirname: - folder = Path(tmpdirname) / "data" - with Pipeline(name="unit-test-pipeline") as pipeline: - dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") - dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") - dummy_step_1 = DummyStep1(name="dummy_step_1") - dummy_step_2 = DummyStep2(name="dummy_step_2") - dummy_step_3 = DummyStep2(name="dummy_step_3") - - dummy_generator_1.connect(dummy_step_1) - dummy_generator_2.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_3) - - write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) - - assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []} - assert write_buffer._buffers_dump_batch_size == { - "dummy_step_2": 50, - "dummy_step_3": 50, - } - assert write_buffer._buffer_last_schema == {} - assert write_buffer._buffers_last_file == { - "dummy_step_2": 1, - "dummy_step_3": 1, - } - - def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None: - with tempfile.TemporaryDirectory() as tmpdirname: - folder = Path(tmpdirname) / "data" - with Pipeline(name="unit-test-pipeline") as pipeline: - dummy_generator = 
DummyGeneratorStep(name="dummy_generator_step") - dummy_step_1 = DummyStep1(name="dummy_step_1") - dummy_step_2 = DummyStep2(name="dummy_step_2") - - dummy_generator.connect(dummy_step_1) - dummy_step_1.connect(dummy_step_2) - - write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) - - # Add one batch with 5 rows, shouldn't write anything 5 < 50 - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - # Add 45 more rows, should write now - for _ in range(9): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_2", "00001.parquet").exists() - - # Add 50 more rows, we should have a new file - for _ in range(10): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_2", "00002.parquet").exists() - - # Add more rows and close the write buffer, we should have a new file - for _ in range(5): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - write_buffer.close() - - assert Path(folder, "dummy_step_2", "00003.parquet").exists() - - ds = create_distiset(write_buffer._path) - assert isinstance(ds, Distiset) - assert len(ds.keys()) == 1 - assert len(ds["default"]["train"]) == 125 - - def test_write_buffer_multiple_leaf_steps_and_create_dataset(self): - with tempfile.TemporaryDirectory() as tmpdirname: - folder = Path(tmpdirname) / "data" - with Pipeline(name="unit-test-pipeline") as pipeline: - dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") - dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") - dummy_step_1 = DummyStep1(name="dummy_step_1") - dummy_step_2 = DummyStep2(name="dummy_step_2") - dummy_step_3 = DummyStep2(name="dummy_step_3") - - dummy_generator_1.connect(dummy_step_1) - dummy_generator_2.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_2) - dummy_step_1.connect(dummy_step_3) - - write_buffer = _WriteBuffer(path=folder, 
leaf_steps=pipeline.dag.leaf_steps) - - for _ in range(10): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_2", "00001.parquet").exists() - - for _ in range(10): - batch = batch_gen(dummy_step_3.name) - write_buffer.add_batch(batch) - - assert Path(folder, "dummy_step_3", "00001.parquet").exists() - - for _ in range(5): - batch = batch_gen(dummy_step_2.name) - write_buffer.add_batch(batch) - - for _ in range(5): - batch = batch_gen(dummy_step_3.name) - write_buffer.add_batch(batch) - - write_buffer.close() - - assert Path(folder, "dummy_step_2", "00002.parquet").exists() - assert Path(folder, "dummy_step_3", "00002.parquet").exists() - - ds = create_distiset(write_buffer._path) - assert isinstance(ds, Distiset) - assert len(ds.keys()) == 2 - assert len(ds["dummy_step_2"]["train"]) == 75 - assert len(ds["dummy_step_3"]["train"]) == 75 diff --git a/tests/unit/pipeline/test_batch.py b/tests/unit/pipeline/test_batch.py new file mode 100644 index 0000000000..ed246e491f --- /dev/null +++ b/tests/unit/pipeline/test_batch.py @@ -0,0 +1,172 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from distilabel.pipeline.batch import _Batch + + +class TestBatch: + def test_get_data(self) -> None: + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[ + [ + {"a": 0}, + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + ) + + batch.set_data( + [ + [ + {"a": 0}, + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ] + ) + + old_hash = batch.data_hash + + data = batch.get_data(5) + assert data == [{"a": 0}, {"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}] + assert batch.data == [[{"a": 5}, {"a": 6}]] + assert batch.data_hash != old_hash + + def test_set_data(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False) + data = [[{"i": i} for i in range(5000)]] + batch.set_data(data) + + assert batch.data == data + assert batch.size == 5000 + + def test_next_batch(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False) + next_batch = batch.next_batch() + + assert next_batch == _Batch(seq_no=1, step_name="step1", last_batch=False) + + def test_accumulate(self) -> None: + batches = [ + [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + ), + _Batch( + seq_no=1, + step_name="step1", + last_batch=True, + data=[[{"a": 4}, {"a": 5}, {"a": 6}]], + ), + ], + [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}]], + ), + _Batch( + seq_no=1, + step_name="step2", + last_batch=True, + data=[[{"b": 4}, {"b": 5}, {"b": 6}]], + ), + ], + ] + + batch = _Batch.accumulate("step3", batches) + + assert batch.seq_no == 0 + assert batch.step_name == "step3" + assert batch.last_batch is True + assert batch.data == [ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}], + ] + + def test_dump(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False) + assert batch.dump() == { 
+ "seq_no": 0, + "size": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "data_hash": None, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": {"module": "distilabel.pipeline.batch", "name": "_Batch"}, + } + + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + data_hash="hash", + accumulated=False, + created_from={"step0": [(0, 5), (1, 5)]}, + batch_routed_to=["step2", "step3"], + ) + assert batch.dump() == { + "seq_no": 0, + "size": 0, + "step_name": "step1", + "last_batch": False, + "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], + "data_hash": "hash", + "accumulated": False, + "created_from": {"step0": [(0, 5), (1, 5)]}, + "batch_routed_to": ["step2", "step3"], + "type_info": {"module": "distilabel.pipeline.batch", "name": "_Batch"}, + } + + def test_from_dict(self) -> None: + batch = _Batch.from_dict( + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], + "accumulated": False, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ) + + assert isinstance(batch, _Batch) + assert batch.seq_no == 0 + assert batch.step_name == "step1" + assert batch.last_batch is False + assert batch.data == [[{"a": 1}, {"a": 2}, {"a": 3}]] + assert batch.accumulated is False diff --git a/tests/unit/pipeline/test_batch_manager.py b/tests/unit/pipeline/test_batch_manager.py new file mode 100644 index 0000000000..7b1cb1a8a6 --- /dev/null +++ b/tests/unit/pipeline/test_batch_manager.py @@ -0,0 +1,2214 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path +from typing import Dict, List + +import pytest +from distilabel.pipeline._dag import DAG +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager, _BatchManagerStep +from distilabel.steps.base import GeneratorStep, GlobalStep, Step + + +class TestBatchManagerStep: + def test_add_batch(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} + ) + + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + ) + + batch_manager_step.add_batch(batch) + + assert batch_manager_step.data["step1"] == [batch] + assert batch_manager_step.last_batch_received == [] + + def test_add_batch_with_prepend(self) -> None: + batch_1 = _Batch( + seq_no=1, + step_name="step1", + last_batch=False, + data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], + ) + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=10, + data={"step1": [batch_1]}, + ) + + batch_0 = _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager_step.add_batch(batch_0, prepend=True) + + assert batch_manager_step.built_batches == [batch_0] + assert batch_manager_step.data["step1"] == [batch_1] + assert batch_manager_step.last_batch_received == [] + + def test_add_batch_last_batch(self) -> None: + batch_manager_step = _BatchManagerStep( + 
step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} + ) + + batch = _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + ) + + batch_manager_step.add_batch(batch) + + assert batch_manager_step.data["step1"] == [batch] + assert batch_manager_step.last_batch_received == ["step1"] + + def test_get_batch(self) -> None: + previously_built_batch = _Batch( + seq_no=0, + step_name="step3", + last_batch=False, + data=[ + [ + {"a": -1}, + {"a": 0}, + ], + [ + {"b": -1}, + {"b": 0}, + ], + ], + ) + + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=2, + seq_no=1, + data={ + "step1": [ + _Batch( + seq_no=1, + step_name="step1", + last_batch=False, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + size=5, + ) + ], + "step2": [ + _Batch( + seq_no=1, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + ] + ], + size=5, + ) + ], + }, + built_batches=[previously_built_batch], + ) + + batch = batch_manager_step.get_batch() + + assert batch == previously_built_batch + + batch = batch_manager_step.get_batch() + + assert batch == _Batch( + step_name="step3", + seq_no=1, + last_batch=False, + data=[ + [ + {"a": 1}, + {"a": 2}, + ], + [ + {"b": 1}, + {"b": 2}, + ], + ], + created_from={"step1": [(1, 5)], "step2": [(1, 5)]}, + ) + + batch = batch_manager_step.get_batch() + + assert batch == _Batch( + step_name="step3", + seq_no=2, + last_batch=False, + data=[ + [ + {"a": 3}, + {"a": 4}, + ], + [ + {"b": 3}, + {"b": 4}, + ], + ], + created_from={"step1": [(1, 5)], "step2": [(1, 5)]}, + ) + + def test_get_batches_accumulate(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data={ + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 
4}, + {"a": 5}, + ] + ], + size=5, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + ] + ], + size=6, + ) + ], + }, + last_batch_received=["step1", "step2"], + ) + + batch = batch_manager_step.get_batch() + + assert batch == _Batch( + step_name="step3", + seq_no=0, + last_batch=True, + accumulated=True, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ], + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + ], + ], + created_from={"step1": [(0, 5)], "step2": [(0, 6)]}, + ) + + def test_get_batches_not_enough_data(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=2, + data={ + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[ + [ + {"a": 1}, + ] + ], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + ] + ], + ) + ], + }, + ) + + assert batch_manager_step.get_batch() is None + + def test_from_step(self, dummy_step_1: "Step") -> None: + batch_manager_step = _BatchManagerStep.from_step( + step=dummy_step_1, predecessors=["step1", "step2"] + ) + + assert batch_manager_step.step_name == "dummy_step_1" + assert batch_manager_step.accumulate is False + assert batch_manager_step.input_batch_size == 50 + assert batch_manager_step.data == {"step1": [], "step2": []} + assert batch_manager_step.seq_no == 0 + assert batch_manager_step.last_batch_received == [] + + def test_from_step_with_global_step(self, dummy_global_step: "GlobalStep") -> None: + batch_manager_step = _BatchManagerStep.from_step( + step=dummy_global_step, predecessors=["step1", "step2"] + ) + + assert batch_manager_step.step_name == "dummy_global_step" + assert batch_manager_step.accumulate is True + assert batch_manager_step.input_batch_size == 50 + assert 
batch_manager_step.data == {"step1": [], "step2": []} + assert batch_manager_step.seq_no == 0 + assert batch_manager_step.last_batch_received == [] + + def test_get_seq_no(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", accumulate=False, input_batch_size=5, data={"step1": []} + ) + + seq_no = batch_manager_step._get_seq_no() + + assert seq_no == 0 + assert batch_manager_step.seq_no == 1 + + def test_get_data(self) -> None: + batch_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], + size=6, + batch_routed_to=["step1", "step2"], + ) + batch_step_2 = _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + size=7, + batch_routed_to=["step1", "step2"], + ) + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={ + "step1": [batch_step_1], + "step2": [batch_step_2], + }, + ) + + data, created_from, routed_to = batch_manager_step._get_data() + assert data == [ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}], + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}], + ] + assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} + assert routed_to == ["step1", "step2"] + + assert batch_manager_step.data == { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 6}]], + data_hash=batch_step_1.data_hash, + size=6, + batch_routed_to=["step1", "step2"], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 6}, {"b": 7}]], + data_hash=batch_step_2.data_hash, + size=7, + batch_routed_to=["step1", "step2"], + ) + ], + } + + def test_get_data_accumulate(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data={ + "step1": [ + _Batch( + seq_no=0, + 
step_name="step1", + last_batch=False, + data=[ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] + ], + size=6, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + size=7, + ) + ], + }, + ) + + data, created_from, routed_to = batch_manager_step._get_data() + + assert data == [ + [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}], + ] + assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} + assert routed_to == [] + + assert batch_manager_step.data == {"step1": [], "step2": []} + + def test_get_data_convergence_step(self) -> None: + batch_a_0 = _Batch( + seq_no=0, + step_name="A", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + ] + ], + size=3, + created_from={"Z": [(0, 3)]}, + ) + + batch_a_1 = _Batch( + seq_no=1, + step_name="A", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + ] + ], + size=3, + created_from={"Z": [(1, 3)]}, + ) + + batch_b_0 = _Batch( + seq_no=0, + step_name="B", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + ] + ], + size=3, + created_from={"Z": [(0, 3)]}, + ) + + batch_c_0 = _Batch( + seq_no=0, + step_name="C", + last_batch=False, + data=[ + [ + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + ] + ], + size=3, + created_from={"Z": [(1, 3)]}, + ) + + batch_manager_step = _BatchManagerStep( + step_name="D", + input_batch_size=3, + convergence_step=True, + accumulate=False, + data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]}, + ) + + data, 
created_from, routed_to = batch_manager_step._get_data() + + assert data == [ + [ + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + {"generation": "Hello, I'm A 0"}, + ], + [ + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + {"generation": "Hello, I'm B 0"}, + ], + ] + assert created_from == {"A": [(0, 3)], "B": [(0, 3)]} + assert routed_to == [] + assert batch_manager_step.next_expected_created_from_batch_seq_no == 1 + + data, created_from, routed_to = batch_manager_step._get_data() + + assert data == [ + [ + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + {"generation": "Hello, I'm A 1"}, + ], + [ + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + {"generation": "Hello, I'm C 0"}, + ], + ] + assert created_from == {"A": [(1, 3)], "C": [(0, 3)]} + assert routed_to == [] + assert batch_manager_step.next_expected_created_from_batch_seq_no == 2 + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ] + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + ) + ] + }, + ["step1"], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ] + }, + ["step1"], + True, + ), + ], + ) + def test_last_batch( + self, + data: Dict[str, List[_Batch]], + 
last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=5, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._last_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1"], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ], + ) + def test_last_batch_accumulate( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._last_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, 
+ step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + created_from={"step0": [(0, 3)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}]], + created_from={"step0": [(0, 3)]}, + ) + ], + }, + [], + True, + ), + ], + ) + def test_last_batch_convergence_step( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + data=data, + last_batch_received=last_batch_received, + input_batch_size=3, + convergence_step=True, + ) + + assert batch_manager_step._last_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + 
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ], + ) + def test_ready_to_create_batch( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=5, + data=data, + last_batch_received=last_batch_received, + ) + + assert batch_manager_step._ready_to_create_batch() is expected + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1", "step2"], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + ) + ], + }, + ["step1"], + False, + ), + ], + ) + def test_ready_to_create_batch_accumulate( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data=data, + last_batch_received=last_batch_received, + ) + 
+ assert batch_manager_step._ready_to_create_batch() is expected + + def test_dump(self) -> None: + batch_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], + data_hash="hash0", + size=6, + ) + batch_step_2 = _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[ + [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}] + ], + data_hash="hash1", + size=7, + ) + batch_step_3 = _Batch( + seq_no=0, + step_name="step3", + last_batch=True, + data=[[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]], + data_hash="hash2", + size=5, + ) + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=True, + data={ + "step1": [batch_step_1], + "step2": [batch_step_2], + }, + built_batches=[batch_step_3], + ) + assert batch_manager_step.dump() == { + "step_name": "step3", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {}, + "input_batch_size": None, + "data": { + "step1": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": True, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + "data_hash": "hash0", + "size": 6, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "hash1", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step3", + "last_batch": True, + "data": [[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]], + "data_hash": "hash2", + "size": 5, + 
"accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "next_expected_created_from_batch_seq_no": 0, + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + } + + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], + 
batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + }, + ["step1", "step2"], + True, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ], + ) + def test_ready_to_create_batch_convergence_step( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data=data, + last_batch_received=last_batch_received, + convergence_step=True, + ) + + assert batch_manager_step._ready_to_create_batch() is expected + + def test_from_dict(self) -> None: + batch_manager_step = _BatchManagerStep.from_dict( + { + "step_name": "step3", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step1": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": True, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + ] + ], + "size": 6, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + } + ], + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + } + ], + }, + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + } + ) + + assert 
isinstance(batch_manager_step, _BatchManagerStep) + assert batch_manager_step.step_name == "step3" + assert batch_manager_step.accumulate is True + assert batch_manager_step.convergence_step is False + assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}} + assert batch_manager_step.input_batch_size is None + assert batch_manager_step.seq_no == 0 + assert batch_manager_step.last_batch_received == [] + + +class TestBatchManager: + def test_add_batch(self) -> None: + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={"step1": [], "step2": []}, + ) + }, + last_batch_received={"step3": None}, + last_batch_sent={"step3": None}, + last_batch_flag_sent_to=[], + ) + + batch_from_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager.add_batch(to_step="step3", batch=batch_from_step_1) + + assert batch_manager._steps["step3"].data == { + "step1": [batch_from_step_1], + "step2": [], + } + + def test_add_batch_with_prepend(self) -> None: + batch_1 = _Batch( + seq_no=1, + step_name="step1", + last_batch=False, + data=[[{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}]], + ) + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={ + "step1": [batch_1], + "step2": [], + }, + ) + }, + last_batch_received={"step3": None}, + last_batch_sent={"step3": None}, + last_batch_flag_sent_to=[], + ) + batch_0 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager.add_batch(to_step="step3", batch=batch_0, prepend=True) + assert batch_manager._steps["step3"].built_batches == [batch_0] + assert batch_manager._steps["step3"].data == { + "step1": [batch_1], + "step2": [], + } + + def test_from_dag( + self, + 
dummy_generator_step: "GeneratorStep", + dummy_step_1: "Step", + dummy_step_2: "Step", + dummy_global_step: "GlobalStep", + ) -> None: + dag = DAG() + dag.add_step(dummy_generator_step) + dag.add_step(dummy_step_1) + dag.add_step(dummy_step_2) + dag.add_step(dummy_global_step) + dag.add_edge("dummy_generator_step", "dummy_step_1") + dag.add_edge("dummy_generator_step", "dummy_global_step") + dag.add_edge("dummy_step_1", "dummy_step_2") + + batch_manager = _BatchManager.from_dag(dag) + + assert batch_manager._steps == { + "dummy_step_1": _BatchManagerStep( + step_name="dummy_step_1", + accumulate=False, + input_batch_size=50, + data={"dummy_generator_step": []}, + ), + "dummy_global_step": _BatchManagerStep( + step_name="dummy_global_step", + accumulate=True, + input_batch_size=50, + data={"dummy_generator_step": []}, + ), + "dummy_step_2": _BatchManagerStep( + step_name="dummy_step_2", + accumulate=False, + input_batch_size=50, + data={"dummy_step_1": []}, + ), + } + + def test_can_generate(self) -> None: + batch_manager = _BatchManager( + steps={}, + last_batch_received={ + "step_1": _Batch(seq_no=0, step_name="step_1", last_batch=False), + "step_2": _Batch(seq_no=0, step_name="step_2", last_batch=False), + "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False), + }, + last_batch_sent={"step_1": None, "step_2": None, "step_3": None}, + last_batch_flag_sent_to=[], + ) + + assert batch_manager.can_generate() + + batch_1 = _Batch(seq_no=0, step_name="step_1", last_batch=True) + batch_2 = _Batch(seq_no=0, step_name="step_2", last_batch=True) + batch_3 = _Batch(seq_no=0, step_name="step_3", last_batch=True) + + batch_manager = _BatchManager( + steps={}, + last_batch_received={ + "step_1": batch_1, + "step_2": batch_2, + "step_3": batch_3, + }, + last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3}, + last_batch_flag_sent_to=[], + ) + + assert not batch_manager.can_generate() + + def test_dump(self) -> None: + built_batch = _Batch( + 
seq_no=0, + last_batch=False, + step_name="step3", + data=[[]], + data_hash="hash", + ) + + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={"step1": [], "step2": []}, + built_batches=[built_batch], + seq_no=1, + ) + }, + last_batch_received={ + "step3": _Batch( + seq_no=0, + step_name="step3", + last_batch=False, + ) + }, + last_batch_sent={ + "step3": _Batch( + seq_no=1, + step_name="step3", + last_batch=False, + ) + }, + last_batch_flag_sent_to=["step99"], + ) + assert batch_manager.dump() == { + "steps": { + "step3": { + "step_name": "step3", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {}, + "input_batch_size": 5, + "data": {"step1": [], "step2": []}, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step3", + "last_batch": False, + "data": [[]], + "data_hash": "hash", + "size": 0, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 1, + "last_batch_received": [], + "next_expected_created_from_batch_seq_no": 0, + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step3": { + "seq_no": 0, + "step_name": "step3", + "batch_routed_to": [], + "created_from": {}, + "last_batch": False, + "data": [], + "data_hash": None, + "size": 0, + "accumulated": False, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + }, + "last_batch_sent": { + "step3": { + "seq_no": 1, + "step_name": "step3", + "batch_routed_to": [], + "created_from": {}, + "last_batch": False, + "data": [], + "data_hash": None, + "size": 0, + "accumulated": False, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + }, + "last_batch_flag_sent_to": ["step99"], + "type_info": { + 
"module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + + def test_from_dict(self) -> None: + batch_manager = _BatchManager.from_dict( + { + "steps": { + "step1": { + "step_name": "step1", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + "step2": { + "step_name": "step2", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, + "input_batch_size": 50, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + 
"size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_sent": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_flag_sent_to": ["step3"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + ) + + assert isinstance(batch_manager, _BatchManager) + + assert len(batch_manager._steps) == 2 + for step in batch_manager._steps.values(): + assert isinstance(step, _BatchManagerStep) + + assert len(batch_manager._last_batch_received) == 2 + for step in batch_manager._last_batch_received.values(): + assert isinstance(step, _Batch) + + assert len(batch_manager._last_batch_sent) == 2 + for step in batch_manager._last_batch_sent.values(): + assert isinstance(step, _Batch) + + assert batch_manager._last_batch_flag_sent_to == ["step3"] + + def test_cache(self) -> None: + batch_manager = _BatchManager.from_dict( + { + "steps": { + "step1": { + "step_name": "step1", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + 
"batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + "step2": { + "step_name": "step2", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": 50, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": 
"distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_sent": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_Batch", + }, + }, + }, + "last_batch_flag_sent_to": ["step3"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + batch_manager_path = Path(tmp_dir) / "batch_manager.json" + batch_manager.cache(batch_manager_path) + + assert batch_manager_path.exists() and batch_manager_path.is_file() + + for step_name, step in batch_manager._steps.items(): + batch_manager_step_dir = ( + Path(tmp_dir) / "batch_manager_steps" / step_name + ) + assert ( + batch_manager_step_dir.exists() and batch_manager_step_dir.is_dir() + ) + + batch_manager_step_path = ( + batch_manager_step_dir / "batch_manager_step.json" + ) + assert ( + batch_manager_step_path.exists() + and batch_manager_step_path.is_file() + ) + + built_batches_dir = batch_manager_step_dir / "built_batches" + assert built_batches_dir.exists() + + for batch in step.built_batches: + batch_path = ( + built_batches_dir + / f"batch_{batch.seq_no}_{batch.data_hash}.json" + ) + assert batch_path.exists() and batch_path.is_file() + + for buffered_step_name in 
step.data: + buffered_step_dir = batch_manager_step_dir / buffered_step_name + assert buffered_step_dir.exists() and buffered_step_dir.is_dir() + + for batch in step.data[buffered_step_name]: + batch_path = ( + buffered_step_dir + / f"batch_{batch.seq_no}_{batch.data_hash}.json" + ) + assert batch_path.exists() and batch_path.is_file() + + def test_load_from_cache(self) -> None: + batch_manager = _BatchManager.from_dict( + { + "steps": { + "step1": { + "step_name": "step1", + "accumulate": True, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": None, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + "step2": { + "step_name": "step2", + "accumulate": False, + "convergence_step": False, + "convergence_step_batches_consumed": {"0": {"Z": 1234}}, + "input_batch_size": 50, + "data": { + "step2": [ + { + "seq_no": 0, + "step_name": "step2", + "last_batch": True, + "data": [ + [ + {"b": 1}, + {"b": 2}, + {"b": 3}, + {"b": 4}, + {"b": 5}, + {"b": 6}, + {"b": 7}, + ] + ], + "data_hash": "1234", + "size": 7, + "accumulated": False, + "created_from": {}, + 
"batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + }, + "built_batches": [ + { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [ + [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ] + ], + "data_hash": "1234", + "size": 5, + "accumulated": False, + "batch_routed_to": [], + "created_from": {}, + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + } + ], + "seq_no": 0, + "last_batch_received": [], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManagerStep", + }, + }, + }, + "last_batch_received": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + }, + "last_batch_sent": { + "step1": { + "seq_no": 0, + "step_name": "step1", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + "step2": { + "seq_no": 0, + "step_name": "step2", + "last_batch": False, + "data": [], + "size": 0, + "accumulated": False, + "created_from": {}, + "batch_routed_to": [], + "type_info": { + "module": "distilabel.pipeline.batch", + "name": "_Batch", + }, + }, + }, + "last_batch_flag_sent_to": ["step3"], + "type_info": { + "module": "distilabel.pipeline.batch_manager", + "name": "_BatchManager", + }, + } + ) + + with tempfile.TemporaryDirectory() as tmp_dir: + batch_manager_path = Path(tmp_dir) / 
"batch_manager.json" + batch_manager.cache(batch_manager_path) + loaded_batch_manager = _BatchManager.load_from_cache(batch_manager_path) + + assert batch_manager.dump() == loaded_batch_manager.dump() diff --git a/tests/unit/pipeline/test_local.py b/tests/unit/pipeline/test_local.py index 511f8f5040..4797f8e66d 100644 --- a/tests/unit/pipeline/test_local.py +++ b/tests/unit/pipeline/test_local.py @@ -15,7 +15,8 @@ from typing import TYPE_CHECKING from unittest import mock -from distilabel.pipeline.base import _Batch, _BatchManager +from distilabel.pipeline.batch import _Batch +from distilabel.pipeline.batch_manager import _BatchManager from distilabel.pipeline.local import Pipeline from .utils import DummyGeneratorStep, DummyStep1, DummyStep2 @@ -63,11 +64,6 @@ def test_send_batch_to_step(self, dummy_generator_step: "GeneratorStep") -> None @mock.patch("distilabel.pipeline.local._ProcessWrapper") def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: - pool = mock.MagicMock() - manager = mock.MagicMock() - queue = mock.MagicMock() - shared_info = mock.MagicMock() - with Pipeline(name="unit-test-pipeline") as pipeline: dummy_generator = DummyGeneratorStep(name="dummy_generator_step") dummy_step_1 = DummyStep1(name="dummy_step_1") @@ -76,51 +72,52 @@ def test_create_processes(self, process_wrapper_mock: mock.MagicMock) -> None: dummy_generator.connect(dummy_step_1) dummy_step_1.connect(dummy_step_2) - pipeline._run_steps_in_loop(pool, manager, queue, shared_info) + pipeline._pool = mock.MagicMock() + pipeline._manager = mock.MagicMock() + pipeline._output_queue = mock.MagicMock() + pipeline._load_queue = mock.MagicMock() + pipeline._run_steps() - assert manager.Queue.call_count == 3 + assert pipeline._manager.Queue.call_count == 3 process_wrapper_mock.assert_has_calls( [ mock.call( step=dummy_generator, input_queue=mock.ANY, - output_queue=queue, - shared_info=shared_info, + output_queue=pipeline._output_queue, + 
load_queue=pipeline._load_queue, dry_run=False, ), mock.call( step=dummy_step_1, input_queue=mock.ANY, - output_queue=queue, - shared_info=shared_info, + output_queue=pipeline._output_queue, + load_queue=pipeline._load_queue, dry_run=False, ), mock.call( step=dummy_step_2, input_queue=mock.ANY, - output_queue=queue, - shared_info=shared_info, + output_queue=pipeline._output_queue, + load_queue=pipeline._load_queue, dry_run=False, ), ], ) - pool.apply_async.assert_has_calls( + pipeline._pool.apply_async.assert_has_calls( [ mock.call( process_wrapper_mock.return_value.run, - callback=pipeline._finished_callback, error_callback=pipeline._error_callback, ), mock.call( process_wrapper_mock.return_value.run, - callback=pipeline._finished_callback, error_callback=pipeline._error_callback, ), mock.call( process_wrapper_mock.return_value.run, - callback=pipeline._finished_callback, error_callback=pipeline._error_callback, ), ] diff --git a/tests/unit/pipeline/test_routing_batch_function.py b/tests/unit/pipeline/test_routing_batch_function.py index 5e3f208c5b..6cc3090eb7 100644 --- a/tests/unit/pipeline/test_routing_batch_function.py +++ b/tests/unit/pipeline/test_routing_batch_function.py @@ -14,7 +14,7 @@ from typing import List -from distilabel.pipeline.base import _Batch +from distilabel.pipeline.batch import _Batch from distilabel.pipeline.local import Pipeline from distilabel.pipeline.routing_batch_function import ( RoutingBatchFunction, diff --git a/tests/unit/pipeline/test_write_buffer.py b/tests/unit/pipeline/test_write_buffer.py new file mode 100644 index 0000000000..a7ae64c91e --- /dev/null +++ b/tests/unit/pipeline/test_write_buffer.py @@ -0,0 +1,150 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +from pathlib import Path + +from distilabel.distiset import Distiset, create_distiset +from distilabel.pipeline.local import Pipeline +from distilabel.pipeline.write_buffer import _WriteBuffer + +from tests.unit.pipeline.utils import ( + DummyGeneratorStep, + DummyStep1, + DummyStep2, + batch_gen, +) + + +class TestWriteBuffer: + def test_create(self) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "data" + with Pipeline(name="unit-test-pipeline") as pipeline: + dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") + dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") + dummy_step_1 = DummyStep1(name="dummy_step_1") + dummy_step_2 = DummyStep2(name="dummy_step_2") + dummy_step_3 = DummyStep2(name="dummy_step_3") + + dummy_generator_1.connect(dummy_step_1) + dummy_generator_2.connect(dummy_step_2) + dummy_step_1.connect(dummy_step_2) + dummy_step_1.connect(dummy_step_3) + + write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + + assert write_buffer._buffers == {"dummy_step_2": [], "dummy_step_3": []} + assert write_buffer._buffers_dump_batch_size == { + "dummy_step_2": 50, + "dummy_step_3": 50, + } + assert write_buffer._buffer_last_schema == {} + assert write_buffer._buffers_last_file == { + "dummy_step_2": 1, + "dummy_step_3": 1, + } + + def test_write_buffer_one_leaf_step_and_create_dataset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "data" + with Pipeline(name="unit-test-pipeline") 
as pipeline: + dummy_generator = DummyGeneratorStep(name="dummy_generator_step") + dummy_step_1 = DummyStep1(name="dummy_step_1") + dummy_step_2 = DummyStep2(name="dummy_step_2") + + dummy_generator.connect(dummy_step_1) + dummy_step_1.connect(dummy_step_2) + + write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + + # Add one batch with 5 rows, shouldn't write anything 5 < 50 + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + # Add 45 more rows, should write now + for _ in range(9): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_2", "00001.parquet").exists() + + # Add 50 more rows, we should have a new file + for _ in range(10): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_2", "00002.parquet").exists() + + # Add more rows and close the write buffer, we should have a new file + for _ in range(5): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + write_buffer.close() + + assert Path(folder, "dummy_step_2", "00003.parquet").exists() + + ds = create_distiset(write_buffer._path) + assert isinstance(ds, Distiset) + assert len(ds.keys()) == 1 + assert len(ds["default"]["train"]) == 125 + + def test_write_buffer_multiple_leaf_steps_and_create_dataset(self) -> None: + with tempfile.TemporaryDirectory() as tmpdirname: + folder = Path(tmpdirname) / "data" + with Pipeline(name="unit-test-pipeline") as pipeline: + dummy_generator_1 = DummyGeneratorStep(name="dummy_generator_step_1") + dummy_generator_2 = DummyGeneratorStep(name="dummy_generator_step_2") + dummy_step_1 = DummyStep1(name="dummy_step_1") + dummy_step_2 = DummyStep2(name="dummy_step_2") + dummy_step_3 = DummyStep2(name="dummy_step_3") + + dummy_generator_1.connect(dummy_step_1) + dummy_generator_2.connect(dummy_step_2) + dummy_step_1.connect(dummy_step_2) + 
dummy_step_1.connect(dummy_step_3) + + write_buffer = _WriteBuffer(path=folder, leaf_steps=pipeline.dag.leaf_steps) + + for _ in range(10): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_2", "00001.parquet").exists() + + for _ in range(10): + batch = batch_gen(dummy_step_3.name) # type: ignore + write_buffer.add_batch(batch) + + assert Path(folder, "dummy_step_3", "00001.parquet").exists() + + for _ in range(5): + batch = batch_gen(dummy_step_2.name) # type: ignore + write_buffer.add_batch(batch) + + for _ in range(5): + batch = batch_gen(dummy_step_3.name) # type: ignore + write_buffer.add_batch(batch) + + write_buffer.close() + + assert Path(folder, "dummy_step_2", "00002.parquet").exists() + assert Path(folder, "dummy_step_3", "00002.parquet").exists() + + ds = create_distiset(write_buffer._path) + assert isinstance(ds, Distiset) + assert len(ds.keys()) == 2 + assert len(ds["dummy_step_2"]["train"]) == 75 + assert len(ds["dummy_step_3"]["train"]) == 75 diff --git a/tests/unit/pipeline/utils.py b/tests/unit/pipeline/utils.py index 8d02340114..7f771271d0 100644 --- a/tests/unit/pipeline/utils.py +++ b/tests/unit/pipeline/utils.py @@ -14,7 +14,7 @@ from typing import List -from distilabel.pipeline.base import _Batch +from distilabel.pipeline.batch import _Batch from distilabel.steps.base import GeneratorStep, GlobalStep, Step, StepInput from distilabel.steps.typing import GeneratorStepOutput, StepOutput