diff --git a/dlt/common/typing.py b/dlt/common/typing.py index a0322fe01e..e9d14141a2 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -438,6 +438,10 @@ def get_generic_type_argument_from_instance( """ orig_param_type = Any if cls_ := getattr(instance, "__orig_class__", None): + # unfurl Optional[Incremental[...]] to Incremental[...] + if is_optional_type(cls_): + cls_ = get_args(cls_)[0] + # instance of generic class pass elif bases_ := get_original_bases(instance.__class__): diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index e81c3e7fa2..b7865ddb7f 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -1,3 +1,5 @@ +from types import SimpleNamespace + import pytest from dataclasses import dataclass from typing import ( @@ -44,7 +46,9 @@ is_annotated, is_callable_type, add_value_to_literal, + get_generic_type_argument_from_instance, ) +from dlt.extract import Incremental class TTestTyDi(TypedDict): @@ -310,3 +314,17 @@ def test_add_value_to_literal() -> None: add_value_to_literal(TestSingleLiteral, "green") add_value_to_literal(TestSingleLiteral, "blue") assert get_args(TestSingleLiteral) == ("red", "green", "blue") + + +def test_get_generic_type_argument_from_instance() -> None: + # generic contains hint + instance = SimpleNamespace(__orig_class__=Incremental[str]) + assert get_generic_type_argument_from_instance(instance) is str + instance = SimpleNamespace(__orig_class__=Optional[Incremental[str]]) + assert get_generic_type_argument_from_instance(instance) is str + + # with sample values + instance = SimpleNamespace(__orig_class__=Incremental[Any]) + assert get_generic_type_argument_from_instance(instance, 1) is int + instance = SimpleNamespace(__orig_class__=Optional[Incremental[Any]]) + assert get_generic_type_argument_from_instance(instance, 1) is int