From 690013ada15950322eead35f68dbbbe2d149ee62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 20 May 2024 15:23:33 +0200 Subject: [PATCH] Fix routing batch function deadlocks and unordered batches (#649) * Add checking step `input_batch_size` multiple * Fix unordered batches when using `routing_batch_function` * Fix `can_generate` condition * Remove metadata and style * Fix getting data for batch when irregular batch sizes * Fix steps receiving routed batches getting stuck * Fix `_last_batch_convergence_step` method * Fix stop not checking for `None` * Fix issues related to the queues * Remove unused variable * Add integration tests timeout * Fix deadlock caused becase next expected batch in convergence step * Update unit tests * Add timeout to tests * Simplify condition * Fix unit test * Update timeouts --- .github/workflows/test.yml | 1 + README.md | 8 - pyproject.toml | 2 +- src/distilabel/distiset.py | 20 +- src/distilabel/pipeline/_dag.py | 12 ++ src/distilabel/pipeline/base.py | 143 +++++++++---- src/distilabel/pipeline/constants.py | 2 + src/distilabel/pipeline/local.py | 89 +++++--- src/distilabel/utils/serialization.py | 8 + tests/integration/test_pipe_simple.py | 4 - .../test_routing_batch_function.py | 97 ++++++++- tests/unit/pipeline/test_base.py | 200 ++++++++++++++++-- tests/unit/pipeline/test_dag.py | 24 +++ 13 files changed, 489 insertions(+), 121 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 3b9de467f4..01f1ebcb9a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -59,3 +59,4 @@ jobs: - name: Integration Tests run: make integration-tests + timeout-minutes: 5 diff --git a/README.md b/README.md index 097d3b5c80..4e071df69d 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,3 @@ ---- -description: Distilabel is an AI Feedback (AIF) framework for building datasets with and for LLMs. -hide: - - toc ---- - - -
diff --git a/pyproject.toml b/pyproject.toml index 0e5176c089..c3317bf5b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ docs = [ "CairoSVG >= 2.7.1", "mknotebooks >= 0.8.0", ] -tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio"] +tests = ["pytest >= 7.4.0", "pytest-asyncio", "nest-asyncio", "pytest-timeout"] # Optional LLMs, integrations, etc anthropic = ["anthropic >= 0.20.0"] diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 245ed29c17..b9bd1b0e03 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -234,22 +234,22 @@ def create_distiset( # noqa: C901 continue files = [str(file) for file in list_files_in_dir(file)] - try: - if files: + if files: + try: ds = load_dataset( "parquet", name=file.stem, data_files={"train": files} ) if not enable_metadata and DISTILABEL_METADATA_KEY in ds.column_names: ds = ds.remove_columns(DISTILABEL_METADATA_KEY) distiset[file.stem] = ds - else: - logger.warning( - f"No output files for step '{file.stem}', can't create a dataset." - " Did the step produce any data?" - ) - except ArrowInvalid: - logger.warning(f"❌ Failed to load the subset from '{file}' directory.") - continue + except ArrowInvalid: + logger.warning(f"❌ Failed to load the subset from '{file}' directory.") + continue + else: + logger.warning( + f"No output files for step '{file.stem}', can't create a dataset." + " Did the step produce any data?" + ) # If there's only one dataset i.e. one config, then set the config name to `default` if len(distiset.keys()) == 1: diff --git a/src/distilabel/pipeline/_dag.py b/src/distilabel/pipeline/_dag.py index 86c4e80bf1..873fe9e935 100644 --- a/src/distilabel/pipeline/_dag.py +++ b/src/distilabel/pipeline/_dag.py @@ -30,6 +30,7 @@ import networkx as nx from distilabel.pipeline.constants import ( + CONVERGENCE_STEP_ATTR_NAME, ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME, ) @@ -353,6 +354,9 @@ def _validate_convergence_step( ): return + # Mark the step as a convergence step + self.set_step_attr(step.name, CONVERGENCE_STEP_ATTR_NAME, True) # type: ignore + # Check if all the predecessors of the step are receiving routed batches from the # same step previous_steps_predecessors = [ @@ -431,6 +435,14 @@ def _validate_routing_batch_function( f" from step '{predecessor_step.name}' to step '{step.name}'." ) + if batch_size % step.input_batch_size != 0: # type: ignore + raise ValueError( + f"Step '{step.name}' should have an `input_batch_size` that is a multiple" + f" of the `input_batch_size` or `batch_size` of the previous step." + f" This is because the batches are being routed with a `routing_batch_function`" + f" from step '{predecessor_step.name}' to step '{step.name}'." + ) + return True def _validate_process_step_input_parameter( diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 8b20cd04c9..63eb8116d3 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -413,12 +413,12 @@ def _cache(self) -> None: """Saves the `BasePipeline` using the `_cache_filename`.""" self.save( path=self._cache_location["pipeline"], - format=self._cache_location["pipeline"].suffix.replace(".", ""), + format=self._cache_location["pipeline"].suffix.replace(".", ""), # type: ignore ) if self._batch_manager is not None: self._batch_manager.save( self._cache_location["batch_manager"], - format=self._cache_location["batch_manager"].suffix.replace(".", ""), + format=self._cache_location["batch_manager"].suffix.replace(".", ""), # type: ignore ) self._logger.debug("Pipeline and batch manager saved to cache.") @@ -428,12 +428,6 @@ def _load_from_cache(self) -> None: """ cache_loc = self._cache_location if cache_loc["pipeline"].exists(): - # Refresh the DAG to avoid errors when it's created within a context manager - # (it will check the steps aren't already defined for the DAG). - self.dag = DAG() - new_class = self.from_yaml(cache_loc["pipeline"]) - # Update the internal dag and batch_manager - self.dag.G = new_class.dag.G if cache_loc["batch_manager"].exists(): self._batch_manager = _BatchManager.from_json( cache_loc["batch_manager"] @@ -453,6 +447,7 @@ class _Batch(_Serializable): accumulated: A flag to indicate if the batch is accumulated. created_from: A dictionary containing the `seq_no` of the batches of the steps that were used to create this batch. + size: The size of the batch. """ seq_no: int @@ -460,8 +455,9 @@ class _Batch(_Serializable): last_batch: bool data: List[List[Dict[str, Any]]] = field(default_factory=list, repr=False) accumulated: bool = False - created_from: Dict[str, List[int]] = field(default_factory=dict) + created_from: Dict[str, List[Tuple[int, int]]] = field(default_factory=dict) batch_routed_to: List[str] = field(default_factory=list) + size: int = 0 def next_batch(self) -> "_Batch": """Create a new `_Batch` instance with the next batch of data. @@ -476,6 +472,15 @@ def next_batch(self) -> "_Batch": seq_no=self.seq_no + 1, step_name=self.step_name, last_batch=self.last_batch ) + def set_data(self, data: List[List[Dict[str, Any]]]) -> None: + """Sets the data of the batch and updates the size of the batch. + + Args: + data: The data of the batch. + """ + self.data = data + self.size = len(data[0]) + @classmethod def accumulate(cls, step_name: str, batches: List[List["_Batch"]]) -> "_Batch": """Creates a `_Batch` instance using the data from the list of batches that @@ -540,6 +545,16 @@ class _BatchManagerStep(_Serializable): convergence_step: A flag to indicate if the step is a convergence step. An `Step` is a convergence step if all its predecessors are receiving routed batches. Defaults to `False`. + convergence_step_batches_consumed: A dictionary in which the key is the `seq_no` + of the batch created by step A, that was used by step B and C and obtained from + the `created_from` of the batches created by them. It's used to know if all + the batches from B and C steps created from batches of A have been consumed + by D, in order to not mess up the order of the batches. Only used if `convergence_step=True`. + Defaults to an empty dictionary. + next_expected_created_from_batch_seq_no: The next expected sequence number of the + batch from step A used by steps B and C and obtained from the `created_from` + of the batches created by them. It's used to avoid messing up the order of the + batches. Only used if `convergence_step=True`. Defaults to `0`. """ step_name: str @@ -549,6 +564,10 @@ class _BatchManagerStep(_Serializable): seq_no: int = 0 last_batch_received: List[str] = field(default_factory=list) convergence_step: bool = False + convergence_step_batches_consumed: Dict[int, Dict[str, int]] = field( + default_factory=dict + ) + next_expected_created_from_batch_seq_no: int = 0 def add_batch(self, batch: _Batch, prepend: bool = False) -> None: """Add a batch of data from `batch.step_name` to the step. It will accumulate the @@ -582,6 +601,7 @@ def get_batch(self) -> Union[_Batch, None]: # `_last_batch` must be called before `_get_data`, as `_get_data` will update the # list of data which is used to determine if the batch to be created is the last one. + # TODO: remove `_last_batch` method and integrate logic in `_get_data` last_batch = self._last_batch() data, created_from, batch_routed_to = self._get_data() @@ -653,7 +673,7 @@ def _get_seq_no(self) -> int: def _get_data( self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]], List[str]]: + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: """Gets the data needed to create a batch for the step to process. If the step is accumulating data, then it will return a list with all the data received from the predecessors. Otherwise, it will return a list of data with the `input_batch_size` @@ -679,7 +699,7 @@ def _get_data( def _get_data_for_accumulate( self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]]]: + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: """Gets the data needed to create a batch for the step to process when the step is accumulating data. It will return a list with all the data received from the predecessors. In addition, it will remove the data used to create the batch from @@ -695,7 +715,7 @@ def _get_data_for_accumulate( for step_name, batches in self.data.items(): batches_used[step_name] = [] for batch in batches: - batches_used[step_name].append(batch.seq_no) + batches_used[step_name].append((batch.seq_no, batch.size)) data.append([row for batch in batches for row in batch.data[0]]) # Reset the data buffer self.data = {step_name: [] for step_name in self.data} @@ -703,7 +723,7 @@ def _get_data_for_accumulate( def _get_data_for_convergence_step( self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]]]: + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]]]: """Gets the data needed to create a batch for the step to process when the step is a convergence step. @@ -713,25 +733,35 @@ def _get_data_for_convergence_step( used to create the batch. """ grouped_batches = self._group_batches_by_created_from() - _, batches = grouped_batches[0] + seq_no, batches = grouped_batches[0] remaining_rows_per_step = { step_name: self.input_batch_size for step_name in self.data } batches_used = defaultdict(list) data = defaultdict(list) - for batch in batches: + for batch, batch_size in batches: batch_data = batch.data[0] remaining_rows = remaining_rows_per_step[batch.step_name] selected_data = batch_data[:remaining_rows] data[batch.step_name].extend(selected_data) + # If A -> [B, C] -> D, then in D (this step) we keep track of the remaining + # rows from the batches of A that B and C used to create the `batches`. + batch_size = self.convergence_step_batches_consumed.setdefault( + seq_no, {} + ).get(batch.step_name, batch_size) + remaining_rows_in_batch = batch_size - len(selected_data) + self.convergence_step_batches_consumed[seq_no].update( + {batch.step_name: remaining_rows_in_batch} + ) + # Update the remaining rows num_rows = len(selected_data) remaining_rows_per_step[batch.step_name] -= num_rows # type: ignore # Keep track of the batches used to create the batch - batches_used[batch.step_name].append(batch.seq_no) + batches_used[batch.step_name].append((batch.seq_no, batch.size)) # If the batch was entirely consumed, then remove it from the buffer if num_rows >= len(batch_data): @@ -744,11 +774,21 @@ def _get_data_for_convergence_step( batch_ref = self.data[batch.step_name][batch_idx] batch_ref.data[0] = batch_data[len(selected_data) :] + # If all the batches grouped by the `seq_no` in `created_from` were consumed, then + # we can update the `next_expected_created_from_batch_seq_no` to the next one + # to avoid skipping batches. + no_remaining_rows = all( + count == 0 + for count in self.convergence_step_batches_consumed[seq_no].values() + ) + if no_remaining_rows: + self.next_expected_created_from_batch_seq_no += 1 + return list(data.values()), dict(batches_used) def _get_data_normal( self, - ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[int]], List[str]]: + ) -> Tuple[List[List[Dict[str, Any]]], Dict[str, List[Tuple[int, int]]], List[str]]: """Gets the data needed to create a batch for the step to process when the step is not accumulating data. It will return a list of data with the `input_batch_size` for each predecessor. In addition, it will remove the data used to create the batch @@ -771,6 +811,9 @@ def _get_data_normal( idx_drop_batches = [] remaining_rows: int = self.input_batch_size # type: ignore for idx, batch in enumerate(self.data[step_name]): + if remaining_rows == 0: + break + # Get `remaining_rows` or the remaining rows in the batch and add it to # the step data that will be used to create the batch batch_data = batch.data[0] @@ -783,7 +826,7 @@ def _get_data_normal( remaining_rows -= num_rows # Keep track of the batches used to create the batch - batches_used[step_name].append(batch.seq_no) + batches_used[step_name].append((batch.seq_no, batch.size)) # If the batch was entirely consumed, then remove it from the buffer if num_rows >= len(batch_data): @@ -843,19 +886,24 @@ def _ready_to_create_batch_convergence_step(self) -> bool: grouped_batches = self._group_batches_by_created_from() if not grouped_batches: return False - _, batches = grouped_batches[0] + seq_no, batches = grouped_batches[0] + + # If the `seq_no` from the `created_from` field is not the expected one, then + # we cannot create a batch yet or the order will be messed up + if seq_no != self.next_expected_created_from_batch_seq_no: + return False # Not all output batches to which the input batch was routed to haven't been # received - batch_routed_to = batches[0].batch_routed_to - batches_received_from = {batch.step_name for batch in batches} + batch_routed_to = batches[0][0].batch_routed_to + batches_received_from = {batch.step_name for batch, _ in batches} if any(step_name not in batches_received_from for step_name in batch_routed_to): return False # There are output batches to which the input batch was routed to from all # the steps. Check if there is enough data for creating a batch with `input_batch_size` rows_per_step = defaultdict(lambda: 0) - for batch in batches: + for batch, _ in batches: num_rows = len(batch.data[0]) rows_per_step[batch.step_name] += num_rows @@ -931,10 +979,15 @@ def _last_batch_convergence_step(self) -> bool: if not grouped_batches: return False _, batches = grouped_batches[0] - steps_in_batches = {batch.step_name for batch in batches} - return all( - step_name in self.last_batch_received for step_name in steps_in_batches - ) + + for batch, _ in batches: + if not batch.last_batch: + return False + + if len(batch.data[0]) > self.input_batch_size: # type: ignore + return False + + return True def _last_batch_normal(self) -> bool: """Checks if the batch to be created is the last one for a normal step. `True` if @@ -958,7 +1011,9 @@ def _last_batch_normal(self) -> bool: return True - def _group_batches_by_created_from(self) -> List[Tuple[int, List["_Batch"]]]: + def _group_batches_by_created_from( + self, + ) -> List[Tuple[int, List[Tuple["_Batch", int]]]]: """Group the batches by the first key of `created_from` field. This method is meant to be used only with a `convergence_step`. @@ -966,23 +1021,23 @@ def _group_batches_by_created_from(self) -> List[Tuple[int, List["_Batch"]]]: A list of the batches grouped by the `seq_no` of the first step name in `created_from`. The list is sorted by the `seq_no`. """ - grouped_batches = defaultdict(list) + grouped_batches: Dict[int, List[Tuple["_Batch", int]]] = defaultdict(list) for batches in self.data.values(): for batch in batches: first_key = next(iter(batch.created_from)) - batch_seq_no = batch.created_from[first_key][0] - grouped_batches[batch_seq_no].append(batch) + batch_seq_no, batch_size = batch.created_from[first_key][0] + grouped_batches[batch_seq_no].append((batch, batch_size)) return sorted((seq_no, batches) for seq_no, batches in grouped_batches.items()) def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: """Dumps the content of the `_BatchManagerStep` to a dictionary, using the `dataclass` helper function. Args: - obj (Any): Unused, just kept to match the signature of the parent method. - kwargs (Any): Additional arguments that are kept to match the signature of the parent method. + obj: Unused, just kept to match the signature of the parent method. + kwargs: Additional arguments that are kept to match the signature of the parent method. Returns: - Dict[str, Any]: Internal representation of the `_BatchManagerStep`. + Internal representation of the `_BatchManagerStep`. """ return asdict(self) @@ -1007,7 +1062,7 @@ def __init__( steps: Dict[str, _BatchManagerStep], last_batch_received: Dict[str, Union[_Batch, None]], last_batch_sent: Dict[str, Union[_Batch, None]], - last_batch_flag_sent_to: Optional[List[str]] = None, + last_batch_flag_sent_to: List[str], ) -> None: """Initialize the `_BatchManager` instance. @@ -1019,12 +1074,9 @@ def __init__( last_batch_sent: A dictionary with the step name as the key and a the last `_Batch` sent to the step. last_batch_flag_sent_to: A list with the names of the steps to which `LAST_BATCH_SENT_FLAG` - was sent. Defaults to `None`. + was sent. """ - if last_batch_flag_sent_to is None: - last_batch_flag_sent_to = [] - self._steps = steps self._last_batch_received = last_batch_received self._last_batch_sent = last_batch_sent @@ -1039,14 +1091,15 @@ def can_generate(self) -> bool: """ for step_name, batch in self._last_batch_received.items(): - if not batch: - return True + if step_name not in self._last_batch_flag_sent_to: + if not batch: + return True - if not batch.last_batch and step_name not in self._last_batch_flag_sent_to: - return True + if not batch.last_batch: + return True - if not self.get_last_batch_sent(step_name): - return True + if not self.get_last_batch_sent(step_name): + return True return False @@ -1172,7 +1225,7 @@ def from_dag(cls, dag: "DAG") -> "_BatchManager": convergence_step=convergence_step, ) steps[step_name] = batch_manager_step - return cls(steps, last_batch_received, last_batch_sent) + return cls(steps, last_batch_received, last_batch_sent, []) def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]: """Dumps the content of the `_BatchManager` to a dictionary. diff --git a/src/distilabel/pipeline/constants.py b/src/distilabel/pipeline/constants.py index 829aeb3993..450ef0ed6d 100644 --- a/src/distilabel/pipeline/constants.py +++ b/src/distilabel/pipeline/constants.py @@ -15,5 +15,7 @@ from typing import Final STEP_ATTR_NAME: Final[str] = "step" +INPUT_QUEUE_ATTR_NAME: Final[str] = "input_queue" RECEIVES_ROUTED_BATCHES_ATTR_NAME: Final[str] = "receives_routed_batches" ROUTING_BATCH_FUNCTION_ATTR_NAME: Final[str] = "routing_batch_function" +CONVERGENCE_STEP_ATTR_NAME: Final[str] = "convergence_step" diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py index 94401eb38f..5bb7fd7991 100644 --- a/src/distilabel/pipeline/local.py +++ b/src/distilabel/pipeline/local.py @@ -32,6 +32,8 @@ _WriteBuffer, ) from distilabel.pipeline.constants import ( + CONVERGENCE_STEP_ATTR_NAME, + INPUT_QUEUE_ATTR_NAME, ROUTING_BATCH_FUNCTION_ATTR_NAME, STEP_ATTR_NAME, ) @@ -44,7 +46,7 @@ from queue import Queue from distilabel.distiset import Distiset - from distilabel.steps.base import GeneratorStep + from distilabel.steps.base import GeneratorStep, _Step _STEPS_LOADED_KEY = "steps_loaded" @@ -188,7 +190,7 @@ def _notify_steps_to_stop(self) -> None: """Notifies the steps to stop their infinite running loop by sending `None` to their input queues.""" for step_name in self.dag: - if input_queue := self.dag.get_step(step_name).get("input_queue"): + if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): input_queue.put(None) def _output_queue_loop(self, write_buffer: "_WriteBuffer") -> None: @@ -236,16 +238,27 @@ def _manage_batch_flow(self, batch: "_Batch") -> None: """ assert self._batch_manager, "Batch manager is not set" + # Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the convergence + # step if the batch is the last one, so they stop their processing loop even if + # they haven't received the last batch because of the routing function. + if self._is_convergence_step(batch.step_name) and batch.last_batch: + for step_name in self.dag.get_step_predecessors(batch.step_name): + self._send_last_batch_flag_to_step(step_name) + + route_to, routed = self._get_successors(batch) + + # Keep track of the steps that the batch was routed to + if routed: + batch.batch_routed_to = route_to + self._register_batch(batch) - successors, route_to, routed = self._get_successors(batch) + step = self._get_step_from_batch(batch) # Add the batch to the successors input buffers for successor in route_to: # Copy batch to avoid modifying the same reference in the batch manager batch_to_add = batch.copy() if len(route_to) > 1 else batch - if routed: - batch_to_add.batch_routed_to = route_to self._batch_manager.add_batch(successor, batch_to_add) @@ -256,23 +269,14 @@ def _manage_batch_flow(self, batch: "_Batch") -> None: step.is_generator and step.name in self._batch_manager.step_empty_buffers(successor) ): - last_batch = self._batch_manager.get_last_batch_sent(step.name) - self._send_batch_to_step(last_batch.next_batch()) # type: ignore + last_batch_sent = self._batch_manager.get_last_batch_sent(step.name) + self._send_batch_to_step(last_batch_sent.next_batch()) # type: ignore # If successor step has enough data in its buffer to create a new batch, then # send the batch to the step. if new_batch := self._batch_manager.get_batch(successor): self._send_batch_to_step(new_batch) - # If `last_batch` was not routed to all the successors of the step, that - # means that the batch was routed to specific steps using a routing function. - # We have to send the `LAST_BATCH_SENT_FLAG` to the steps that the batch - # was not routed to, so they can stop processing batches. - not_routed_to = [s for s in successors if s not in route_to] - if batch.last_batch and len(not_routed_to): - for step_name in not_routed_to: - self._send_last_batch_flag_to_step(step_name) - if step.is_generator: return @@ -280,9 +284,8 @@ def _manage_batch_flow(self, batch: "_Batch") -> None: # buffers to create a new batch if new_batch := self._batch_manager.get_batch(step.name): # type: ignore self._send_batch_to_step(new_batch) - return - - self._request_more_batches_if_needed(step) + else: + self._request_more_batches_if_needed(step) self._cache() @@ -298,15 +301,15 @@ def _register_batch(self, batch: "_Batch") -> None: " manager" ) - def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]: + def _get_successors(self, batch: "_Batch") -> Tuple[List[str], bool]: """Gets the successors and the successors to which the batch has to be routed. Args: batch: The batch to which the successors will be determined. Returns: - The successors, the successors to route the batch to, and whether the batch was - routed using a routing function. + The successors to route the batch to and whether the batch was routed using + a routing function. """ node = self.dag.get_step(batch.step_name) step: "Step" = node[STEP_ATTR_NAME] @@ -321,7 +324,7 @@ def _get_successors(self, batch: "_Batch") -> Tuple[List[str], List[str], bool]: f"🚏 Using '{step.name}' routing function to send batch {batch.seq_no} to steps: {successors_str}" ) - return successors, route_to, route_to != successors + return route_to, route_to != successors def _get_step_from_batch(self, batch: "_Batch") -> "Step": """Gets the `Step` instance from a batch. @@ -366,12 +369,17 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None: """ self._logger.debug("Handling stop of the pipeline execution...") - # Send `None` to the input queues of all the steps to notify them to stop - # processing batches. + # Add the remaining batches in the input queues back to the batch manager for step_name in self.dag: - if input_queue := self.dag.get_step(step_name).get("input_queue"): + node = self.dag.get_step(step_name) + step: "_Step" = node[STEP_ATTR_NAME] + if step.is_generator: + continue + if input_queue := node.get(INPUT_QUEUE_ATTR_NAME): while not input_queue.empty(): batch = input_queue.get() + if batch is None: + continue self._batch_manager.add_batch( # type: ignore to_step=step_name, batch=batch, prepend=True ) @@ -389,8 +397,12 @@ def _handle_stop(self, write_buffer: "_WriteBuffer") -> None: # processed by the steps before stop was called. while not self.output_queue.empty(): batch = self.output_queue.get() + if batch is None: + continue + if batch.step_name in self.dag.leaf_steps: write_buffer.add_batch(batch) + self._handle_batch_on_stop(batch) self._cache() @@ -419,7 +431,7 @@ def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", No if self._check_step_not_loaded_or_finished(step_name): return None - if input_queue := self.dag.get_step(step_name).get("input_queue"): + if input_queue := self.dag.get_step(step_name).get(INPUT_QUEUE_ATTR_NAME): while input_queue.qsize() != 0: pass return input_queue @@ -520,10 +532,23 @@ def _send_batch_to_step(self, batch: "_Batch") -> None: self._logger.debug( f"Sending batch {batch.seq_no} to step '{batch.step_name}': {batch}" ) - input_queue = self.dag.get_step(batch.step_name)["input_queue"] + input_queue = self.dag.get_step(batch.step_name)[INPUT_QUEUE_ATTR_NAME] input_queue.put(batch) + def _is_convergence_step(self, step_name: str) -> None: + """Checks if a step is a convergence step. + + Args: + step_name: The name of the step. + """ + return self.dag.get_step(step_name).get(CONVERGENCE_STEP_ATTR_NAME) + def _send_last_batch_flag_to_step(self, step_name: str) -> None: + """Sends the `LAST_BATCH_SENT_FLAG` to a step to stop processing batches. + + Args: + step_name: The name of the step. + """ batch = self._batch_manager.get_last_batch_sent(step_name) # type: ignore if batch and batch.last_batch: return @@ -532,7 +557,7 @@ def _send_last_batch_flag_to_step(self, step_name: str) -> None: f"Sending `LAST_BATCH_SENT_FLAG` to '{step_name}' step to stop processing" " batches..." ) - input_queue = self.dag.get_step(step_name)["input_queue"] + input_queue = self.dag.get_step(step_name)[INPUT_QUEUE_ATTR_NAME] input_queue.put(LAST_BATCH_SENT_FLAG) self._batch_manager.set_last_batch_flag_sent_to(step_name) # type: ignore @@ -559,7 +584,7 @@ def _run_steps_in_loop( for step_name in self.dag: step: "Step" = self.dag.get_step(step_name)[STEP_ATTR_NAME] input_queue = manager.Queue() - self.dag.set_step_attr(step.name, "input_queue", input_queue) + self.dag.set_step_attr(step.name, INPUT_QUEUE_ATTR_NAME, input_queue) # Set `pipeline` to `None` as in some Python environments the pipeline is not # picklable and it will raise an error when trying to send the step to the process. @@ -889,7 +914,7 @@ def _generator_step_process_loop(self) -> None: ) for data, last_batch in step.process_applying_mappings(offset=offset): - batch.data = [data] + batch.set_data([data]) batch.last_batch = self._dry_run or last_batch self._send_batch(batch) @@ -962,7 +987,7 @@ def _non_generator_process_loop(self) -> None: f"Subprocess traceback:\n\n{traceback.format_exc()}" ) finally: - batch.data = [result] + batch.set_data([result]) self._send_batch(batch) if batch.last_batch: diff --git a/src/distilabel/utils/serialization.py b/src/distilabel/utils/serialization.py index 276cac1f2a..b97669b809 100644 --- a/src/distilabel/utils/serialization.py +++ b/src/distilabel/utils/serialization.py @@ -277,6 +277,14 @@ def from_yaml(cls, path: StrOrPath) -> Self: @classmethod def from_file(cls, path: StrOrPath) -> Self: + """Loads a class from a file. + + Args: + path: the path to the file containing the serialized class. + + Returns: + An instance of the class. + """ path = Path(path) if path.suffix == ".json": return cls.from_json(path) diff --git a/tests/integration/test_pipe_simple.py b/tests/integration/test_pipe_simple.py index 28388b8f77..8ab1fff29a 100644 --- a/tests/integration/test_pipe_simple.py +++ b/tests/integration/test_pipe_simple.py @@ -157,10 +157,6 @@ def run_pipeline(): return pipeline.run( parameters={ - "load_dataset": { - "repo_id": "plaguss/test", - "split": "train", - }, "rename_columns": { "rename_mappings": { "prompt": "instruction", diff --git a/tests/integration/test_routing_batch_function.py b/tests/integration/test_routing_batch_function.py index 2fbaa508e3..0ea2ee3cdc 100644 --- a/tests/integration/test_routing_batch_function.py +++ b/tests/integration/test_routing_batch_function.py @@ -16,6 +16,7 @@ import time from typing import TYPE_CHECKING, List +import pytest from distilabel.pipeline import Pipeline, routing_batch_function from distilabel.steps import LoadDataFromDicts, StepInput, step @@ -41,22 +42,31 @@ def Generate(inputs: StepInput) -> "StepOutput": yield inputs +@step(outputs=["generations"]) +def Generate2(inputs: StepInput) -> "StepOutput": + sleep_time = random.uniform(1.0, 2.0) + time.sleep(sleep_time) + for input in inputs: + input["2generation"] = "I slept for {} seconds".format(sleep_time) + yield inputs + + @step(outputs=["generations"]) def CombineGenerations(*inputs: StepInput) -> "StepOutput": + generation_key = ( + "2generation" if "2generation" in inputs[0][0].keys() else "generation" + ) + combined_list = [] for rows in zip(*inputs): combined_dict = { "index": rows[0]["index"], - "instruction": rows[0]["instruction"], - "generations": [row["generation"] for row in rows], + "instruction": [row["instruction"] for row in rows], + f"{generation_key}s": [row[generation_key] for row in rows], } # Check consistency in "index" and "instruction" - if any( - row["index"] != combined_dict["index"] - or row["instruction"] != combined_dict["instruction"] - for row in rows - ): + if any(row["index"] != combined_dict["index"] for row in rows): raise ValueError("Inconsistent 'index' or 'instruction'") combined_list.append(combined_dict) @@ -64,6 +74,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput": yield combined_list +@pytest.mark.timeout(120) def test_routing_batch_function() -> None: with Pipeline(name="test") as pipeline: load_dataset = LoadDataFromDicts( @@ -80,3 +91,75 @@ def test_routing_batch_function() -> None: for i, row in enumerate(distiset["default"]["train"]): assert row["index"] == i + assert row["instruction"] == [f"Instruction {i}", f"Instruction {i}"] + assert len(row["generations"]) == 2 + + +@pytest.mark.timeout(120) +def test_routing_batch_function_irregular_batch_sizes() -> None: + with Pipeline(name="test") as pipeline: + load_dataset = LoadDataFromDicts( + data=[{"index": i, "instruction": f"Instruction {i}"} for i in range(1000)], + batch_size=200, + ) + + generates = [ + Generate(input_batch_size=input_batch_size) + for input_batch_size in [25, 50, 100, 200] + ] + + combine_generations = CombineGenerations(input_batch_size=25) + + load_dataset >> random_routing_batch >> generates >> combine_generations + + distiset = pipeline.run(use_cache=False) + + for i, row in enumerate(distiset["default"]["train"]): + assert row["index"] == i + assert row["instruction"] == [f"Instruction {i}", f"Instruction {i}"] + assert len(row["generations"]) == 2 + + +@pytest.mark.timeout(120) +def test_multiple_routing_batch_function() -> None: + batch_size = 200 + + with Pipeline(name="test") as pipeline: + load_dataset = LoadDataFromDicts( + data=[ + { + "index": i, + "instruction": f"Instruction {i}", + "batch": i // batch_size, + } + for i in range(1000) + ], + batch_size=batch_size, + ) + + generates = [ + Generate(input_batch_size=input_batch_size) + for input_batch_size in [25, 50, 100, 200] + ] + + combine_generations = CombineGenerations(input_batch_size=25) + + generates2 = [Generate2(input_batch_size=25) for _ in range(4)] + + combine_generations_2 = CombineGenerations(input_batch_size=25) + + ( + load_dataset + >> random_routing_batch + >> generates + >> combine_generations + >> random_routing_batch + >> generates2 + >> combine_generations_2 + ) + + distiset = pipeline.run(use_cache=False) + + for i, row in enumerate(distiset["default"]["train"]): + assert row["index"] == i + assert len(row["2generations"]) == 2 diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 6c7dd33b6b..76a48ec4be 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -261,6 +261,14 @@ def test_infer_step_names_big_pipeline(self) -> None: class TestBatch: + def test_set_data(self) -> None: + batch = _Batch(seq_no=0, step_name="step1", last_batch=False) + data = [[{"i": i} for i in range(5000)]] + batch.set_data(data) + + assert batch.data == data + assert batch.size == 5000 + def test_next_batch(self) -> None: batch = _Batch(seq_no=0, step_name="step1", last_batch=False) next_batch = batch.next_batch() @@ -313,6 +321,7 @@ def test_dump(self) -> None: batch = _Batch(seq_no=0, step_name="step1", last_batch=False) assert batch.dump() == { "seq_no": 0, + "size": 0, "step_name": "step1", "last_batch": False, "data": [], @@ -333,6 +342,7 @@ def test_dump(self) -> None: ) assert batch.dump() == { "seq_no": 0, + "size": 0, "step_name": "step1", "last_batch": False, "data": [[{"a": 1}, {"a": 2}, {"a": 3}]], @@ -441,6 +451,7 @@ def test_get_batch(self) -> None: {"a": 5}, ] ], + size=5, ) ], "step2": [ @@ -458,6 +469,7 @@ def test_get_batch(self) -> None: {"b": 6}, ] ], + size=5, ) ], }, @@ -479,7 +491,7 @@ def test_get_batch(self) -> None: {"b": 2}, ], ], - created_from={"step1": [0], "step2": [0]}, + created_from={"step1": [(0, 5)], "step2": [(0, 5)]}, ) batch = batch_manager_step.get_batch() @@ -498,7 +510,7 @@ def test_get_batch(self) -> None: {"b": 4}, ], ], - created_from={"step1": [0], "step2": [0]}, + created_from={"step1": [(0, 5)], "step2": [(0, 5)]}, ) def test_get_batches_accumulate(self) -> None: @@ -520,6 +532,7 @@ def test_get_batches_accumulate(self) -> None: {"a": 5}, ] ], + size=5, ) ], "step2": [ @@ -537,6 +550,7 @@ def test_get_batches_accumulate(self) -> None: {"b": 6}, ] ], + size=6, ) ], }, @@ -567,7 +581,7 @@ def test_get_batches_accumulate(self) -> None: {"b": 6}, ], ], - created_from={"step1": [0], "step2": [0]}, + created_from={"step1": [(0, 5)], "step2": [(0, 6)]}, ) def test_get_batches_not_enough_data(self) -> None: @@ -654,6 +668,7 @@ def test_get_data(self) -> None: data=[ [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] ], + size=6, batch_routed_to=["step1", "step2"], ) ], @@ -673,6 +688,7 @@ def test_get_data(self) -> None: {"b": 7}, ] ], + size=7, batch_routed_to=["step1", "step2"], ) ], @@ -684,7 +700,7 @@ def test_get_data(self) -> None: [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}], [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}], ] - assert created_from == {"step1": [0], "step2": [0]} + assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} assert routed_to == ["step1", "step2"] assert batch_manager_step.data == { @@ -694,6 +710,7 @@ def test_get_data(self) -> None: step_name="step1", last_batch=False, data=[[{"a": 6}]], + size=6, batch_routed_to=["step1", "step2"], ) ], @@ -703,6 +720,7 @@ def test_get_data(self) -> None: step_name="step2", last_batch=False, data=[[{"b": 6}, {"b": 7}]], + size=7, batch_routed_to=["step1", "step2"], ) ], @@ -721,6 +739,7 @@ def test_get_data_accumulate(self) -> None: data=[ [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}] ], + size=6, ) ], "step2": [ @@ -739,6 +758,7 @@ def test_get_data_accumulate(self) -> None: {"b": 7}, ] ], + size=7, ) ], }, @@ -750,7 +770,7 @@ def test_get_data_accumulate(self) -> None: [{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}], [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}], ] - assert created_from == {"step1": [0], "step2": [0]} + assert created_from == {"step1": [(0, 6)], "step2": [(0, 7)]} assert routed_to == [] assert batch_manager_step.data == {"step1": [], "step2": []} @@ -767,7 +787,8 @@ def test_get_data_convergence_step(self) -> None: {"generation": "Hello, I'm A 0"}, ] ], - created_from={"Z": [0]}, + size=3, + created_from={"Z": [(0, 3)]}, ) batch_a_1 = _Batch( @@ -781,7 +802,8 @@ def test_get_data_convergence_step(self) -> None: {"generation": "Hello, I'm A 1"}, ] ], - created_from={"Z": [1]}, + size=3, + created_from={"Z": [(1, 3)]}, ) batch_b_0 = _Batch( @@ -795,7 +817,8 @@ def test_get_data_convergence_step(self) -> None: {"generation": "Hello, I'm B 0"}, ] ], - created_from={"Z": [0]}, + size=3, + created_from={"Z": [(0, 3)]}, ) batch_c_0 = _Batch( @@ -809,10 +832,11 @@ def test_get_data_convergence_step(self) -> None: {"generation": "Hello, I'm C 0"}, ] ], - created_from={"Z": [1]}, + size=3, + created_from={"Z": [(1, 3)]}, ) - bath_manager_step = _BatchManagerStep( + batch_manager_step = _BatchManagerStep( step_name="D", input_batch_size=3, convergence_step=True, @@ -820,7 +844,7 @@ def test_get_data_convergence_step(self) -> None: data={"A": [batch_a_0, batch_a_1], "B": [batch_b_0], "C": [batch_c_0]}, ) - data, created_from, routed_to = bath_manager_step._get_data() + data, created_from, routed_to = batch_manager_step._get_data() assert data == [ [ @@ -834,10 +858,11 @@ def test_get_data_convergence_step(self) -> None: {"generation": "Hello, I'm B 0"}, ], ] - assert created_from == {"A": [0], "B": [0]} + assert created_from == {"A": [(0, 3)], "B": [(0, 3)]} assert routed_to == [] + assert batch_manager_step.next_expected_created_from_batch_seq_no == 1 - data, created_from, routed_to = bath_manager_step._get_data() + data, created_from, routed_to = batch_manager_step._get_data() assert data == [ [ @@ -851,8 +876,9 @@ def test_get_data_convergence_step(self) -> None: {"generation": "Hello, I'm C 0"}, ], ] - assert created_from == {"A": [1], "C": [0]} + assert created_from == {"A": [(1, 3)], "C": [(0, 3)]} assert routed_to == [] + assert batch_manager_step.next_expected_created_from_batch_seq_no == 2 @pytest.mark.parametrize( "data, last_batch_received, expected", @@ -1034,6 +1060,100 @@ def test_last_batch_accumulate( assert batch_manager_step._last_batch() is expected + @pytest.mark.parametrize( + "data, last_batch_received, expected", + [ + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"a": 1}, {"a": 2}, {"a": 3}]], + created_from={"step0": [(0, 3)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=True, + data=[[{"b": 1}, {"b": 2}, {"b": 3}]], + created_from={"step0": [(0, 3)]}, + ) + ], + }, + [], + True, + ), + ], + ) + def test_last_batch_convergence_step( + self, + data: Dict[str, List[_Batch]], + last_batch_received: List[str], + expected: bool, + ) -> None: + batch_manager_step = _BatchManagerStep( + step_name="step3", + accumulate=False, + data=data, + last_batch_received=last_batch_received, + input_batch_size=3, + convergence_step=True, + ) + + assert batch_manager_step._last_batch() is expected + @pytest.mark.parametrize( "data, last_batch_received, expected", [ @@ -1206,6 +1326,7 @@ def test_dump(self) -> None: step_name="step1", last_batch=True, data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}, {"a": 6}]], + size=6, ) batch_step_2 = _Batch( seq_no=0, @@ -1214,6 +1335,7 @@ def test_dump(self) -> None: data=[ [{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}, {"b": 6}, {"b": 7}] ], + size=7, ) batch_manager_step = _BatchManagerStep( step_name="step3", @@ -1227,6 +1349,7 @@ def test_dump(self) -> None: "step_name": "step3", "accumulate": True, "convergence_step": False, + "convergence_step_batches_consumed": {}, "input_batch_size": None, "data": { "step1": [ @@ -1244,6 +1367,7 @@ def test_dump(self) -> None: {"a": 6}, ] ], + "size": 6, "accumulated": False, "created_from": {}, "batch_routed_to": [], @@ -1265,6 +1389,7 @@ def test_dump(self) -> None: {"b": 7}, ] ], + "size": 7, "accumulated": False, "created_from": {}, "batch_routed_to": [], @@ -1273,6 +1398,7 @@ def test_dump(self) -> None: }, "seq_no": 0, "last_batch_received": [], + "next_expected_created_from_batch_seq_no": 0, "type_info": { "module": "distilabel.pipeline.base", "name": "_BatchManagerStep", @@ -1291,6 +1417,7 @@ def test_dump(self) -> None: last_batch=False, data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, ) ], "step2": [], @@ -1307,6 +1434,7 @@ def test_dump(self) -> None: last_batch=False, data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}, {"a": 5}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, ) ], "step2": [ @@ -1316,6 +1444,7 @@ def test_dump(self) -> None: last_batch=False, data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, ) ], }, @@ -1331,6 +1460,7 @@ def test_dump(self) -> None: last_batch=False, data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, ) ], "step2": [ @@ -1340,6 +1470,7 @@ def test_dump(self) -> None: last_batch=False, data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, ) ], }, @@ -1355,6 +1486,7 @@ def test_dump(self) -> None: last_batch=True, data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, ) ], "step2": [ @@ -1364,12 +1496,39 @@ def test_dump(self) -> None: last_batch=True, data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}]], batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, ) ], }, ["step1", "step2"], True, ), + ( + { + "step1": [ + _Batch( + seq_no=0, + step_name="step1", + last_batch=False, + data=[[{"a": 1}, {"a": 2}, {"a": 3}, {"a": 4}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 4)]}, + ) + ], + "step2": [ + _Batch( + seq_no=0, + step_name="step2", + last_batch=False, + data=[[{"b": 1}, {"b": 2}, {"b": 3}, {"b": 4}, {"b": 5}]], + batch_routed_to=["step1", "step2"], + created_from={"step0": [(0, 5)]}, + ) + ], + }, + [], + False, + ), ], ) def test_ready_to_create_batch_convergence_step( @@ -1384,6 +1543,7 @@ def test_ready_to_create_batch_convergence_step( input_batch_size=5, data=data, last_batch_received=last_batch_received, + convergence_step=True, ) assert batch_manager_step._ready_to_create_batch() is expected @@ -1394,6 +1554,7 @@ def test_from_dict(self) -> None: "step_name": "step3", "accumulate": True, "convergence_step": False, + "convergence_step_batches_consumed": {0: {"Z": 1234}}, "input_batch_size": None, "data": { "step1": [ @@ -1411,6 +1572,7 @@ def test_from_dict(self) -> None: {"a": 6}, ] ], + "size": 6, "accumulated": False, "created_from": {}, "batch_routed_to": [], @@ -1432,6 +1594,7 @@ def test_from_dict(self) -> None: {"b": 7}, ] ], + "size": 7, "accumulated": False, "created_from": {}, "batch_routed_to": [], @@ -1451,6 +1614,7 @@ def test_from_dict(self) -> None: assert batch_manager_step.step_name == "step3" assert batch_manager_step.accumulate is True assert batch_manager_step.convergence_step is False + assert batch_manager_step.convergence_step_batches_consumed == {0: {"Z": 1234}} assert batch_manager_step.input_batch_size is None assert batch_manager_step.seq_no == 0 assert batch_manager_step.last_batch_received == [] @@ -1469,6 +1633,7 @@ def test_add_batch(self) -> None: }, last_batch_received={"step3": None}, last_batch_sent={"step3": None}, + last_batch_flag_sent_to=[], ) batch_from_step_1 = _Batch( @@ -1505,6 +1670,7 @@ def test_add_batch_with_prepend(self) -> None: }, last_batch_received={"step3": None}, last_batch_sent={"step3": None}, + last_batch_flag_sent_to=[], ) batch_0 = _Batch( seq_no=0, @@ -1566,6 +1732,7 @@ def test_can_generate(self) -> None: "step_3": _Batch(seq_no=0, step_name="step_3", last_batch=False), }, last_batch_sent={"step_1": None, "step_2": None, "step_3": None}, + last_batch_flag_sent_to=[], ) assert batch_manager.can_generate() @@ -1582,6 +1749,7 @@ def test_can_generate(self) -> None: "step_3": batch_3, }, last_batch_sent={"step_1": batch_1, "step_2": batch_2, "step_3": batch_3}, + last_batch_flag_sent_to=[], ) assert not batch_manager.can_generate() @@ -1619,10 +1787,12 @@ def test_dump(self) -> None: "step_name": "step3", "accumulate": False, "convergence_step": False, + "convergence_step_batches_consumed": {}, "input_batch_size": 5, "data": {"step1": [], "step2": []}, "seq_no": 1, "last_batch_received": [], + "next_expected_created_from_batch_seq_no": 0, "type_info": { "module": "distilabel.pipeline.base", "name": "_BatchManagerStep", @@ -1637,6 +1807,7 @@ def test_dump(self) -> None: "created_from": {}, "last_batch": False, "data": [], + "size": 0, "accumulated": False, "type_info": { "module": "distilabel.pipeline.base", @@ -1652,6 +1823,7 @@ def test_dump(self) -> None: "created_from": {}, "last_batch": False, "data": [], + "size": 0, "accumulated": False, "type_info": { "module": "distilabel.pipeline.base", diff --git a/tests/unit/pipeline/test_dag.py b/tests/unit/pipeline/test_dag.py index dbdddb4dae..2b38f55abd 100644 --- a/tests/unit/pipeline/test_dag.py +++ b/tests/unit/pipeline/test_dag.py @@ -595,6 +595,30 @@ def routing_batch_function_1(steps: List[str]) -> List[str]: ): pipeline.dag.validate() + def test_validate_step_receiving_routed_batches_input_batch_size_multiple( + self, pipeline: "Pipeline" + ) -> None: + generator_step_1 = DummyGeneratorStep(pipeline=pipeline) + dummy_step_1 = DummyStep1(pipeline=pipeline) + dummy_step_2 = DummyStep1(name="demon", pipeline=pipeline, input_batch_size=7) + + @routing_batch_function() + def routing_batch_function_1(steps: List[str]) -> List[str]: + return steps + + convergence_step = DummyStep2(name="convergence_step", pipeline=pipeline) + ( + generator_step_1 + >> routing_batch_function_1 + >> [dummy_step_1, dummy_step_2] + >> convergence_step + ) + with pytest.raises( + ValueError, + match="Step 'demon' should have an `input_batch_size` that is a multiple of the `input_batch_size` or `batch_size`", + ): + pipeline.dag.validate() + class TestDagSerialization: def test_dag_dump(self, dummy_step_1: "Step", dummy_step_2: "Step") -> None: