diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py
index c6debf5a74..9a4d8f47be 100644
--- a/dlt/pipeline/__init__.py
+++ b/dlt/pipeline/__init__.py
@@ -1,4 +1,4 @@
-from typing import Sequence, Type, cast, overload
+from typing import Sequence, Type, TypeVar, cast, overload
 
 from dlt.common.schema import Schema
 from dlt.common.schema.typing import TColumnSchema, TWriteDisposition, TSchemaContract
@@ -15,6 +15,8 @@
 from dlt.pipeline.progress import _from_name as collector_from_name, TCollectorArg, _NULL_COLLECTOR
 from dlt.pipeline.warnings import credentials_argument_deprecated
 
+TPipeline = TypeVar("TPipeline", bound=Pipeline)
+
 
 @overload
 def pipeline(
@@ -29,7 +31,8 @@ def pipeline(
     full_refresh: bool = False,
     credentials: Any = None,
     progress: TCollectorArg = _NULL_COLLECTOR,
-) -> Pipeline:
+    _impl_cls: Type[TPipeline] = Pipeline,
+) -> TPipeline:
     """Creates a new instance of `dlt` pipeline, which moves the data from the source ie. a REST API to a destination ie. database or a data lake.
 
     #### Note:
@@ -97,9 +100,9 @@ def pipeline(
     full_refresh: bool = False,
     credentials: Any = None,
     progress: TCollectorArg = _NULL_COLLECTOR,
-    _impl_cls: Type[Pipeline] = Pipeline,
+    _impl_cls: Type[TPipeline] = Pipeline,
     **kwargs: Any,
-) -> Pipeline:
+) -> TPipeline:
     ensure_correct_pipeline_kwargs(pipeline, **kwargs)
     # call without arguments returns current pipeline
     orig_args = get_orig_args(**kwargs)  # original (*args, **kwargs)
@@ -112,7 +115,7 @@ def pipeline(
         context = Container()[PipelineContext]
         # if pipeline instance is already active then return it, otherwise create a new one
         if context.is_active():
-            return cast(Pipeline, context.pipeline())
+            return cast(TPipeline, context.pipeline())
         else:
             pass