diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 911273c2c9..324b3f23c9 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -18,7 +18,6 @@ import signal import threading import time -import uuid from abc import ABC, abstractmethod from pathlib import Path from typing import ( @@ -131,6 +130,7 @@ def get_pipeline(cls) -> Union["BasePipeline", None]: _STEP_NOT_LOADED_CODE = -999 _ATTRIBUTES_IGNORED_CACHE = ("disable_cuda_device_placement",) +_PIPELINE_DEFAULT_NAME = "__default_pipeline_name__" class BasePipeline(ABC, RequirementsMixin, _Serializable): @@ -189,7 +189,7 @@ def __init__( Defaults to `None`, but can be helpful to inform in a pipeline to be shared that this requirements must be installed. """ - self.name = name or f"pipeline_{str(uuid.uuid4())[:8]}" + self.name = name or _PIPELINE_DEFAULT_NAME self.description = description self._enable_metadata = enable_metadata self.dag = DAG() @@ -235,6 +235,12 @@ def __enter__(self) -> Self: def __exit__(self, exc_type, exc_value, traceback) -> None: """Unset the global pipeline instance when exiting a pipeline context.""" _GlobalPipelineManager.set_pipeline(None) + self._set_pipeline_name() + + def _set_pipeline_name(self) -> None: + """Creates a name for the pipeline if it's the default one (i.e. if it hasn't been set).""" + if self.name == _PIPELINE_DEFAULT_NAME: + self.name = f"pipeline_{'_'.join(self.dag)}" def _create_signature(self) -> str: """Makes a signature (hash) of a pipeline, using the step ids and the adjacency between them. @@ -351,6 +357,13 @@ def run( log_queue=self._log_queue, filename=str(self._cache_location["log_file"]) ) + # Set the name of the pipeline if it's the default one. This should be called + # if the pipeline is defined within the context manager, and the run is called + # outside of it. 
It is also called here to cover the case where `run` is invoked inside the context manager (before `__exit__` sets the name): + # with Pipeline() as pipeline: + # pipeline.run() + self._set_pipeline_name() + # Validate the pipeline DAG to check that all the steps are chainable, there are # no missing runtime parameters, batch sizes are correct, etc. self.dag.validate() diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index a2a043f737..77bfee600b 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -105,28 +105,29 @@ def test_context_manager(self) -> None: @pytest.mark.parametrize("use_cache", [False, True]) def test_load_batch_manager(self, use_cache: bool) -> None: - pipeline = DummyPipeline(name="unit-test-pipeline") - pipeline._load_batch_manager(use_cache=True) - pipeline._cache() - - with ( - mock.patch( - "distilabel.pipeline.base._BatchManager.load_from_cache" - ) as mock_load_from_cache, - mock.patch( - "distilabel.pipeline.base._BatchManager.from_dag" - ) as mock_from_dag, - ): - pipeline._load_batch_manager(use_cache=use_cache) - - if use_cache: - mock_load_from_cache.assert_called_once_with( - pipeline._cache_location["batch_manager"] - ) - mock_from_dag.assert_not_called() - else: - mock_load_from_cache.assert_not_called() - mock_from_dag.assert_called_once_with(pipeline.dag) + with tempfile.TemporaryDirectory() as temp_dir: + pipeline = DummyPipeline(name="unit-test-pipeline", cache_dir=temp_dir) + pipeline._load_batch_manager(use_cache=True) + pipeline._cache() + + with ( + mock.patch( + "distilabel.pipeline.base._BatchManager.load_from_cache" + ) as mock_load_from_cache, + mock.patch( + "distilabel.pipeline.base._BatchManager.from_dag" + ) as mock_from_dag, + ): + pipeline._load_batch_manager(use_cache=use_cache) + + if use_cache: + mock_load_from_cache.assert_called_once_with( + pipeline._cache_location["batch_manager"] + ) + mock_from_dag.assert_not_called() + else: + mock_load_from_cache.assert_not_called() + mock_from_dag.assert_called_once_with(pipeline.dag) def 
test_setup_write_buffer(self) -> None: pipeline = DummyPipeline(name="unit-test-pipeline") @@ -1187,13 +1188,16 @@ def test_pipeline_with_dataset_and_generator_step(self): ) def test_optional_name(self): - import random + from distilabel.pipeline.base import _PIPELINE_DEFAULT_NAME + + assert DummyPipeline().name == _PIPELINE_DEFAULT_NAME - random.seed(42) with DummyPipeline() as pipeline: - name = pipeline.name - assert name.startswith("pipeline") - assert len(name.split("_")[-1]) == 8 + gen_step = DummyGeneratorStep() + step1_0 = DummyStep1() + gen_step >> step1_0 + + assert pipeline.name == "pipeline_dummy_generator_step_0_dummy_step1_0" class TestPipelineSerialization: