From a4d2f2edefe3e7702625333c62a33f35ce0ccb25 Mon Sep 17 00:00:00 2001 From: Agus Date: Thu, 25 Jul 2024 17:13:04 +0200 Subject: [PATCH] Add default name for a pipeline (#809) * Add default name for a pipeline * Move to uuid instead * Fix test and update final name based on uuid --- src/distilabel/pipeline/base.py | 7 ++++--- tests/unit/pipeline/test_base.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py index 5c621c2ed5..e854dce49e 100644 --- a/src/distilabel/pipeline/base.py +++ b/src/distilabel/pipeline/base.py @@ -18,6 +18,7 @@ import signal import threading import time +import uuid from abc import ABC, abstractmethod from pathlib import Path from typing import ( @@ -158,7 +159,7 @@ class BasePipeline(ABC, RequirementsMixin, _Serializable): def __init__( self, - name: str, + name: Optional[str] = None, description: Optional[str] = None, cache_dir: Optional[Union[str, "PathLike"]] = None, enable_metadata: bool = False, @@ -167,7 +168,7 @@ def __init__( """Initialize the `BasePipeline` instance. Args: - name: The name of the pipeline. + name: The name of the pipeline. If not provided, a random one will be generated by default. description: A description of the pipeline. Defaults to `None`. cache_dir: A directory where the pipeline will be cached. Defaults to `None`. enable_metadata: Whether to include the distilabel metadata column for the pipeline @@ -178,7 +179,7 @@ def __init__( Defaults to `None`, but can be helpful to inform in a pipeline to be shared that this requirements must be installed. 
""" - self.name = name + self.name = name or f"pipeline_{str(uuid.uuid4())[:8]}" self.description = description self._enable_metadata = enable_metadata self.dag = DAG() diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py index 26773f9a4c..d0e954b586 100644 --- a/tests/unit/pipeline/test_base.py +++ b/tests/unit/pipeline/test_base.py @@ -1150,6 +1150,15 @@ def process(self, inputs: StepInput) -> StepOutput: # type: ignore gen_step >> step1_0 >> step2 pipeline.run() + def test_optional_name(self): + import random + + random.seed(42) + with DummyPipeline() as pipeline: + name = pipeline.name + assert name.startswith("pipeline") + assert len(name.split("_")[-1]) == 8 + class TestPipelineSerialization: @pytest.mark.parametrize(