Skip to content

Commit

Permalink
Fix error on Pipeline.dry_run without parameters (#655)
Browse files: browse the repository at this point in the history
* Prepare branch for v1.2.0

* Fix error on dry_run method without parameters

* Fix version

* Refactor test to rerun on a new pipeline

---------

Co-authored-by: Alvaro Bartolome <[email protected]>
  • Loading branch information
plaguss and alvarobartt authored May 22, 2024
1 parent b8fe3dc commit a0c9a41
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 32 deletions.
2 changes: 1 addition & 1 deletion src/distilabel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@

from rich import traceback as rich_traceback

__version__ = "1.2.0"
__version__ = "1.1.1"

rich_traceback.install(show_locals=True)
10 changes: 5 additions & 5 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,13 @@ def dry_run(

for step_name in self.dag:
step = self.dag.get_step(step_name)[STEP_ATTR_NAME]

if step.is_generator:
if parameters.get(step_name) and parameters[step_name].get(
"batch_size"
):
parameters[step_name]["batch_size"] = batch_size
if not parameters:
parameters = {}
parameters[step_name] = {"batch_size": batch_size}

distiset = self.run(parameters, use_cache=False)
distiset = self.run(parameters=parameters, use_cache=False)

self._dry_run = False
return distiset
Expand Down
53 changes: 27 additions & 26 deletions tests/integration/test_dry_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,35 +24,36 @@ def SucceedAlways(inputs: StepInput) -> "StepOutput":


def test_dry_run():
with Pipeline(name="other-pipe") as pipeline:
load_dataset = LoadDataFromDicts(
data=[
{"instruction": "Tell me a joke."},
]
* 50,
batch_size=20,
)
text_generation = SucceedAlways()

load_dataset >> text_generation

distiset = pipeline.dry_run(parameters={load_dataset.name: {"batch_size": 8}})
assert len(distiset["default"]["train"]) == 1
load_dataset_name = "load_dataset"

def get_pipeline():
with Pipeline(name="other-pipe") as pipeline:
load_dataset = LoadDataFromDicts(
name=load_dataset_name,
data=[
{"instruction": "Tell me a joke."},
]
* 50,
batch_size=20,
)
text_generation = SucceedAlways()

load_dataset >> text_generation
return pipeline

# Test with and without parameters
pipeline = get_pipeline()
distiset = pipeline.dry_run(batch_size=2)
assert len(distiset["default"]["train"]) == 2
assert pipeline._dry_run is False

with Pipeline(name="other-pipe") as pipeline:
load_dataset = LoadDataFromDicts(
data=[
{"instruction": "Tell me a joke."},
]
* 50,
batch_size=20,
)
text_generation = SucceedAlways()

load_dataset >> text_generation
pipeline = get_pipeline()
distiset = pipeline.dry_run(parameters={load_dataset_name: {"batch_size": 8}})
assert len(distiset["default"]["train"]) == 1
assert pipeline._dry_run is False

pipeline = get_pipeline()
distiset = pipeline.run(
parameters={load_dataset.name: {"batch_size": 10}}, use_cache=False
parameters={load_dataset_name: {"batch_size": 10}}, use_cache=False
)
assert len(distiset["default"]["train"]) == 50

0 comments on commit a0c9a41

Please sign in to comment.