diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index c6debf5a74..9a4d8f47be 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -1,4 +1,4 @@ -from typing import Sequence, Type, cast, overload +from typing import Sequence, Type, TypeVar, cast, overload from dlt.common.schema import Schema from dlt.common.schema.typing import TColumnSchema, TWriteDisposition, TSchemaContract @@ -15,6 +15,8 @@ from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR from dlt.pipeline.warnings import credentials_argument_deprecated +TPipeline = TypeVar("TPipeline", bound=Pipeline) + @overload def pipeline( @@ -29,7 +31,8 @@ def pipeline( full_refresh: bool = False, credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, -) -> Pipeline: + _impl_cls: Type[TPipeline] = Pipeline, +) -> TPipeline: """Creates a new instance of `dlt` pipeline, which moves the data from the source ie. a REST API to a destination ie. database or a data lake. #### Note: @@ -97,9 +100,9 @@ def pipeline( full_refresh: bool = False, credentials: Any = None, progress: TCollectorArg = _NULL_COLLECTOR, - _impl_cls: Type[Pipeline] = Pipeline, + _impl_cls: Type[TPipeline] = Pipeline, **kwargs: Any, -) -> Pipeline: +) -> TPipeline: ensure_correct_pipeline_kwargs(pipeline, **kwargs) # call without arguments returns current pipeline orig_args = get_orig_args(**kwargs) # original (*args, **kwargs) @@ -112,7 +115,7 @@ def pipeline( context = Container()[PipelineContext] # if pipeline instance is already active then return it, otherwise create a new one if context.is_active(): - return cast(Pipeline, context.pipeline()) + return cast(TPipeline, context.pipeline()) else: pass