Skip to content

Commit

Permalink
Add CombineOutputs step (#939)
Browse files Browse the repository at this point in the history
* Add `CombineOutputs` step

* Add `merge_distilabel_metadata` function

* Add unit tests

* Add docstrings

* Update docs

* Update src/distilabel/steps/columns/combine.py

Co-authored-by: Agus <[email protected]>

* Update docstrings

* Update mkdocs

---------

Co-authored-by: Agus <[email protected]>
  • Loading branch information
gabrielmbmb and plaguss authored Sep 2, 2024
1 parent 4556135 commit d5f2ae3
Show file tree
Hide file tree
Showing 15 changed files with 265 additions and 25 deletions.
3 changes: 0 additions & 3 deletions docs/api/pipeline/utils.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/api/step_gallery/columns.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ This section contains the existing steps intended to be used for common column o
::: distilabel.steps.columns.keep
::: distilabel.steps.columns.merge
::: distilabel.steps.columns.group
::: distilabel.steps.columns.utils
5 changes: 2 additions & 3 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ plugins:
# Members
inherited_members: false # allow looking up inherited methods
members_order: source # order methods according to their order of definition in the source code, not alphabetical order
show_labels : true
show_labels: true
# Docstring
docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
docstring_style: google # more info: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
show_if_no_docstring: false
# Signature
separate_signature: false
Expand Down Expand Up @@ -240,7 +240,6 @@ nav:
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
- Typing: "api/pipeline/typing.md"
- Step Wrapper: "api/pipeline/step_wrapper.md"
- Utils: "api/pipeline/utils.md"
- Mixins:
- RuntimeParametersMixin: "api/mixins/runtime_parameters.md"
- RequirementsMixin: "api/mixins/requirements.md"
Expand Down
22 changes: 12 additions & 10 deletions src/distilabel/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
StepInput,
StepResources,
)
from distilabel.steps.columns.combine import CombineOutputs
from distilabel.steps.columns.expand import ExpandColumns
from distilabel.steps.columns.group import CombineColumns, GroupColumns
from distilabel.steps.columns.keep import KeepColumns
Expand Down Expand Up @@ -54,34 +55,35 @@
__all__ = [
"PreferenceToArgilla",
"TextGenerationToArgilla",
"GeneratorStep",
"GlobalStep",
"Step",
"StepInput",
"StepResources",
"CombineOutputs",
"ExpandColumns",
"CombineColumns",
"GroupColumns",
"KeepColumns",
"MergeColumns",
"CombineColumns",
"ConversationTemplate",
"step",
"DeitaFiltering",
"EmbeddingGeneration",
"FaissNearestNeighbour",
"ExpandColumns",
"ConversationTemplate",
"FormatChatGenerationDPO",
"FormatChatGenerationSFT",
"FormatTextGenerationDPO",
"FormatChatGenerationSFT",
"FormatTextGenerationSFT",
"GeneratorStep",
"GlobalStep",
"KeepColumns",
"LoadDataFromDicts",
"LoadDataFromDisk",
"LoadDataFromFileSystem",
"LoadDataFromHub",
"MinHashDedup",
"make_generator_step",
"PushToHub",
"Step",
"StepInput",
"RewardModelScore",
"TruncateTextColumn",
"GeneratorStepOutput",
"StepOutput",
"step",
]
99 changes: 99 additions & 0 deletions src/distilabel/steps/columns/combine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.steps.base import Step, StepInput
from distilabel.steps.columns.utils import merge_distilabel_metadata

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class CombineOutputs(Step):
    """Merge the rows produced by several upstream steps into single rows.

    `CombineOutputs` is a `Step` that receives the outputs of several upstream steps
    and, position by position, folds the dictionaries into one dictionary holding
    every key/column found in them. When the same column appears in more than one
    upstream output, the value from the later input wins. The distilabel metadata
    column is not copied directly; it is rebuilt with `merge_distilabel_metadata`
    whenever at least one of the rows being combined carries it.

    Input columns:
        - dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.

    Output columns:
        - dynamic (based on the upstream `Step`s): All the columns of the upstream steps outputs.

    Categories:
        - columns

    Examples:

        Combine dictionaries of a dataset:

        ```python
        from distilabel.steps import CombineOutputs

        combine_outputs = CombineOutputs()
        combine_outputs.load()

        result = next(
            combine_outputs.process(
                [{"a": 1, "b": 2}, {"a": 3, "b": 4}],
                [{"c": 5, "d": 6}, {"c": 7, "d": 8}],
            )
        )
        # [
        #   {"a": 1, "b": 2, "c": 5, "d": 6},
        #   {"a": 3, "b": 4, "c": 7, "d": 8},
        # ]
        ```

        Combine upstream steps outputs in a pipeline:

        ```python
        from distilabel.pipeline import Pipeline
        from distilabel.steps import CombineOutputs

        with Pipeline() as pipeline:
            step_1 = ...
            step_2 = ...
            step_3 = ...
            combine = CombineOutputs()

            [step_1, step_2, step_3] >> combine
        ```
    """

    def process(self, *inputs: StepInput) -> "StepOutput":
        """Yield one batch in which the i-th row merges the i-th row of every input.

        Args:
            *inputs: one list of row dictionaries per upstream step.

        Yields:
            A single list containing the combined row dictionaries.
        """
        merged_rows = []
        # `zip` pairs rows by position and stops at the shortest input.
        for row_group in zip(*inputs):
            merged_row = {}
            saw_metadata = False
            for row in row_group:
                for column, value in row.items():
                    if column == DISTILABEL_METADATA_KEY:
                        # Metadata is merged separately below, never overwritten.
                        saw_metadata = True
                    else:
                        merged_row[column] = value

            if saw_metadata:
                merged_row[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
                    *row_group
                )
            merged_rows.append(merged_row)

        yield merged_rows
3 changes: 3 additions & 0 deletions src/distilabel/steps/columns/expand.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ class ExpandColumns(Step):
Output columns:
- dynamic (determined by `columns` attribute): The expanded columns.
Categories:
- columns
Examples:
Expand the selected columns into multiple rows:
Expand Down
8 changes: 6 additions & 2 deletions src/distilabel/steps/columns/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from typing_extensions import override

from distilabel.pipeline.utils import group_columns
from distilabel.steps.base import Step, StepInput
from distilabel.steps.columns.utils import group_columns

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns, StepOutput
Expand All @@ -43,8 +43,12 @@ class GroupColumns(Step):
- dynamic (determined by `columns` and `output_columns` attributes): The columns
that were grouped.
Categories:
- columns
Examples:
Combine columns of a dataset:
Group columns of a dataset:
```python
from distilabel.steps import GroupColumns
Expand Down
3 changes: 3 additions & 0 deletions src/distilabel/steps/columns/keep.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class KeepColumns(Step):
Output columns:
- dynamic (determined by `columns` attribute): The columns that were kept.
Categories:
- columns
Examples:
Select the columns to keep:
Expand Down
5 changes: 4 additions & 1 deletion src/distilabel/steps/columns/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

from typing_extensions import override

from distilabel.pipeline.utils import merge_columns
from distilabel.steps.base import Step, StepInput
from distilabel.steps.columns.utils import merge_columns

if TYPE_CHECKING:
from distilabel.steps.typing import StepColumns, StepOutput
Expand Down Expand Up @@ -47,6 +47,9 @@ class MergeColumns(Step):
- dynamic (determined by `columns` and `output_column` attributes): The columns
that were merged.
Categories:
- columns
Examples:
Combine columns in rows of a dataset:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,47 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from distilabel.steps.base import StepInput
from distilabel.constants import DISTILABEL_METADATA_KEY

if TYPE_CHECKING:
from distilabel.steps.base import StepInput


def merge_distilabel_metadata(*output_dicts: Dict[str, Any]) -> Dict[str, Any]:
    """Merge the `DISTILABEL_METADATA_KEY` entries of several output dictionaries.

    Every metadata key found in any of the dictionaries is kept. A key that occurs
    in exactly one dictionary keeps its value as is; a key that occurs in several
    dictionaries maps to the list of its values, in the order the dictionaries
    were passed.

    Args:
        *output_dicts: Variable number of dictionaries containing distilabel metadata.
            Dictionaries without the metadata key contribute nothing.

    Returns:
        A merged dictionary containing all the distilabel metadata from the input
        dictionaries.
    """
    collected = defaultdict(list)
    for row in output_dicts:
        for name, item in row.get(DISTILABEL_METADATA_KEY, {}).items():
            collected[name].append(item)

    # Unwrap single occurrences so they are not returned as one-element lists.
    return {
        name: items[0] if len(items) == 1 else items
        for name, items in collected.items()
    }


def group_columns(
*inputs: StepInput,
*inputs: "StepInput",
group_columns: List[str],
output_group_columns: Optional[List[str]] = None,
) -> StepInput:
) -> "StepInput":
"""Groups multiple lists of dictionaries into a single list of dictionaries on the
specified `group_columns`. If `output_group_columns` are provided, then it will also rename
`group_columns`.
Expand Down Expand Up @@ -49,16 +80,30 @@ def group_columns(
# Use zip to iterate over lists based on their index
for dicts_at_index in zip(*inputs):
combined_dict = {}
metadata_dicts = []
# Iterate over dicts at the same index
for d in dicts_at_index:
# Extract metadata for merging
if DISTILABEL_METADATA_KEY in d:
metadata_dicts.append(
{DISTILABEL_METADATA_KEY: d[DISTILABEL_METADATA_KEY]}
)
# Iterate over key-value pairs in each dict
for key, value in d.items():
if key == DISTILABEL_METADATA_KEY:
continue
# If the key is in the merge_keys, append the value to the existing list
if key in group_columns_dict.keys():
combined_dict.setdefault(group_columns_dict[key], []).append(value)
# If the key is not in the merge_keys, create a new key-value pair
else:
combined_dict[key] = value

if metadata_dicts:
combined_dict[DISTILABEL_METADATA_KEY] = merge_distilabel_metadata(
*metadata_dicts
)

result.append(combined_dict)
return result

Expand Down
3 changes: 3 additions & 0 deletions src/distilabel/steps/embeddings/embedding_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class EmbeddingGeneration(Step):
Output columns:
- embedding (`List[Union[float, int]]`): the generated sentence embedding.
Categories:
- embedding
Examples:
Generate sentence embeddings with Sentence Transformers:
Expand Down
1 change: 1 addition & 0 deletions src/distilabel/utils/mkdocs/components_gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
"scorer": ":octicons-number-16:",
"text-generation": ":material-text-box-edit:",
"text-manipulation": ":material-receipt-text-edit:",
"columns": ":material-table-column:",
}


Expand Down
14 changes: 14 additions & 0 deletions tests/unit/steps/columns/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Loading

0 comments on commit d5f2ae3

Please sign in to comment.