Skip to content

Commit

Permalink
Feat: parameterize pipeline class in the primary factory method (#1176)
Browse files Browse the repository at this point in the history
* feat: parameterize pipeline class in the primary factory method

* chore: use generic typing

* chore: remove no args overload

* uses TypeVal with default

---------

Co-authored-by: Marcin Rudolf <[email protected]>
  • Loading branch information
z3z1ma and rudolfix authored Apr 8, 2024
1 parent 18665f1 commit 4c6bdbc
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions dlt/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Sequence, cast, overload
from typing import Sequence, Type, cast, overload
from typing_extensions import TypeVar

from dlt.common.schema import Schema
from dlt.common.schema.typing import TColumnSchema, TWriteDisposition, TSchemaContract
Expand All @@ -15,6 +16,8 @@
from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR
from dlt.pipeline.warnings import credentials_argument_deprecated

TPipeline = TypeVar("TPipeline", bound=Pipeline, default=Pipeline)


@overload
def pipeline(
Expand All @@ -29,7 +32,8 @@ def pipeline(
full_refresh: bool = False,
credentials: Any = None,
progress: TCollectorArg = _NULL_COLLECTOR,
) -> Pipeline:
_impl_cls: Type[TPipeline] = Pipeline, # type: ignore[assignment]
) -> TPipeline:
"""Creates a new instance of `dlt` pipeline, which moves the data from the source ie. a REST API to a destination ie. database or a data lake.
#### Note:
Expand Down Expand Up @@ -97,8 +101,9 @@ def pipeline(
full_refresh: bool = False,
credentials: Any = None,
progress: TCollectorArg = _NULL_COLLECTOR,
_impl_cls: Type[TPipeline] = Pipeline, # type: ignore[assignment]
**kwargs: Any,
) -> Pipeline:
) -> TPipeline:
ensure_correct_pipeline_kwargs(pipeline, **kwargs)
# call without arguments returns current pipeline
orig_args = get_orig_args(**kwargs) # original (*args, **kwargs)
Expand All @@ -111,7 +116,7 @@ def pipeline(
context = Container()[PipelineContext]
# if pipeline instance is already active then return it, otherwise create a new one
if context.is_active():
return cast(Pipeline, context.pipeline())
return cast(TPipeline, context.pipeline())
else:
pass

Expand All @@ -129,7 +134,7 @@ def pipeline(

progress = collector_from_name(progress)
# create new pipeline instance
p = Pipeline(
p = _impl_cls(
pipeline_name,
pipelines_dir,
pipeline_salt,
Expand Down

0 comments on commit 4c6bdbc

Please sign in to comment.