Skip to content

Commit

Permalink
Add default name for a pipeline (#809)
Browse files Browse the repository at this point in the history
* Add default name for a pipeline

* Move to uuid instead

* Fix test and update final name based on uuid
  • Loading branch information
plaguss authored Jul 25, 2024
1 parent 90909ab commit a4d2f2e
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import signal
import threading
import time
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -158,7 +159,7 @@ class BasePipeline(ABC, RequirementsMixin, _Serializable):

def __init__(
self,
name: str,
name: Optional[str] = None,
description: Optional[str] = None,
cache_dir: Optional[Union[str, "PathLike"]] = None,
enable_metadata: bool = False,
Expand All @@ -167,7 +168,7 @@ def __init__(
"""Initialize the `BasePipeline` instance.
Args:
name: The name of the pipeline.
name: The name of the pipeline. If not provided, a random one will be generated by default.
description: A description of the pipeline. Defaults to `None`.
cache_dir: A directory where the pipeline will be cached. Defaults to `None`.
enable_metadata: Whether to include the distilabel metadata column for the pipeline
Expand All @@ -178,7 +179,7 @@ def __init__(
Defaults to `None`, but can be helpful to inform in a pipeline to be shared
that these requirements must be installed.
"""
self.name = name
self.name = name or f"pipeline_{str(uuid.uuid4())[:8]}"
self.description = description
self._enable_metadata = enable_metadata
self.dag = DAG()
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,15 @@ def process(self, inputs: StepInput) -> StepOutput: # type: ignore
gen_step >> step1_0 >> step2
pipeline.run()

def test_optional_name(self):
    """Check that a pipeline created without a name gets a default one.

    The generated name is expected to start with ``pipeline`` and to end,
    after the last underscore, with an 8-character suffix.
    """
    # No RNG seeding here: the default name is built from `uuid.uuid4()`,
    # which does not draw from the `random` module, so `random.seed(...)`
    # cannot make it deterministic. Only the shape of the name is asserted.
    with DummyPipeline() as pipeline:
        name = pipeline.name
        assert name.startswith("pipeline")
        assert len(name.split("_")[-1]) == 8


class TestPipelineSerialization:
@pytest.mark.parametrize(
Expand Down

0 comments on commit a4d2f2e

Please sign in to comment.