From c004a1f78ae10cb7f695b9ee0f1c739bb52ca8cb Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Sat, 25 May 2024 17:18:01 +0200 Subject: [PATCH] adds NotResolved type annotations that excludes type from resolving in configspec --- dlt/common/configuration/__init__.py | 9 ++- dlt/common/configuration/resolve.py | 6 +- .../configuration/specs/base_configuration.py | 38 ++++++++++++- dlt/common/destination/reference.py | 16 +++--- dlt/destinations/impl/qdrant/configuration.py | 7 ++- .../impl/weaviate/configuration.py | 5 +- .../configuration/test_configuration.py | 55 ++++++++++++++++++- tests/load/utils.py | 6 +- 8 files changed, 121 insertions(+), 21 deletions(-) diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 8de57f7799..2abc31b17d 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,4 +1,10 @@ -from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type +from .specs.base_configuration import ( + configspec, + is_valid_hint, + is_secret_hint, + resolve_type, + NotResolved, +) from .specs import known_sections from .resolve import resolve_configuration, inject_section from .inject import with_config, last_config, get_fun_spec, create_resolved_partial @@ -15,6 +21,7 @@ "configspec", "is_valid_hint", "is_secret_hint", + "NotResolved", "resolve_type", "known_sections", "resolve_configuration", diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index ebfa7b6b89..9101cfdd9c 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -8,7 +8,6 @@ StrAny, TSecretValue, get_all_types_of_class_in_union, - is_final_type, is_optional_type, is_union_type, ) @@ -21,6 +20,7 @@ is_context_inner_hint, is_base_configuration_inner_hint, is_valid_hint, + is_hint_not_resolved, ) from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.configuration.specs.exceptions import NativeValueError @@ -194,7 +194,7 @@ def _resolve_config_fields( if explicit_values: explicit_value = explicit_values.get(key) else: - if is_final_type(hint): + if is_hint_not_resolved(hint): # for final fields default value is like explicit explicit_value = default_value else: @@ -258,7 +258,7 @@ def _resolve_config_fields( unresolved_fields[key] = traces # set resolved value in config if default_value != current_value: - if not is_final_type(hint): + if not is_hint_not_resolved(hint): # ignore final types setattr(config, key, current_value) diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 1329feae6c..006cde8dce 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -20,7 +20,7 @@ ClassVar, TypeVar, ) -from typing_extensions import get_args, get_origin, dataclass_transform +from typing_extensions import get_args, get_origin, dataclass_transform, Annotated, TypeAlias from functools import wraps if TYPE_CHECKING: @@ -29,8 +29,11 @@ TDtcField = dataclasses.Field from dlt.common.typing import ( + AnyType, TAnyClass, extract_inner_type, + is_annotated, + is_final_type, is_optional_type, is_union_type, ) @@ -48,6 +51,34 @@ _C = TypeVar("_C", bound="CredentialsConfiguration") +class NotResolved: + """Used in type annotations to indicate types that should not be resolved.""" + + def __init__(self, not_resolved: bool = True): + self.not_resolved = not_resolved + + def __bool__(self) -> bool: + return self.not_resolved + + +def is_hint_not_resolved(hint: AnyType) -> bool: + """Checks if hint should NOT be resolved. Final and types annotated like + + >>> Annotated[str, NotResolved()] + + are not resolved. + """ + if is_final_type(hint): + return True + + if is_annotated(hint): + _, *a_m = get_args(hint) + for annotation in a_m: + if isinstance(annotation, NotResolved): + return bool(annotation) + return False + + def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool: return inspect.isclass(inner_hint) and issubclass(inner_hint, BaseConfiguration) @@ -70,6 +101,11 @@ def is_valid_hint(hint: Type[Any]) -> bool: if get_origin(hint) is ClassVar: # class vars are skipped by dataclass return True + + if is_hint_not_resolved(hint): + # all hints that are not resolved are valid + return True + hint = extract_inner_type(hint) hint = get_config_if_union_hint(hint) or hint hint = get_origin(hint) or hint diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 2ad5131e63..d4cdfb729d 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -18,8 +18,8 @@ Any, TypeVar, Generic, - Final, ) +from typing_extensions import Annotated import datetime # noqa: 251 from copy import deepcopy import inspect @@ -35,7 +35,7 @@ has_column_with_prop, get_first_column_name_with_prop, ) -from dlt.common.configuration import configspec, resolve_configuration, known_sections +from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -78,7 +78,7 @@ class StateInfo(NamedTuple): @configspec class DestinationClientConfiguration(BaseConfiguration): - destination_type: Final[str] = dataclasses.field( + destination_type: Annotated[str, NotResolved()] = dataclasses.field( default=None, init=False, repr=False, compare=False ) # which destination to load data to credentials: Optional[CredentialsConfiguration] = None @@ -103,11 +103,11 @@ def on_resolved(self) -> None: class DestinationClientDwhConfiguration(DestinationClientConfiguration): """Configuration of a destination that supports datasets/schemas""" - dataset_name: Final[str] = dataclasses.field( + dataset_name: Annotated[str, NotResolved()] = dataclasses.field( default=None, init=False, repr=False, compare=False - ) # dataset must be final so it is not configurable + ) # dataset cannot be resolved """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" - default_schema_name: Final[Optional[str]] = dataclasses.field( + default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( default=None, init=False, repr=False, compare=False ) """name of default schema to be used to name effective dataset to load data to""" @@ -121,8 +121,8 @@ def _bind_dataset_name( This method is intended to be used internally. """ - self.dataset_name = dataset_name # type: ignore[misc] - self.default_schema_name = default_schema_name # type: ignore[misc] + self.dataset_name = dataset_name + self.default_schema_name = default_schema_name return self def normalize_dataset_name(self, schema: Schema) -> str: diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index d589537742..fd11cc7dcb 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -1,7 +1,8 @@ import dataclasses from typing import Optional, Final +from typing_extensions import Annotated -from dlt.common.configuration import configspec +from dlt.common.configuration import configspec, NotResolved from dlt.common.configuration.specs.base_configuration import ( BaseConfiguration, CredentialsConfiguration, @@ -55,7 +56,9 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration): dataset_separator: str = "_" # make it optional so empty dataset is allowed - dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc] + dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # Batch size for generating embeddings embedding_batch_size: int = 32 diff --git a/dlt/destinations/impl/weaviate/configuration.py b/dlt/destinations/impl/weaviate/configuration.py index 90fb7ce5ce..211a6ec029 100644 --- a/dlt/destinations/impl/weaviate/configuration.py +++ b/dlt/destinations/impl/weaviate/configuration.py @@ -1,8 +1,9 @@ import dataclasses from typing import Dict, Literal, Optional, Final +from typing_extensions import Annotated from urllib.parse import urlparse -from dlt.common.configuration import configspec +from dlt.common.configuration import configspec, NotResolved from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.utils import digest128 @@ -26,7 +27,7 @@ def __str__(self) -> str: class WeaviateClientConfiguration(DestinationClientDwhConfiguration): destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore # make it optional so empty dataset is allowed - dataset_name: Optional[str] = None # type: ignore[misc] + dataset_name: Annotated[Optional[str], NotResolved()] = None batch_size: int = 100 batch_workers: int = 1 diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 84b2d1893d..43ccdf856c 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -12,11 +12,12 @@ Optional, Type, Union, - TYPE_CHECKING, ) +from typing_extensions import Annotated from dlt.common import json, pendulum, Decimal, Wei from dlt.common.configuration.providers.provider import ConfigProvider +from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolved from dlt.common.configuration.specs.gcp_credentials import ( GcpServiceAccountCredentialsWithoutDefaults, ) @@ -917,6 +918,58 @@ def test_is_valid_hint() -> None: assert is_valid_hint(Wei) is True # any class type, except deriving from BaseConfiguration is wrong type assert is_valid_hint(ConfigFieldMissingException) is False + # but final and annotated types are not ok because they are not resolved + assert is_valid_hint(Final[ConfigFieldMissingException]) is True # type: ignore[arg-type] + assert is_valid_hint(Annotated[ConfigFieldMissingException, NotResolved()]) is True # type: ignore[arg-type] + assert is_valid_hint(Annotated[ConfigFieldMissingException, "REQ"]) is False # type: ignore[arg-type] + + +def test_is_not_resolved_hint() -> None: + assert is_hint_not_resolved(Final[ConfigFieldMissingException]) is True + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved()]) is True + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(True)]) is True + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(False)]) is False + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, "REQ"]) is False + assert is_hint_not_resolved(str) is False + + +def test_not_resolved_hint() -> None: + class SentinelClass: + pass + + @configspec + class OptionalNotResolveConfiguration(BaseConfiguration): + trace: Final[Optional[SentinelClass]] = None + traces: Annotated[Optional[List[SentinelClass]], NotResolved()] = None + + c = resolve.resolve_configuration(OptionalNotResolveConfiguration()) + assert c.trace is None + assert c.traces is None + + s1 = SentinelClass() + s2 = SentinelClass() + + c = resolve.resolve_configuration(OptionalNotResolveConfiguration(s1, [s2])) + assert c.trace is s1 + assert c.traces[0] is s2 + + @configspec + class NotResolveConfiguration(BaseConfiguration): + trace: Final[SentinelClass] = None + traces: Annotated[List[SentinelClass], NotResolved()] = None + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolveConfiguration()) + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolveConfiguration(trace=s1)) + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolveConfiguration(traces=[s2])) + + c2 = resolve.resolve_configuration(NotResolveConfiguration(s1, [s2])) + assert c2.trace is s1 + assert c2.traces[0] is s2 def test_configspec_auto_base_config_derivation() -> None: diff --git a/tests/load/utils.py b/tests/load/utils.py index 81107e83d9..c03470676f 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -574,8 +574,8 @@ def yield_client( destination = Destination.from_reference(destination_type) # create initial config dest_config: DestinationClientDwhConfiguration = None - dest_config = destination.spec() # type: ignore[assignment] - dest_config.dataset_name = dataset_name # type: ignore[misc] + dest_config = destination.spec() # type: ignore + dest_config.dataset_name = dataset_name if default_config_values is not None: # apply the values to credentials, if dict is provided it will be used as default @@ -597,7 +597,7 @@ def yield_client( staging_config = DestinationClientStagingConfiguration( bucket_url=AWS_BUCKET, )._bind_dataset_name(dataset_name=dest_config.dataset_name) - staging_config.destination_type = "filesystem" # type: ignore[misc] + staging_config.destination_type = "filesystem" staging_config.resolve() dest_config.staging_config = staging_config # type: ignore[attr-defined]