Skip to content

Commit

Permalink
Fix pipeline getting stuck on empty batches (#531)
Browse files Browse the repository at this point in the history
* Fix `None` not being sent to the output queue

* Fix pipeline getting stuck when a step sends an empty batch
  • Loading branch information
gabrielmbmb authored Apr 15, 2024
1 parent 34bab6a commit 25d237b
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 0 deletions.
25 changes: 25 additions & 0 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,15 @@ def next_batch(self) -> "_Batch":
seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch
)

@property
def empty(self) -> bool:
    """Report whether this batch carries no rows at all.

    Returns:
        `True` when every row list in `self.data` has length zero (which also
        holds when `self.data` has no entries). Otherwise, `False`.
    """
    for step_rows in self.data:
        # A single non-empty row list is enough to make the batch non-empty.
        if len(step_rows) != 0:
            return False
    return True

@classmethod
def from_batches(cls, step_name: str, batches: List["_Batch"]) -> "_Batch":
"""Create a `_Batch` instance with the outputs from the list of batches that
Expand Down Expand Up @@ -665,6 +674,22 @@ def can_generate(self) -> bool:
`True` if there are still batches to be processed by the steps. Otherwise,
`False`.
"""

# Check if any step that hasn't finished producing data (we haven't received its
# last batch) still needs data from its predecessors, and those predecessors have
# already sent their last batch and it's empty. In this case, we cannot continue
# the pipeline.
for batch_manager_step in self._steps.values():
for predecessor in batch_manager_step.data.keys():
batch = self._last_batch_received.get(predecessor)
if (
batch
and batch.last_batch
and batch.empty
and batch_manager_step.data[predecessor] == []
):
return False

return not all(
batch and batch.last_batch for batch in self._last_batch_received.values()
)
Expand Down
12 changes: 12 additions & 0 deletions src/distilabel/pipeline/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ def run(
# Start a loop to receive the output batches from the steps
self._run_output_queue_loop_in_thread(write_buffer)

# Send `None` to steps `input_queue`s just in case some step is still waiting
self._notify_steps_to_stop()

pool.close()
pool.join()

Expand All @@ -158,6 +161,13 @@ def _run_output_queue_loop_in_thread(self, write_buffer: "_WriteBuffer") -> None
thread.start()
thread.join()

def _notify_steps_to_stop(self) -> None:
    """Notifies the steps to stop their infinite running loop by sending `None` to
    their input queues.

    Iterates over the step names in `self.dag` and puts `None` on the
    `input_queue` of every step that has one. Steps without an `input_queue`
    entry (or with it set to `None`) are skipped.
    """
    for step_name in self.dag:
        input_queue = self.dag.get_step(step_name).get("input_queue")
        # Compare against `None` explicitly instead of relying on truthiness: a
        # queue-like object that implements `__len__` is falsy while it is empty,
        # which is exactly when a step may be blocked waiting on it.
        if input_queue is not None:
            input_queue.put(None)

def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:
"""Loop to receive the output batches from the steps and manage the flow of the
batches through the pipeline.
Expand Down Expand Up @@ -541,6 +551,8 @@ def _stop(self) -> None:
self._logger.info(
"🛑 Stopping pipeline. Waiting for steps to finish processing batches..."
)
self._logger.debug("Sending `None` to the output queue to notify stop...")
self.output_queue.put(None)

def _handle_keyboard_interrupt(self) -> None:
"""Handles KeyboardInterrupt signal sent during the Pipeline.run method.
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ def test_accumulate(self) -> None:
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}],
]

def test_empty(self) -> None:
    """Check `_Batch.empty` for both empty and non-empty batch data."""
    batch = _Batch(seq_no=0, step_name="step1", last_batch=False, data=[[]])
    assert batch.empty

    # Negative case: without it, an `empty` that always returned `True` would pass.
    batch = _Batch(
        seq_no=0, step_name="step1", last_batch=False, data=[[{"a": 1}]]
    )
    assert not batch.empty

    # One non-empty row list among empty ones makes the whole batch non-empty.
    batch = _Batch(
        seq_no=0, step_name="step1", last_batch=False, data=[[], [{"a": 1}]]
    )
    assert not batch.empty

def test_dump(self) -> None:
batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
assert batch.dump() == {
Expand Down

0 comments on commit 25d237b

Please sign in to comment.