diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3b9de467f..01f1ebcb9 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -59,3 +59,4 @@ jobs:
- name: Integration Tests
run: make integration-tests
+ timeout-minutes: 5
diff --git a/README.md b/README.md
index 097d3b5c8..4e071df69 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,3 @@
----
-description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs.
-hide:
- - toc
----
-
-
-
diff --git a/pyproject.toml b/pyproject.toml
index 0e5176c08..c3317bf5b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -55,7 +55,7 @@ docs = [
"CairoSVG >= 2.7.1",
"mknotebooks >= 0.8.0",
]
-tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio"]
+tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio", "pytest-timeout"]
# Optional LLMs, integrations, etc
anthropic = ["anthropic >= 0.20.0"]
diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py
index 245ed29c1..b9bd1b0e0 100644
--- a/src/distilabel/distiset.py
+++ b/src/distilabel/distiset.py
@@ -234,22 +234,22 @@ def create_distiset( # noqa: C901
continue
files = [str(file) for file in list_files_in_dir(file)]
- try:
- if files:
+ if files:
+ try:
ds = load_dataset(
"parquet", name=file.stem, data_files={"train": files}
)
if not enable_metadata and DISTILABEL_METADATA_KEY in ds.column_names:
ds = ds.remove_columns(DISTILABEL_METADATA_KEY)
distiset[file.stem] = ds
- else:
- logger.warning(
- f"No output files for step '{file.stem}', can't create a dataset."
- " Did the step produce any data?"
- )
- except ArrowInvalid:
- logger.warning(f"❌ Failed to load the subset from '{file}' directory.")
- continue
+ except ArrowInvalid:
+ logger.warning(f"❌ Failed to load the subset from '{file}' directory.")
+ continue
+ else:
+ logger.warning(
+ f"No output files for step '{file.stem}', can't create a dataset."
+ " Did the step produce any data?"
+ )
# If there's only one dataset i.e. one config, then set the config name to `default`
if len(distiset.keys()) == 1:
diff --git a/src/distilabel/pipeline/_dag.py b/src/distilabel/pipeline/_dag.py
index 86c4e80bf..873fe9e93 100644
--- a/src/distilabel/pipeline/_dag.py
+++ b/src/distilabel/pipeline/_dag.py
@@ -30,6 +30,7 @@
import networkx as nx
from distilabel.pipeline.constants import (
+ CONVERGENCE_STEP_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
@@ -353,6 +354,9 @@ def _validate_convergence_step(
):
return
+ # Mark the step as a convergence step
+ self.set_step_attr(step.name, CONVERGENCE_STEP_ATTR_NAME, True) # type: ignore
+
# Check if all the predecessors of the step are receiving routed batches from the
# same step
previous_steps_predecessors = [
@@ -431,6 +435,14 @@ def _validate_routing_batch_function(
f" from step '{predecessor_step.name}' to step '{step.name}'."
)
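+        # The routed batches are merged again downstream in a convergence step, so the
+        # batch size of the previous step must be evenly divisible by the receiving
+        # step's `input_batch_size` to keep the batch boundaries aligned.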
+ if batch_size % step.input_batch_size != 0: # type: ignore
+ raise ValueError(
+ f"Step '{step.name}' should have an `input_batch_size` that is a multiple"
+ f" of the `input_batch_size` or `batch_size` of the previous step."
+ f" This is because the batches are being routed with a `routing_batch_function`"
+ f" from step '{predecessor_step.name}' to step '{step.name}'."
+ )
+
return True
def _validate_process_step_input_parameter(
diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py
index 8b20cd04c..63eb8116d 100644
--- a/src/distilabel/pipeline/base.py
+++ b/src/distilabel/pipeline/base.py
@@ -413,12 +413,12 @@ def _cache(self) -> None:
"""Saves the `BasePipeline` using the `_cache_filename`."""
self.save(
path=self._cache_location["pipeline"],
- format=self._cache_location["pipeline"].suffix.replace(".", ""),
+ format=self._cache_location["pipeline"].suffix.replace(".", ""), # type: ignore
)
if self._batch_manager is not None:
self._batch_manager.save(
self._cache_location["batch_manager"],
- format=self._cache_location["batch_manager"].suffix.replace(".", ""),
+ format=self._cache_location["batch_manager"].suffix.replace(".", ""), # type: ignore
)
self._logger.debug("Pipeline and batch manager saved to cache.")
@@ -428,12 +428,6 @@ def _load_from_cache(self) -> None:
"""
cache_loc = self._cache_location
if cache_loc["pipeline"].exists():
- # Refresh the DAG to avoid errors when it's created within a context manager
- # (it will check the steps aren't already defined for the DAG).
- self.dag = DAG()
- new_class = self.from_yaml(cache_loc["pipeline"])
- # Update the internal dag and batch_manager
- self.dag.G = new_class.dag.G
if cache_loc["batch_manager"].exists():
self._batch_manager = _BatchManager.from_json(
cache_loc["batch_manager"]
@@ -453,6 +447,7 @@ class _Batch(_Serializable):
accumulated: A flag to indicate if the batch is accumulated.
created_from: A dictionary containing the `seq_no` of the batches of the steps that
were used to create this batch.
+ size: The number of rows in the batch.
"""
seq_no: int
@@ -460,8 +455,9 @@ class _Batch(_Serializable):
last_batch: bool
data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False)
accumulated: bool = False
- created_from: Dict[str, List[int]] = field(default_factory=dict)
+ created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict)
batch_routed_to: List[str] = field(default_factory=list)
+ size: int = 0
def next_batch(self) -> "_Batch":
"""Create a new `_Batch` instance with the next batch of data.
@@ -476,6 +472,15 @@ def next_batch(self) -> "_Batch":
seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch
)
+ def set_data(self, data: List[List[Dict[str, Any]]]) -> None:
+ """Sets the data of the batch and updates the size of the batch.
+
+ Args:
+ data: The data of the batch.
+ """
+ self.data = data
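+ # `data` holds one list of rows per predecessor, so the size of the batch is
+ # the row count of the first list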
+ self.size = len(data[0])
+
@classmethod
def accumulate(cls, step_name: str, batches: List[List["_Batch"]]) -> "_Batch":
"""Creates a `_Batch` instance using the data from the list of batches that
@@ -540,6 +545,16 @@ class _BatchManagerStep(_Serializable):
convergence_step: A flag to indicate if the step is a convergence step. A
`Step` is a convergence step if all its predecessors are receiving routed
batches. Defaults to `False`.
+ convergence_step_batches_consumed: A dictionary keyed by the `seq_no` of a
+ batch created by an upstream step A, taken from the `created_from` field of
+ the batches that steps B and C built from it. In a topology A -> [B, C] -> D,
+ step D uses it to know whether all the batches that B and C derived from a
+ given batch of A have been consumed, so the order of the batches is preserved.
+ Only used if `convergence_step=True`. Defaults to an empty dictionary.
+ next_expected_created_from_batch_seq_no: The `seq_no` of the next batch of
+ step A (taken from the `created_from` field of the batches created by B and
+ C) that is expected to be consumed. It's used to preserve the order of the
+ batches. Only used if `convergence_step=True`. Defaults to `0`.
"""
step_name: str
@@ -549,6 +564,10 @@ class _BatchManagerStep(_Serializable):
seq_no: int = 0
last_batch_received: List[str] = field(default_factory=list)
convergence_step: bool = False
+ convergence_step_batches_consumed: Dict[int, Dict[str, int]] = field(
+ default_factory=dict
+ )
+ next_expected_created_from_batch_seq_no: int = 0
def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
"""Add a batch of data from `batch.step_name` to the step. It will accumulate the
@@ -582,6 +601,7 @@ def get_batch(self) -> Union[_Batch, None]:
# `_last_batch` must be called before `_get_data`, as `_get_data` will update the
# list of data which is used to determine if the batch to be created is the last one.
+ # TODO: remove `_last_batch` method and integrate logic in `_get_data`
last_batch = self._last_batch()
data, created_from, batch_routed_to = self._get_data()
@@ -653,7 +673,7 @@ def _get_seq_no(self) -> int:
def _get_data(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]], List[str]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]:
"""Gets the data needed to create a batch for the step to process. If the step is
accumulating data, then it will return a list with all the data received from the
predecessors. Otherwise, it will return a list of data with the `input_batch_size`
@@ -679,7 +699,7 @@ def _get_data(
def _get_data_for_accumulate(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step
is accumulating data. It will return a list with all the data received from the
predecessors. In addition, it will remove the data used to create the batch from
@@ -695,7 +715,7 @@ def _get_data_for_accumulate(
for step_name, batches in self.data.items():
batches_used[step_name] = []
for batch in batches:
- batches_used[step_name].append(batch.seq_no)
+ batches_used[step_name].append((batch.seq_no, batch.size))
data.append([row for batch in batches for row in batch.data[0]])
# Reset the data buffer
self.data = {step_name: [] for step_name in self.data}
@@ -703,7 +723,7 @@ def _get_data_for_accumulate(
def _get_data_for_convergence_step(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]:
"""Gets the data needed to create a batch for the step to process when the step is
a convergence step.
@@ -713,25 +733,35 @@ def _get_data_for_convergence_step(
used to create the batch.
"""
grouped_batches = self._group_batches_by_created_from()
- _, batches = grouped_batches[0]
+ seq_no, batches = grouped_batches[0]
remaining_rows_per_step = {
step_name: self.input_batch_size for step_name in self.data
}
batches_used = defaultdict(list)
data = defaultdict(list)
- for batch in batches:
+ for batch, batch_size in batches:
batch_data = batch.data[0]
remaining_rows = remaining_rows_per_step[batch.step_name]
selected_data = batch_data[:remaining_rows]
data[batch.step_name].extend(selected_data)
+ # If A -> [B, C] -> D, then in D (this step) we keep track of the remaining
+ # rows from the batches of A that B and C used to create the `batches`.
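+ # On the first encounter of this `seq_no` and step pair, fall back to the
+ # original size of the batch taken from `created_from`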
+ batch_size = self.convergence_step_batches_consumed.setdefault(
+ seq_no, {}
+ ).get(batch.step_name, batch_size)
+ remaining_rows_in_batch = batch_size - len(selected_data)
+ self.convergence_step_batches_consumed[seq_no].update(
+ {batch.step_name: remaining_rows_in_batch}
+ )
+
# Update the remaining rows
num_rows = len(selected_data)
remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore
# Keep track of the batches used to create the batch
- batches_used[batch.step_name].append(batch.seq_no)
+ batches_used[batch.step_name].append((batch.seq_no, batch.size))
# If the batch was entirely consumed, then remove it from the buffer
if num_rows >= len(batch_data):
@@ -744,11 +774,21 @@ def _get_data_for_convergence_step(
batch_ref = self.data[batch.step_name][batch_idx]
batch_ref.data[0] = batch_data[len(selected_data) :]
+ # If all the batches grouped by the `seq_no` in `created_from` were consumed, then
+ # we can update the `next_expected_created_from_batch_seq_no` to the next one
+ # to avoid skipping batches.
+ no_remaining_rows = all(
+ count == 0
+ for count in self.convergence_step_batches_consumed[seq_no].values()
+ )
+ if no_remaining_rows:
+ self.next_expected_created_from_batch_seq_no += 1
+
return list(data.values()), dict(batches_used)
def _get_data_normal(
self,
- ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]], List[str]]:
+ ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]:
"""Gets the data needed to create a batch for the step to process when the step is
not accumulating data. It will return a list of data with the `input_batch_size`
for each predecessor. In addition, it will remove the data used to create the batch
@@ -771,6 +811,9 @@ def _get_data_normal(
idx_drop_batches = []
remaining_rows: int = self.input_batch_size # type: ignore
for idx, batch in enumerate(self.data[step_name]):
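+ # Stop iterating the buffered batches once enough rows have been collected
+ # for the batch to be created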
+ if remaining_rows == 0:
+ break
+
# Get `remaining_rows` or the remaining rows in the batch and add it to
# the step data that will be used to create the batch
batch_data = batch.data[0]
@@ -783,7 +826,7 @@ def _get_data_normal(
remaining_rows -= num_rows
# Keep track of the batches used to create the batch
- batches_used[step_name].append(batch.seq_no)
+ batches_used[step_name].append((batch.seq_no, batch.size))
# If the batch was entirely consumed, then remove it from the buffer
if num_rows >= len(batch_data):
@@ -843,19 +886,24 @@ def _ready_to_create_batch_convergence_step(self) -> bool:
grouped_batches = self._group_batches_by_created_from()
if not grouped_batches:
return False
- _, batches = grouped_batches[0]
+ seq_no, batches = grouped_batches[0]
+
+ # If the `seq_no` from the `created_from` field is not the expected one, then
+ # we cannot create a batch yet or the order will be messed up
+ if seq_no != self.next_expected_created_from_batch_seq_no:
+ return False
# Return `False` if not all the output batches to which the input batch was
# routed have been received yet
- batch_routed_to = batches[0].batch_routed_to
- batches_received_from = {batch.step_name for batch in batches}
+ batch_routed_to = batches[0][0].batch_routed_to
+ batches_received_from = {batch.step_name for batch, _ in batches}
if any(step_name not in batches_received_from for step_name in batch_routed_to):
return False
# There are output batches to which the input batch was routed to from all
# the steps. Check if there is enough data for creating a batch with `input_batch_size`
rows_per_step = defaultdict(lambda: 0)
- for batch in batches:
+ for batch, _ in batches:
num_rows = len(batch.data[0])
rows_per_step[batch.step_name] += num_rows
@@ -931,10 +979,15 @@ def _last_batch_convergence_step(self) -> bool:
if not grouped_batches:
return False
_, batches = grouped_batches[0]
- steps_in_batches = {batch.step_name for batch in batches}
- return all(
- step_name in self.last_batch_received for step_name in steps_in_batches
- )
+
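+ # The batch to be created is the last one only if all the grouped batches are
+ # marked as the last one and none of them holds more rows than `input_batch_size`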
+ for batch, _ in batches:
+ if not batch.last_batch:
+ return False
+
+ if len(batch.data[0]) > self.input_batch_size: # type: ignore
+ return False
+
+ return True
def _last_batch_normal(self) -> bool:
"""Checks if the batch to be created is the last one for a normal step. `True` if
@@ -958,7 +1011,9 @@ def _last_batch_normal(self) -> bool:
return True
- def _group_batches_by_created_from(self) -> List[Tuple[int, List["_Batch"]]]:
+ def _group_batches_by_created_from(
+ self,
+ ) -> List[Tuple[int, List[Tuple["_Batch", int]]]]:
"""Group the batches by the first key of `created_from` field. This method is
meant to be used only with a `convergence_step`.
@@ -966,23 +1021,23 @@ def _group_batches_by_created_from(self) -> List[Tuple[int, List["_Batch"]]]:
A list of the batches grouped by the `seq_no` of the first step name in `created_from`.
The list is sorted by the `seq_no`.
"""
- grouped_batches = defaultdict(list)
+ grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list)
for batches in self.data.values():
for batch in batches:
first_key = next(iter(batch.created_from))
- batch_seq_no = batch.created_from[first_key][0]
- grouped_batches[batch_seq_no].append(batch)
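+ # `created_from` values are `(seq_no, batch_size)` tuples now, so unpack both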
+ batch_seq_no, batch_size = batch.created_from[first_key][0]
+ grouped_batches[batch_seq_no].append((batch, batch_size))
return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items())
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function.
Args:
- obj (Any): Unused, just kept to match the signature of the parent method.
- kwargs (Any): Additional arguments that are kept to match the signature of the parent method.
+ obj: Unused, just kept to match the signature of the parent method.
+ kwargs: Additional arguments that are kept to match the signature of the parent method.
Returns:
- Dict[str, Any]: Internal representation of the `_BatchManagerStep`.
+ Internal representation of the `_BatchManagerStep`.
"""
return asdict(self)
@@ -1007,7 +1062,7 @@ def __init__(
steps: Dict[str, _BatchManagerStep],
last_batch_received: Dict[str, Union[_Batch, None]],
last_batch_sent: Dict[str, Union[_Batch, None]],
- last_batch_flag_sent_to: Optional[List[str]] = None,
+ last_batch_flag_sent_to: List[str],
) -> None:
"""Initialize the `_BatchManager` instance.
@@ -1019,12 +1074,9 @@ def __init__(
last_batch_sent: A dictionary with the step name as the key and the last
`_Batch` sent to the step.
last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG`
- was sent. Defaults to `None`.
+ was sent.
"""
- if last_batch_flag_sent_to is None:
- last_batch_flag_sent_to = []
-
self._steps = steps
self._last_batch_received = last_batch_received
self._last_batch_sent = last_batch_sent
@@ -1039,14 +1091,15 @@ def can_generate(self) -> bool:
"""
for step_name, batch in self._last_batch_received.items():
- if not batch:
- return True
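+ # Steps that were already notified with `LAST_BATCH_SENT_FLAG` won't produce
+ # more batches, so they shouldn't block the pipeline from finishing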
+ if step_name not in self._last_batch_flag_sent_to:
+ if not batch:
+ return True
- if not batch.last_batch and step_name not in self._last_batch_flag_sent_to:
- return True
+ if not batch.last_batch:
+ return True
- if not self.get_last_batch_sent(step_name):
- return True
+ if not self.get_last_batch_sent(step_name):
+ return True
return False
@@ -1172,7 +1225,7 @@ def from_dag(cls, dag: "DAG") -> "_BatchManager":
convergence_step=convergence_step,
)
steps[step_name] = batch_manager_step
- return cls(steps, last_batch_received, last_batch_sent)
+ return cls(steps, last_batch_received, last_batch_sent, [])
def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
"""Dumps the content of the `_BatchManager` to a dictionary.
diff --git a/src/distilabel/pipeline/constants.py b/src/distilabel/pipeline/constants.py
index 829aeb399..450ef0ed6 100644
--- a/src/distilabel/pipeline/constants.py
+++ b/src/distilabel/pipeline/constants.py
@@ -15,5 +15,7 @@
from typing import Final
STEP_ATTR_NAME: Final[str] = "step"
+INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue"
RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches"
ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function"
+CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step"
diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py
index 94401eb38..5bb7fd799 100644
--- a/src/distilabel/pipeline/local.py
+++ b/src/distilabel/pipeline/local.py
@@ -32,6 +32,8 @@
_WriteBuffer,
)
from distilabel.pipeline.constants import (
+ CONVERGENCE_STEP_ATTR_NAME,
+ INPUT_QUEUE_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
@@ -44,7 +46,7 @@
from queue import Queue
from distilabel.distiset import Distiset
- from distilabel.steps.base import GeneratorStep
+ from distilabel.steps.base import GeneratorStep, _Step
_STEPS_LOADED_KEY = "steps_loaded"
@@ -188,7 +190,7 @@ def _notify_steps_to_stop(self) -> None:
"""Notifies the steps to stop their infinite running loop by sending `None` to
their input queues."""
for step_name in self.dag:
- if input_queue := self.dag.get_step(step_name).get("input_queue"):
+ if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME):
input_queue.put(None)
def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None:
@@ -236,16 +238,27 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
"""
assert self._batch_manager, "Batch manager is not set"
+ # Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence
+ # step if the batch is the last one, so they stop their processing loop even if
+ # they haven't received the last batch because of the routing function.
+ if self._is_convergence_step(batch.step_name) and batch.last_batch:
+ for step_name in self.dag.get_step_predecessors(batch.step_name):
+ self._send_last_batch_flag_to_step(step_name)
+
+ route_to, routed = self._get_successors(batch)
+
+ # Keep track of the steps that the batch was routed to
+ if routed:
+ batch.batch_routed_to = route_to
+
self._register_batch(batch)
- successors, route_to, routed = self._get_successors(batch)
+
step = self._get_step_from_batch(batch)
# Add the batch to the successors input buffers
for successor in route_to:
# Copy batch to avoid modifying the same reference in the batch manager
batch_to_add = batch.copy() if len(route_to) > 1 else batch
- if routed:
- batch_to_add.batch_routed_to = route_to
self._batch_manager.add_batch(successor, batch_to_add)
@@ -256,23 +269,14 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
step.is_generator
and step.name in self._batch_manager.step_empty_buffers(successor)
):
- last_batch = self._batch_manager.get_last_batch_sent(step.name)
- self._send_batch_to_step(last_batch.next_batch()) # type: ignore
+ last_batch_sent = self._batch_manager.get_last_batch_sent(step.name)
+ self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore
# If successor step has enough data in its buffer to create a new batch, then
# send the batch to the step.
if new_batch := self._batch_manager.get_batch(successor):
self._send_batch_to_step(new_batch)
- # If `last_batch` was not routed to all the successors of the step, that
- # means that the batch was routed to specific steps using a routing function.
- # We have to send the `LAST_BATCH_SENT_FLAG` to the steps that the batch
- # was not routed to, so they can stop processing batches.
- not_routed_to = [s for s in successors if s not in route_to]
- if batch.last_batch and len(not_routed_to):
- for step_name in not_routed_to:
- self._send_last_batch_flag_to_step(step_name)
-
if step.is_generator:
return
@@ -280,9 +284,8 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:
# buffers to create a new batch
if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
self._send_batch_to_step(new_batch)
- return
-
- self._request_more_batches_if_needed(step)
+ else:
+ self._request_more_batches_if_needed(step)
self._cache()
@@ -298,15 +301,15 @@ def _register_batch(self, batch: "_Batch") -> None:
" manager"
)
- def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]:
+ def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]:
"""Gets the successors and the successors to which the batch has to be routed.
Args:
batch: The batch to which the successors will be determined.
Returns:
- The successors, the successors to route the batch to, and whether the batch was
- routed using a routing function.
+ The successors to route the batch to and whether the batch was routed using
+ a routing function.
"""
node = self.dag.get_step(batch.step_name)
step: "Step" = node[STEP_ATTR_NAME]
@@ -321,7 +324,7 @@ def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]:
f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}"
)
- return successors, route_to, route_to != successors
+ return route_to, route_to != successors
def _get_step_from_batch(self, batch: "_Batch") -> "Step":
"""Gets the `Step` instance from a batch.
@@ -366,12 +369,17 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None:
"""
self._logger.debug("Handling stop of the pipeline execution...")
- # Send `None` to the input queues of all the steps to notify them to stop
- # processing batches.
+ # Add the remaining batches in the input queues back to the batch manager
for step_name in self.dag:
- if input_queue := self.dag.get_step(step_name).get("input_queue"):
+ node = self.dag.get_step(step_name)
+ step: "_Step" = node[STEP_ATTR_NAME]
+ if step.is_generator:
+ continue
+ if input_queue := node.get(INPUT_QUEUE_ATTR_NAME):
while not input_queue.empty():
batch = input_queue.get()
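+ # `None` is the sentinel used to notify the steps to stop, not a real batch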
+ if batch is None:
+ continue
self._batch_manager.add_batch( # type: ignore
to_step=step_name, batch=batch, prepend=True
)
@@ -389,8 +397,12 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None:
# processed by the steps before stop was called.
while not self.output_queue.empty():
batch = self.output_queue.get()
+ if batch is None:
+ continue
+
if batch.step_name in self.dag.leaf_steps:
write_buffer.add_batch(batch)
+
self._handle_batch_on_stop(batch)
self._cache()
@@ -419,7 +431,7 @@ def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", No
if self._check_step_not_loaded_or_finished(step_name):
return None
- if input_queue := self.dag.get_step(step_name).get("input_queue"):
+ if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME):
while input_queue.qsize() != 0:
pass
return input_queue
@@ -520,10 +532,23 @@ def _send_batch_to_step(self, batch: "_Batch") -> None:
self._logger.debug(
f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}"
)
- input_queue = self.dag.get_step(batch.step_name)["input_queue"]
+ input_queue = self.dag.get_step(batch.step_name)[INPUT_QUEUE_ATTR_NAME]
input_queue.put(batch)
+ def _is_convergence_step(self, step_name: str) -> bool:
+ """Checks if a step is a convergence step.
+
+ Args:
+ step_name: The name of the step.
+
+ Returns:
+ `True` if the step is a convergence step, `False` otherwise.
+ """
+ return bool(self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME))
+
def _send_last_batch_flag_to_step(self, step_name: str) -> None:
+ """Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches.
+
+ Args:
+ step_name: The name of the step.
+ """
batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore
if batch and batch.last_batch:
return
@@ -532,7 +557,7 @@ def _send_last_batch_flag_to_step(self, step_name: str) -> None:
f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing"
" batches..."
)
- input_queue = self.dag.get_step(step_name)["input_queue"]
+ input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME]
input_queue.put(LAST_BATCH_SENT_FLAG)
self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore
@@ -559,7 +584,7 @@ def _run_steps_in_loop(
for step_name in self.dag:
step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME]
input_queue = manager.Queue()
- self.dag.set_step_attr(step.name, "input_queue", input_queue)
+ self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue)
# Set `pipeline` to `None` as in some Python environments the pipeline is not
# picklable and it will raise an error when trying to send the step to the process.
@@ -889,7 +914,7 @@ def _generator_step_process_loop(self) -> None:
)
for data, last_batch in step.process_applying_mappings(offset=offset):
- batch.data = [data]
+ batch.set_data([data])
batch.last_batch = self._dry_run or last_batch
self._send_batch(batch)
@@ -962,7 +987,7 @@ def _non_generator_process_loop(self) -> None:
f"Subprocess traceback:\n\n{traceback.format_exc()}"
)
finally:
- batch.data = [result]
+ batch.set_data([result])
self._send_batch(batch)
if batch.last_batch:
diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py
index 276cac1f2..b97669b80 100644
--- a/src/distilabel/utils/serialization.py
+++ b/src/distilabel/utils/serialization.py
@@ -277,6 +277,14 @@ def from_yaml(cls, path: StrOrPath) -> Self:
@classmethod
def from_file(cls, path: StrOrPath) -> Self:
+ """Loads a class from a file.
+
+ Args:
+ path: the path to the file containing the serialized class.
+
+ Returns:
+ An instance of the class.
+ """
path = Path(path)
if path.suffix == ".json":
return cls.from_json(path)
diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py
index 28388b8f7..8ab1fff29 100644
--- a/tests/integration/test_pipe_simple.py
+++ b/tests/integration/test_pipe_simple.py
@@ -157,10 +157,6 @@ def run_pipeline():
return pipeline.run(
parameters={
- "load_dataset": {
- "repo_id": "plaguss/test",
- "split": "train",
- },
"rename_columns": {
"rename_mappings": {
"prompt": "instruction",
diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py
index 2fbaa508e..0ea2ee3cd 100644
--- a/tests/integration/test_routing_batch_function.py
+++ b/tests/integration/test_routing_batch_function.py
@@ -16,6 +16,7 @@
import time
from typing import TYPE_CHECKING, List
+import pytest
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import LoadDataFromDicts, StepInput, step
@@ -41,22 +42,31 @@ def Generate(inputs: StepInput) -> "StepOutput":
yield inputs
+@step(outputs=["generations"])
+def Generate2(inputs: StepInput) -> "StepOutput":
+ sleep_time = random.uniform(1.0, 2.0)
+ time.sleep(sleep_time)
+ for input in inputs:
+ input["2generation"] = "I slept for {} seconds".format(sleep_time)
+ yield inputs
+
+
@step(outputs=["generations"])
def CombineGenerations(*inputs: StepInput) -> "StepOutput":
+ generation_key = (
+ "2generation" if "2generation" in inputs[0][0].keys() else "generation"
+ )
+
combined_list = []
for rows in zip(*inputs):
combined_dict = {
"index": rows[0]["index"],
- "instruction": rows[0]["instruction"],
- "generations": [row["generation"] for row in rows],
+ "instruction": [row["instruction"] for row in rows],
+ f"{generation_key}s": [row[generation_key] for row in rows],
}
# Check consistency in "index"
- if any(
- row["index"] != combined_dict["index"]
- or row["instruction"] != combined_dict["instruction"]
- for row in rows
- ):
+ if any(row["index"] != combined_dict["index"] for row in rows):
raise ValueError("Inconsistent 'index'")
combined_list.append(combined_dict)
@@ -64,6 +74,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput":
yield combined_list
+@pytest.mark.timeout(120)
def test_routing_batch_function() -> None:
with Pipeline(name="test") as pipeline:
load_dataset = LoadDataFromDicts(
@@ -80,3 +91,75 @@ def test_routing_batch_function() -> None:
for i, row in enumerate(distiset["default"]["train"]):
assert row["index"] == i
+ assert row["instruction"] == [f"Instruction {i}", f"Instruction {i}"]
+ assert len(row["generations"]) == 2
+
+
+@pytest.mark.timeout(120)
+def test_routing_batch_function_irregular_batch_sizes() -> None:
+ with Pipeline(name="test") as pipeline:
+ load_dataset = LoadDataFromDicts(
+ data=[{"index": i, "instruction": f"Instruction {i}"} for i in range(1000)],
+ batch_size=200,
+ )
+
+ generates = [
+ Generate(input_batch_size=input_batch_size)
+ for input_batch_size in [25, 50, 100, 200]
+ ]
+
+ combine_generations = CombineGenerations(input_batch_size=25)
+
+ load_dataset >> random_routing_batch >> generates >> combine_generations
+
+ distiset = pipeline.run(use_cache=False)
+
+ for i, row in enumerate(distiset["default"]["train"]):
+ assert row["index"] == i
+ assert row["instruction"] == [f"Instruction {i}", f"Instruction {i}"]
+ assert len(row["generations"]) == 2
+
+
+@pytest.mark.timeout(120)
+def test_multiple_routing_batch_function() -> None:
+ batch_size = 200
+
+ with Pipeline(name="test") as pipeline:
+ load_dataset = LoadDataFromDicts(
+ data=[
+ {
+ "index": i,
+ "instruction": f"Instruction {i}",
+ "batch": i // batch_size,
+ }
+ for i in range(1000)
+ ],
+ batch_size=batch_size,
+ )
+
+ generates = [
+ Generate(input_batch_size=input_batch_size)
+ for input_batch_size in [25, 50, 100, 200]
+ ]
+
+ combine_generations = CombineGenerations(input_batch_size=25)
+
+ generates2 = [Generate2(input_batch_size=25) for _ in range(4)]
+
+ combine_generations_2 = CombineGenerations(input_batch_size=25)
+
+ (
+ load_dataset
+ >> random_routing_batch
+ >> generates
+ >> combine_generations
+ >> random_routing_batch
+ >> generates2
+ >> combine_generations_2
+ )
+
+ distiset = pipeline.run(use_cache=False)
+
+ for i, row in enumerate(distiset["default"]["train"]):
+ assert row["index"] == i
+ assert len(row["2generations"]) == 2
diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py
index 6c7dd33b6..76a48ec4b 100644
--- a/tests/unit/pipeline/test_base.py
+++ b/tests/unit/pipeline/test_base.py
@@ -261,6 +261,14 @@ def test_infer_step_names_big_pipeline(self) -> None:
class TestBatch:
+ def test_set_data(self) -> None:
+ batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
+ data = [[{"i": i} for i in range(5000)]]
+ batch.set_data(data)
+
+ assert batch.data == data
+ assert batch.size == 5000
+
def test_next_batch(self) -> None:
batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
next_batch = batch.next_batch()
@@ -313,6 +321,7 @@ def test_dump(self) -> None:
batch = _Batch(seq_no=0, step_name="step1", last_batch=False)
assert batch.dump() == {
"seq_no": 0,
+ "size": 0,
"step_name": "step1",
"last_batch": False,
"data": [],
@@ -333,6 +342,7 @@ def test_dump(self) -> None:
)
assert batch.dump() == {
"seq_no": 0,
+ "size": 0,
"step_name": "step1",
"last_batch": False,
"data": [[{"a": 1}, {"a": 2}, {"a": 3}]],
@@ -441,6 +451,7 @@ def test_get_batch(self) -> None:
{"a": 5},
]
],
+ size=5,
)
],
"step2": [
@@ -458,6 +469,7 @@ def test_get_batch(self) -> None:
{"b": 6},
]
],
+ size=5,
)
],
},
@@ -479,7 +491,7 @@ def test_get_batch(self) -> None:
{"b": 2},
],
],
- created_from={"step1": [0], "step2": [0]},
+ created_from={"step1": [(0, 5)], "step2": [(0, 5)]},
)
batch = batch_manager_step.get_batch()
@@ -498,7 +510,7 @@ def test_get_batch(self) -> None:
{"b": 4},
],
],
- created_from={"step1": [0], "step2": [0]},
+ created_from={"step1": [(0, 5)], "step2": [(0, 5)]},
)
def test_get_batches_accumulate(self) -> None:
@@ -520,6 +532,7 @@ def test_get_batches_accumulate(self) -> None:
{"a": 5},
]
],
+ size=5,
)
],
"step2": [
@@ -537,6 +550,7 @@ def test_get_batches_accumulate(self) -> None:
{"b": 6},
]
],
+ size=6,
)
],
},
@@ -567,7 +581,7 @@ def test_get_batches_accumulate(self) -> None:
{"b": 6},
],
],
- created_from={"step1": [0], "step2": [0]},
+ created_from={"step1": [(0, 5)], "step2": [(0, 6)]},
)
def test_get_batches_not_enough_data(self) -> None:
@@ -654,6 +668,7 @@ def test_get_data(self) -> None:
data=[
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]
],
+ size=6,
batch_routed_to=["step1", "step2"],
)
],
@@ -673,6 +688,7 @@ def test_get_data(self) -> None:
{"b": 7},
]
],
+ size=7,
batch_routed_to=["step1", "step2"],
)
],
@@ -684,7 +700,7 @@ def test_get_data(self) -> None:
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}],
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}],
]
- assert created_from == {"step1": [0], "step2": [0]}
+ assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
assert routed_to == ["step1", "step2"]
assert batch_manager_step.data == {
@@ -694,6 +710,7 @@ def test_get_data(self) -> None:
step_name="step1",
last_batch=False,
data=[[{"a": 6}]],
+ size=6,
batch_routed_to=["step1", "step2"],
)
],
@@ -703,6 +720,7 @@ def test_get_data(self) -> None:
step_name="step2",
last_batch=False,
data=[[{"b": 6}, {"b": 7}]],
+ size=7,
batch_routed_to=["step1", "step2"],
)
],
@@ -721,6 +739,7 @@ def test_get_data_accumulate(self) -> None:
data=[
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]
],
+ size=6,
)
],
"step2": [
@@ -739,6 +758,7 @@ def test_get_data_accumulate(self) -> None:
{"b": 7},
]
],
+ size=7,
)
],
},
@@ -750,7 +770,7 @@ def test_get_data_accumulate(self) -> None:
[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}],
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}],
]
- assert created_from == {"step1": [0], "step2": [0]}
+ assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]}
assert routed_to == []
assert batch_manager_step.data == {"step1": [], "step2": []}
@@ -767,7 +787,8 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm A 0"},
]
],
- created_from={"Z": [0]},
+ size=3,
+ created_from={"Z": [(0, 3)]},
)
batch_a_1 = _Batch(
@@ -781,7 +802,8 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm A 1"},
]
],
- created_from={"Z": [1]},
+ size=3,
+ created_from={"Z": [(1, 3)]},
)
batch_b_0 = _Batch(
@@ -795,7 +817,8 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm B 0"},
]
],
- created_from={"Z": [0]},
+ size=3,
+ created_from={"Z": [(0, 3)]},
)
batch_c_0 = _Batch(
@@ -809,10 +832,11 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm C 0"},
]
],
- created_from={"Z": [1]},
+ size=3,
+ created_from={"Z": [(1, 3)]},
)
- bath_manager_step = _BatchManagerStep(
+ batch_manager_step = _BatchManagerStep(
step_name="D",
input_batch_size=3,
convergence_step=True,
@@ -820,7 +844,7 @@ def test_get_data_convergence_step(self) -> None:
data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]},
)
- data, created_from, routed_to = bath_manager_step._get_data()
+ data, created_from, routed_to = batch_manager_step._get_data()
assert data == [
[
@@ -834,10 +858,11 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm B 0"},
],
]
- assert created_from == {"A": [0], "B": [0]}
+ assert created_from == {"A": [(0, 3)], "B": [(0, 3)]}
assert routed_to == []
+ assert batch_manager_step.next_expected_created_from_batch_seq_no == 1
- data, created_from, routed_to = bath_manager_step._get_data()
+ data, created_from, routed_to = batch_manager_step._get_data()
assert data == [
[
@@ -851,8 +876,9 @@ def test_get_data_convergence_step(self) -> None:
{"generation": "Hello, I'm C 0"},
],
]
- assert created_from == {"A": [1], "C": [0]}
+ assert created_from == {"A": [(1, 3)], "C": [(0, 3)]}
assert routed_to == []
+ assert batch_manager_step.next_expected_created_from_batch_seq_no == 2
@pytest.mark.parametrize(
"data, last_batch_received, expected",
@@ -1034,6 +1060,100 @@ def test_last_batch_accumulate(
assert batch_manager_step._last_batch() is expected
+ @pytest.mark.parametrize(
+ "data, last_batch_received, expected",
+ [
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}]],
+ created_from={"step0": [(0, 3)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=True,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}]],
+ created_from={"step0": [(0, 3)]},
+ )
+ ],
+ },
+ [],
+ True,
+ ),
+ ],
+ )
+ def test_last_batch_convergence_step(
+ self,
+ data: Dict[str, List[_Batch]],
+ last_batch_received: List[str],
+ expected: bool,
+ ) -> None:
+ batch_manager_step = _BatchManagerStep(
+ step_name="step3",
+ accumulate=False,
+ data=data,
+ last_batch_received=last_batch_received,
+ input_batch_size=3,
+ convergence_step=True,
+ )
+
+ assert batch_manager_step._last_batch() is expected
+
@pytest.mark.parametrize(
"data, last_batch_received, expected",
[
@@ -1206,6 +1326,7 @@ def test_dump(self) -> None:
step_name="step1",
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]],
+ size=6,
)
batch_step_2 = _Batch(
seq_no=0,
@@ -1214,6 +1335,7 @@ def test_dump(self) -> None:
data=[
[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}]
],
+ size=7,
)
batch_manager_step = _BatchManagerStep(
step_name="step3",
@@ -1227,6 +1349,7 @@ def test_dump(self) -> None:
"step_name": "step3",
"accumulate": True,
"convergence_step": False,
+ "convergence_step_batches_consumed": {},
"input_batch_size": None,
"data": {
"step1": [
@@ -1244,6 +1367,7 @@ def test_dump(self) -> None:
{"a": 6},
]
],
+ "size": 6,
"accumulated": False,
"created_from": {},
"batch_routed_to": [],
@@ -1265,6 +1389,7 @@ def test_dump(self) -> None:
{"b": 7},
]
],
+ "size": 7,
"accumulated": False,
"created_from": {},
"batch_routed_to": [],
@@ -1273,6 +1398,7 @@ def test_dump(self) -> None:
},
"seq_no": 0,
"last_batch_received": [],
+ "next_expected_created_from_batch_seq_no": 0,
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_BatchManagerStep",
@@ -1291,6 +1417,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
)
],
"step2": [],
@@ -1307,6 +1434,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
)
],
"step2": [
@@ -1316,6 +1444,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
)
],
},
@@ -1331,6 +1460,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
)
],
"step2": [
@@ -1340,6 +1470,7 @@ def test_dump(self) -> None:
last_batch=False,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
)
],
},
@@ -1355,6 +1486,7 @@ def test_dump(self) -> None:
last_batch=True,
data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
)
],
"step2": [
@@ -1364,12 +1496,39 @@ def test_dump(self) -> None:
last_batch=True,
data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]],
batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
)
],
},
["step1", "step2"],
True,
),
+ (
+ {
+ "step1": [
+ _Batch(
+ seq_no=0,
+ step_name="step1",
+ last_batch=False,
+ data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 4)]},
+ )
+ ],
+ "step2": [
+ _Batch(
+ seq_no=0,
+ step_name="step2",
+ last_batch=False,
+ data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]],
+ batch_routed_to=["step1", "step2"],
+ created_from={"step0": [(0, 5)]},
+ )
+ ],
+ },
+ [],
+ False,
+ ),
],
)
def test_ready_to_create_batch_convergence_step(
@@ -1384,6 +1543,7 @@ def test_ready_to_create_batch_convergence_step(
input_batch_size=5,
data=data,
last_batch_received=last_batch_received,
+ convergence_step=True,
)
assert batch_manager_step._ready_to_create_batch() is expected
@@ -1394,6 +1554,7 @@ def test_from_dict(self) -> None:
"step_name": "step3",
"accumulate": True,
"convergence_step": False,
+ "convergence_step_batches_consumed": {0: {"Z": 1234}},
"input_batch_size": None,
"data": {
"step1": [
@@ -1411,6 +1572,7 @@ def test_from_dict(self) -> None:
{"a": 6},
]
],
+ "size": 6,
"accumulated": False,
"created_from": {},
"batch_routed_to": [],
@@ -1432,6 +1594,7 @@ def test_from_dict(self) -> None:
{"b": 7},
]
],
+ "size": 7,
"accumulated": False,
"created_from": {},
"batch_routed_to": [],
@@ -1451,6 +1614,7 @@ def test_from_dict(self) -> None:
assert batch_manager_step.step_name == "step3"
assert batch_manager_step.accumulate is True
assert batch_manager_step.convergence_step is False
+ assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}}
assert batch_manager_step.input_batch_size is None
assert batch_manager_step.seq_no == 0
assert batch_manager_step.last_batch_received == []
@@ -1469,6 +1633,7 @@ def test_add_batch(self) -> None:
},
last_batch_received={"step3": None},
last_batch_sent={"step3": None},
+ last_batch_flag_sent_to=[],
)
batch_from_step_1 = _Batch(
@@ -1505,6 +1670,7 @@ def test_add_batch_with_prepend(self) -> None:
},
last_batch_received={"step3": None},
last_batch_sent={"step3": None},
+ last_batch_flag_sent_to=[],
)
batch_0 = _Batch(
seq_no=0,
@@ -1566,6 +1732,7 @@ def test_can_generate(self) -> None:
"step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False),
},
last_batch_sent={"step_1": None, "step_2": None, "step_3": None},
+ last_batch_flag_sent_to=[],
)
assert batch_manager.can_generate()
@@ -1582,6 +1749,7 @@ def test_can_generate(self) -> None:
"step_3": batch_3,
},
last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3},
+ last_batch_flag_sent_to=[],
)
assert not batch_manager.can_generate()
@@ -1619,10 +1787,12 @@ def test_dump(self) -> None:
"step_name": "step3",
"accumulate": False,
"convergence_step": False,
+ "convergence_step_batches_consumed": {},
"input_batch_size": 5,
"data": {"step1": [], "step2": []},
"seq_no": 1,
"last_batch_received": [],
+ "next_expected_created_from_batch_seq_no": 0,
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_BatchManagerStep",
@@ -1637,6 +1807,7 @@ def test_dump(self) -> None:
"created_from": {},
"last_batch": False,
"data": [],
+ "size": 0,
"accumulated": False,
"type_info": {
"module": "distilabel.pipeline.base",
@@ -1652,6 +1823,7 @@ def test_dump(self) -> None:
"created_from": {},
"last_batch": False,
"data": [],
+ "size": 0,
"accumulated": False,
"type_info": {
"module": "distilabel.pipeline.base",
diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py
index dbdddb4da..2b38f55ab 100644
--- a/tests/unit/pipeline/test_dag.py
+++ b/tests/unit/pipeline/test_dag.py
@@ -595,6 +595,30 @@ def routing_batch_function_1(steps: List[str]) -> List[str]:
):
pipeline.dag.validate()
+ def test_validate_step_receiving_routed_batches_input_batch_size_multiple(
+ self, pipeline: "Pipeline"
+ ) -> None:
+ generator_step_1 = DummyGeneratorStep(pipeline=pipeline)
+ dummy_step_1 = DummyStep1(pipeline=pipeline)
+ dummy_step_2 = DummyStep1(name="demon", pipeline=pipeline, input_batch_size=7)
+
+ @routing_batch_function()
+ def routing_batch_function_1(steps: List[str]) -> List[str]:
+ return steps
+
+ convergence_step = DummyStep2(name="convergence_step", pipeline=pipeline)
+ (
+ generator_step_1
+ >> routing_batch_function_1
+ >> [dummy_step_1, dummy_step_2]
+ >> convergence_step
+ )
+ with pytest.raises(
+ ValueError,
+ match="Step 'demon' should have an `input_batch_size` that is a multiple of the `input_batch_size` or `batch_size`",
+ ):
+ pipeline.dag.validate()
+
class TestDagSerialization:
def test_dag_dump(self, dummy_step_1: "Step", dummy_step_2: "Step") -> None: