Skip to content

Commit

Permalink
Send as many Nones as replicas in the step (#982)
Browse files Browse the repository at this point in the history
gabrielmbmb authored Sep 16, 2024
1 parent e1253a6 commit 75e34e1
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
@@ -1496,7 +1496,8 @@ def _notify_steps_to_stop(self) -> None:
with self._steps_load_status_lock:
for step_name, replicas in self._steps_load_status.items():
if replicas > 0:
self._send_to_step(step_name, None)
for _ in range(replicas):
self._send_to_step(step_name, None)

def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]:
"""Gets the successors and the successors to which the batch has to be routed.
9 changes: 9 additions & 0 deletions src/distilabel/pipeline/step_wrapper.py
Original file line number Diff line number Diff line change
@@ -134,14 +134,23 @@ def run(self) -> str:

def _notify_load(self) -> None:
"""Notifies that the step has finished executing its `load` function successfully."""
self.step._logger.debug(
f"Notifying load of step '{self.step.name}' (replica ID {self.replica})..."
)
self.load_queue.put({"name": self.step.name, "status": "loaded"}) # type: ignore

def _notify_unload(self) -> None:
"""Notifies that the step has been unloaded."""
self.step._logger.debug(
f"Notifying unload of step '{self.step.name}' (replica ID {self.replica})..."
)
self.load_queue.put({"name": self.step.name, "status": "unloaded"}) # type: ignore

def _notify_load_failed(self) -> None:
"""Notifies that the step failed to load."""
self.step._logger.debug(
f"Notifying load failed of step '{self.step.name}' (replica ID {self.replica})..."
)
self.load_queue.put({"name": self.step.name, "status": "load_failed"}) # type: ignore

def _generator_step_process_loop(self) -> None:

0 comments on commit 75e34e1

Please sign in to comment.