Skip to content

Commit

Permalink
Empty input queues when CTRL + C (#528)
Browse files Browse the repository at this point in the history
* Add `prepend` argument to `add_batch`

* Empty step input queue and add back to batch manager in stop
  • Loading branch information
gabrielmbmb authored Apr 15, 2024
1 parent d4abf13 commit 8063572
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 13 deletions.
15 changes: 11 additions & 4 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,20 @@ class _BatchManagerStep(_Serializable):
seq_no: int = 0
last_batch_received: List[str] = field(default_factory=list)

def add_batch(self, batch: _Batch) -> None:
def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
"""Add a batch of data from `batch.step_name` to the step. It will accumulate the
data and keep track of the last batch received from the predecessors.
Args:
batch: The output batch of a step to be processed by the step.
prepend: If `True`, the content of the batch will be added at the start of
the buffer.
"""
from_step = batch.step_name
self.data[from_step].extend(batch.data[0])
if prepend:
self.data[from_step] = batch.data[0] + self.data[from_step]
else:
self.data[from_step].extend(batch.data[0])
if batch.last_batch:
self.last_batch_received.append(from_step)

Expand Down Expand Up @@ -676,12 +681,14 @@ def register_batch(self, batch: _Batch) -> None:
def get_last_batch(self, step_name: str) -> Union[_Batch, None]:
    """Look up `step_name` in `self._last_batch_received`.

    Args:
        step_name: The name of the step used as the lookup key.

    Returns:
        Whatever `self._last_batch_received.get(step_name)` yields, i.e. the
        stored value for `step_name` or `None` when the lookup finds nothing.
    """
    last_batches = self._last_batch_received
    return last_batches.get(step_name)

def add_batch(self, to_step: str, batch: _Batch) -> None:
def add_batch(self, to_step: str, batch: _Batch, prepend: bool = False) -> None:
"""Add an output batch from `batch.step_name` to `to_step`.
Args:
to_step: The name of the step that will process the batch.
batch: The output batch of a step to be processed by `to_step`.
prepend: If `True`, the content of the batch will be added at the start of
the buffer.
Raises:
ValueError: If `to_step` is not found in the batch manager.
Expand All @@ -690,7 +697,7 @@ def add_batch(self, to_step: str, batch: _Batch) -> None:
raise ValueError(f"Step '{to_step}' not found in the batch manager.")

step = self._steps[to_step]
step.add_batch(batch)
step.add_batch(batch, prepend)

def get_batch(self, step_name: str) -> Union[_Batch, None]:
"""Get the next batch to be processed by the step.
Expand Down
20 changes: 13 additions & 7 deletions src/distilabel/pipeline/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:
# slower.
if _STOP_CALLED:
self._handle_batch_on_stop(batch)
self._handle_stop(write_buffer)
break

self._logger.debug(
Expand Down Expand Up @@ -264,15 +263,16 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None:
# Send `None` to the input queues of all the steps to notify them to stop
# processing batches.
for step_name in self.dag:
if input_queue := self._wait_step_input_queue_empty(step_name):
if self._check_step_not_loaded_or_finished(step_name):
if input_queue := self.dag.get_step(step_name).get("input_queue"):
while not input_queue.empty():
batch = input_queue.get()
self._batch_manager.add_batch( # type: ignore
to_step=step_name, batch=batch, prepend=True
)
self._logger.debug(
f"Step '{step_name}' not loaded or already finished. Skipping sending"
" sentinel `None`"
f"Adding batch back to the batch manager: {batch}"
)
continue
input_queue.put(None)
self._logger.debug(f"Send `None` to step '{step_name}' input queue.")

# Wait for the input queue to be empty, which means that all the steps finished
# processing the batches that were sent before the stop flag.
Expand Down Expand Up @@ -385,13 +385,19 @@ def _request_initial_batches(self) -> None:

for step in self._batch_manager._steps.values():
if batch := step.get_batch():
self._logger.debug(
f"Sending initial batch to '{step.step_name}' step: {batch}"
)
self._send_batch_to_step(batch)

for step_name in self.dag.root_steps:
seq_no = 0
if last_batch := self._batch_manager.get_last_batch(step_name):
seq_no = last_batch.seq_no + 1
batch = _Batch(seq_no=seq_no, step_name=step_name, last_batch=False)
self._logger.debug(
f"Requesting initial batch to '{step_name}' generator step: {batch}"
)
self._send_batch_to_step(batch)

def _send_batch_to_step(self, batch: "_Batch") -> None:
Expand Down
91 changes: 89 additions & 2 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,46 @@ def test_add_batch(self) -> None:
assert batch_manager_step.data["step1"] == [{"a": 1}, {"a": 2}, {"a": 3}]
assert batch_manager_step.last_batch_received == []

def test_add_batch_with_prepend(self) -> None:
    """Check that `add_batch(..., prepend=True)` puts the new rows first.

    The step starts with rows 6..10 buffered for "step1"; a batch carrying
    rows 1..5 is added with `prepend=True`, and the buffer is expected to
    hold rows 1..10 in order, with no last batch recorded.
    """
    buffered_rows = [{"a": value} for value in range(6, 11)]
    incoming_rows = [{"a": value} for value in range(1, 6)]

    # Copies are handed over so the expected value below cannot be affected
    # by any in-place mutation of the step's buffer or the batch data.
    step = _BatchManagerStep(
        step_name="step2",
        accumulate=False,
        input_batch_size=10,
        data={"step1": list(buffered_rows)},
    )
    incoming_batch = _Batch(
        seq_no=0,
        step_name="step1",
        last_batch=False,
        data=[list(incoming_rows)],
    )

    step.add_batch(incoming_batch, prepend=True)

    assert step.data["step1"] == incoming_rows + buffered_rows
    assert step.last_batch_received == []

def test_add_batch_last_batch(self) -> None:
batch_manager_step = _BatchManagerStep(
step_name="step2", accumulate=False, input_batch_size=10, data={"step1": []}
Expand Down Expand Up @@ -784,9 +824,56 @@ def test_add_batch(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
)
batch = batch_manager.add_batch(to_step="step3", batch=batch_from_step_1)
batch_manager.add_batch(to_step="step3", batch=batch_from_step_1)

assert batch is None
assert batch_manager._steps["step3"].data == {
"step1": [
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
],
"step2": [],
}

def test_add_batch_with_prepend(self) -> None:
    """Check that the manager forwards `prepend=True` to the target step.

    "step3" starts with rows 6..10 buffered from "step1" and nothing from
    "step2"; after adding a "step1" batch with rows 1..5 using
    `prepend=True`, the "step1" buffer is expected to hold rows 1..10 in
    order while "step2" stays empty.
    """
    already_buffered = [{"a": n} for n in range(6, 11)]
    returned_rows = [{"a": n} for n in range(1, 6)]

    # Copies are handed over so the expected value below cannot be affected
    # by any in-place mutation of the step's buffer or the batch data.
    step3 = _BatchManagerStep(
        step_name="step3",
        accumulate=False,
        input_batch_size=5,
        data={"step1": list(already_buffered), "step2": []},
    )
    batch_manager = _BatchManager(
        steps={"step3": step3},
        last_batch_received={"step3": None},
    )
    batch_from_step_1 = _Batch(
        seq_no=0,
        step_name="step1",
        last_batch=False,
        data=[list(returned_rows)],
    )

    batch_manager.add_batch(to_step="step3", batch=batch_from_step_1, prepend=True)

    expected_buffers = {
        "step1": returned_rows + already_buffered,
        "step2": [],
    }
    assert batch_manager._steps["step3"].data == expected_buffers

def test_add_batch_enough_data(self) -> None:
batch_manager = _BatchManager(
Expand Down

0 comments on commit 8063572

Please sign in to comment.