Skip to content

Commit

Permalink
Update default names in GroupColumns (#808)
Browse files Browse the repository at this point in the history
* Update default names in GroupColumns

* Fix integration test
  • Loading branch information
plaguss authored Jul 23, 2024
1 parent b7f124f commit a3f6cdd
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
20 changes: 10 additions & 10 deletions src/distilabel/steps/columns/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ class GroupColumns(Step):
```python
from distilabel.steps import GroupColumns
combine_columns = GroupColumns(
name="combine_columns",
group_columns = GroupColumns(
name="group_columns",
columns=["generation", "model_name"],
)
combine_columns.load()
group_columns.load()
result = next(
combine_columns.process(
group_columns.process(
[{"generation": "AI generated text"}, {"model_name": "my_model"}],
[{"generation": "Other generated text", "model_name": "my_model"}]
)
Expand All @@ -71,15 +71,15 @@ class GroupColumns(Step):
```python
from distilabel.steps import GroupColumns
combine_columns = GroupColumns(
name="combine_columns",
group_columns = GroupColumns(
name="group_columns",
columns=["generation", "model_name"],
output_columns=["generations", "generation_models"]
)
combine_columns.load()
group_columns.load()
result = next(
combine_columns.process(
group_columns.process(
[{"generation": "AI generated text"}, {"model_name": "my_model"}],
[{"generation": "Other generated text", "model_name": "my_model"}]
)
Expand All @@ -100,11 +100,11 @@ def inputs(self) -> List[str]:
@property
def outputs(self) -> List[str]:
"""The outputs for the task are the column names in `output_columns` or
`merged_{column}` for each column in `columns`."""
`grouped_{column}` for each column in `columns`."""
return (
self.output_columns
if self.output_columns is not None
else [f"merged_{column}" for column in self.columns]
else [f"grouped_{column}" for column in self.columns]
)

@override
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/test_branching_missaligmnent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,6 @@ def test_branching_missalignment_because_step_fails_processing_batch() -> None:
distiset = pipeline.run(use_cache=False)

assert (
distiset["default"]["train"]["merged_response"]
distiset["default"]["train"]["grouped_response"]
== [[None, "This step always succeeds"]] * 20
)
14 changes: 7 additions & 7 deletions tests/unit/steps/columns/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
class TestGroupColumns:
def test_init(self) -> None:
task = GroupColumns(
name="combine-columns",
name="group-columns",
columns=["a", "b"],
pipeline=Pipeline(name="unit-test-pipeline"),
)
assert task.inputs == ["a", "b"]
assert task.outputs == ["merged_a", "merged_b"]
assert task.outputs == ["grouped_a", "grouped_b"]

task = GroupColumns(
name="combine-columns",
name="group-columns",
columns=["a", "b"],
output_columns=["c", "d"],
pipeline=Pipeline(name="unit-test-pipeline"),
Expand All @@ -38,13 +38,13 @@ def test_init(self) -> None:
assert task.outputs == ["c", "d"]

def test_process(self) -> None:
combine = GroupColumns(
name="combine-columns",
group = GroupColumns(
name="group-columns",
columns=["a", "b"],
pipeline=Pipeline(name="unit-test-pipeline"),
)
output = next(combine.process([{"a": 1, "b": 2}], [{"a": 3, "b": 4}]))
assert output == [{"merged_a": [1, 3], "merged_b": [2, 4]}]
output = next(group.process([{"a": 1, "b": 2}], [{"a": 3, "b": 4}]))
assert output == [{"grouped_a": [1, 3], "grouped_b": [2, 4]}]


def test_CombineColumns_deprecation_warning():
Expand Down

0 comments on commit a3f6cdd

Please sign in to comment.