From 25d237b947c0b311b1f0fd248a9cf2c30e2750f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 15 Apr 2024 14:57:43 +0200 Subject: [PATCH] Fix pipeline stuck when empty batches (#531) * Fix missing sending `None` to output queue * Fix pipeline stuck when step sent empty batch --- src/distilabel/pipeline/base.py | 25 +++++++++++++++++++++++++ src/distilabel/pipeline/local.py | 12 ++++++++++++ tests/unit/pipeline/test_base.py | 4 ++++ 3 files changed, 41 insertions(+) diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index b2acd7e37e..884d05a75c 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -368,6 +368,15 @@ def next_batch(self) -> "_Batch": seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch ) + @property + def empty(self) -> bool: + """Checks if the batch is empty. + + Returns: + `True` if the batch is empty. Otherwise, `False`. + """ + return all(len(rows) == 0 for rows in self.data) + @classmethod def from_batches(cls, step_name: str, batches: List["_Batch"]) -> "_Batch": """Create a `_Batch` instance with the outputs from the list of batches that @@ -665,6 +674,22 @@ def can_generate(self) -> bool: `True` if there are still batches to be processed by the steps. Otherwise, `False`. """ + + # Check if any step that hasn't finished producing data (we haven't received its + # last batch) still needs data from its predecessors, and those predecessors have + # already sent their last batch and it's empty. In this case, we cannot continue + # the pipeline. + for batch_manager_step in self._steps.values(): + for predecessor in batch_manager_step.data.keys(): + batch = self._last_batch_received.get(predecessor) + if ( + batch + and batch.last_batch + and batch.empty + and batch_manager_step.data[predecessor] == [] + ): + return False + return not all( batch and batch.last_batch for batch in self._last_batch_received.values() ) diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index ec6cd08177..4770125fb4 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -136,6 +136,9 @@ def run( # Start a loop to receive the output batches from the steps self._run_output_queue_loop_in_thread(write_buffer) + # Send `None` to steps `input_queue`s just in case some step is still waiting + self._notify_steps_to_stop() + pool.close() pool.join() @@ -158,6 +161,13 @@ def _run_output_queue_loop_in_thread(self, write_buffer: "_WriteBuffer") -> None thread.start() thread.join() + def _notify_steps_to_stop(self) -> None: + """Notifies the steps to stop their infinite running loop by sending `None` to + their input queues.""" + for step_name in self.dag: + if input_queue := self.dag.get_step(step_name).get("input_queue"): + input_queue.put(None) + def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: """Loop to receive the output batches from the steps and manage the flow of the batches through the pipeline. @@ -541,6 +551,8 @@ def _stop(self) -> None: self._logger.info( "🛑 Stopping pipeline. Waiting for steps to finish processing batches..." ) + self._logger.debug("Sending `None` to the output queue to notify stop...") + self.output_queue.put(None) def _handle_keyboard_interrupt(self) -> None: """Handles KeyboardInterrupt signal sent during the Pipeline.run method. diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 1939199bfd..ded8dc816b 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -202,6 +202,10 @@ def test_accumulate(self) -> None: [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}], ] + def test_empty(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False, data=[[]]) + assert batch.empty + def test_dump(self) -> None: batch = _Batch(seq_no=0, step_name="step1", last_batch=False) assert batch.dump() == {