diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index f858af48ea..b2acd7e37e 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -452,15 +452,20 @@ class _BatchManagerStep(_Serializable): seq_no: int = 0 last_batch_received: List[str] = field(default_factory=list) - def add_batch(self, batch: _Batch) -> None: + def add_batch(self, batch: _Batch, prepend: bool = False) -> None: """Add a batch of data from `batch.step_name` to the step. It will accumulate the data and keep track of the last batch received from the predecessors. Args: batch: The output batch of an step to be processed by the step. + prepend: If `True`, the content of the batch will be added at the start of + the buffer. """ from_step = batch.step_name - self.data[from_step].extend(batch.data[0]) + if prepend: + self.data[from_step] = batch.data[0] + self.data[from_step] + else: + self.data[from_step].extend(batch.data[0]) if batch.last_batch: self.last_batch_received.append(from_step) @@ -676,12 +681,14 @@ def register_batch(self, batch: _Batch) -> None: def get_last_batch(self, step_name: str) -> Union[_Batch, None]: return self._last_batch_received.get(step_name) - def add_batch(self, to_step: str, batch: _Batch) -> None: + def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None: """Add an output batch from `batch.step_name` to `to_step`. Args: to_step: The name of the step that will process the batch. batch: The output batch of an step to be processed by `to_step`. + prepend: If `True`, the content of the batch will be added at the start of + the buffer. Raises: ValueError: If `to_step` is not found in the batch manager. @@ -690,7 +697,7 @@ def add_batch(self, to_step: str, batch: _Batch) -> None: raise ValueError(f"Step '{to_step}' not found in the batch manager.") step = self._steps[to_step] - step.add_batch(batch) + step.add_batch(batch, prepend) def get_batch(self, step_name: str) -> Union[_Batch, None]: """Get the next batch to be processed by the step. diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index c41dc26a45..ec6cd08177 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -49,8 +49,6 @@ _STEPS_FINISHED = set() _STEPS_FINISHED_LOCK = threading.Lock() -_STOP_LOOP = False - def _init_worker(queue: "Queue[Any]") -> None: signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -167,7 +165,7 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: Args: write_buffer: The write buffer to write the data from the leaf steps to disk. """ - while self._batch_manager.can_generate() and not _STOP_LOOP: # type: ignore + while self._batch_manager.can_generate() and not _STOP_CALLED: # type: ignore self._logger.debug("Waiting for output batch from step...") if (batch := self.output_queue.get()) is None: self._logger.debug("Received `None` from output queue. Breaking loop.") @@ -176,13 +174,12 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: if batch.step_name in self.dag.leaf_steps: write_buffer.add_batch(batch) - # If `_STOP_LOOP` was set to `True` while waiting for the output queue, then + # If `_STOP_CALLED` was set to `True` while waiting for the output queue, then # we need to handle the stop of the pipeline and break the loop to avoid # propagating the batches through the pipeline and making the stop process # slower. - if _STOP_LOOP: + if _STOP_CALLED: self._handle_batch_on_stop(batch) - self._handle_stop(write_buffer) break self._logger.debug( @@ -192,7 +189,7 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: self._manage_batch_flow(batch) - if _STOP_LOOP: + if _STOP_CALLED: self._handle_stop(write_buffer) def _manage_batch_flow(self, batch: "_Batch") -> None: @@ -266,15 +263,16 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None: # Send `None` to the input queues of all the steps to notify them to stop # processing batches. for step_name in self.dag: - if input_queue := self._wait_step_input_queue_empty(step_name): - if self._check_step_not_loaded_or_finished(step_name): + if input_queue := self.dag.get_step(step_name).get("input_queue"): + while not input_queue.empty(): + batch = input_queue.get() + self._batch_manager.add_batch( # type: ignore + to_step=step_name, batch=batch, prepend=True + ) self._logger.debug( - f"Step '{step_name}' not loaded or already finished. Skipping sending" - " sentinel `None`" + f"Adding batch back to the batch manager: {batch}" ) - continue input_queue.put(None) - self._logger.debug(f"Send `None` to step '{step_name}' input queue.") # Wait for the input queue to be empty, which means that all the steps finished # processing the batches that were sent before the stop flag. @@ -352,7 +350,7 @@ def _update_all_steps_loaded(steps_loaded: List[str]) -> None: self._logger.info("⏳ Waiting for all the steps to load...") previous_message = None - while True: + while not _STOP_CALLED: with self.shared_info[_STEPS_LOADED_LOCK_KEY]: steps_loaded = self.shared_info[_STEPS_LOADED_KEY] num_steps_loaded = ( @@ -379,12 +377,17 @@ def _update_all_steps_loaded(steps_loaded: List[str]) -> None: time.sleep(2.5) + return not _STOP_CALLED + def _request_initial_batches(self) -> None: """Requests the initial batches to the generator steps.""" assert self._batch_manager, "Batch manager is not set" for step in self._batch_manager._steps.values(): if batch := step.get_batch(): + self._logger.debug( + f"Sending initial batch to '{step.step_name}' step: {batch}" + ) self._send_batch_to_step(batch) for step_name in self.dag.root_steps: @@ -392,6 +395,9 @@ def _request_initial_batches(self) -> None: if last_batch := self._batch_manager.get_last_batch(step_name): seq_no = last_batch.seq_no + 1 batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=False) + self._logger.debug( + f"Requesting initial batch to '{step_name}' generator step: {batch}" + ) self._send_batch_to_step(batch) def _send_batch_to_step(self, batch: "_Batch") -> None: @@ -520,9 +526,7 @@ def _stop(self) -> None: finished processing the batches that were sent before the stop flag. Then it will send `None` to the output queue to notify the pipeline to stop.""" - global _STOP_LOOP, _STOP_CALLED - - _STOP_LOOP = True + global _STOP_CALLED with _STOP_CALLED_LOCK: if _STOP_CALLED: diff --git a/src/distilabel/steps/argilla/base.py b/src/distilabel/steps/argilla/base.py index ea7627b914..a626a2cae5 100644 --- a/src/distilabel/steps/argilla/base.py +++ b/src/distilabel/steps/argilla/base.py @@ -45,15 +45,19 @@ class Argilla(Step, ABC): This class is not intended to be instanced directly, but via subclass. Attributes: - dataset_name: The name of the dataset in Argilla. + dataset_name: The name of the dataset in Argilla where the records will be added. dataset_workspace: The workspace where the dataset will be created in Argilla. Defaults to - None, which means it will be created in the default workspace. + `None`, which means it will be created in the default workspace. api_url: The URL of the Argilla API. Defaults to `None`, which means it will be read from the `ARGILLA_API_URL` environment variable. api_key: The API key to authenticate with Argilla. Defaults to `None`, which means it will be read from the `ARGILLA_API_KEY` environment variable. Runtime parameters: + - `dataset_name`: The name of the dataset in Argilla where the records will be + added. + - `dataset_workspace`: The workspace where the dataset will be created in Argilla. + Defaults to `None`, which means it will be created in the default workspace. - `api_url`: The base URL to use for the Argilla API requests. - `api_key`: The API key to authenticate the requests to the Argilla API. @@ -61,11 +65,17 @@ class Argilla(Step, ABC): - dynamic, based on the `inputs` value provided """ - dataset_name: str - dataset_workspace: Optional[str] = None + dataset_name: RuntimeParameter[str] = Field( + default=None, description="The name of the dataset in Argilla." + ) + dataset_workspace: Optional[RuntimeParameter[str]] = Field( + default=None, + description="The workspace where the dataset will be created in Argilla. Defaults" + "to `None` which means it will be created in the default workspace.", + ) api_url: Optional[RuntimeParameter[str]] = Field( - default_factory=lambda: os.getenv("ARGILLA_BASE_URL"), + default_factory=lambda: os.getenv("ARGILLA_API_URL"), description="The base URL to use for the Argilla API requests.", ) api_key: Optional[RuntimeParameter[SecretStr]] = Field( @@ -122,6 +132,8 @@ def load(self) -> None: """ super().load() + self._rg_init() + @property @abstractmethod def inputs(self) -> List[str]: diff --git a/src/distilabel/steps/argilla/preference.py b/src/distilabel/steps/argilla/preference.py index 576c9bf502..08c0722d83 100644 --- a/src/distilabel/steps/argilla/preference.py +++ b/src/distilabel/steps/argilla/preference.py @@ -88,8 +88,6 @@ def load(self) -> None: """ super().load() - self._rg_init() - # Both `instruction` and `generations` will be used as the fields of the dataset self._instruction = self.input_mappings.get("instruction", "instruction") self._generations = self.input_mappings.get("generations", "generations") diff --git a/src/distilabel/steps/argilla/text_generation.py b/src/distilabel/steps/argilla/text_generation.py index 6e7ca21ec8..8ccea11790 100644 --- a/src/distilabel/steps/argilla/text_generation.py +++ b/src/distilabel/steps/argilla/text_generation.py @@ -69,8 +69,6 @@ def load(self) -> None: """ super().load() - self._rg_init() - self._instruction = self.input_mappings.get("instruction", "instruction") self._generation = self.input_mappings.get("generation", "generation") diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 1ebf685183..1939199bfd 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -264,6 +264,46 @@ def test_add_batch(self) -> None: assert batch_manager_step.data["step1"] == [{"a": 1}, {"a": 2}, {"a": 3}] assert batch_manager_step.last_batch_received == [] + def test_add_batch_with_prepend(self) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step2", + accumulate=False, + input_batch_size=10, + data={ + "step1": [ + {"a": 6}, + {"a": 7}, + {"a": 8}, + {"a": 9}, + {"a": 10}, + ] + }, + ) + + batch_manager_step.add_batch( + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ), + prepend=True, + ) + + assert batch_manager_step.data["step1"] == [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + {"a": 7}, + {"a": 8}, + {"a": 9}, + {"a": 10}, + ] + assert batch_manager_step.last_batch_received == [] + def test_add_batch_last_batch(self) -> None: batch_manager_step = _BatchManagerStep( step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []} @@ -784,9 +824,56 @@ def test_add_batch(self) -> None: last_batch=False, data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], ) - batch = batch_manager.add_batch(to_step="step3", batch=batch_from_step_1) + batch_manager.add_batch(to_step="step3", batch=batch_from_step_1) - assert batch is None + assert batch_manager._steps["step3"].data == { + "step1": [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + ], + "step2": [], + } + + def test_add_batch_with_prepend(self) -> None: + batch_manager = _BatchManager( + steps={ + "step3": _BatchManagerStep( + step_name="step3", + accumulate=False, + input_batch_size=5, + data={ + "step1": [{"a": 6}, {"a": 7}, {"a": 8}, {"a": 9}, {"a": 10}], + "step2": [], + }, + ) + }, + last_batch_received={"step3": None}, + ) + batch_from_step_1 = _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + ) + batch_manager.add_batch(to_step="step3", batch=batch_from_step_1, prepend=True) + assert batch_manager._steps["step3"].data == { + "step1": [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + {"a": 7}, + {"a": 8}, + {"a": 9}, + {"a": 10}, + ], + "step2": [], + } def test_add_batch_enough_data(self) -> None: batch_manager = _BatchManager( diff --git a/tests/unit/steps/argilla/test_base.py b/tests/unit/steps/argilla/test_base.py index fcc9dd6ebc..14895e6caf 100644 --- a/tests/unit/steps/argilla/test_base.py +++ b/tests/unit/steps/argilla/test_base.py @@ -19,7 +19,6 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.argilla.base import Argilla from distilabel.steps.base import StepInput -from pydantic import ValidationError if TYPE_CHECKING: from distilabel.steps.typing import StepOutput @@ -82,17 +81,6 @@ def test_with_errors(self) -> None: dataset_workspace="argilla", ) - with pytest.raises( - ValidationError, match="dataset_name\n Field required \\[type=missing" - ): - CustomArgilla( - name="step", - api_url="https://example.com", - api_key="api.key", # type: ignore - dataset_workspace="argilla", - pipeline=Pipeline(name="unit-test-pipeline"), - ) - with pytest.raises( TypeError, match="Can't instantiate abstract class Argilla with abstract methods inputs, process", @@ -136,6 +124,18 @@ def test_serialization(self) -> None: "name": "input_batch_size", "optional": True, }, + { + "description": "The name of the dataset in Argilla.", + "name": "dataset_name", + "optional": False, + }, + { + "description": "The workspace where the dataset will be created in Argilla. " + "Defaultsto `None` which means it will be created in the default " + "workspace.", + "name": "dataset_workspace", + "optional": True, + }, { "name": "api_url", "optional": True, diff --git a/tests/unit/steps/argilla/test_preference.py b/tests/unit/steps/argilla/test_preference.py index d460c82280..447550cc59 100644 --- a/tests/unit/steps/argilla/test_preference.py +++ b/tests/unit/steps/argilla/test_preference.py @@ -109,6 +109,18 @@ def test_serialization(self) -> None: "name": "input_batch_size", "optional": True, }, + { + "description": "The name of the dataset in Argilla.", + "name": "dataset_name", + "optional": False, + }, + { + "description": "The workspace where the dataset will be created in Argilla. " + "Defaultsto `None` which means it will be created in the default " + "workspace.", + "name": "dataset_workspace", + "optional": True, + }, { "name": "api_url", "optional": True, diff --git a/tests/unit/steps/argilla/test_text_generation.py b/tests/unit/steps/argilla/test_text_generation.py index f18f9dcb0c..1a1126a3a7 100644 --- a/tests/unit/steps/argilla/test_text_generation.py +++ b/tests/unit/steps/argilla/test_text_generation.py @@ -82,6 +82,18 @@ def test_serialization(self) -> None: "name": "input_batch_size", "optional": True, }, + { + "description": "The name of the dataset in Argilla.", + "name": "dataset_name", + "optional": False, + }, + { + "description": "The workspace where the dataset will be created in Argilla. " + "Defaultsto `None` which means it will be created in the default " + "workspace.", + "name": "dataset_workspace", + "optional": True, + }, { "name": "api_url", "optional": True,