diff --git a/dlt/common/__init__.py b/dlt/common/__init__.py index 466fd7c546..0c8a09ec3e 100644 --- a/dlt/common/__init__.py +++ b/dlt/common/__init__.py @@ -1,8 +1,8 @@ +from dlt.common import logger from dlt.common.arithmetics import Decimal from dlt.common.wei import Wei from dlt.common.pendulum import pendulum from dlt.common.json import json from dlt.common.runtime.signals import sleep -from dlt.common.runtime import logger __all__ = ["Decimal", "Wei", "pendulum", "json", "sleep", "logger"] diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index b398f0463a..ebfa7b6b89 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -76,7 +76,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur first_credentials: CredentialsConfiguration = None for idx, spec in enumerate(specs_in_union): try: - credentials = spec(initial_value) + credentials = spec.from_init_value(initial_value) if credentials.is_resolved(): return credentials # keep first credentials in the union to return in case all of the match but not resolve @@ -88,7 +88,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur return first_credentials else: assert issubclass(hint, CredentialsConfiguration) - return hint(initial_value) # type: ignore + return hint.from_init_value(initial_value) # type: ignore def inject_section( diff --git a/dlt/common/configuration/specs/api_credentials.py b/dlt/common/configuration/specs/api_credentials.py index fd7ae8cb09..918cd4ee45 100644 --- a/dlt/common/configuration/specs/api_credentials.py +++ b/dlt/common/configuration/specs/api_credentials.py @@ -6,9 +6,9 @@ @configspec class OAuth2Credentials(CredentialsConfiguration): - client_id: str - client_secret: TSecretValue - refresh_token: Optional[TSecretValue] + client_id: str = None + client_secret: TSecretValue = None + refresh_token: Optional[TSecretValue] = None scopes: Optional[List[str]] = None token: Optional[TSecretValue] = None diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index ee7360e2cb..ee49e79e40 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -121,3 +121,9 @@ def parse_native_representation(self, native_value: Any) -> None: self.__is_resolved__ = True except Exception: raise InvalidBoto3Session(self.__class__, native_value) + + @classmethod + def from_session(cls, botocore_session: Any) -> "AwsCredentials": + self = cls() + self.parse_native_representation(botocore_session) + return self diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 62abf42f27..06fb97fcdd 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -2,6 +2,7 @@ import inspect import contextlib import dataclasses +import warnings from collections.abc import Mapping as C_Mapping from typing import ( @@ -19,7 +20,7 @@ ClassVar, TypeVar, ) -from typing_extensions import get_args, get_origin +from typing_extensions import get_args, get_origin, dataclass_transform from functools import wraps if TYPE_CHECKING: @@ -44,6 +45,7 @@ _F_BaseConfiguration: Any = type(object) _F_ContainerInjectableContext: Any = type(object) _T = TypeVar("_T", bound="BaseConfiguration") +_C = TypeVar("_C", bound="CredentialsConfiguration") def 
is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool: @@ -106,18 +108,26 @@ def is_secret_hint(hint: Type[Any]) -> bool: @overload -def configspec(cls: Type[TAnyClass]) -> Type[TAnyClass]: ... +def configspec(cls: Type[TAnyClass], init: bool = True) -> Type[TAnyClass]: ... @overload -def configspec(cls: None = ...) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ... +def configspec( + cls: None = ..., init: bool = True +) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ... +@dataclass_transform(eq_default=False, field_specifiers=(dataclasses.Field, dataclasses.field)) def configspec( - cls: Optional[Type[Any]] = None, + cls: Optional[Type[Any]] = None, init: bool = True ) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: """Converts (via derivation) any decorated class to a Python dataclass that may be used as a spec to resolve configurations + __init__ method is synthesized by default. `init` flag is ignored if the decorated class implements a custom __init__ as well as + when any of the base classes has no synthesized __init__ + + All fields must have default values. This decorator will add a `None` default to fields that miss one. + In comparison the Python dataclass, a spec implements full dictionary interface for its attributes, allows instance creation from ie. strings or other types (parsing, deserialization) and control over configuration resolution process. See `BaseConfiguration` and CredentialsConfiguration` for more information. @@ -142,6 +152,10 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: # get all annotations without corresponding attributes and set them to None for ann in cls.__annotations__: if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_")): + warnings.warn( + f"Missing default value for field {ann} on {cls.__name__}. None assumed. All" + " fields in configspec must have a default." + ) setattr(cls, ann, None) # get all attributes without corresponding annotations for att_name, att_value in list(cls.__dict__.items()): @@ -177,17 +191,18 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def] # We don't want to overwrite user's __init__ method # Create dataclass init only when not defined in the class - # (never put init on BaseConfiguration itself) - try: - is_base = cls is BaseConfiguration - except NameError: - is_base = True - init = False - base_params = getattr(cls, "__dataclass_params__", None) - if not is_base and (base_params and base_params.init or cls.__init__ is object.__init__): - init = True + # NOTE: any class without synthesized __init__ breaks the creation chain + has_default_init = super(cls, cls).__init__ == cls.__init__ # type: ignore[misc] + base_params = getattr(cls, "__dataclass_params__", None) # cls.__init__ is object.__init__ + synth_init = init and ((not base_params or base_params.init) and has_default_init) + if synth_init != init and has_default_init: + warnings.warn( + f"__init__ method will not be generated on {cls.__name__} because the base class didn't" + " synthesize __init__. Please correct the `init` flag in the configspec decorator. 
You are" + " probably receiving incorrect __init__ signature for type checking" + ) # do not generate repr as it may contain secret values - return dataclasses.dataclass(cls, init=init, eq=False, repr=False) # type: ignore + return dataclasses.dataclass(cls, init=synth_init, eq=False, repr=False) # type: ignore # called with parenthesis if cls is None: @@ -198,12 +213,14 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def] @configspec class BaseConfiguration(MutableMapping[str, Any]): - __is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False) + __is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False, compare=False) """True when all config fields were resolved and have a specified value type""" - __section__: str = dataclasses.field(default=None, init=False, repr=False) - """Obligatory section used by config providers when searching for keys, always present in the search path""" - __exception__: Exception = dataclasses.field(default=None, init=False, repr=False) + __exception__: Exception = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) """Holds the exception that prevented the full resolution""" + __section__: ClassVar[str] = None + """Obligatory section used by config providers when searching for keys, always present in the search path""" __config_gen_annotations__: ClassVar[List[str]] = [] """Additional annotations for config generator, currently holds a list of fields of interest that have defaults""" __dataclass_fields__: ClassVar[Dict[str, TDtcField]] @@ -342,9 +359,10 @@ def call_method_in_mro(config, method_name: str) -> None: class CredentialsConfiguration(BaseConfiguration): """Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets.""" - __section__: str = "credentials" + __section__: ClassVar[str] = "credentials" - def __init__(self, init_value: Any = None) -> None: + @classmethod + def from_init_value(cls: Type[_C], init_value: Any = None) -> _C: """Initializes credentials from `init_value` Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials) @@ -353,14 +371,10 @@ def __init__(self, init_value: Any = None) -> None: Credentials will be marked as resolved if all required fields are set. """ - if init_value is None: - return - elif isinstance(init_value, C_Mapping): - self.update(init_value) - else: - self.parse_native_representation(init_value) - if not self.is_partial(): - self.resolve() + # create an instance + self = cls() + self._apply_init_value(init_value) + return self def to_native_credentials(self) -> Any: """Returns native credentials object. 
@@ -369,6 +383,16 @@ def to_native_credentials(self) -> Any: """ return self.to_native_representation() + def _apply_init_value(self, init_value: Any = None) -> None: + if isinstance(init_value, C_Mapping): + self.update(init_value) + elif init_value is not None: + self.parse_native_representation(init_value) + else: + return + if not self.is_partial(): + self.resolve() + def __str__(self) -> str: """Get string representation of credentials to be displayed, with all secret parts removed""" return super().__str__() diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 860e7414de..642634fb0a 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -1,4 +1,5 @@ import contextlib +import dataclasses import io from typing import ClassVar, List @@ -28,7 +29,7 @@ class ConfigProvidersConfiguration(BaseConfiguration): only_toml_fragments: bool = True # always look in providers - __section__ = known_sections.PROVIDERS + __section__: ClassVar[str] = known_sections.PROVIDERS @configspec @@ -37,8 +38,12 @@ class ConfigProvidersContext(ContainerInjectableContext): global_affinity: ClassVar[bool] = True - providers: List[ConfigProvider] - context_provider: ConfigProvider + providers: List[ConfigProvider] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + context_provider: ConfigProvider = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) def __init__(self) -> None: super().__init__() diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index a656a2b0fe..1e6cd56155 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -8,7 +8,7 @@ class ConfigSectionContext(ContainerInjectableContext): TMergeFunc = Callable[["ConfigSectionContext", "ConfigSectionContext"], None] - pipeline_name: Optional[str] + pipeline_name: Optional[str] = None sections: Tuple[str, ...] = () merge_style: TMergeFunc = None source_state_key: str = None @@ -70,13 +70,3 @@ def __str__(self) -> str: super().__str__() + f": {self.pipeline_name} {self.sections}@{self.merge_style} state['{self.source_state_key}']" ) - - if TYPE_CHECKING: - # provide __init__ signature when type checking - def __init__( - self, - pipeline_name: str = None, - sections: Tuple[str, ...] = (), - merge_style: TMergeFunc = None, - source_state_key: str = None, - ) -> None: ... 
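Usage sketch (editor's illustration, not part of the patch): credentials specs are now created from a native value or a mapping via the new `from_init_value` classmethod instead of through `__init__`; `initialize_credentials` in resolve.py calls it as shown above. The spec used below is the connection string spec from the next file; the connection string itself is made up.

    from dlt.common.configuration.specs import ConnectionStringCredentials

    # a native (string) representation is parsed via parse_native_representation()
    creds = ConnectionStringCredentials.from_init_value("postgresql://loader:secret@localhost:5432/dlt_data")
    assert creds.is_resolved()

    # a mapping updates fields instead; missing required fields leave the spec unresolved
    partial = ConnectionStringCredentials.from_init_value({"username": "loader", "host": "localhost"})
    assert not partial.is_resolved()

    # AwsCredentials gains a related helper: AwsCredentials.from_session(botocore_session)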
diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 9dd6f00942..2691c5d886 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,14 +1,15 @@ -from typing import Any, ClassVar, Dict, List, Optional +import dataclasses +from typing import Any, ClassVar, Dict, List, Optional, Union + from dlt.common.libs.sql_alchemy import URL, make_url from dlt.common.configuration.specs.exceptions import InvalidConnectionString - from dlt.common.typing import TSecretValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec @configspec class ConnectionStringCredentials(CredentialsConfiguration): - drivername: str = None + drivername: str = dataclasses.field(default=None, init=False, repr=False, compare=False) database: str = None password: Optional[TSecretValue] = None username: str = None @@ -18,6 +19,11 @@ class ConnectionStringCredentials(CredentialsConfiguration): __config_gen_annotations__: ClassVar[List[str]] = ["port", "password", "host"] + def __init__(self, connection_string: Union[str, Dict[str, Any]] = None) -> None: + """Initializes the credentials from SQLAlchemy like connection string or from dict holding connection string elements""" + super().__init__() + self._apply_init_value(connection_string) + def parse_native_representation(self, native_value: Any) -> None: if not isinstance(native_value, str): raise InvalidConnectionString(self.__class__, native_value, self.drivername) diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index 431f35c8d0..4d81a493a3 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -1,5 +1,6 @@ +import dataclasses import sys -from typing import Any, Final, List, Tuple, Union, Dict +from typing import Any, ClassVar, Final, List, Tuple, Union, Dict from dlt.common import json, pendulum from dlt.common.configuration.specs.api_credentials import OAuth2Credentials @@ -22,8 +23,12 @@ @configspec class GcpCredentials(CredentialsConfiguration): - token_uri: Final[str] = "https://oauth2.googleapis.com/token" - auth_uri: Final[str] = "https://accounts.google.com/o/oauth2/auth" + token_uri: Final[str] = dataclasses.field( + default="https://oauth2.googleapis.com/token", init=False, repr=False, compare=False + ) + auth_uri: Final[str] = dataclasses.field( + default="https://accounts.google.com/o/oauth2/auth", init=False, repr=False, compare=False + ) project_id: str = None @@ -69,7 +74,9 @@ def to_gcs_credentials(self) -> Dict[str, Any]: class GcpServiceAccountCredentialsWithoutDefaults(GcpCredentials): private_key: TSecretValue = None client_email: str = None - type: Final[str] = "service_account" # noqa: A003 + type: Final[str] = dataclasses.field( # noqa: A003 + default="service_account", init=False, repr=False, compare=False + ) def parse_native_representation(self, native_value: Any) -> None: """Accepts ServiceAccountCredentials as native value. 
In other case reverts to serialized services.json""" @@ -121,8 +128,10 @@ def __str__(self) -> str: @configspec class GcpOAuthCredentialsWithoutDefaults(GcpCredentials, OAuth2Credentials): # only desktop app supported - refresh_token: TSecretValue - client_type: Final[str] = "installed" + refresh_token: TSecretValue = None + client_type: Final[str] = dataclasses.field( + default="installed", init=False, repr=False, compare=False + ) def parse_native_representation(self, native_value: Any) -> None: """Accepts Google OAuth2 credentials as native value. In other case reverts to serialized oauth client secret json""" @@ -237,7 +246,7 @@ def __str__(self) -> str: @configspec class GcpDefaultCredentials(CredentialsWithDefault, GcpCredentials): - _LAST_FAILED_DEFAULT: float = 0.0 + _LAST_FAILED_DEFAULT: ClassVar[float] = 0.0 def parse_native_representation(self, native_value: Any) -> None: """Accepts google credentials as native value""" diff --git a/dlt/common/configuration/specs/known_sections.py b/dlt/common/configuration/specs/known_sections.py index 97ba85ffd6..8bd754ddd5 100644 --- a/dlt/common/configuration/specs/known_sections.py +++ b/dlt/common/configuration/specs/known_sections.py @@ -13,6 +13,9 @@ EXTRACT = "extract" """extract stage of the pipeline""" +SCHEMA = "schema" +"""schema configuration, ie. normalizers""" + PROVIDERS = "providers" """secrets and config providers""" diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index 54ce46ceba..b57b4abbdd 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -1,7 +1,7 @@ import binascii from os.path import isfile, join from pathlib import Path -from typing import Any, Optional, Tuple, IO +from typing import Any, ClassVar, Optional, IO from dlt.common.typing import TSecretStrValue from dlt.common.utils import encoding_for_mode, main_module_file_path, reveal_pseudo_secret @@ -30,7 +30,7 @@ class RunConfiguration(BaseConfiguration): """Platform connection""" dlthub_dsn: Optional[TSecretStrValue] = None - __section__ = "runtime" + __section__: ClassVar[str] = "runtime" def on_resolved(self) -> None: # generate pipeline name from the entry point script name diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 24935d73ac..b10b1d14b9 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,6 +1,6 @@ import gzip import time -from typing import List, IO, Any, Optional, Type, TypeVar, Generic +from typing import ClassVar, List, IO, Any, Optional, Type, TypeVar, Generic from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers import TLoaderFileFormat @@ -33,7 +33,7 @@ class BufferedDataWriterConfiguration(BaseConfiguration): disable_compression: bool = False _caps: Optional[DestinationCapabilitiesContext] = None - __section__ = known_sections.DATA_WRITER + __section__: ClassVar[str] = known_sections.DATA_WRITER @with_config(spec=BufferedDataWriterConfiguration) def __init__( diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 0f3640da1e..2aadb010e0 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -4,6 +4,7 @@ IO, TYPE_CHECKING, Any, + ClassVar, Dict, List, Optional, @@ -236,7 +237,7 @@ class ParquetDataWriterConfiguration(BaseConfiguration): timestamp_timezone: str = "UTC" row_group_size: Optional[int] = None - __section__: str 
= known_sections.DATA_WRITER + __section__: ClassVar[str] = known_sections.DATA_WRITER class ParquetDataWriter(DataWriter): diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 36a9cc3b6e..7a64f32ea3 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -30,22 +30,22 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): """Injectable destination capabilities required for many Pipeline stages ie. normalize""" - preferred_loader_file_format: TLoaderFileFormat - supported_loader_file_formats: List[TLoaderFileFormat] - preferred_staging_file_format: Optional[TLoaderFileFormat] - supported_staging_file_formats: List[TLoaderFileFormat] - escape_identifier: Callable[[str], str] - escape_literal: Callable[[Any], Any] - decimal_precision: Tuple[int, int] - wei_precision: Tuple[int, int] - max_identifier_length: int - max_column_identifier_length: int - max_query_length: int - is_max_query_length_in_bytes: bool - max_text_data_type_length: int - is_max_text_data_type_length_in_bytes: bool - supports_transactions: bool - supports_ddl_transactions: bool + preferred_loader_file_format: TLoaderFileFormat = None + supported_loader_file_formats: List[TLoaderFileFormat] = None + preferred_staging_file_format: Optional[TLoaderFileFormat] = None + supported_staging_file_formats: List[TLoaderFileFormat] = None + escape_identifier: Callable[[str], str] = None + escape_literal: Callable[[Any], Any] = None + decimal_precision: Tuple[int, int] = None + wei_precision: Tuple[int, int] = None + max_identifier_length: int = None + max_column_identifier_length: int = None + max_query_length: int = None + is_max_query_length_in_bytes: bool = None + max_text_data_type_length: int = None + is_max_text_data_type_length_in_bytes: bool = None + supports_transactions: bool = None + supports_ddl_transactions: bool = None naming_convention: str = "snake_case" alter_add_multi_column: bool = True supports_truncate_command: bool = True diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 738c07bdc7..ddcc5d1146 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import dataclasses from importlib import import_module from types import TracebackType from typing import ( @@ -11,7 +12,6 @@ Iterable, Type, Union, - TYPE_CHECKING, List, ContextManager, Dict, @@ -48,11 +48,11 @@ from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.configuration.specs import GcpCredentials, AwsCredentialsWithoutDefaults TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") +TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") class StorageSchemaInfo(NamedTuple): @@ -75,8 +75,10 @@ class StateInfo(NamedTuple): @configspec class DestinationClientConfiguration(BaseConfiguration): - destination_type: Final[str] = None # which destination to load data to - credentials: Optional[CredentialsConfiguration] + destination_type: Final[str] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # which destination to 
load data to + credentials: Optional[CredentialsConfiguration] = None destination_name: Optional[str] = ( None # name of the destination, if not set, destination_type is used ) @@ -93,28 +95,33 @@ def __str__(self) -> str: def on_resolved(self) -> None: self.destination_name = self.destination_name or self.destination_type - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... - @configspec class DestinationClientDwhConfiguration(DestinationClientConfiguration): """Configuration of a destination that supports datasets/schemas""" - dataset_name: Final[str] = None # dataset must be final so it is not configurable + dataset_name: Final[str] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # dataset must be final so it is not configurable """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" - default_schema_name: Optional[str] = None + default_schema_name: Final[Optional[str]] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) """name of default schema to be used to name effective dataset to load data to""" replace_strategy: TLoaderReplaceStrategy = "truncate-and-insert" """How to handle replace disposition for this destination, can be classic or staging""" + def _bind_dataset_name( + self: TDestinationDwhClient, dataset_name: str, default_schema_name: str = None + ) -> TDestinationDwhClient: + """Binds the dataset and default schema name to the configuration + + This method is intended to be used internally. + """ + self.dataset_name = dataset_name # type: ignore[misc] + self.default_schema_name = default_schema_name # type: ignore[misc] + return self + def normalize_dataset_name(self, schema: Schema) -> str: """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. @@ -136,18 +143,6 @@ def normalize_dataset_name(self, schema: Schema) -> str: else schema.naming.normalize_table_identifier(self.dataset_name) ) - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... - @configspec class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): @@ -161,21 +156,6 @@ class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): # layout of the destination files layout: str = "{table_name}/{load_id}.{file_id}.{ext}" - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Union[AwsCredentialsWithoutDefaults, GcpCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - as_staging: bool = False, - bucket_url: str = None, - layout: str = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
- @configspec class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfiguration): @@ -183,18 +163,6 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura staging_config: Optional[DestinationClientStagingConfiguration] = None """configuration of the staging, if present, injected at runtime""" - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - staging_config: Optional[DestinationClientStagingConfiguration] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... TLoadJobState = Literal["running", "failed", "retry", "completed"] diff --git a/dlt/common/runtime/logger.py b/dlt/common/logger.py similarity index 84% rename from dlt/common/runtime/logger.py rename to dlt/common/logger.py index 9dd8ce4e3a..02412248c3 100644 --- a/dlt/common/runtime/logger.py +++ b/dlt/common/logger.py @@ -2,12 +2,7 @@ import logging import traceback from logging import LogRecord, Logger -from typing import Any, Iterator, Protocol - -from dlt.common.json import json -from dlt.common.runtime.exec_info import dlt_version_info -from dlt.common.typing import StrAny, StrStr -from dlt.common.configuration.specs import RunConfiguration +from typing import Any, Mapping, Iterator, Protocol DLT_LOGGER_NAME = "dlt" LOGGER: Logger = None @@ -32,7 +27,7 @@ def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: return wrapper -def metrics(name: str, extra: StrAny, stacklevel: int = 1) -> None: +def metrics(name: str, extra: Mapping[str, Any], stacklevel: int = 1) -> None: """Forwards metrics call to LOGGER""" if LOGGER: LOGGER.info(name, extra=extra, stacklevel=stacklevel) @@ -46,15 +41,6 @@ def suppress_and_warn() -> Iterator[None]: LOGGER.warning("Suppressed exception", exc_info=True) -def init_logging(config: RunConfiguration) -> None: - global LOGGER - - version = dlt_version_info(config.pipeline_name) - LOGGER = _init_logging( - DLT_LOGGER_NAME, config.log_level, config.log_format, config.pipeline_name, version - ) - - def is_logging() -> bool: return LOGGER is not None @@ -75,6 +61,8 @@ def pretty_format_exception() -> str: class _MetricsFormatter(logging.Formatter): def format(self, record: LogRecord) -> str: # noqa: A003 + from dlt.common.json import json + s = super(_MetricsFormatter, self).format(record) # dump metrics dictionary nicely if "metrics" in record.__dict__: @@ -83,7 +71,7 @@ def format(self, record: LogRecord) -> str: # noqa: A003 def _init_logging( - logger_name: str, level: str, fmt: str, component: str, version: StrStr + logger_name: str, level: str, fmt: str, component: str, version: Mapping[str, str] ) -> Logger: if logger_name == "root": logging.basicConfig(level=level) @@ -102,7 +90,7 @@ def _init_logging( from dlt.common.runtime import json_logging class _CustomJsonFormatter(json_logging.JSONLogFormatter): - version: StrStr = None + version: Mapping[str, str] = None def _format_log_object(self, record: LogRecord) -> Any: json_log_object = super(_CustomJsonFormatter, self)._format_log_object(record) diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/normalizers/configuration.py index adeefe2237..54b725db1f 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/normalizers/configuration.py @@ -1,8 +1,7 @@ -import dataclasses -from typing import Optional, TYPE_CHECKING +from typing import ClassVar, Optional, TYPE_CHECKING from dlt.common.configuration import 
configspec -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration.specs import BaseConfiguration, known_sections from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers.typing import TJSONNormalizer from dlt.common.typing import DictStrAny @@ -11,7 +10,7 @@ @configspec class NormalizersConfiguration(BaseConfiguration): # always in section - __section__: str = "schema" + __section__: ClassVar[str] = known_sections.SCHEMA naming: Optional[str] = None json_normalizer: Optional[DictStrAny] = None @@ -32,7 +31,3 @@ def on_resolved(self) -> None: self.json_normalizer["config"][ "max_nesting" ] = self.destination_capabilities.max_table_nesting - - if TYPE_CHECKING: - - def __init__(self, naming: str = None, json_normalizer: TJSONNormalizer = None) -> None: ... diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 57dda11c39..7c117d4612 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import dataclasses import os import datetime # noqa: 251 import humanize @@ -553,8 +554,12 @@ def __call__( @configspec class PipelineContext(ContainerInjectableContext): - _deferred_pipeline: Callable[[], SupportsPipeline] - _pipeline: SupportsPipeline + _deferred_pipeline: Callable[[], SupportsPipeline] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + _pipeline: SupportsPipeline = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) can_create_default: ClassVar[bool] = True @@ -592,14 +597,10 @@ def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> @configspec class StateInjectableContext(ContainerInjectableContext): - state: TPipelineState + state: TPipelineState = None can_create_default: ClassVar[bool] = False - if TYPE_CHECKING: - - def __init__(self, state: TPipelineState = None) -> None: ... - def pipeline_state( container: Container, initial_default: TPipelineState = None diff --git a/dlt/common/runners/configuration.py b/dlt/common/runners/configuration.py index c5de2353f4..5857e1799f 100644 --- a/dlt/common/runners/configuration.py +++ b/dlt/common/runners/configuration.py @@ -16,13 +16,3 @@ class PoolRunnerConfiguration(BaseConfiguration): """# how many threads/processes in the pool""" run_sleep: float = 0.1 """how long to sleep between runs with workload, seconds""" - - if TYPE_CHECKING: - - def __init__( - self, - pool_type: TPoolType = None, - start_method: str = None, - workers: int = None, - run_sleep: float = 0.1, - ) -> None: ... 
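dlt/common/runtime/logger.py now lives at dlt/common/logger.py and `init_logging` moves into dlt/common/runtime/init.py (next file), so callers import the logger from `dlt.common` rather than `dlt.common.runtime`. A tiny sketch of the new import path, assuming the patch is applied:

    # old: from dlt.common.runtime import logger
    from dlt.common import logger

    if logger.is_logging():
        logger.info("logging was already initialized by initialize_runtime()")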
diff --git a/dlt/common/runtime/init.py b/dlt/common/runtime/init.py index dc1430a527..5354dee4ff 100644 --- a/dlt/common/runtime/init.py +++ b/dlt/common/runtime/init.py @@ -5,8 +5,17 @@ _RUN_CONFIGURATION: RunConfiguration = None +def init_logging(config: RunConfiguration) -> None: + from dlt.common import logger + from dlt.common.runtime.exec_info import dlt_version_info + + version = dlt_version_info(config.pipeline_name) + logger.LOGGER = logger._init_logging( + logger.DLT_LOGGER_NAME, config.log_level, config.log_format, config.pipeline_name, version + ) + + def initialize_runtime(config: RunConfiguration) -> None: - from dlt.common.runtime.logger import init_logging from dlt.common.runtime.telemetry import start_telemetry from dlt.sources.helpers import requests diff --git a/dlt/common/runtime/prometheus.py b/dlt/common/runtime/prometheus.py index 1b233ffa9b..07c960efe7 100644 --- a/dlt/common/runtime/prometheus.py +++ b/dlt/common/runtime/prometheus.py @@ -3,7 +3,7 @@ from prometheus_client.metrics import MetricWrapperBase from dlt.common.configuration.specs import RunConfiguration -from dlt.common.runtime import logger +from dlt.common import logger from dlt.common.runtime.exec_info import dlt_version_info from dlt.common.typing import DictStrAny, StrAny diff --git a/dlt/common/runtime/segment.py b/dlt/common/runtime/segment.py index e302767fcc..70b81fb4f4 100644 --- a/dlt/common/runtime/segment.py +++ b/dlt/common/runtime/segment.py @@ -10,7 +10,7 @@ from typing import Literal, Optional from dlt.common.configuration.paths import get_dlt_data_dir -from dlt.common.runtime import logger +from dlt.common import logger from dlt.common.managed_thread_pool import ManagedThreadPool from dlt.common.configuration.specs import RunConfiguration diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 8e64c8ba64..8d1cb3803e 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -2,8 +2,9 @@ import signal from contextlib import contextmanager from threading import Event -from typing import Any, TYPE_CHECKING, Iterator +from typing import Any, Iterator +from dlt.common import logger from dlt.common.exceptions import SignalReceivedException _received_signal: int = 0 @@ -11,11 +12,6 @@ def signal_receiver(sig: int, frame: Any) -> None: - if not TYPE_CHECKING: - from dlt.common.runtime import logger - else: - logger: Any = None - global _received_signal logger.info(f"Signal {sig} received") @@ -64,9 +60,5 @@ def delayed_signals() -> Iterator[None]: signal.signal(signal.SIGINT, original_sigint_handler) signal.signal(signal.SIGTERM, original_sigterm_handler) else: - if not TYPE_CHECKING: - from dlt.common.runtime import logger - else: - logger: Any = None logger.info("Running in daemon thread, signals not enabled") yield diff --git a/dlt/common/runtime/slack.py b/dlt/common/runtime/slack.py index 15da89f333..b1e090098d 100644 --- a/dlt/common/runtime/slack.py +++ b/dlt/common/runtime/slack.py @@ -1,8 +1,9 @@ import requests -from dlt.common import json, logger def send_slack_message(incoming_hook: str, message: str, is_markdown: bool = True) -> None: + from dlt.common import json, logger + """Sends a `message` to Slack `incoming_hook`, by default formatted as markdown.""" r = requests.post( incoming_hook, diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index 2cbe7c78d5..d0100c335d 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -31,24 +31,11 @@ class 
SchemaStorageConfiguration(BaseConfiguration): True # remove default values when exporting schema ) - if TYPE_CHECKING: - - def __init__( - self, - schema_volume_path: str = None, - import_schema_path: str = None, - export_schema_path: str = None, - ) -> None: ... - @configspec class NormalizeStorageConfiguration(BaseConfiguration): normalize_volume_path: str = None # path to volume where normalized loader files will be stored - if TYPE_CHECKING: - - def __init__(self, normalize_volume_path: str = None) -> None: ... - @configspec class LoadStorageConfiguration(BaseConfiguration): @@ -59,12 +46,6 @@ class LoadStorageConfiguration(BaseConfiguration): False # if set to true the folder with completed jobs will be deleted ) - if TYPE_CHECKING: - - def __init__( - self, load_volume_path: str = None, delete_completed_jobs: bool = None - ) -> None: ... - FileSystemCredentials = Union[ AwsCredentials, GcpServiceAccountCredentials, AzureCredentials, GcpOAuthCredentials @@ -96,7 +77,7 @@ class FilesystemConfiguration(BaseConfiguration): bucket_url: str = None # should be a union of all possible credentials as found in PROTOCOL_CREDENTIALS - credentials: FileSystemCredentials + credentials: FileSystemCredentials = None read_only: bool = False """Indicates read only filesystem access. Will enable caching""" @@ -144,14 +125,3 @@ def __str__(self) -> str: new_netloc += f":{url.port}" return url._replace(netloc=new_netloc).geturl() return self.bucket_url - - if TYPE_CHECKING: - - def __init__( - self, - bucket_url: str, - credentials: FileSystemCredentials = None, - read_only: bool = False, - kwargs: Optional[DictStrAny] = None, - client_kwargs: Optional[DictStrAny] = None, - ) -> None: ... diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index b9c36143ac..3b8af424ee 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -627,8 +627,8 @@ def filter_jobs_for_table( @configspec class LoadPackageStateInjectableContext(ContainerInjectableContext): - storage: PackageStorage - load_id: str + storage: PackageStorage = None + load_id: str = None can_create_default: ClassVar[bool] = False global_affinity: ClassVar[bool] = False @@ -640,10 +640,6 @@ def on_resolved(self) -> None: self.state_save_lock = threading.Lock() self.state = self.storage.get_load_package_state(self.load_id) - if TYPE_CHECKING: - - def __init__(self, load_id: str, storage: PackageStorage) -> None: ... - def load_package() -> TLoadPackage: """Get full load package state present in current context. 
Across all threads this will be the same in memory dict.""" diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index 62d059c4a6..a920d336a2 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -1,6 +1,6 @@ import functools -from typing import Any, Type, Optional, Callable, Union, cast +from typing import Any, Type, Optional, Callable, Union from typing_extensions import Concatenate from dlt.common.typing import AnyFun @@ -13,7 +13,6 @@ CustomDestinationClientConfiguration, ) from dlt.common.destination import TLoaderFileFormat -from dlt.common.destination.reference import Destination from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema diff --git a/dlt/destinations/impl/athena/configuration.py b/dlt/destinations/impl/athena/configuration.py index 6b985f284a..59dfeee4ec 100644 --- a/dlt/destinations/impl/athena/configuration.py +++ b/dlt/destinations/impl/athena/configuration.py @@ -1,4 +1,5 @@ -from typing import ClassVar, Final, List, Optional, TYPE_CHECKING +import dataclasses +from typing import ClassVar, Final, List, Optional from dlt.common.configuration import configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -7,7 +8,7 @@ @configspec class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "athena" # type: ignore[misc] + destination_type: Final[str] = dataclasses.field(default="athena", init=False, repr=False, compare=False) # type: ignore[misc] query_result_bucket: str = None credentials: AwsCredentials = None athena_work_group: Optional[str] = None @@ -23,19 +24,3 @@ def __str__(self) -> str: return str(self.staging_config.credentials) else: return "[no staging set]" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[AwsCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - athena_work_group: Optional[str] = None, - aws_data_catalog: Optional[str] = None, - supports_truncate_command: bool = False, - force_iceberg: Optional[bool] = False, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
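With the `TYPE_CHECKING`-only `__init__` stubs removed, destination configurations rely on the `__init__` synthesized by the `configspec` dataclass, and the now-`Final` `dataset_name`/`default_schema_name` fields are bound internally through `_bind_dataset_name` (see dlt/common/destination/reference.py above). A rough sketch with made-up values:

    from dlt.common.configuration.specs import AwsCredentials
    from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration

    config = AthenaClientConfiguration(
        credentials=AwsCredentials(),  # unresolved; filled in during config resolution
        query_result_bucket="s3://example-bucket/results",  # hypothetical bucket
        athena_work_group="primary",
    )
    # dataset_name is no longer an __init__ argument; it is bound internally, roughly:
    config._bind_dataset_name(dataset_name="my_dataset", default_schema_name="my_schema")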
diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index 3c4a71c0df..a6686c3f2d 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -1,5 +1,6 @@ +import dataclasses import warnings -from typing import TYPE_CHECKING, ClassVar, List, Optional, Final +from typing import ClassVar, List, Final from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpServiceAccountCredentials @@ -10,7 +11,7 @@ @configspec class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "bigquery" # type: ignore + destination_type: Final[str] = dataclasses.field(default="bigquery", init=False, repr=False, compare=False) # type: ignore credentials: GcpServiceAccountCredentials = None location: str = "US" @@ -38,31 +39,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.project_id: return digest128(self.credentials.project_id) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[GcpServiceAccountCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - location: str = "US", - http_timeout: float = 15.0, - file_upload_timeout: float = 30 * 60.0, - retry_deadline: float = 60.0, - destination_name: str = None, - environment: str = None - ) -> None: - super().__init__( - credentials=credentials, - dataset_name=dataset_name, - default_schema_name=default_schema_name, - destination_name=destination_name, - environment=environment, - ) - self.retry_deadline = retry_deadline - self.file_upload_timeout = file_upload_timeout - self.http_timeout = http_timeout - self.location = location - ... 
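The removed hand-written BigQuery `__init__` only copied its arguments onto fields; the synthesized dataclass `__init__` now does the same, so the call shape stays familiar. A hedged sketch (field names are those referenced in the removed stub; values are invented):

    from dlt.common.configuration.specs import GcpServiceAccountCredentials
    from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration

    config = BigQueryClientConfiguration(
        credentials=GcpServiceAccountCredentials(),
        location="EU",
        http_timeout=30.0,
        retry_deadline=120.0,
    )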
diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 924047e30f..3bd2d12a5a 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -1,3 +1,4 @@ +import dataclasses from typing import ClassVar, Final, Optional, Any, Dict, List from dlt.common.typing import TSecretStrValue @@ -40,8 +41,8 @@ def to_connector_params(self) -> Dict[str, Any]: @configspec class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "databricks" # type: ignore[misc] - credentials: DatabricksCredentials + destination_type: Final[str] = dataclasses.field(default="databricks", init=False, repr=False, compare=False) # type: ignore[misc] + credentials: DatabricksCredentials = None def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index f123ba69b3..30e54a8313 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any +import dataclasses +from typing import Optional, Final, Callable, Union from typing_extensions import ParamSpec from dlt.common.configuration import configspec @@ -16,19 +17,9 @@ @configspec class CustomDestinationClientConfiguration(DestinationClientConfiguration): - destination_type: Final[str] = "destination" # type: ignore + destination_type: Final[str] = dataclasses.field(default="destination", init=False, repr=False, compare=False) # type: ignore destination_callable: Optional[Union[str, TDestinationCallable]] = None # noqa: A003 loader_file_format: TLoaderFileFormat = "puae-jsonl" batch_size: int = 10 skip_dlt_columns_and_tables: bool = True max_table_nesting: int = 0 - - if TYPE_CHECKING: - - def __init__( - self, - *, - loader_file_format: TLoaderFileFormat = "puae-jsonl", - batch_size: int = 10, - destination_callable: Union[TDestinationCallable, str] = None, - ) -> None: ... 
diff --git a/dlt/destinations/impl/duckdb/configuration.py b/dlt/destinations/impl/duckdb/configuration.py index 8cb88c43b5..70d91dcb56 100644 --- a/dlt/destinations/impl/duckdb/configuration.py +++ b/dlt/destinations/impl/duckdb/configuration.py @@ -1,7 +1,8 @@ import os +import dataclasses import threading from pathvalidate import is_valid_filepath -from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Type, Union from dlt.common import logger from dlt.common.configuration import configspec @@ -13,12 +14,17 @@ ) from dlt.common.typing import TSecretValue +try: + from duckdb import DuckDBPyConnection +except ModuleNotFoundError: + DuckDBPyConnection = Type[Any] # type: ignore[assignment,misc] + DUCK_DB_NAME = "%s.duckdb" DEFAULT_DUCK_DB_NAME = DUCK_DB_NAME % "quack" LOCAL_STATE_KEY = "duckdb_database" -@configspec +@configspec(init=False) class DuckDbBaseCredentials(ConnectionStringCredentials): password: Optional[TSecretValue] = None host: Optional[str] = None @@ -95,7 +101,7 @@ def __del__(self) -> None: @configspec class DuckDbCredentials(DuckDbBaseCredentials): - drivername: Final[str] = "duckdb" # type: ignore + drivername: Final[str] = dataclasses.field(default="duckdb", init=False, repr=False, compare=False) # type: ignore username: Optional[str] = None __config_gen_annotations__: ClassVar[List[str]] = [] @@ -193,30 +199,31 @@ def _path_from_pipeline(self, default_path: str) -> Tuple[str, bool]: def _conn_str(self) -> str: return self.database + def __init__(self, conn_or_path: Union[str, DuckDBPyConnection] = None) -> None: + """Access to duckdb database at a given path or from duckdb connection""" + self._apply_init_value(conn_or_path) + @configspec class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "duckdb" # type: ignore - credentials: DuckDbCredentials + destination_type: Final[str] = dataclasses.field(default="duckdb", init=False, repr=False, compare=False) # type: ignore + credentials: DuckDbCredentials = None create_indexes: bool = ( False # should unique indexes be created, this slows loading down massively ) - if TYPE_CHECKING: - try: - from duckdb import DuckDBPyConnection - except ModuleNotFoundError: - DuckDBPyConnection = Any # type: ignore[assignment,misc] - - def __init__( - self, - *, - credentials: Union[DuckDbCredentials, str, DuckDBPyConnection] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - create_indexes: bool = False, - staging_config: Optional[DestinationClientStagingConfiguration] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
+ def __init__( + self, + *, + credentials: Union[DuckDbCredentials, str, DuckDBPyConnection] = None, + create_indexes: bool = False, + destination_name: str = None, + environment: str = None, + ) -> None: + super().__init__( + credentials=credentials, # type: ignore[arg-type] + destination_name=destination_name, + environment=environment, + ) + self.create_indexes = create_indexes diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index cce0dfa8ed..a9fdb1f47d 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Optional, Final +import dataclasses +from typing import Final from dlt.common.configuration import configspec from dlt.common.destination import TLoaderFileFormat @@ -16,7 +17,7 @@ def __str__(self) -> str: @configspec class DummyClientConfiguration(DestinationClientConfiguration): - destination_type: Final[str] = "dummy" # type: ignore + destination_type: Final[str] = dataclasses.field(default="dummy", init=False, repr=False, compare=False) # type: ignore loader_file_format: TLoaderFileFormat = "jsonl" fail_schema_update: bool = False fail_prob: float = 0.0 @@ -30,22 +31,3 @@ class DummyClientConfiguration(DestinationClientConfiguration): create_followup_jobs: bool = False credentials: DummyClientCredentials = None - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - loader_file_format: TLoaderFileFormat = None, - fail_schema_update: bool = None, - fail_prob: float = None, - retry_prob: float = None, - completed_prob: float = None, - exception_prob: float = None, - timeout: float = None, - fail_in_init: bool = None, - create_followup_jobs: bool = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/filesystem/configuration.py b/dlt/destinations/impl/filesystem/configuration.py index 93e5537aab..1521222180 100644 --- a/dlt/destinations/impl/filesystem/configuration.py +++ b/dlt/destinations/impl/filesystem/configuration.py @@ -1,6 +1,5 @@ -from urllib.parse import urlparse - -from typing import Final, Type, Optional, Any, TYPE_CHECKING +import dataclasses +from typing import Final, Type, Optional from dlt.common.configuration import configspec, resolve_type from dlt.common.destination.reference import ( @@ -12,22 +11,9 @@ @configspec class FilesystemDestinationClientConfiguration(FilesystemConfiguration, DestinationClientStagingConfiguration): # type: ignore[misc] - destination_type: Final[str] = "filesystem" # type: ignore + destination_type: Final[str] = dataclasses.field(default="filesystem", init=False, repr=False, compare=False) # type: ignore @resolve_type("credentials") def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: # use known credentials or empty credentials for unknown protocol return self.PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration] # type: ignore[return-value] - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[Any] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - bucket_url: str = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
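Because `@configspec(init=False)` keeps the dataclass from generating a constructor, `DuckDbCredentials` can expose an explicit `__init__` that accepts a database path or an open connection. A minimal sketch, assuming the `duckdb` package is installed and the patch is applied:

    import duckdb
    from dlt.destinations.impl.duckdb.configuration import (
        DuckDbCredentials,
        DuckDbClientConfiguration,
    )

    # from a database path
    creds = DuckDbCredentials("quack.duckdb")

    # or from an already opened native connection
    conn = duckdb.connect("quack.duckdb")
    creds = DuckDbCredentials(conn)

    config = DuckDbClientConfiguration(credentials=creds, create_indexes=False)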
diff --git a/dlt/destinations/impl/motherduck/configuration.py b/dlt/destinations/impl/motherduck/configuration.py index 35f02f709a..3179295c54 100644 --- a/dlt/destinations/impl/motherduck/configuration.py +++ b/dlt/destinations/impl/motherduck/configuration.py @@ -1,4 +1,5 @@ -from typing import Any, ClassVar, Final, List, TYPE_CHECKING, Optional +import dataclasses +from typing import Any, ClassVar, Final, List from dlt.common.configuration import configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -12,9 +13,9 @@ MOTHERDUCK_DRIVERNAME = "md" -@configspec +@configspec(init=False) class MotherDuckCredentials(DuckDbBaseCredentials): - drivername: Final[str] = "md" # type: ignore + drivername: Final[str] = dataclasses.field(default="md", init=False, repr=False, compare=False) # type: ignore username: str = "motherduck" read_only: bool = False # open database read/write @@ -57,8 +58,8 @@ def on_resolved(self) -> None: @configspec class MotherDuckClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "motherduck" # type: ignore - credentials: MotherDuckCredentials + destination_type: Final[str] = dataclasses.field(default="motherduck", init=False, repr=False, compare=False) # type: ignore + credentials: MotherDuckCredentials = None create_indexes: bool = ( False # should unique indexes be created, this slows loading down massively @@ -70,19 +71,6 @@ def fingerprint(self) -> str: return digest128(self.credentials.password) return "" - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[MotherDuckCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - create_indexes: Optional[bool] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
- class MotherduckLocalVersionNotSupported(DestinationTerminalException): def __init__(self, duckdb_version: str) -> None: diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index 45c448fab7..1d085f40c1 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,4 +1,5 @@ -from typing import Final, ClassVar, Any, List, Dict, Optional, TYPE_CHECKING +import dataclasses +from typing import Final, ClassVar, Any, List, Dict from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec @@ -10,11 +11,11 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -@configspec +@configspec(init=False) class MsSqlCredentials(ConnectionStringCredentials): - drivername: Final[str] = "mssql" # type: ignore - password: TSecretValue - host: str + drivername: Final[str] = dataclasses.field(default="mssql", init=False, repr=False, compare=False) # type: ignore + password: TSecretValue = None + host: str = None port: int = 1433 connect_timeout: int = 15 driver: str = None @@ -90,8 +91,8 @@ def to_odbc_dsn(self) -> str: @configspec class MsSqlClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "mssql" # type: ignore - credentials: MsSqlCredentials + destination_type: Final[str] = dataclasses.field(default="mssql", init=False, repr=False, compare=False) # type: ignore + credentials: MsSqlCredentials = None create_indexes: bool = False @@ -100,16 +101,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[MsSqlCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - create_indexes: Optional[bool] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
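Credential specs derived from `ConnectionStringCredentials` (such as `MsSqlCredentials` above) are now declared with `@configspec(init=False)`, so they keep the inherited connection-string constructor rather than a synthesized one. A hedged sketch with made-up connection values; parsing and resolution behavior follows the base class:

    from dlt.destinations.impl.mssql.configuration import MsSqlCredentials

    # parsed via the ConnectionStringCredentials __init__ preserved by init=False
    creds = MsSqlCredentials("mssql://loader:secret@loader.example.com/dlt_data")

    # a dict of connection string elements is also accepted
    creds = MsSqlCredentials(
        {"username": "loader", "password": "secret", "host": "loader.example.com", "database": "dlt_data"}
    )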
diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 109d422650..0d12abbac7 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -1,6 +1,7 @@ -from typing import Final, ClassVar, Any, List, TYPE_CHECKING -from dlt.common.libs.sql_alchemy import URL +import dataclasses +from typing import Final, ClassVar, Any, List, TYPE_CHECKING, Union +from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.utils import digest128 @@ -9,11 +10,11 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -@configspec +@configspec(init=False) class PostgresCredentials(ConnectionStringCredentials): - drivername: Final[str] = "postgresql" # type: ignore - password: TSecretValue - host: str + drivername: Final[str] = dataclasses.field(default="postgresql", init=False, repr=False, compare=False) # type: ignore + password: TSecretValue = None + host: str = None port: int = 5432 connect_timeout: int = 15 @@ -33,8 +34,8 @@ def to_url(self) -> URL: @configspec class PostgresClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "postgres" # type: ignore - credentials: PostgresCredentials + destination_type: Final[str] = dataclasses.field(default="postgres", init=False, repr=False, compare=False) # type: ignore + credentials: PostgresCredentials = None create_indexes: bool = True @@ -43,16 +44,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: PostgresCredentials = None, - dataset_name: str = None, - default_schema_name: str = None, - create_indexes: bool = True, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index 23637dee33..d589537742 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Optional, Final from dlt.common.configuration import configspec @@ -15,7 +16,7 @@ class QdrantCredentials(CredentialsConfiguration): # If `None` - use default values for `host` and `port` location: Optional[str] = None # API key for authentication in Qdrant Cloud. 
Default: `None` - api_key: Optional[str] + api_key: Optional[str] = None def __str__(self) -> str: return self.location or "localhost" @@ -47,12 +48,14 @@ class QdrantClientOptions(BaseConfiguration): @configspec class QdrantClientConfiguration(DestinationClientDwhConfiguration): - destination_type: Final[str] = "qdrant" # type: ignore + destination_type: Final[str] = dataclasses.field(default="qdrant", init=False, repr=False, compare=False) # type: ignore + # Qdrant connection credentials + credentials: QdrantCredentials = None # character for the dataset separator dataset_separator: str = "_" # make it optional so empty dataset is allowed - dataset_name: Final[Optional[str]] = None # type: ignore[misc] + dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc] # Batch size for generating embeddings embedding_batch_size: int = 32 @@ -67,10 +70,7 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration): upload_max_retries: int = 3 # Qdrant client options - options: QdrantClientOptions - - # Qdrant connection credentials - credentials: QdrantCredentials + options: QdrantClientOptions = None # FlagEmbedding model to use # Find the list here. https://qdrant.github.io/fastembed/examples/Supported_Models/. diff --git a/dlt/destinations/impl/redshift/configuration.py b/dlt/destinations/impl/redshift/configuration.py index 2a6ade4a4f..72d7f70a9f 100644 --- a/dlt/destinations/impl/redshift/configuration.py +++ b/dlt/destinations/impl/redshift/configuration.py @@ -1,4 +1,5 @@ -from typing import Final, Optional, TYPE_CHECKING +import dataclasses +from typing import Final, Optional from dlt.common.typing import TSecretValue from dlt.common.configuration import configspec @@ -10,7 +11,7 @@ ) -@configspec +@configspec(init=False) class RedshiftCredentials(PostgresCredentials): port: int = 5439 password: TSecretValue = None @@ -20,8 +21,8 @@ class RedshiftCredentials(PostgresCredentials): @configspec class RedshiftClientConfiguration(PostgresClientConfiguration): - destination_type: Final[str] = "redshift" # type: ignore - credentials: RedshiftCredentials + destination_type: Final[str] = dataclasses.field(default="redshift", init=False, repr=False, compare=False) # type: ignore + credentials: RedshiftCredentials = None staging_iam_role: Optional[str] = None def fingerprint(self) -> str: @@ -29,17 +30,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - destination_type: str = None, - credentials: PostgresCredentials = None, - dataset_name: str = None, - default_schema_name: str = None, - staging_iam_role: str = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... 
diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4f97f08700..5a1f7a65a9 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -1,11 +1,9 @@ +import dataclasses import base64 -import binascii - -from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING - -from dlt.common.libs.sql_alchemy import URL +from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING, Union from dlt import version +from dlt.common.libs.sql_alchemy import URL from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs import ConnectionStringCredentials @@ -51,9 +49,9 @@ def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes ) -@configspec +@configspec(init=False) class SnowflakeCredentials(ConnectionStringCredentials): - drivername: Final[str] = "snowflake" # type: ignore[misc] + drivername: Final[str] = dataclasses.field(default="snowflake", init=False, repr=False, compare=False) # type: ignore[misc] password: Optional[TSecretStrValue] = None host: str = None database: str = None @@ -118,8 +116,8 @@ def to_connector_params(self) -> Dict[str, Any]: @configspec class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "snowflake" # type: ignore[misc] - credentials: SnowflakeCredentials + destination_type: Final[str] = dataclasses.field(default="snowflake", init=False, repr=False, compare=False) # type: ignore[misc] + credentials: SnowflakeCredentials = None stage_name: Optional[str] = None """Use an existing named stage instead of the default. Default uses the implicit table stage per table""" @@ -131,18 +129,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - destination_type: str = None, - credentials: SnowflakeCredentials = None, - dataset_name: str = None, - default_schema_name: str = None, - stage_name: str = None, - keep_staged_files: bool = True, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index bb1ba632dc..37b932cd67 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -1,9 +1,8 @@ +import dataclasses +from dlt import version from typing import Final, Any, List, Dict, Optional, ClassVar -from dlt.common import logger from dlt.common.configuration import configspec -from dlt.common.schema.typing import TSchemaTables -from dlt.common.schema.utils import get_write_disposition from dlt.destinations.impl.mssql.configuration import ( MsSqlCredentials, @@ -14,9 +13,9 @@ from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType -@configspec +@configspec(init=False) class SynapseCredentials(MsSqlCredentials): - drivername: Final[str] = "synapse" # type: ignore + drivername: Final[str] = dataclasses.field(default="synapse", init=False, repr=False, compare=False) # type: ignore # LongAsMax keyword got introduced in ODBC Driver 18 for SQL Server. 
SUPPORTED_DRIVERS: ClassVar[List[str]] = ["ODBC Driver 18 for SQL Server"] @@ -32,8 +31,8 @@ def _get_odbc_dsn_dict(self) -> Dict[str, Any]: @configspec class SynapseClientConfiguration(MsSqlClientConfiguration): - destination_type: Final[str] = "synapse" # type: ignore - credentials: SynapseCredentials + destination_type: Final[str] = dataclasses.field(default="synapse", init=False, repr=False, compare=False) # type: ignore + credentials: SynapseCredentials = None # While Synapse uses CLUSTERED COLUMNSTORE INDEX tables by default, we use # HEAP tables (no indexing) by default. HEAP is a more robust choice, because diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index b7eddd6ef7..100878ae05 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -16,6 +16,11 @@ class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): spec = SynapseClientConfiguration + # TODO: implement as property everywhere and makes sure not accessed as class property + # @property + # def spec(self) -> t.Type[SynapseClientConfiguration]: + # return SynapseClientConfiguration + def capabilities(self) -> DestinationCapabilitiesContext: return capabilities() diff --git a/dlt/destinations/impl/weaviate/configuration.py b/dlt/destinations/impl/weaviate/configuration.py index 5014e69163..90fb7ce5ce 100644 --- a/dlt/destinations/impl/weaviate/configuration.py +++ b/dlt/destinations/impl/weaviate/configuration.py @@ -1,5 +1,5 @@ -from typing import Dict, Literal, Optional, Final, TYPE_CHECKING -from dataclasses import field +import dataclasses +from typing import Dict, Literal, Optional, Final from urllib.parse import urlparse from dlt.common.configuration import configspec @@ -13,7 +13,7 @@ @configspec class WeaviateCredentials(CredentialsConfiguration): url: str = "http://localhost:8080" - api_key: Optional[str] + api_key: Optional[str] = None additional_headers: Optional[Dict[str, str]] = None def __str__(self) -> str: @@ -24,7 +24,7 @@ def __str__(self) -> str: @configspec class WeaviateClientConfiguration(DestinationClientDwhConfiguration): - destination_type: Final[str] = "weaviate" # type: ignore + destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore # make it optional so empty dataset is allowed dataset_name: Optional[str] = None # type: ignore[misc] @@ -39,9 +39,9 @@ class WeaviateClientConfiguration(DestinationClientDwhConfiguration): dataset_separator: str = "_" - credentials: WeaviateCredentials + credentials: WeaviateCredentials = None vectorizer: str = "text2vec-openai" - module_config: Dict[str, Dict[str, str]] = field( + module_config: Dict[str, Dict[str, str]] = dataclasses.field( default_factory=lambda: { "text2vec-openai": { "model": "ada", @@ -58,26 +58,3 @@ def fingerprint(self) -> str: hostname = urlparse(self.credentials.url).hostname return digest128(hostname) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - destination_type: str = None, - credentials: WeaviateCredentials = None, - name: str = None, - environment: str = None, - dataset_name: str = None, - default_schema_name: str = None, - batch_size: int = None, - batch_workers: int = None, - batch_consistency: TWeaviateBatchConsistency = None, - batch_retries: int = None, - conn_timeout: float = None, - read_timeout: float = None, - startup_period: int = None, - dataset_separator: str = None, - vectorizer: str = None, - module_config: Dict[str, Dict[str, 
str]] = None, - ) -> None: ... diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 215bcf9fe5..91be3a60c9 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,7 +1,7 @@ from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional import yaml -from dlt.common.runtime.logger import pretty_format_exception +from dlt.common.logger import pretty_format_exception from dlt.common.schema.typing import TTableSchema, TSortOrder from dlt.common.schema.utils import ( diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index e5525519ec..28a2aca633 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -72,27 +72,19 @@ class SourceSchemaInjectableContext(ContainerInjectableContext): """A context containing the source schema, present when dlt.source/resource decorated function is executed""" - schema: Schema + schema: Schema = None can_create_default: ClassVar[bool] = False - if TYPE_CHECKING: - - def __init__(self, schema: Schema = None) -> None: ... - @configspec class SourceInjectableContext(ContainerInjectableContext): """A context containing the source schema, present when dlt.resource decorated function is executed""" - source: DltSource + source: DltSource = None can_create_default: ClassVar[bool] = False - if TYPE_CHECKING: - - def __init__(self, source: DltSource = None) -> None: ... - TSourceFunParams = ParamSpec("TSourceFunParams") TResourceFunParams = ParamSpec("TResourceFunParams") diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py index 145b517802..1edd9bd039 100644 --- a/dlt/extract/pipe_iterator.py +++ b/dlt/extract/pipe_iterator.py @@ -2,6 +2,7 @@ import types from typing import ( AsyncIterator, + ClassVar, Dict, Sequence, Union, @@ -16,7 +17,11 @@ from dlt.common.configuration import configspec from dlt.common.configuration.inject import with_config -from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext +from dlt.common.configuration.specs import ( + BaseConfiguration, + ContainerInjectableContext, + known_sections, +) from dlt.common.configuration.container import Container from dlt.common.exceptions import PipelineException from dlt.common.source import unset_current_pipe_name, set_current_pipe_name @@ -48,7 +53,7 @@ class PipeIteratorConfiguration(BaseConfiguration): copy_on_fork: bool = False next_item_mode: str = "fifo" - __section__ = "extract" + __section__: ClassVar[str] = known_sections.EXTRACT def __init__( self, diff --git a/dlt/helpers/dbt/configuration.py b/dlt/helpers/dbt/configuration.py index 4cd3f3a0f4..70fa4d1ac5 100644 --- a/dlt/helpers/dbt/configuration.py +++ b/dlt/helpers/dbt/configuration.py @@ -19,7 +19,7 @@ class DBTRunnerConfiguration(BaseConfiguration): package_additional_vars: Optional[StrAny] = None - runtime: RunConfiguration + runtime: RunConfiguration = None def on_resolved(self) -> None: if not self.package_profiles_dir: diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 388b81b2ee..7b1f79dc77 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -11,7 +11,7 @@ from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout_with_result from dlt.common.typing import StrAny, TSecretValue -from dlt.common.runtime.logger import is_json_logging +from dlt.common.logger import is_json_logging from dlt.common.storages import FileStorage from dlt.common.git import git_custom_key_command, ensure_remote_head, force_clone_repo from 
dlt.common.utils import with_custom_environ diff --git a/dlt/helpers/dbt_cloud/configuration.py b/dlt/helpers/dbt_cloud/configuration.py index aac94b2f4a..3c95d53431 100644 --- a/dlt/helpers/dbt_cloud/configuration.py +++ b/dlt/helpers/dbt_cloud/configuration.py @@ -9,13 +9,13 @@ class DBTCloudConfiguration(BaseConfiguration): api_token: TSecretValue = TSecretValue("") - account_id: Optional[str] - job_id: Optional[str] - project_id: Optional[str] - environment_id: Optional[str] - run_id: Optional[str] + account_id: Optional[str] = None + job_id: Optional[str] = None + project_id: Optional[str] = None + environment_id: Optional[str] = None + run_id: Optional[str] = None cause: str = "Triggered via API" - git_sha: Optional[str] - git_branch: Optional[str] - schema_override: Optional[str] + git_sha: Optional[str] = None + git_branch: Optional[str] = None + schema_override: Optional[str] = None diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index 0a84e3c331..97cf23fdfc 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -18,13 +18,3 @@ class LoaderConfiguration(PoolRunnerConfiguration): def on_resolved(self) -> None: self.pool_type = "none" if self.workers == 1 else "thread" - - if TYPE_CHECKING: - - def __init__( - self, - pool_type: TPoolType = "thread", - workers: int = None, - raise_on_failed_jobs: bool = False, - _load_storage_config: LoadStorageConfiguration = None, - ) -> None: ... diff --git a/dlt/load/load.py b/dlt/load/load.py index a0909fa2d0..f02a21f98e 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -15,16 +15,15 @@ SupportsPipeline, WithStepInfo, ) -from dlt.common.schema.utils import get_child_tables, get_top_level_table +from dlt.common.schema.utils import get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.storages.load_package import LoadPackageStateInjectableContext from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR -from dlt.common.runtime.logger import pretty_format_exception +from dlt.common.logger import pretty_format_exception from dlt.common.exceptions import TerminalValueError from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema import Schema from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 3949a07fa8..5676d23569 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -18,18 +18,14 @@ class ItemsNormalizerConfiguration(BaseConfiguration): add_dlt_load_id: bool = False """When true, items to be normalized will have `_dlt_load_id` column added with the current load ID.""" - if TYPE_CHECKING: - - def __init__(self, add_dlt_id: bool = None, add_dlt_load_id: bool = None) -> None: ... 
- @configspec class NormalizeConfiguration(PoolRunnerConfiguration): pool_type: TPoolType = "process" destination_capabilities: DestinationCapabilitiesContext = None # injectable - _schema_storage_config: SchemaStorageConfiguration - _normalize_storage_config: NormalizeStorageConfiguration - _load_storage_config: LoadStorageConfiguration + _schema_storage_config: SchemaStorageConfiguration = None + _normalize_storage_config: NormalizeStorageConfiguration = None + _load_storage_config: LoadStorageConfiguration = None json_normalizer: ItemsNormalizerConfiguration = ItemsNormalizerConfiguration( add_dlt_id=True, add_dlt_load_id=True @@ -41,14 +37,3 @@ class NormalizeConfiguration(PoolRunnerConfiguration): def on_resolved(self) -> None: self.pool_type = "none" if self.workers == 1 else "process" - - if TYPE_CHECKING: - - def __init__( - self, - pool_type: TPoolType = "process", - workers: int = None, - _schema_storage_config: SchemaStorageConfiguration = None, - _normalize_storage_config: NormalizeStorageConfiguration = None, - _load_storage_config: LoadStorageConfiguration = None, - ) -> None: ... diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 7aa54541c0..d7ffca6e89 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -27,7 +27,7 @@ class PipelineConfiguration(BaseConfiguration): full_refresh: bool = False """When set to True, each instance of the pipeline with the `pipeline_name` starts from scratch when run and loads the data to a separate dataset.""" progress: Optional[str] = None - runtime: RunConfiguration + runtime: RunConfiguration = None def on_resolved(self) -> None: if not self.pipeline_name: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index efb6ae078b..de1f7afced 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -12,6 +12,7 @@ Optional, Sequence, Tuple, + Type, cast, get_type_hints, ContextManager, @@ -1122,17 +1123,16 @@ def _get_destination_client_initial_config( ) if issubclass(client_spec, DestinationClientStagingConfiguration): - return client_spec( - dataset_name=self.dataset_name, - default_schema_name=default_schema_name, + spec: DestinationClientDwhConfiguration = client_spec( credentials=credentials, as_staging=as_staging, ) - return client_spec( - dataset_name=self.dataset_name, - default_schema_name=default_schema_name, - credentials=credentials, - ) + else: + spec = client_spec( + credentials=credentials, + ) + spec._bind_dataset_name(self.dataset_name, default_schema_name) + return spec return client_spec(credentials=credentials) diff --git a/dlt/pipeline/platform.py b/dlt/pipeline/platform.py index c8014d5ae7..0955e91b51 100644 --- a/dlt/pipeline/platform.py +++ b/dlt/pipeline/platform.py @@ -6,7 +6,7 @@ from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, TPipelineStep, SupportsPipeline from dlt.common import json -from dlt.common.runtime import logger +from dlt.common import logger from dlt.common.pipeline import LoadInfo from dlt.common.schema.typing import TStoredSchema diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index 5679884b0b..b610d1751f 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -14,7 +14,7 @@ from dlt.common.configuration.utils import _RESOLVED_TRACES from dlt.common.configuration.container import Container from dlt.common.exceptions import ExceptionTrace, ResourceNameNotAvailable -from dlt.common.runtime.logger import suppress_and_warn +from dlt.common.logger import suppress_and_warn from 
dlt.common.runtime.exec_info import TExecutionContext, get_execution_context from dlt.common.pipeline import ( ExtractInfo, diff --git a/dlt/sources/helpers/requests/__init__.py b/dlt/sources/helpers/requests/__init__.py index 3e29a2cf52..d76e24ec42 100644 --- a/dlt/sources/helpers/requests/__init__.py +++ b/dlt/sources/helpers/requests/__init__.py @@ -15,11 +15,12 @@ from requests.exceptions import ChunkedEncodingError from dlt.sources.helpers.requests.retry import Client from dlt.sources.helpers.requests.session import Session +from dlt.sources.helpers.rest_client import paginate from dlt.common.configuration.specs import RunConfiguration client = Client() -get, post, put, patch, delete, options, head, request = ( +get, post, put, patch, delete, options, head, request, paginate = ( client.get, client.post, client.put, @@ -28,6 +29,7 @@ client.options, client.head, client.request, + paginate, ) diff --git a/dlt/sources/helpers/rest_client/__init__.py b/dlt/sources/helpers/rest_client/__init__.py new file mode 100644 index 0000000000..b2fb0a2351 --- /dev/null +++ b/dlt/sources/helpers/rest_client/__init__.py @@ -0,0 +1,46 @@ +from typing import Optional, Dict, Iterator, Union, Any + +from dlt.common import jsonpath + +from .client import RESTClient # noqa: F401 +from .client import PageData +from .auth import AuthConfigBase +from .paginators import BasePaginator +from .typing import HTTPMethodBasic, Hooks + + +def paginate( + url: str, + method: HTTPMethodBasic = "GET", + headers: Optional[Dict[str, str]] = None, + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + auth: AuthConfigBase = None, + paginator: Optional[BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + hooks: Optional[Hooks] = None, +) -> Iterator[PageData[Any]]: + """ + Paginate over a REST API endpoint. + + Args: + url: URL to paginate over. + **kwargs: Keyword arguments to pass to `RESTClient.paginate`. + + Returns: + Iterator[Page]: Iterator over pages. 
+ """ + client = RESTClient( + base_url=url, + headers=headers, + ) + return client.paginate( + path="", + method=method, + params=params, + json=json, + auth=auth, + paginator=paginator, + data_selector=data_selector, + hooks=hooks, + ) diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py new file mode 100644 index 0000000000..5d7a2f7eb2 --- /dev/null +++ b/dlt/sources/helpers/rest_client/auth.py @@ -0,0 +1,215 @@ +from base64 import b64encode +import math +from typing import ( + List, + Dict, + Final, + Literal, + Optional, + Union, + Any, + cast, + Iterable, + TYPE_CHECKING, +) +from dlt.sources.helpers import requests +from requests.auth import AuthBase +from requests import PreparedRequest # noqa: I251 +import pendulum + +from dlt.common.exceptions import MissingDependencyException + +from dlt.common import logger +from dlt.common.configuration.specs.base_configuration import configspec +from dlt.common.configuration.specs import CredentialsConfiguration +from dlt.common.configuration.specs.exceptions import NativeValueError +from dlt.common.typing import TSecretStrValue + +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes +else: + PrivateKeyTypes = Any + +TApiKeyLocation = Literal[ + "header", "cookie", "query", "param" +] # Alias for scheme "in" field + + +class AuthConfigBase(AuthBase, CredentialsConfiguration): + """Authenticator base which is both `requests` friendly AuthBase and dlt SPEC + configurable via env variables or toml files + """ + + pass + + +@configspec +class BearerTokenAuth(AuthConfigBase): + token: TSecretStrValue = None + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, str): + self.token = cast(TSecretStrValue, value) + else: + raise NativeValueError( + type(self), + value, + f"BearerTokenAuth token must be a string, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + request.headers["Authorization"] = f"Bearer {self.token}" + return request + + +@configspec +class APIKeyAuth(AuthConfigBase): + name: str = "Authorization" + api_key: TSecretStrValue = None + location: TApiKeyLocation = "header" + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, str): + self.api_key = cast(TSecretStrValue, value) + else: + raise NativeValueError( + type(self), + value, + f"APIKeyAuth api_key must be a string, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + if self.location == "header": + request.headers[self.name] = self.api_key + elif self.location in ["query", "param"]: + request.prepare_url(request.url, {self.name: self.api_key}) + elif self.location == "cookie": + raise NotImplementedError() + return request + + +@configspec +class HttpBasicAuth(AuthConfigBase): + username: str = None + password: TSecretStrValue = None + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, Iterable) and not isinstance(value, str): + value = list(value) + if len(value) == 2: + self.username, self.password = value + return + raise NativeValueError( + type(self), + value, + f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + encoded = b64encode(f"{self.username}:{self.password}".encode()).decode() + request.headers["Authorization"] = f"Basic {encoded}" + return request + + +@configspec +class 
OAuth2AuthBase(AuthConfigBase): + """Base class for OAuth2 authenticators. Requires an access_token.""" + + # TODO: Separate class for flows (implicit, authorization_code, client_credentials, etc) + access_token: TSecretStrValue = None + + def parse_native_representation(self, value: Any) -> None: + if isinstance(value, str): + self.access_token = cast(TSecretStrValue, value) + else: + raise NativeValueError( + type(self), + value, + f"OAuth2AuthBase access_token must be a string, got {type(value)}", + ) + + def __call__(self, request: PreparedRequest) -> PreparedRequest: + request.headers["Authorization"] = f"Bearer {self.access_token}" + return request + + +@configspec +class OAuthJWTAuth(BearerTokenAuth): + """This is a form of Bearer auth; there is no standard way to declare it in OpenAPI.""" + + format: Final[Literal["JWT"]] = "JWT" # noqa: A003 + client_id: str = None + private_key: TSecretStrValue = None + auth_endpoint: str = None + scopes: Optional[Union[str, List[str]]] = None + headers: Optional[Dict[str, str]] = None + private_key_passphrase: Optional[TSecretStrValue] = None + default_token_expiration: int = 3600 + + def __post_init__(self) -> None: + self.scopes = ( + self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) + ) + self.token = None + self.token_expiry: Optional[pendulum.DateTime] = None + + def __call__(self, r: PreparedRequest) -> PreparedRequest: + if self.token is None or self.is_token_expired(): + self.obtain_token() + r.headers["Authorization"] = f"Bearer {self.token}" + return r + + def is_token_expired(self) -> bool: + return not self.token_expiry or pendulum.now() >= self.token_expiry + + def obtain_token(self) -> None: + try: + import jwt + except ModuleNotFoundError: + raise MissingDependencyException("dlt OAuth helpers", ["PyJWT"]) + + payload = self.create_jwt_payload() + data = { + "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer", + "assertion": jwt.encode( + payload, self.load_private_key(), algorithm="RS256" + ), + } + + logger.debug(f"Obtaining token from {self.auth_endpoint}") + + response = requests.post(self.auth_endpoint, headers=self.headers, data=data) + response.raise_for_status() + + token_response = response.json() + self.token = token_response["access_token"] + self.token_expiry = pendulum.now().add( + seconds=token_response.get("expires_in", self.default_token_expiration) + ) + + def create_jwt_payload(self) -> Dict[str, Union[str, int]]: + now = pendulum.now() + return { + "iss": self.client_id, + "sub": self.client_id, + "aud": self.auth_endpoint, + "exp": math.floor((now.add(hours=1)).timestamp()), + "iat": math.floor(now.timestamp()), + "scope": cast(str, self.scopes), + } + + def load_private_key(self) -> "PrivateKeyTypes": + try: + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + except ModuleNotFoundError: + raise MissingDependencyException("dlt OAuth helpers", ["cryptography"]) + + private_key_bytes = self.private_key.encode("utf-8") + return serialization.load_pem_private_key( + private_key_bytes, + password=self.private_key_passphrase.encode("utf-8") + if self.private_key_passphrase + else None, + backend=default_backend(), + ) diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py new file mode 100644 index 0000000000..4b5625eebe --- /dev/null +++ b/dlt/sources/helpers/rest_client/client.py @@ -0,0 +1,264 @@ +from typing import ( + Iterator, + Optional, + List, + Dict, + Any, + 
TypeVar, + Iterable, + cast, +) +import copy +from urllib.parse import urlparse + +from requests import Session as BaseSession # noqa: I251 + +from dlt.common import logger +from dlt.common import jsonpath +from dlt.sources.helpers.requests.retry import Client +from dlt.sources.helpers.requests import Response, Request + +from .typing import HTTPMethodBasic, HTTPMethod, Hooks +from .paginators import BasePaginator +from .auth import AuthConfigBase +from .detector import PaginatorFactory, find_records +from .exceptions import IgnoreResponseException + +from .utils import join_url + + +_T = TypeVar("_T") + + +class PageData(List[_T]): + """A list of elements in a single page of results with attached request context. + + The context allows inspecting the response, paginator, and authenticator, and modifying the request + """ + + def __init__( + self, + __iterable: Iterable[_T], + request: Request, + response: Response, + paginator: BasePaginator, + auth: AuthConfigBase, + ): + super().__init__(__iterable) + self.request = request + self.response = response + self.paginator = paginator + self.auth = auth + + +class RESTClient: + """A generic REST client for making requests to an API with support for + pagination and authentication. + + Args: + base_url (str): The base URL of the API to make requests to. + headers (Optional[Dict[str, str]]): Default headers to include in all requests. + auth (Optional[AuthConfigBase]): Authentication configuration for all requests. + paginator (Optional[BasePaginator]): Default paginator for handling paginated responses. + data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses. + session (BaseSession): HTTP session for making requests. + paginator_factory (Optional[PaginatorFactory]): Factory for creating paginator instances, + used for detecting paginators. + """ + + def __init__( + self, + base_url: str, + headers: Optional[Dict[str, str]] = None, + auth: Optional[AuthConfigBase] = None, + paginator: Optional[BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + session: BaseSession = None, + paginator_factory: Optional[PaginatorFactory] = None, + ) -> None: + self.base_url = base_url + self.headers = headers + self.auth = auth + + if session: + self._validate_session_raise_for_status(session) + self.session = session + else: + self.session = Client(raise_for_status=False).session + + self.paginator = paginator + self.pagination_factory = paginator_factory or PaginatorFactory() + + self.data_selector = data_selector + + def _validate_session_raise_for_status(self, session: BaseSession) -> None: + # dlt.sources.helpers.requests.session.Session + # has raise_for_status=True by default + # check the passed-in session: self.session is not yet assigned at this point + if getattr(session, "raise_for_status", False): + logger.warning( + "The session provided has raise_for_status enabled. " + "This may cause unexpected behavior."
+ ) + + def _create_request( + self, + path: str, + method: HTTPMethod, + params: Dict[str, Any], + json: Optional[Dict[str, Any]] = None, + auth: Optional[AuthConfigBase] = None, + hooks: Optional[Hooks] = None, + ) -> Request: + parsed_url = urlparse(path) + if parsed_url.scheme in ("http", "https"): + url = path + else: + url = join_url(self.base_url, path) + + return Request( + method=method, + url=url, + headers=self.headers, + params=params, + json=json, + auth=auth or self.auth, + hooks=hooks, + ) + + def _send_request(self, request: Request) -> Response: + logger.info( + f"Making {request.method.upper()} request to {request.url}" + f" with params={request.params}, json={request.json}" + ) + + prepared_request = self.session.prepare_request(request) + + return self.session.send(prepared_request) + + def request( + self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any + ) -> Response: + prepared_request = self._create_request( + path=path, + method=method, + **kwargs, + ) + return self._send_request(prepared_request) + + def get( + self, path: str, params: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Response: + return self.request(path, method="GET", params=params, **kwargs) + + def post( + self, path: str, json: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> Response: + return self.request(path, method="POST", json=json, **kwargs) + + def paginate( + self, + path: str = "", + method: HTTPMethodBasic = "GET", + params: Optional[Dict[str, Any]] = None, + json: Optional[Dict[str, Any]] = None, + auth: Optional[AuthConfigBase] = None, + paginator: Optional[BasePaginator] = None, + data_selector: Optional[jsonpath.TJsonPath] = None, + hooks: Optional[Hooks] = None, + ) -> Iterator[PageData[Any]]: + """Iterates over paginated API responses, yielding pages of data. + + Args: + path (str): Endpoint path for the request, relative to `base_url`. + method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'. + params (Optional[Dict[str, Any]]): URL parameters for the request. + json (Optional[Dict[str, Any]]): JSON payload for the request. + auth (Optional[AuthConfigBase]): Authentication configuration for the request. + paginator (Optional[BasePaginator]): Paginator instance for handling + pagination logic. + data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for + extracting data from the response. + hooks (Optional[Hooks]): Hooks to modify request/response objects. Note that + when hooks are not provided, the default behavior is to raise an exception + on error status codes. + + Yields: + PageData[Any]: A page of data from the paginated API response, along with request and response context. + + Raises: + HTTPError: If the response status code is not a success code. This is raised + by default when hooks are not provided. 
+ + Example: + >>> client = RESTClient(base_url="https://api.example.com") + >>> for page in client.paginate("/search", method="post", json={"query": "foo"}): + >>> print(page) + """ + + paginator = paginator if paginator else copy.deepcopy(self.paginator) + auth = auth or self.auth + data_selector = data_selector or self.data_selector + hooks = hooks or {} + + def raise_for_status(response: Response, *args: Any, **kwargs: Any) -> None: + response.raise_for_status() + + if "response" not in hooks: + hooks["response"] = [raise_for_status] + + request = self._create_request( + path=path, method=method, params=params, json=json, auth=auth, hooks=hooks + ) + + while True: + try: + response = self._send_request(request) + except IgnoreResponseException: + break + + if paginator is None: + paginator = self.detect_paginator(response) + + data = self.extract_response(response, data_selector) + paginator.update_state(response) + paginator.update_request(request) + + # yield data with context + yield PageData( + data, request=request, response=response, paginator=paginator, auth=auth + ) + + if not paginator.has_next_page: + break + + def extract_response( + self, response: Response, data_selector: jsonpath.TJsonPath + ) -> List[Any]: + if data_selector: + # we should compile data_selector + data: Any = jsonpath.find_values(data_selector, response.json()) + # extract if single item selected + data = data[0] if isinstance(data, list) and len(data) == 1 else data + else: + data = find_records(response.json()) + # wrap single pages into lists + if not isinstance(data, list): + data = [data] + return cast(List[Any], data) + + def detect_paginator(self, response: Response) -> BasePaginator: + """Detects a paginator for the response and returns it. + + Args: + response (Response): The response to detect the paginator for. + + Returns: + BasePaginator: The paginator instance that was detected. 
+ """ + paginator = self.pagination_factory.create_paginator(response) + if paginator is None: + raise ValueError( + f"No suitable paginator found for the response at {response.url}" + ) + logger.info(f"Detected paginator: {paginator.__class__.__name__}") + return paginator diff --git a/dlt/sources/helpers/rest_client/detector.py b/dlt/sources/helpers/rest_client/detector.py new file mode 100644 index 0000000000..f3af31bb4d --- /dev/null +++ b/dlt/sources/helpers/rest_client/detector.py @@ -0,0 +1,161 @@ +import re +from typing import List, Dict, Any, Tuple, Union, Optional, Callable, Iterable + +from dlt.sources.helpers.requests import Response + +from .paginators import ( + BasePaginator, + HeaderLinkPaginator, + JSONResponsePaginator, + SinglePagePaginator, +) + +RECORD_KEY_PATTERNS = frozenset( + [ + "data", + "items", + "results", + "entries", + "records", + "rows", + "entities", + "payload", + "content", + "objects", + ] +) + +NON_RECORD_KEY_PATTERNS = frozenset( + [ + "meta", + "metadata", + "pagination", + "links", + "extras", + "headers", + ] +) + +NEXT_PAGE_KEY_PATTERNS = frozenset(["next", "nextpage", "nexturl"]) +NEXT_PAGE_DICT_KEY_PATTERNS = frozenset(["href", "url"]) + + +def single_entity_path(path: str) -> bool: + """Checks if path ends with path param indicating that single object is returned""" + return re.search(r"\{([a-zA-Z_][a-zA-Z0-9_]*)\}$", path) is not None + + +def find_all_lists( + dict_: Dict[str, Any], + result: List[Tuple[int, str, List[Any]]] = None, + level: int = 0, +) -> List[Tuple[int, str, List[Any]]]: + """Recursively looks for lists in dict_ and returns tuples + in format (nesting level, dictionary key, list) + """ + if level > 2: + return [] + + for key, value in dict_.items(): + if isinstance(value, list): + result.append((level, key, value)) + elif isinstance(value, dict): + find_all_lists(value, result, level + 1) + + return result + + +def find_records( + response: Union[Dict[str, Any], List[Any], Any], +) -> Union[Dict[str, Any], List[Any], Any]: + # when a list was returned (or in rare case a simple type or null) + if not isinstance(response, dict): + return response + lists = find_all_lists(response, result=[]) + if len(lists) == 0: + # could not detect anything + return response + # we are ordered by nesting level, find the most suitable list + try: + return next( + list_info[2] + for list_info in lists + if list_info[1] in RECORD_KEY_PATTERNS + and list_info[1] not in NON_RECORD_KEY_PATTERNS + ) + except StopIteration: + # return the least nested element + return lists[0][2] + + +def matches_any_pattern(key: str, patterns: Iterable[str]) -> bool: + normalized_key = key.lower() + return any(pattern in normalized_key for pattern in patterns) + + +def find_next_page_path( + dictionary: Dict[str, Any], path: Optional[List[str]] = None +) -> Optional[List[str]]: + if not isinstance(dictionary, dict): + return None + + if path is None: + path = [] + + for key, value in dictionary.items(): + if matches_any_pattern(key, NEXT_PAGE_KEY_PATTERNS): + if isinstance(value, dict): + for dict_key in value: + if matches_any_pattern(dict_key, NEXT_PAGE_DICT_KEY_PATTERNS): + return [*path, key, dict_key] + return [*path, key] + + if isinstance(value, dict): + result = find_next_page_path(value, [*path, key]) + if result: + return result + + return None + + +def header_links_detector(response: Response) -> Optional[HeaderLinkPaginator]: + links_next_key = "next" + + if response.links.get(links_next_key): + return HeaderLinkPaginator() + return None + + +def 
json_links_detector(response: Response) -> Optional[JSONResponsePaginator]: + dictionary = response.json() + next_path_parts = find_next_page_path(dictionary) + + if not next_path_parts: + return None + + return JSONResponsePaginator(next_url_path=".".join(next_path_parts)) + + +def single_page_detector(response: Response) -> Optional[SinglePagePaginator]: + """This is our fallback paginator, also for results that are single entities""" + return SinglePagePaginator() + + +class PaginatorFactory: + def __init__( + self, detectors: List[Callable[[Response], Optional[BasePaginator]]] = None + ): + if detectors is None: + detectors = [ + header_links_detector, + json_links_detector, + single_page_detector, + ] + self.detectors = detectors + + def create_paginator(self, response: Response) -> Optional[BasePaginator]: + for detector in self.detectors: + paginator = detector(response) + if paginator: + return paginator + return None diff --git a/dlt/sources/helpers/rest_client/exceptions.py b/dlt/sources/helpers/rest_client/exceptions.py new file mode 100644 index 0000000000..4b4d555ca7 --- /dev/null +++ b/dlt/sources/helpers/rest_client/exceptions.py @@ -0,0 +1,5 @@ +from dlt.common.exceptions import DltException + + +class IgnoreResponseException(DltException): + pass diff --git a/dlt/sources/helpers/rest_client/paginators.py b/dlt/sources/helpers/rest_client/paginators.py new file mode 100644 index 0000000000..c098ea667f --- /dev/null +++ b/dlt/sources/helpers/rest_client/paginators.py @@ -0,0 +1,178 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from dlt.sources.helpers.requests import Response, Request +from dlt.common import jsonpath + + +class BasePaginator(ABC): + def __init__(self) -> None: + self._has_next_page = True + self._next_reference: Optional[str] = None + + @property + def has_next_page(self) -> bool: + """ + Check if there is a next page available. + + Returns: + bool: True if there is a next page available, False otherwise. + """ + return self._has_next_page + + @property + def next_reference(self) -> Optional[str]: + return self._next_reference + + @next_reference.setter + def next_reference(self, value: Optional[str]) -> None: + self._next_reference = value + self._has_next_page = value is not None + + @abstractmethod + def update_state(self, response: Response) -> None: + """Update the paginator state based on the response. + + Args: + response (Response): The response object from the API. + """ + ... + + @abstractmethod + def update_request(self, request: Request) -> None: + """ + Update the request object with the next arguments for the API request. + + Args: + request (Request): The request object to be updated. + """ + ... 
+ + +class SinglePagePaginator(BasePaginator): + """A paginator for single-page API responses.""" + + def update_state(self, response: Response) -> None: + self._has_next_page = False + + def update_request(self, request: Request) -> None: + return + + +class OffsetPaginator(BasePaginator): + """A paginator that uses the 'offset' parameter for pagination.""" + + def __init__( + self, + initial_limit: int, + initial_offset: int = 0, + offset_param: str = "offset", + limit_param: str = "limit", + total_path: jsonpath.TJsonPath = "total", + ) -> None: + super().__init__() + self.offset_param = offset_param + self.limit_param = limit_param + self.total_path = jsonpath.compile_path(total_path) + + self.offset = initial_offset + self.limit = initial_limit + + def update_state(self, response: Response) -> None: + values = jsonpath.find_values(self.total_path, response.json()) + total = values[0] if values else None + + if total is None: + raise ValueError( + f"Total count not found in response for {self.__class__.__name__}" + ) + + self.offset += self.limit + + if self.offset >= total: + self._has_next_page = False + + def update_request(self, request: Request) -> None: + if request.params is None: + request.params = {} + + request.params[self.offset_param] = self.offset + request.params[self.limit_param] = self.limit + + +class BaseNextUrlPaginator(BasePaginator): + def update_request(self, request: Request) -> None: + request.url = self.next_reference + + +class HeaderLinkPaginator(BaseNextUrlPaginator): + """A paginator that uses the 'Link' header in HTTP responses + for pagination. + + A good example of this is the GitHub API: + https://docs.github.com/en/rest/guides/traversing-with-pagination + """ + + def __init__(self, links_next_key: str = "next") -> None: + """ + Args: + links_next_key (str, optional): The key (rel ) in the 'Link' header + that contains the next page URL. Defaults to 'next'. + """ + super().__init__() + self.links_next_key = links_next_key + + def update_state(self, response: Response) -> None: + self.next_reference = response.links.get(self.links_next_key, {}).get("url") + + +class JSONResponsePaginator(BaseNextUrlPaginator): + """A paginator that uses a specific key in the JSON response to find + the next page URL. + """ + + def __init__( + self, + next_url_path: jsonpath.TJsonPath = "next", + ): + """ + Args: + next_url_path: The JSON path to the key that contains the next page URL in the response. + Defaults to 'next'. + """ + super().__init__() + self.next_url_path = jsonpath.compile_path(next_url_path) + + def update_state(self, response: Response) -> None: + values = jsonpath.find_values(self.next_url_path, response.json()) + self.next_reference = values[0] if values else None + + +class JSONResponseCursorPaginator(BasePaginator): + """A paginator that uses a cursor query param to paginate. The cursor for the + next page is found in the JSON response. + """ + + def __init__( + self, + cursor_path: jsonpath.TJsonPath = "cursors.next", + cursor_param: str = "after", + ): + """ + Args: + cursor_path: The JSON path to the key that contains the cursor in the response. + cursor_param: The name of the query parameter to be used in the request to get the next page. 
+ """ + super().__init__() + self.cursor_path = jsonpath.compile_path(cursor_path) + self.cursor_param = cursor_param + + def update_state(self, response: Response) -> None: + values = jsonpath.find_values(self.cursor_path, response.json()) + self.next_reference = values[0] if values else None + + def update_request(self, request: Request) -> None: + if request.params is None: + request.params = {} + + request.params[self.cursor_param] = self._next_reference diff --git a/dlt/sources/helpers/rest_client/typing.py b/dlt/sources/helpers/rest_client/typing.py new file mode 100644 index 0000000000..626aee4877 --- /dev/null +++ b/dlt/sources/helpers/rest_client/typing.py @@ -0,0 +1,17 @@ +from typing import ( + List, + Dict, + Union, + Literal, + Callable, + Any, +) +from dlt.sources.helpers.requests import Response + + +HTTPMethodBasic = Literal["GET", "POST"] +HTTPMethodExtended = Literal["PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"] +HTTPMethod = Union[HTTPMethodBasic, HTTPMethodExtended] +HookFunction = Callable[[Response, Any, Any], None] +HookEvent = Union[HookFunction, List[HookFunction]] +Hooks = Dict[str, HookEvent] diff --git a/dlt/sources/helpers/rest_client/utils.py b/dlt/sources/helpers/rest_client/utils.py new file mode 100644 index 0000000000..7fe91655c5 --- /dev/null +++ b/dlt/sources/helpers/rest_client/utils.py @@ -0,0 +1,16 @@ +def join_url(base_url: str, path: str) -> str: + if base_url is None: + raise ValueError("Base URL must be provided or set to an empty string.") + + if base_url == "": + return path + + if path == "": + return base_url + + # Normalize the base URL + base_url = base_url.rstrip("/") + if not base_url.endswith("/"): + base_url += "/" + + return base_url + path.lstrip("/") diff --git a/docs/examples/chess_production/chess.py b/docs/examples/chess_production/chess.py index d9e138187f..e2d0b9c10d 100644 --- a/docs/examples/chess_production/chess.py +++ b/docs/examples/chess_production/chess.py @@ -6,6 +6,7 @@ from dlt.common.typing import StrAny, TDataItems from dlt.sources.helpers.requests import client + @dlt.source def chess( chess_url: str = dlt.config.value, @@ -56,6 +57,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: MAX_PLAYERS = 5 + def load_data_with_retry(pipeline, data): try: for attempt in Retrying( @@ -65,9 +67,7 @@ def load_data_with_retry(pipeline, data): reraise=True, ): with attempt: - logger.info( - f"Running the pipeline, attempt={attempt.retry_state.attempt_number}" - ) + logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}") load_info = pipeline.run(data) logger.info(str(load_info)) @@ -89,9 +89,7 @@ def load_data_with_retry(pipeline, data): # print the information on the first load package and all jobs inside logger.info(f"First load package info: {load_info.load_packages[0]}") # print the information on the first completed job in first load package - logger.info( - f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}" - ) + logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}") # check for schema updates: schema_updates = [p.schema_update for p in load_info.load_packages] @@ -149,4 +147,4 @@ def load_data_with_retry(pipeline, data): ) # get data for a few famous players data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS) - load_data_with_retry(pipeline, data) \ No newline at end of file + load_data_with_retry(pipeline, data) diff --git a/docs/examples/connector_x_arrow/load_arrow.py 
b/docs/examples/connector_x_arrow/load_arrow.py index 06ca4e17b3..b3c654cef9 100644 --- a/docs/examples/connector_x_arrow/load_arrow.py +++ b/docs/examples/connector_x_arrow/load_arrow.py @@ -3,6 +3,7 @@ import dlt from dlt.sources.credentials import ConnectionStringCredentials + def read_sql_x( conn_str: ConnectionStringCredentials = dlt.secrets.value, query: str = dlt.config.value, @@ -14,6 +15,7 @@ def read_sql_x( protocol="binary", ) + def genome_resource(): # create genome resource with merge on `upid` primary key genome = dlt.resource( diff --git a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py index fcd964a980..624888f70a 100644 --- a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py +++ b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py @@ -15,6 +15,7 @@ # format: "your-project.your_dataset.your_table" BIGQUERY_TABLE_ID = "chat-analytics-rasa-ci.ci_streaming_insert.natural-disasters" + # dlt sources @dlt.resource(name="natural_disasters") def resource(url: str): @@ -38,6 +39,7 @@ def resource(url: str): ) yield table + # dlt biquery custom destination # we can use the dlt provided credentials class # to retrieve the gcp credentials from the secrets @@ -58,6 +60,7 @@ def bigquery_insert( load_job = client.load_table_from_file(f, BIGQUERY_TABLE_ID, job_config=job_config) load_job.result() # Waits for the job to complete. + if __name__ == "__main__": # run the pipeline and print load results pipeline = dlt.pipeline( @@ -68,4 +71,4 @@ def bigquery_insert( ) load_info = pipeline.run(resource(url=OWID_DISASTERS_URL)) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 8a93df9970..1ba330e4ca 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -9,6 +9,7 @@ ) from dlt.common.typing import DictStrAny, StrAny + def _initialize_sheets( credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials] ) -> Any: @@ -16,6 +17,7 @@ def _initialize_sheets( service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service + @dlt.source def google_spreadsheet( spreadsheet_id: str, @@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: for name in sheet_names ] + if __name__ == "__main__": pipeline = dlt.pipeline(destination="duckdb") # see example.secrets.toml to where to put credentials @@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: sheet_names=range_names, ) ) - print(info) \ No newline at end of file + print(info) diff --git a/docs/examples/incremental_loading/zendesk.py b/docs/examples/incremental_loading/zendesk.py index 4b8597886a..6113f98793 100644 --- a/docs/examples/incremental_loading/zendesk.py +++ b/docs/examples/incremental_loading/zendesk.py @@ -6,12 +6,11 @@ from dlt.common.typing import TAnyDateTime from dlt.sources.helpers.requests import client + @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -113,6 +112,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = 
response_json["next_page"] + if __name__ == "__main__": # create dlt pipeline pipeline = dlt.pipeline( @@ -120,4 +120,4 @@ def get_pages( ) load_info = pipeline.run(zendesk_support()) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index 3464448de6..7f85f0522e 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -13,6 +13,7 @@ CHUNK_SIZE = 10000 + # You can limit how deep dlt goes when generating child tables. # By default, the library will descend and generate child tables # for all nested lists, without a limit. @@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]: while docs_slice := list(islice(cursor, CHUNK_SIZE)): yield map_nested_in_place(convert_mongo_objs, docs_slice) + def convert_mongo_objs(value: Any) -> Any: if isinstance(value, (ObjectId, Decimal128)): return str(value) diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 8f7833e7d7..e7f57853ed 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -4,6 +4,7 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from PyPDF2 import PdfReader + @dlt.resource(selected=False) def list_files(folder_path: str): folder_path = os.path.abspath(folder_path) @@ -15,6 +16,7 @@ def list_files(folder_path: str): "mtime": os.path.getmtime(file_path), } + @dlt.transformer(primary_key="page_id", write_disposition="merge") def pdf_to_text(file_item, separate_pages: bool = False): if not separate_pages: @@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False): page_item["page_id"] = file_item["file_name"] + "_" + str(page_no) yield page_item + pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate") # this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf" @@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above -print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) \ No newline at end of file +print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) diff --git a/docs/examples/qdrant_zendesk/qdrant.py b/docs/examples/qdrant_zendesk/qdrant.py index 300d8dc6ad..bd0cbafc99 100644 --- a/docs/examples/qdrant_zendesk/qdrant.py +++ b/docs/examples/qdrant_zendesk/qdrant.py @@ -10,13 +10,12 @@ from dlt.common.configuration.inject import with_config + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -80,6 +79,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: return None return ensure_pendulum_datetime(value) + # modify dates to return datetime objects instead def _fix_date(ticket): ticket["updated_at"] = _parse_date_or_none(ticket["updated_at"]) @@ -87,6 +87,7 @@ def _fix_date(ticket): ticket["due_at"] = 
_parse_date_or_none(ticket["due_at"]) return ticket + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk def get_pages( url: str, @@ -127,6 +128,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create a pipeline with an appropriate name pipeline = dlt.pipeline( @@ -146,7 +148,6 @@ def get_pages( print(load_info) - # running the Qdrant client to connect to your Qdrant database @with_config(sections=("destination", "qdrant", "credentials")) diff --git a/docs/examples/transformers/pokemon.py b/docs/examples/transformers/pokemon.py index 2181c33259..ca32c570ef 100644 --- a/docs/examples/transformers/pokemon.py +++ b/docs/examples/transformers/pokemon.py @@ -1,6 +1,7 @@ import dlt from dlt.sources.helpers import requests + @dlt.source(max_table_nesting=2) def source(pokemon_api_url: str): """""" @@ -44,6 +45,7 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) + if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -52,4 +54,4 @@ def species(pokemon_details): # the pokemon_list resource does not need to be loaded load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/website/docs/general-usage/incremental-loading.md b/docs/website/docs/general-usage/incremental-loading.md index fe3bb8b61d..23b2218b46 100644 --- a/docs/website/docs/general-usage/incremental-loading.md +++ b/docs/website/docs/general-usage/incremental-loading.md @@ -298,13 +298,13 @@ We just yield all the events and `dlt` does the filtering (using `id` column dec Github returns events ordered from newest to oldest. So we declare the `rows_order` as **descending** to [stop requesting more pages once the incremental value is out of range](#declare-row-order-to-not-request-unnecessary-data). We stop requesting more data from the API after finding the first event with `created_at` earlier than `initial_value`. -:::note +:::note **Note on Incremental Cursor Behavior:** -When using incremental cursors for loading data, it's essential to understand how `dlt` handles records in relation to the cursor's +When using incremental cursors for loading data, it's essential to understand how `dlt` handles records in relation to the cursor's last value. By default, `dlt` will load only those records for which the incremental cursor value is higher than the last known value of the cursor. This means that any records with a cursor value lower than or equal to the last recorded value will be ignored during the loading process. -This behavior ensures efficiency by avoiding the reprocessing of records that have already been loaded, but it can lead to confusion if -there are expectations of loading older records that fall below the current cursor threshold. If your use case requires the inclusion of +This behavior ensures efficiency by avoiding the reprocessing of records that have already been loaded, but it can lead to confusion if +there are expectations of loading older records that fall below the current cursor threshold. If your use case requires the inclusion of such records, you can consider adjusting your data extraction logic, using a full refresh strategy where appropriate or using `last_value_func` as discussed in the subsquent section. 
::: @@ -625,6 +625,35 @@ Before `dlt` starts executing incremental resources, it looks for `data_interval You can run DAGs manually but you must remember to specify the Airflow logical date of the run in the past (use Run with config option). For such run `dlt` will load all data from that past date until now. If you do not specify the past date, a run with a range (now, now) will happen yielding no data. +### Reading incremental loading parameters from configuration + +Consider the example below for reading incremental loading parameters from "config.toml". We create a `generate_incremental_records` resource that yields "id", "idAfter", and "name". This resource retrieves `cursor_path` and `initial_value` from "config.toml". + +1. In "config.toml", define the `cursor_path` and `initial_value` as: + ```toml + # Configuration snippet for an incremental resource + [pipeline_with_incremental.sources.id_after] + cursor_path = "idAfter" + initial_value = 10 + ``` + + `cursor_path` is assigned the value "idAfter" with an initial value of 10. + +1. Here's how the `generate_incremental_records` resource uses `cursor_path` defined in "config.toml": + ```py + @dlt.resource(table_name="incremental_records") + def generate_incremental_records(id_after: dlt.sources.incremental = dlt.config.value): + for i in range(150): + yield {"id": i, "idAfter": i, "name": "name-" + str(i)} + + pipeline = dlt.pipeline( + pipeline_name="pipeline_with_incremental", + destination="duckdb", + ) + + pipeline.run(generate_incremental_records) + ``` + `id_after` incrementally stores the latest `cursor_path` value for future pipeline runs. ## Doing a full refresh diff --git a/docs/website/docs/general-usage/schema.md b/docs/website/docs/general-usage/schema.md index 9b0d8ec622..cb1c73c340 100644 --- a/docs/website/docs/general-usage/schema.md +++ b/docs/website/docs/general-usage/schema.md @@ -268,9 +268,44 @@ settings: re:^updated_at$: timestamp re:^_dlt_list_idx$: bigint ``` +### Applying data types directly with `@dlt.resource` and `apply_hints` +`dlt` offers the flexibility to directly apply data types and hints in your code, bypassing the need for importing and adjusting schemas. This approach is ideal for rapid prototyping and handling data sources with dynamic schema requirements. + +### Direct specification in `@dlt.resource` +Directly define data types and their properties, such as nullability, within the `@dlt.resource` decorator. This eliminates the dependency on external schema files. For example: + +```py +@dlt.resource(name='my_table', columns={"my_column": {"data_type": "bool", "nullable": True}}) +def my_resource(): + for i in range(10): + yield {'my_column': i % 2 == 0} +``` +This code snippet sets up a nullable boolean column named `my_column` directly in the decorator. + +#### Using `apply_hints` +When dealing with dynamically generated resources or needing to programmatically set hints, `apply_hints` is your tool. It's especially useful for applying hints across various collections or tables at once. 
+ +For example, to apply a complex data type across all collections from a MongoDB source: + +```py +all_collections = ["collection1", "collection2", "collection3"] # replace with your actual collection names +source_data = mongodb().with_resources(*all_collections) + +for col in all_collections: + source_data.resources[col].apply_hints(columns={"column_name": {"data_type": "complex"}}) + +pipeline = dlt.pipeline( + pipeline_name="mongodb_pipeline", + destination="duckdb", + dataset_name="mongodb_data" +) +load_info = pipeline.run(source_data) +``` +This example iterates through MongoDB collections, applying the complex [data type](schema#data-types) to a specified column, and then processes the data with `pipeline.run`. ## Export and import schema files + Please follow the guide on [how to adjust a schema](../walkthroughs/adjust-a-schema.md) to export and import `yaml` schema files in your pipeline. diff --git a/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy-with-kestra.md b/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy-with-kestra.md new file mode 100644 index 0000000000..cfb63ce808 --- /dev/null +++ b/docs/website/docs/walkthroughs/deploy-a-pipeline/deploy-with-kestra.md @@ -0,0 +1,116 @@ +--- +title: Deploy with Kestra +description: How to deploy a pipeline with Kestra +keywords: [how to, deploy a pipeline, Kestra] +--- + +# Deploy with Kestra + +## Introduction to Kestra + +[Kestra](https://kestra.io/docs) is an open-source, scalable orchestration platform that enables +engineers to manage business-critical workflows declaratively in code. By applying  +infrastructure as code best practices to data, process, and microservice orchestration, you +can build and manage reliable workflows. + +Kestra facilitates reliable workflow management, offering advanced settings for resiliency, +triggers, real-time monitoring, and integration capabilities, making it a valuable tool for data +engineers and developers. + +### Kestra features + +Kestra provides a robust orchestration engine with features including: + +- Workflows accessible through a user interface, event-driven + automation, and an embedded visual studio code editor. +- It also offers embedded documentation, a live-updating topology view, and access to over 400 + plugins, enhancing its versatility. +- Kestra supports Git & CI/CD integrations, basic authentication, and benefits from community + support. + +To know more, please refer to [Kestra's documentation.](https://kestra.io/docs) + +## Building Data Pipelines with `dlt` + +**`dlt`** is an open-source Python library that allows you to declaratively load data sources +into well-structured tables or datasets. It does this through automatic schema inference and evolution. +The library simplifies building data pipeline by providing functionality to support the entire extract +and load process. + +### How does `dlt` integrate with Kestra for pipeline orchestration? + +To illustrate setting up a pipeline in Kestra, we’ll be using the following example: +[From Inbox to Insights AI-Enhanced Email Analysis with dlt and Kestra.](https://kestra.io/blogs/2023-12-04-dlt-kestra-usage) + +The example demonstrates automating a workflow to load data from Gmail to BigQuery using the `dlt`, +complemented by AI-driven summarization and sentiment analysis. 
You can refer to the project's +github repo by clicking [here.](https://github.com/dlt-hub/dlt-kestra-demo) + +:::info +For the detailed guide, please take a look at the project's [README](https://github.com/dlt-hub/dlt-kestra-demo/blob/main/README.md) section. +::: + +Here is the summary of the steps: + +1. Start by creating a virtual environment. + +1. Generate an `.env` File: Inside your project repository, create an `.env` file to store + credentials in "base64" format, prefixed with 'SECRET\_' for compatibility with Kestra's `secret()` + function. + +1. As per Kestra’s recommendation, install the docker desktop on your machine. + +1. Ensure Docker is running, then download the Docker compose file with: + + ```sh + curl -o docker-compose.yml \ + https://raw.githubusercontent.com/kestra-io/kestra/develop/docker-compose.yml + ``` + +1. Configure Docker compose file: + Edit the downloaded Docker compose file to link the `.env` file for environment + variables. + + ```yaml + kestra: + image: kestra/kestra:develop-full + env_file: + - .env + ``` + +1. Enable Auto-Restart: In your `docker-compose.yml`, set `restart: always` for both postgres and + kestra services to ensure they reboot automatically after a system restart. + +1. Launch Kestra Server: Execute `docker compose up -d` to start the server. + +1. Access Kestra UI: Navigate to `http://localhost:8080/` to use the Kestra user interface. + +1. Create and Configure Flows: + + - Go to 'Flows', then 'Create'. + - Configure the flow files in the editor. + - Save your flows. + +1. **Understand Flow Components**: + + - Each flow must have an `id`, `namespace`, and a list of `tasks` with their respective `id` and + `type`. + - The main flow orchestrates tasks like loading data from a source to a destination. + +By following these steps, you establish a structured workflow within Kestra, leveraging its powerful +features for efficient data pipeline orchestration. + +:::info +For detailed information on these steps, please consult the `README.md` in the +[dlt-kestra-demo](https://github.com/dlt-hub/dlt-kestra-demo/blob/main/README.md) repo. +::: + +### Additional Resources + +- Ingest Zendesk data into Weaviate using `dlt` with Kestra: + [here](https://kestra.io/blueprints/148-ingest-zendesk-data-into-weaviate-using-dlt). 
+- Ingest Zendesk data into DuckDb using dlt with Kestra: + [here.](https://kestra.io/blueprints/147-ingest-zendesk-data-into-duckdb-using-dlt) +- Ingest Pipedrive CRM data to BigQuery using `dlt` and schedule it to run every hour: + [here.](https://kestra.io/blueprints/146-ingest-pipedrive-crm-data-to-bigquery-using-dlt-and-schedule-it-to-run-every-hour) + diff --git a/docs/website/sidebars.js b/docs/website/sidebars.js index 275c1f438a..4fd6bfca6b 100644 --- a/docs/website/sidebars.js +++ b/docs/website/sidebars.js @@ -218,6 +218,7 @@ const sidebars = { 'reference/explainers/airflow-gcp-cloud-composer', 'walkthroughs/deploy-a-pipeline/deploy-with-google-cloud-functions', 'walkthroughs/deploy-a-pipeline/deploy-gcp-cloud-function-as-webhook', + 'walkthroughs/deploy-a-pipeline/deploy-with-kestra', ] }, { diff --git a/poetry.lock b/poetry.lock index 96e730bf3a..a7c3979625 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4681,16 +4681,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -7081,7 +7071,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -7089,16 +7078,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -7115,7 +7096,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -7123,7 +7103,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -8011,7 +7990,6 @@ files = [ {file = "SQLAlchemy-1.4.49-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:03db81b89fe7ef3857b4a00b63dedd632d6183d4ea5a31c5d8a92e000a41fc71"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-macosx_11_0_x86_64.whl", hash = "sha256:95b9df9afd680b7a3b13b38adf6e3a38995da5e162cc7524ef08e3be4e5ed3e1"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a63e43bf3f668c11bb0444ce6e809c1227b8f067ca1068898f3008a273f52b09"}, - {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca46de16650d143a928d10842939dab208e8d8c3a9a8757600cae9b7c579c5cd"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f835c050ebaa4e48b18403bed2c0fda986525896efd76c245bdd4db995e51a4c"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c21b172dfb22e0db303ff6419451f0cac891d2e911bb9fbf8003d717f1bcf91"}, {file = "SQLAlchemy-1.4.49-cp310-cp310-win32.whl", hash = "sha256:5fb1ebdfc8373b5a291485757bd6431de8d7ed42c27439f543c81f6c8febd729"}, @@ -8021,35 +7999,26 @@ files = [ {file = "SQLAlchemy-1.4.49-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:5debe7d49b8acf1f3035317e63d9ec8d5e4d904c6e75a2a9246a119f5f2fdf3d"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win32.whl", hash = "sha256:82b08e82da3756765c2e75f327b9bf6b0f043c9c3925fb95fb51e1567fa4ee87"}, {file = "SQLAlchemy-1.4.49-cp311-cp311-win_amd64.whl", hash = "sha256:171e04eeb5d1c0d96a544caf982621a1711d078dbc5c96f11d6469169bd003f1"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f23755c384c2969ca2f7667a83f7c5648fcf8b62a3f2bbd883d805454964a800"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8396e896e08e37032e87e7fbf4a15f431aa878c286dc7f79e616c2feacdb366c"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:66da9627cfcc43bbdebd47bfe0145bb662041472393c03b7802253993b6b7c90"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-win32.whl", hash = "sha256:9a06e046ffeb8a484279e54bda0a5abfd9675f594a2e38ef3133d7e4d75b6214"}, - {file = "SQLAlchemy-1.4.49-cp312-cp312-win_amd64.whl", hash = "sha256:7cf8b90ad84ad3a45098b1c9f56f2b161601e4670827d6b892ea0e884569bd1d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:36e58f8c4fe43984384e3fbe6341ac99b6b4e083de2fe838f0fdb91cebe9e9cb"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b31e67ff419013f99ad6f8fc73ee19ea31585e1e9fe773744c0f3ce58c039c30"}, - {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc22807a7e161c0d8f3da34018ab7c97ef6223578fcdd99b1d3e7ed1100a5db"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c14b29d9e1529f99efd550cd04dbb6db6ba5d690abb96d52de2bff4ed518bc95"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c40f3470e084d31247aea228aa1c39bbc0904c2b9ccbf5d3cfa2ea2dac06f26d"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win32.whl", hash = "sha256:706bfa02157b97c136547c406f263e4c6274a7b061b3eb9742915dd774bbc264"}, {file = "SQLAlchemy-1.4.49-cp36-cp36m-win_amd64.whl", hash = "sha256:a7f7b5c07ae5c0cfd24c2db86071fb2a3d947da7bd487e359cc91e67ac1c6d2e"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-macosx_11_0_x86_64.whl", hash = "sha256:4afbbf5ef41ac18e02c8dc1f86c04b22b7a2125f2a030e25bbb4aff31abb224b"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24e300c0c2147484a002b175f4e1361f102e82c345bf263242f0449672a4bccf"}, - {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:393cd06c3b00b57f5421e2133e088df9cabcececcea180327e43b937b5a7caa5"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:201de072b818f8ad55c80d18d1a788729cccf9be6d9dc3b9d8613b053cd4836d"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7653ed6817c710d0c95558232aba799307d14ae084cc9b1f4c389157ec50df5c"}, {file = "SQLAlchemy-1.4.49-cp37-cp37m-win32.whl", hash = "sha256:647e0b309cb4512b1f1b78471fdaf72921b6fa6e750b9f891e09c6e2f0e5326f"}, {file 
= "SQLAlchemy-1.4.49-cp37-cp37m-win_amd64.whl", hash = "sha256:ab73ed1a05ff539afc4a7f8cf371764cdf79768ecb7d2ec691e3ff89abbc541e"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:37ce517c011560d68f1ffb28af65d7e06f873f191eb3a73af5671e9c3fada08a"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1878ce508edea4a879015ab5215546c444233881301e97ca16fe251e89f1c55"}, - {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95ab792ca493891d7a45a077e35b418f68435efb3e1706cb8155e20e86a9013c"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:0e8e608983e6f85d0852ca61f97e521b62e67969e6e640fe6c6b575d4db68557"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ccf956da45290df6e809ea12c54c02ace7f8ff4d765d6d3dfb3655ee876ce58d"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win32.whl", hash = "sha256:f167c8175ab908ce48bd6550679cc6ea20ae169379e73c7720a28f89e53aa532"}, {file = "SQLAlchemy-1.4.49-cp38-cp38-win_amd64.whl", hash = "sha256:45806315aae81a0c202752558f0df52b42d11dd7ba0097bf71e253b4215f34f4"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:b6d0c4b15d65087738a6e22e0ff461b407533ff65a73b818089efc8eb2b3e1de"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a843e34abfd4c797018fd8d00ffffa99fd5184c421f190b6ca99def4087689bd"}, - {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:738d7321212941ab19ba2acf02a68b8ee64987b248ffa2101630e8fccb549e0d"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:1c890421651b45a681181301b3497e4d57c0d01dc001e10438a40e9a9c25ee77"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d26f280b8f0a8f497bc10573849ad6dc62e671d2468826e5c748d04ed9e670d5"}, {file = "SQLAlchemy-1.4.49-cp39-cp39-win32.whl", hash = "sha256:ec2268de67f73b43320383947e74700e95c6770d0c68c4e615e9897e46296294"}, @@ -9066,4 +9035,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "99658baf1bfda2ac065bda897637cae0eb122c76777688a7d606df0ef06c7fcc" +content-hash = "e6e43e82afedfa274c91f3fd13dbbddd9cac64f386d2f5f1c4564ff6f5784cd2" diff --git a/pyproject.toml b/pyproject.toml index de5f8055c5..62a45c86f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,6 +145,7 @@ google-api-python-client = ">=1.7.11" pytest-asyncio = "^0.23.5" types-sqlalchemy = "^1.4.53.38" ruff = "^0.3.2" +pyjwt = "^2.8.0" [tool.poetry.group.pipeline] optional=true diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index a883f76ddb..5fbcd86d92 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -126,18 +126,14 @@ class MockProdConfiguration(RunConfiguration): @configspec class FieldWithNoDefaultConfiguration(RunConfiguration): - no_default: str - - if TYPE_CHECKING: - - def 
__init__(self, no_default: str = None, sentry_dsn: str = None) -> None: ... + no_default: str = None @configspec class InstrumentedConfiguration(BaseConfiguration): - head: str - tube: List[str] - heels: str + head: str = None + tube: List[str] = None + heels: str = None def to_native_representation(self) -> Any: return self.head + ">" + ">".join(self.tube) + ">" + self.heels @@ -156,63 +152,50 @@ def on_resolved(self) -> None: if self.head > self.heels: raise RuntimeError("Head over heels") - if TYPE_CHECKING: - - def __init__(self, head: str = None, tube: List[str] = None, heels: str = None) -> None: ... - @configspec class EmbeddedConfiguration(BaseConfiguration): - default: str - instrumented: InstrumentedConfiguration - sectioned: SectionedConfiguration - - if TYPE_CHECKING: - - def __init__( - self, - default: str = None, - instrumented: InstrumentedConfiguration = None, - sectioned: SectionedConfiguration = None, - ) -> None: ... + default: str = None + instrumented: InstrumentedConfiguration = None + sectioned: SectionedConfiguration = None @configspec class EmbeddedOptionalConfiguration(BaseConfiguration): - instrumented: Optional[InstrumentedConfiguration] + instrumented: Optional[InstrumentedConfiguration] = None @configspec class EmbeddedSecretConfiguration(BaseConfiguration): - secret: SecretConfiguration + secret: SecretConfiguration = None @configspec class NonTemplatedComplexTypesConfiguration(BaseConfiguration): - list_val: list # type: ignore[type-arg] - tuple_val: tuple # type: ignore[type-arg] - dict_val: dict # type: ignore[type-arg] + list_val: list = None # type: ignore[type-arg] + tuple_val: tuple = None # type: ignore[type-arg] + dict_val: dict = None # type: ignore[type-arg] @configspec class DynamicConfigA(BaseConfiguration): - field_for_a: str + field_for_a: str = None @configspec class DynamicConfigB(BaseConfiguration): - field_for_b: str + field_for_b: str = None @configspec class DynamicConfigC(BaseConfiguration): - field_for_c: str + field_for_c: str = None @configspec class ConfigWithDynamicType(BaseConfiguration): - discriminator: str - embedded_config: BaseConfiguration + discriminator: str = None + embedded_config: BaseConfiguration = None @resolve_type("embedded_config") def resolve_embedded_type(self) -> Type[BaseConfiguration]: @@ -240,8 +223,8 @@ def resolve_c_type(self) -> Type[BaseConfiguration]: @configspec class SubclassConfigWithDynamicType(ConfigWithDynamicType): - is_number: bool - dynamic_type_field: Any + is_number: bool = None + dynamic_type_field: Any = None @resolve_type("embedded_config") def resolve_embedded_type(self) -> Type[BaseConfiguration]: @@ -937,11 +920,7 @@ def test_is_valid_hint() -> None: def test_configspec_auto_base_config_derivation() -> None: @configspec class AutoBaseDerivationConfiguration: - auto: str - - if TYPE_CHECKING: - - def __init__(self, auto: str = None) -> None: ... 
+ auto: str = None assert issubclass(AutoBaseDerivationConfiguration, BaseConfiguration) assert hasattr(AutoBaseDerivationConfiguration, "auto") diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 9521f5960d..eddd0b21dc 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -20,19 +20,15 @@ @configspec class InjectableTestContext(ContainerInjectableContext): - current_value: str + current_value: str = None def parse_native_representation(self, native_value: Any) -> None: raise ValueError(native_value) - if TYPE_CHECKING: - - def __init__(self, current_value: str = None) -> None: ... - @configspec class EmbeddedWithInjectableContext(BaseConfiguration): - injected: InjectableTestContext + injected: InjectableTestContext = None @configspec @@ -47,12 +43,12 @@ class GlobalTestContext(InjectableTestContext): @configspec class EmbeddedWithNoDefaultInjectableContext(BaseConfiguration): - injected: NoDefaultInjectableContext + injected: NoDefaultInjectableContext = None @configspec class EmbeddedWithNoDefaultInjectableOptionalContext(BaseConfiguration): - injected: Optional[NoDefaultInjectableContext] + injected: Optional[NoDefaultInjectableContext] = None @pytest.fixture() diff --git a/tests/common/configuration/test_credentials.py b/tests/common/configuration/test_credentials.py index ae9b96e903..7c184c16e5 100644 --- a/tests/common/configuration/test_credentials.py +++ b/tests/common/configuration/test_credentials.py @@ -158,6 +158,34 @@ def test_connection_string_resolved_from_native_representation_env(environment: assert c.host == "aws.12.1" +def test_connection_string_from_init() -> None: + c = ConnectionStringCredentials("postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d") + assert c.drivername == "postgres" + assert c.is_resolved() + assert not c.is_partial() + + c = ConnectionStringCredentials( + { + "drivername": "postgres", + "username": "loader", + "password": "pass", + "host": "localhost", + "port": 5432, + "database": "dlt_data", + "query": {"a": "b", "c": "d"}, + } + ) + assert c.drivername == "postgres" + assert c.username == "loader" + assert c.password == "pass" + assert c.host == "localhost" + assert c.port == 5432 + assert c.database == "dlt_data" + assert c.query == {"a": "b", "c": "d"} + assert c.is_resolved() + assert not c.is_partial() + + def test_gcp_service_credentials_native_representation(environment) -> None: with pytest.raises(InvalidGoogleNativeCredentialsType): GcpServiceAccountCredentials().parse_native_representation(1) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index c6ab8aa756..1aa52c1919 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -167,7 +167,7 @@ def test_inject_with_sections() -> None: def test_inject_spec_in_func_params() -> None: @configspec class TestConfig(BaseConfiguration): - base_value: str + base_value: str = None # if any of args (ie. 
`init` below) is an instance of SPEC, we use it as initial value @@ -179,7 +179,7 @@ def test_spec_arg(base_value=dlt.config.value, init: TestConfig = None): spec = get_fun_spec(test_spec_arg) assert spec == TestConfig # call function with init, should resolve even if we do not provide the base_value in config - assert test_spec_arg(init=TestConfig(base_value="A")) == "A" # type: ignore[call-arg] + assert test_spec_arg(init=TestConfig(base_value="A")) == "A" def test_inject_with_sections_and_sections_context() -> None: @@ -272,7 +272,7 @@ def test_sections(value=dlt.config.value): def test_base_spec() -> None: @configspec class BaseParams(BaseConfiguration): - str_str: str + str_str: str = None @with_config(base=BaseParams) def f_explicit_base(str_str=dlt.config.value, opt: bool = True): diff --git a/tests/common/configuration/test_sections.py b/tests/common/configuration/test_sections.py index 9e0bc7e26d..bf6780e087 100644 --- a/tests/common/configuration/test_sections.py +++ b/tests/common/configuration/test_sections.py @@ -25,33 +25,33 @@ @configspec class SingleValConfiguration(BaseConfiguration): - sv: str + sv: str = None @configspec class EmbeddedConfiguration(BaseConfiguration): - sv_config: Optional[SingleValConfiguration] + sv_config: Optional[SingleValConfiguration] = None @configspec class EmbeddedWithSectionedConfiguration(BaseConfiguration): - embedded: SectionedConfiguration + embedded: SectionedConfiguration = None @configspec class EmbeddedIgnoredConfiguration(BaseConfiguration): # underscore prevents the field name to be added to embedded sections - _sv_config: Optional[SingleValConfiguration] + _sv_config: Optional[SingleValConfiguration] = None @configspec class EmbeddedIgnoredWithSectionedConfiguration(BaseConfiguration): - _embedded: SectionedConfiguration + _embedded: SectionedConfiguration = None @configspec class EmbeddedWithIgnoredEmbeddedConfiguration(BaseConfiguration): - ignored_embedded: EmbeddedIgnoredWithSectionedConfiguration + ignored_embedded: EmbeddedIgnoredWithSectionedConfiguration = None def test_sectioned_configuration(environment: Any, env_provider: ConfigProvider) -> None: diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index 4892967ab7..b1e316734d 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -26,8 +26,8 @@ def auth(self): @configspec class ZenEmailCredentials(ZenCredentials): - email: str - password: TSecretValue + email: str = None + password: TSecretValue = None def parse_native_representation(self, native_value: Any) -> None: assert isinstance(native_value, str) @@ -44,8 +44,8 @@ def auth(self): @configspec class ZenApiKeyCredentials(ZenCredentials): - api_key: str - api_secret: TSecretValue + api_key: str = None + api_secret: TSecretValue = None def parse_native_representation(self, native_value: Any) -> None: assert isinstance(native_value, str) @@ -62,14 +62,14 @@ def auth(self): @configspec class ZenConfig(BaseConfiguration): - credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials] + credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials] = None some_option: bool = False @configspec class ZenConfigOptCredentials: # add none to union to make it optional - credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, None] + credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, None] = None some_option: bool = False @@ -200,10 +200,10 @@ class 
GoogleAnalyticsCredentialsOAuth(GoogleAnalyticsCredentialsBase): This class is used to store credentials Google Analytics """ - client_id: str - client_secret: TSecretValue - project_id: TSecretValue - refresh_token: TSecretValue + client_id: str = None + client_secret: TSecretValue = None + project_id: TSecretValue = None + refresh_token: TSecretValue = None access_token: Optional[TSecretValue] = None diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index fcec881521..4f2219716a 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -42,12 +42,12 @@ @configspec class EmbeddedWithGcpStorage(BaseConfiguration): - gcp_storage: GcpServiceAccountCredentialsWithoutDefaults + gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = None @configspec class EmbeddedWithGcpCredentials(BaseConfiguration): - credentials: GcpServiceAccountCredentialsWithoutDefaults + credentials: GcpServiceAccountCredentialsWithoutDefaults = None def test_secrets_from_toml_secrets(toml_providers: ConfigProvidersContext) -> None: @@ -378,7 +378,7 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # dict creates only shallow dict so embedded credentials will fail creds = WithCredentialsConfiguration() - creds.credentials = SecretCredentials({"secret_value": "***** ***"}) + creds.credentials = SecretCredentials(secret_value=TSecretValue("***** ***")) with pytest.raises(ValueError): provider.set_value("written_creds", dict(creds), None) diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 73643561dc..670dcac87a 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -3,6 +3,7 @@ import datetime # noqa: I251 from typing import ( Any, + ClassVar, Iterator, List, Optional, @@ -71,19 +72,15 @@ class SecretCredentials(CredentialsConfiguration): @configspec class WithCredentialsConfiguration(BaseConfiguration): - credentials: SecretCredentials + credentials: SecretCredentials = None @configspec class SectionedConfiguration(BaseConfiguration): - __section__ = "DLT_TEST" + __section__: ClassVar[str] = "DLT_TEST" password: str = None - if TYPE_CHECKING: - - def __init__(self, password: str = None) -> None: ... 
- @pytest.fixture(scope="function") def environment() -> Any: diff --git a/tests/common/reflection/test_reflect_spec.py b/tests/common/reflection/test_reflect_spec.py index 092d25b717..952d0fc596 100644 --- a/tests/common/reflection/test_reflect_spec.py +++ b/tests/common/reflection/test_reflect_spec.py @@ -314,7 +314,7 @@ def f_kw_defaults_args( def test_reflect_custom_base() -> None: @configspec class BaseParams(BaseConfiguration): - str_str: str + str_str: str = None def _f_1(str_str=dlt.config.value, p_def: bool = True): pass diff --git a/tests/common/runtime/test_logging.py b/tests/common/runtime/test_logging.py index 19f67fe899..5ff92f7d94 100644 --- a/tests/common/runtime/test_logging.py +++ b/tests/common/runtime/test_logging.py @@ -3,7 +3,7 @@ from dlt.common import logger from dlt.common.runtime import exec_info -from dlt.common.runtime.logger import is_logging +from dlt.common.logger import is_logging from dlt.common.typing import StrStr, DictStrStr from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration diff --git a/tests/common/runtime/test_telemetry.py b/tests/common/runtime/test_telemetry.py index eece36aae7..e67f7e8360 100644 --- a/tests/common/runtime/test_telemetry.py +++ b/tests/common/runtime/test_telemetry.py @@ -35,16 +35,6 @@ class SentryLoggerConfiguration(RunConfiguration): class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): log_level: str = "CRITICAL" - if TYPE_CHECKING: - - def __init__( - self, - pipeline_name: str = "logger", - sentry_dsn: str = "https://sentry.io", - dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB", - log_level: str = "CRITICAL", - ) -> None: ... - def test_sentry_log_level() -> None: from dlt.common.runtime.sentry import _get_sentry_log_level diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 5240b889f3..24b0928463 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -78,7 +78,7 @@ def test_import_destination_config() -> None: dest = Destination.from_reference(ref="dlt.destinations.duckdb", environment="stage") assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "stage" - config = dest.configuration(dest.spec(dataset_name="dataset")) # type: ignore + config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" assert config.destination_name == "duckdb" assert config.environment == "stage" @@ -87,7 +87,7 @@ def test_import_destination_config() -> None: dest = Destination.from_reference(ref=None, destination_name="duckdb", environment="production") assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "production" - config = dest.configuration(dest.spec(dataset_name="dataset")) # type: ignore + config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" assert config.destination_name == "duckdb" assert config.environment == "production" @@ -98,7 +98,7 @@ def test_import_destination_config() -> None: ) assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "devel" - config = dest.configuration(dest.spec(dataset_name="dataset")) # type: ignore + config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == 
"duckdb" assert config.destination_name == "my_destination" assert config.environment == "devel" @@ -112,63 +112,63 @@ def test_normalize_dataset_name() -> None: # with schema name appended assert ( - DestinationClientDwhConfiguration( - dataset_name="ban_ana_dataset", default_schema_name="default" - ).normalize_dataset_name(Schema("banana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name="default") + .normalize_dataset_name(Schema("banana")) == "ban_ana_dataset_banana" ) # without schema name appended assert ( - DestinationClientDwhConfiguration( - dataset_name="ban_ana_dataset", default_schema_name="default" - ).normalize_dataset_name(Schema("default")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name="default") + .normalize_dataset_name(Schema("default")) == "ban_ana_dataset" ) # dataset name will be normalized (now it is up to destination to normalize this) assert ( - DestinationClientDwhConfiguration( - dataset_name="BaNaNa", default_schema_name="default" - ).normalize_dataset_name(Schema("banana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="BaNaNa", default_schema_name="default") + .normalize_dataset_name(Schema("banana")) == "ba_na_na_banana" ) # empty schemas are invalid with pytest.raises(ValueError): - DestinationClientDwhConfiguration( - dataset_name="banana_dataset", default_schema_name=None + DestinationClientDwhConfiguration()._bind_dataset_name( + dataset_name="banana_dataset" ).normalize_dataset_name(Schema(None)) with pytest.raises(ValueError): - DestinationClientDwhConfiguration( + DestinationClientDwhConfiguration()._bind_dataset_name( dataset_name="banana_dataset", default_schema_name="" ).normalize_dataset_name(Schema("")) # empty dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name="", default_schema_name="ban_schema" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="", default_schema_name="ban_schema") + .normalize_dataset_name(Schema("schema_ana")) == "_schema_ana" ) # empty dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name="", default_schema_name="schema_ana" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="", default_schema_name="schema_ana") + .normalize_dataset_name(Schema("schema_ana")) == "" ) # None dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name=None, default_schema_name="ban_schema" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name=None, default_schema_name="ban_schema") + .normalize_dataset_name(Schema("schema_ana")) == "_schema_ana" ) # None dataset name is valid! 
assert ( - DestinationClientDwhConfiguration( - dataset_name=None, default_schema_name="schema_ana" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name=None, default_schema_name="schema_ana") + .normalize_dataset_name(Schema("schema_ana")) is None ) @@ -176,9 +176,9 @@ def test_normalize_dataset_name() -> None: schema = Schema("barbapapa") schema._schema_name = "BarbaPapa" assert ( - DestinationClientDwhConfiguration( - dataset_name="set", default_schema_name="default" - ).normalize_dataset_name(schema) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="set", default_schema_name="default") + .normalize_dataset_name(schema) == "set_barba_papa" ) @@ -186,8 +186,8 @@ def test_normalize_dataset_name() -> None: def test_normalize_dataset_name_none_default_schema() -> None: # if default schema is None, suffix is not added assert ( - DestinationClientDwhConfiguration( - dataset_name="ban_ana_dataset", default_schema_name=None - ).normalize_dataset_name(Schema("default")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name=None) + .normalize_dataset_name(Schema("default")) == "ban_ana_dataset" ) diff --git a/tests/destinations/test_custom_destination.py b/tests/destinations/test_custom_destination.py index 7280ec419b..cfefceac88 100644 --- a/tests/destinations/test_custom_destination.py +++ b/tests/destinations/test_custom_destination.py @@ -455,7 +455,7 @@ def my_gcp_sink( def test_destination_with_spec() -> None: @configspec class MyDestinationSpec(CustomDestinationClientConfiguration): - my_predefined_val: str + my_predefined_val: str = None # check destination without additional config params @dlt.destination(spec=MyDestinationSpec) diff --git a/tests/helpers/dbt_tests/local/utils.py b/tests/helpers/dbt_tests/local/utils.py index 7097140a83..8fd3dba44f 100644 --- a/tests/helpers/dbt_tests/local/utils.py +++ b/tests/helpers/dbt_tests/local/utils.py @@ -40,7 +40,9 @@ def setup_rasa_runner( runner = create_runner( Venv.restore_current(), # credentials are exported to env in setup_rasa_runner_client - DestinationClientDwhConfiguration(dataset_name=dataset_name or FIXTURES_DATASET_NAME), + DestinationClientDwhConfiguration()._bind_dataset_name( + dataset_name=dataset_name or FIXTURES_DATASET_NAME + ), TEST_STORAGE_ROOT, package_profile_name=profile_name, config=C, diff --git a/tests/helpers/providers/test_google_secrets_provider.py b/tests/helpers/providers/test_google_secrets_provider.py index d6d94774b9..00c54b5705 100644 --- a/tests/helpers/providers/test_google_secrets_provider.py +++ b/tests/helpers/providers/test_google_secrets_provider.py @@ -1,6 +1,5 @@ -import dlt from dlt import TSecretValue -from dlt.common import logger +from dlt.common.runtime.init import init_logging from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.configuration.providers import GoogleSecretsProvider from dlt.common.configuration.accessors import secrets @@ -24,7 +23,7 @@ def test_regular_keys() -> None: - logger.init_logging(RunConfiguration()) + init_logging(RunConfiguration()) # copy bigquery credentials into providers credentials c = resolve_configuration( GcpServiceAccountCredentials(), sections=(known_sections.DESTINATION, "bigquery") diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index ac17bb8316..a97b612ad0 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ 
b/tests/load/bigquery/test_bigquery_client.py @@ -203,7 +203,8 @@ def test_get_oauth_access_token() -> None: def test_bigquery_configuration() -> None: config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.location == "US" assert config.get_location() == "US" @@ -215,7 +216,8 @@ def test_bigquery_configuration() -> None: # credential location is deprecated os.environ["CREDENTIALS__LOCATION"] = "EU" config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.location == "US" assert config.credentials.location == "EU" @@ -223,17 +225,21 @@ def test_bigquery_configuration() -> None: assert config.get_location() == "EU" os.environ["LOCATION"] = "ATLANTIS" config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.get_location() == "ATLANTIS" os.environ["DESTINATION__FILE_UPLOAD_TIMEOUT"] = "20000" config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.file_upload_timeout == 20000.0 # default fingerprint is empty - assert BigQueryClientConfiguration(dataset_name="dataset").fingerprint() == "" + assert ( + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset").fingerprint() == "" + ) def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index a223de9b26..fd58a6e033 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -4,10 +4,6 @@ from dlt.destinations.impl.bigquery.bigquery_adapter import ( PARTITION_HINT, CLUSTER_HINT, - TABLE_DESCRIPTION_HINT, - ROUND_HALF_EVEN_HINT, - ROUND_HALF_AWAY_FROM_ZERO_HINT, - TABLE_EXPIRATION_HINT, ) import google @@ -17,9 +13,12 @@ import dlt from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults +from dlt.common.configuration.specs import ( + GcpServiceAccountCredentialsWithoutDefaults, + GcpServiceAccountCredentials, +) from dlt.common.pendulum import pendulum -from dlt.common.schema import Schema, TColumnHint +from dlt.common.schema import Schema from dlt.common.utils import custom_environ from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate @@ -53,13 +52,13 @@ def test_configuration() -> None: @pytest.fixture def gcp_client(empty_schema: Schema) -> BigQueryClient: # return a client without opening connection - creds = GcpServiceAccountCredentialsWithoutDefaults() + creds = GcpServiceAccountCredentials() creds.project_id = "test_project_id" # noinspection PydanticTypeChecker return BigQueryClient( empty_schema, - BigQueryClientConfiguration( - dataset_name=f"test_{uniq_id()}", credentials=creds # 
type: ignore[arg-type] + BigQueryClientConfiguration(credentials=creds)._bind_dataset_name( + dataset_name=f"test_{uniq_id()}" ), ) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 9127e39be4..8d30d05e42 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -17,7 +17,9 @@ def test_databricks_credentials_to_connector_params(): # JSON encoded dict of extra args os.environ["CREDENTIALS__CONNECTION_PARAMETERS"] = '{"extra_a": "a", "extra_b": "b"}' - config = resolve_configuration(DatabricksClientConfiguration(dataset_name="my-dataset")) + config = resolve_configuration( + DatabricksClientConfiguration()._bind_dataset_name(dataset_name="my-dataset") + ) credentials = config.credentials diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index ef151833e4..3deed7a77d 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -31,7 +31,9 @@ def test_duckdb_open_conn_default() -> None: delete_quack_db() try: get_resolved_traces().clear() - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) + c = resolve_configuration( + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_dataset") + ) # print(str(c.credentials)) # print(str(os.getcwd())) # print(get_resolved_traces()) @@ -52,11 +54,15 @@ def test_duckdb_open_conn_default() -> None: def test_duckdb_database_path() -> None: # resolve without any path provided - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) + c = resolve_configuration( + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_dataset") + ) assert c.credentials._conn_str().lower() == os.path.abspath("quack.duckdb").lower() # resolve without any path but with pipeline context p = dlt.pipeline(pipeline_name="quack_pipeline") - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) + c = resolve_configuration( + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_dataset") + ) # still cwd db_path = os.path.abspath(os.path.join(".", "quack_pipeline.duckdb")) assert c.credentials._conn_str().lower() == db_path.lower() @@ -75,7 +81,9 @@ def test_duckdb_database_path() -> None: # test special :pipeline: path to create in pipeline folder c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=":pipeline:") + DuckDbClientConfiguration(credentials=":pipeline:")._bind_dataset_name( + dataset_name="test_dataset" + ) ) db_path = os.path.abspath(os.path.join(p.working_dir, DEFAULT_DUCK_DB_NAME)) assert c.credentials._conn_str().lower() == db_path.lower() @@ -90,8 +98,8 @@ def test_duckdb_database_path() -> None: db_path = "_storage/test_quack.duckdb" c = resolve_configuration( DuckDbClientConfiguration( - dataset_name="test_dataset", credentials="duckdb:///_storage/test_quack.duckdb" - ) + credentials="duckdb:///_storage/test_quack.duckdb" + )._bind_dataset_name(dataset_name="test_dataset") ) assert c.credentials._conn_str().lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) @@ -102,7 +110,9 @@ def test_duckdb_database_path() -> None: # provide absolute path db_path = os.path.abspath("_storage/abs_test_quack.duckdb") c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=f"duckdb:///{db_path}") + 
DuckDbClientConfiguration(credentials=f"duckdb:///{db_path}")._bind_dataset_name( + dataset_name="test_dataset", + ) ) assert os.path.isabs(c.credentials.database) assert c.credentials._conn_str().lower() == db_path.lower() @@ -114,7 +124,9 @@ def test_duckdb_database_path() -> None: # set just path as credentials db_path = "_storage/path_test_quack.duckdb" c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path) + DuckDbClientConfiguration(credentials=db_path)._bind_dataset_name( + dataset_name="test_dataset" + ) ) assert c.credentials._conn_str().lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) @@ -124,7 +136,9 @@ def test_duckdb_database_path() -> None: db_path = os.path.abspath("_storage/abs_path_test_quack.duckdb") c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path) + DuckDbClientConfiguration(credentials=db_path)._bind_dataset_name( + dataset_name="test_dataset" + ) ) assert os.path.isabs(c.credentials.database) assert c.credentials._conn_str().lower() == db_path.lower() @@ -138,7 +152,9 @@ def test_duckdb_database_path() -> None: with pytest.raises(duckdb.IOException): c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=TEST_STORAGE_ROOT) + DuckDbClientConfiguration(credentials=TEST_STORAGE_ROOT)._bind_dataset_name( + dataset_name="test_dataset" + ) ) conn = c.credentials.borrow_conn(read_only=False) @@ -225,7 +241,7 @@ def test_external_duckdb_database() -> None: # pass explicit in memory database conn = duckdb.connect(":memory:") c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=conn) + DuckDbClientConfiguration(credentials=conn)._bind_dataset_name(dataset_name="test_dataset") ) assert c.credentials._conn_borrows == 0 assert c.credentials._conn is conn diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 9b12e04f77..542b18993c 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -14,7 +14,10 @@ @pytest.fixture def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection - return DuckDbClient(empty_schema, DuckDbClientConfiguration(dataset_name="test_" + uniq_id())) + return DuckDbClient( + empty_schema, + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_" + uniq_id()), + ) def test_create_table(client: DuckDbClient) -> None: @@ -89,7 +92,9 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: # same thing with indexes client = DuckDbClient( client.schema, - DuckDbClientConfiguration(dataset_name="test_" + uniq_id(), create_indexes=True), + DuckDbClientConfiguration(create_indexes=True)._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql) diff --git a/tests/load/duckdb/test_motherduck_client.py b/tests/load/duckdb/test_motherduck_client.py index d57cf58f53..ba60e0de6d 100644 --- a/tests/load/duckdb/test_motherduck_client.py +++ b/tests/load/duckdb/test_motherduck_client.py @@ -19,13 +19,15 @@ def test_motherduck_database() -> None: # os.environ.pop("HOME", None) cred = MotherDuckCredentials("md:///?token=TOKEN") + print(dict(cred)) assert cred.password == "TOKEN" cred = MotherDuckCredentials() cred.parse_native_representation("md:///?token=TOKEN") assert 
cred.password == "TOKEN" config = resolve_configuration( - MotherDuckClientConfiguration(dataset_name="test"), sections=("destination", "motherduck") + MotherDuckClientConfiguration()._bind_dataset_name(dataset_name="test"), + sections=("destination", "motherduck"), ) # connect con = config.credentials.borrow_conn(read_only=False) diff --git a/tests/load/filesystem/test_aws_credentials.py b/tests/load/filesystem/test_aws_credentials.py index 7a0d42eb6d..62c2e3cd85 100644 --- a/tests/load/filesystem/test_aws_credentials.py +++ b/tests/load/filesystem/test_aws_credentials.py @@ -45,7 +45,7 @@ def test_aws_credentials_from_botocore(environment: Dict[str, str]) -> None: session = botocore.session.get_session() region_name = "eu-central-1" # session.get_config_variable('region') - c = AwsCredentials(session) + c = AwsCredentials.from_session(session) assert c.profile_name is None assert c.aws_access_key_id == "fake_access_key" assert c.region_name == region_name @@ -83,7 +83,7 @@ def test_aws_credentials_from_boto3(environment: Dict[str, str]) -> None: session = boto3.Session() - c = AwsCredentials(session) + c = AwsCredentials.from_session(session) assert c.profile_name is None assert c.aws_access_key_id == "fake_access_key" assert c.region_name == session.region_name diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index 1232be5c43..6e697fdef9 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -14,7 +14,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import LoadJob, TDestination +from dlt.common.destination.reference import LoadJob from dlt.common.pendulum import timedelta, __utcnow from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient @@ -24,11 +24,11 @@ def setup_loader(dataset_name: str) -> Load: - destination: TDestination = filesystem() # type: ignore[assignment] - config = filesystem.spec(dataset_name=dataset_name) + destination = filesystem() + config = destination.spec()._bind_dataset_name(dataset_name=dataset_name) # setup loader with Container().injectable_context(ConfigSectionContext(sections=("filesystem",))): - return Load(destination, initial_client_config=config) + return Load(destination, initial_client_config=config) # type: ignore[arg-type] @contextmanager diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index 75f46e8905..1b4a77a2ab 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -17,7 +17,9 @@ def client(empty_schema: Schema) -> MsSqlClient: # return client without opening connection return MsSqlClient( empty_schema, - MsSqlClientConfiguration(dataset_name="test_" + uniq_id(), credentials=MsSqlCredentials()), + MsSqlClientConfiguration(credentials=MsSqlCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) diff --git a/tests/load/postgres/test_postgres_client.py b/tests/load/postgres/test_postgres_client.py index daabf6fc51..896e449b28 100644 --- a/tests/load/postgres/test_postgres_client.py +++ b/tests/load/postgres/test_postgres_client.py @@ -62,6 +62,10 @@ def test_postgres_credentials_native_value(environment) -> None: assert c.is_resolved() assert c.password == "loader" + c = PostgresCredentials("postgres://loader:loader@localhost/dlt_data") + assert c.password == "loader" + 
assert c.database == "dlt_data" + def test_postgres_credentials_timeout() -> None: # test postgres timeout diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index fde9d82cf7..0ab1343a3b 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -2,7 +2,6 @@ from copy import deepcopy import sqlfluff -from dlt.common.schema.utils import new_table from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -20,8 +19,8 @@ def client(empty_schema: Schema) -> PostgresClient: # return client without opening connection return PostgresClient( empty_schema, - PostgresClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=PostgresCredentials() + PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ), ) @@ -97,10 +96,9 @@ def test_create_table_with_hints(client: PostgresClient) -> None: client = PostgresClient( client.schema, PostgresClientConfiguration( - dataset_name="test_" + uniq_id(), create_indexes=False, credentials=PostgresCredentials(), - ), + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="postgres") diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index c6981e5553..bc132c7818 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -20,8 +20,8 @@ def client(empty_schema: Schema) -> RedshiftClient: # return client without opening connection return RedshiftClient( empty_schema, - RedshiftClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=RedshiftCredentials() + RedshiftClientConfiguration(credentials=RedshiftCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ), ) diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 1e80a61f1c..5d7108803e 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -21,7 +21,9 @@ def snowflake_client(empty_schema: Schema) -> SnowflakeClient: creds = SnowflakeCredentials() return SnowflakeClient( empty_schema, - SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds), + SnowflakeClientConfiguration(credentials=creds)._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 871ceecf96..8575835820 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -25,8 +25,8 @@ def client(empty_schema: Schema) -> SynapseClient: # return client without opening connection client = SynapseClient( empty_schema, - SynapseClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=SynapseCredentials() + SynapseClientConfiguration(credentials=SynapseCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ), ) assert client.config.create_indexes is False @@ -39,8 +39,8 @@ def client_with_indexes_enabled(empty_schema: Schema) -> SynapseClient: client = SynapseClient( empty_schema, SynapseClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=SynapseCredentials(), create_indexes=True - ), + 
credentials=SynapseCredentials(), create_indexes=True + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) assert client.config.create_indexes is True return client diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index d7884abcf0..c5e4f874fc 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -6,8 +6,7 @@ from typing import List from dlt.common.exceptions import TerminalException, TerminalValueError -from dlt.common.schema.typing import TWriteDisposition -from dlt.common.storages import FileStorage, LoadStorage, PackageStorage, ParsedLoadJobFileName +from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName from dlt.common.storages.load_package import LoadJobInfo from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.destination.reference import LoadJob, TDestination @@ -814,7 +813,9 @@ def setup_loader( if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination client_config = client_config or DummyClientConfiguration(loader_file_format="reference") - staging_system_config = FilesystemDestinationClientConfiguration(dataset_name="dummy") + staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( + dataset_name="dummy" + ) staging_system_config.as_staging = True os.makedirs(REMOTE_FILESYSTEM) staging = filesystem(bucket_url=REMOTE_FILESYSTEM) diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 8906958e0c..ccf926cc62 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -70,7 +70,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: } # check loads table without attaching to pipeline duckdb_cfg = resolve_configuration( - DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), + DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: @@ -189,7 +189,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: venv = Venv.restore_current() print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_load.py")) duckdb_cfg = resolve_configuration( - DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), + DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: diff --git a/tests/sources/helpers/rest_client/__init__.py b/tests/sources/helpers/rest_client/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/helpers/rest_client/conftest.py b/tests/sources/helpers/rest_client/conftest.py new file mode 100644 index 0000000000..7eec090db6 --- /dev/null +++ b/tests/sources/helpers/rest_client/conftest.py @@ -0,0 +1,204 @@ +import re +from typing import NamedTuple, Callable, Pattern, List, TYPE_CHECKING +import base64 + +from urllib.parse import urlsplit, urlunsplit + +import pytest +import requests_mock + +from dlt.common import json + +if TYPE_CHECKING: + RequestCallback = Callable[[requests_mock.Request, requests_mock.Context], str] +else: + RequestCallback = Callable + +MOCK_BASE_URL = "https://api.example.com" + + +class Route(NamedTuple): + method: str + pattern: Pattern[str] + callback: RequestCallback + + +class APIRouter: + def __init__(self, base_url: str): + self.routes: 
List[Route] = [] + self.base_url = base_url + + def _add_route( + self, method: str, pattern: str, func: RequestCallback + ) -> RequestCallback: + compiled_pattern = re.compile(f"{self.base_url}{pattern}") + self.routes.append(Route(method, compiled_pattern, func)) + return func + + def get(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: + def decorator(func: RequestCallback) -> RequestCallback: + return self._add_route("GET", pattern, func) + + return decorator + + def post(self, pattern: str) -> Callable[[RequestCallback], RequestCallback]: + def decorator(func: RequestCallback) -> RequestCallback: + return self._add_route("POST", pattern, func) + + return decorator + + def register_routes(self, mocker: requests_mock.Mocker) -> None: + for route in self.routes: + mocker.register_uri( + route.method, + route.pattern, + text=route.callback, + ) + + +router = APIRouter(MOCK_BASE_URL) + + +def serialize_page(records, page_number, total_pages, base_url, records_key="data"): + if records_key is None: + return json.dumps(records) + + response = { + records_key: records, + "page": page_number, + "total_pages": total_pages, + } + + if page_number < total_pages: + next_page = page_number + 1 + + scheme, netloc, path, _, _ = urlsplit(base_url) + next_page = urlunsplit([scheme, netloc, path, f"page={next_page}", ""]) + response["next_page"] = next_page + + return json.dumps(response) + + +def generate_posts(count=100): + return [{"id": i, "title": f"Post {i}"} for i in range(count)] + + +def generate_comments(post_id, count=50): + return [{"id": i, "body": f"Comment {i} for post {post_id}"} for i in range(count)] + + +def get_page_number(qs, key="page", default=1): + return int(qs.get(key, [default])[0]) + + +def paginate_response(request, records, page_size=10, records_key="data"): + page_number = get_page_number(request.qs) + total_records = len(records) + total_pages = (total_records + page_size - 1) // page_size + start_index = (page_number - 1) * 10 + end_index = start_index + 10 + records_slice = records[start_index:end_index] + return serialize_page( + records_slice, page_number, total_pages, request.url, records_key + ) + + +@pytest.fixture(scope="module") +def mock_api_server(): + with requests_mock.Mocker() as m: + + @router.get(r"/posts_no_key(\?page=\d+)?$") + def posts_no_key(request, context): + return paginate_response(request, generate_posts(), records_key=None) + + @router.get(r"/posts(\?page=\d+)?$") + def posts(request, context): + return paginate_response(request, generate_posts()) + + @router.get(r"/posts/(\d+)/comments") + def post_comments(request, context): + post_id = int(request.url.split("/")[-2]) + return paginate_response(request, generate_comments(post_id)) + + @router.get(r"/posts/\d+$") + def post_detail(request, context): + post_id = request.url.split("/")[-1] + return json.dumps({"id": post_id, "body": f"Post body {post_id}"}) + + @router.get(r"/posts/\d+/some_details_404") + def post_detail_404(request, context): + """Return 404 for post with id > 0. 
Used to test ignoring 404 errors.""" + post_id = int(request.url.split("/")[-2]) + if post_id < 1: + return json.dumps({"id": post_id, "body": f"Post body {post_id}"}) + else: + context.status_code = 404 + return json.dumps({"error": "Post not found"}) + + @router.get(r"/posts_under_a_different_key$") + def posts_with_results_key(request, context): + return paginate_response( + request, generate_posts(), records_key="many-results" + ) + + @router.get("/protected/posts/basic-auth") + def protected_basic_auth(request, context): + auth = request.headers.get("Authorization") + creds = "user:password" + creds_base64 = base64.b64encode(creds.encode()).decode() + if auth == f"Basic {creds_base64}": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return json.dumps({"error": "Unauthorized"}) + + @router.get("/protected/posts/bearer-token") + def protected_bearer_token(request, context): + auth = request.headers.get("Authorization") + if auth == "Bearer test-token": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return json.dumps({"error": "Unauthorized"}) + + @router.get("/protected/posts/bearer-token-plain-text-error") + def protected_bearer_token_plain_text_erorr(request, context): + auth = request.headers.get("Authorization") + if auth == "Bearer test-token": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return "Unauthorized" + + @router.get("/protected/posts/api-key") + def protected_api_key(request, context): + api_key = request.headers.get("x-api-key") + if api_key == "test-api-key": + return paginate_response(request, generate_posts()) + context.status_code = 401 + return json.dumps({"error": "Unauthorized"}) + + @router.post("/oauth/token") + def oauth_token(request, context): + return json.dumps( + { + "access_token": "test-token", + "expires_in": 3600, + } + ) + + @router.post("/auth/refresh") + def refresh_token(request, context): + body = request.json() + if body.get("refresh_token") == "valid-refresh-token": + return json.dumps({"access_token": "new-valid-token"}) + context.status_code = 401 + return json.dumps({"error": "Invalid refresh token"}) + + router.register_routes(m) + + yield m + + +def assert_pagination(pages, expected_start=0, page_size=10): + for i, page in enumerate(pages): + assert page == [ + {"id": i, "title": f"Post {i}"} for i in range(i * 10, (i + 1) * 10) + ] diff --git a/tests/sources/helpers/rest_client/private_key.pem b/tests/sources/helpers/rest_client/private_key.pem new file mode 100644 index 0000000000..ce4592157b --- /dev/null +++ b/tests/sources/helpers/rest_client/private_key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDQQxVECHvO2Gs9 +MaRlD0HG5IpoJ3jhuG+nTgDEY7AU75nO74juOZuQR6AxO5nS/QeZS6bbjrzgz9P4 +vtDTksuSwXrgFJF1M5qiYwLZBr3ZNQA/e/D39+L2735craFsy8x6Xz5OCSCWaAyu +ufOMl1Yt2vRsDZ+x0OPPvKgUCBkgRMDxPbf4kuWnG/f4Z6czt3oReE6SiriT7EXS +ucNccSzgVs9HRopJ0M7jcbWPwGUfSlA3IO1G5sAEfVCihpzFlC7OoB+qAKj0wnAZ +Kr6gOuEFneoNUlErpLaeQwdRE+h61s5JybxZhFgr69n6kYIPG8ra6spVyB13WYt1 +FMEtL4P1AgMBAAECggEALv0vx2OdoaApZAt3Etk0J17JzrG3P8CIKqi6GhV+9V5R +JwRbMhrb21wZy/ntXVI7XG5aBbhJK/UgV8Of5Ni+Z0yRv4zMe/PqfCCYVCTGAYPI +nEpH5n7u3fXP3jPL0/sQlfy2108OY/kygVrR1YMQzfRUyStywGFIAUdI6gogtyt7 +cjh07mmMc8HUMhAVyluE5hpQCLDv5Xige2PY7zv1TqhI3OoJFi27VeBCSyI7x/94 +GM1XpzdFcvYPNPo6aE9vGnDq8TfYwjy+hkY+D9DRpnEmVEXmeBdsxsSD+ybyprO1 +C2sytiV9d3wJ96fhsYupLK88EGxU2uhmFntHuasMQQKBgQD9cWVo7B18FCV/NAdS 
+nV3KzNtlIrGRFZ7FMZuVZ/ZjOpvzbTVbla3YbRjTkXYpK9Meo8KczwzxQ2TQ1qxY +67SrhfFRRWzktMWqwBSKHPIig+DnqUCUo7OSA0pN+u6yUvFWdINZucB+yMWtgRrj +8GuAMXD/vaoCiNrHVf2V191fwQKBgQDSXP3cqBjBtDLP3qFwDzOG8cR9qiiDvesQ +DXf5seV/rBCXZvkw81t+PGz0O/UrUonv/FqxQR0GqpAdX1ZM3Jko0WxbfoCgsT0u +1aSzcMq1JQt0CI77T8tIPYvym9FO+Jz89kX0WliL/I7GLsmG5EYBK/+dcJBh1QCE +VaMCgrbxNQKBgB10zYWJU8/1A3qqUGOQuLL2ZlV11892BNMEdgHCaIeV60Q6oCX5 +2o+59lW4pVQZrNr1y4uwIN/1pkUDflqDYqdA1RBOEl7uh77Vvk1jGd1bGIu0RzY/ +ZIKG8V7o2E9Pho820YFfLnlN2nPU+owdiFEI7go7QAQ1ZcAfRW7h/O/BAoGBAJg+ +IKO/LBuUFGoIT4HQHpR9CJ2BtkyR+Drn5HpbWyKpHmDUb2gT15VmmduwQOEXnSiH +1AMQgrc+XYpEYyrBRD8cQXV9+g1R+Fua1tXevXWX19AkGYab2xzvHgd46WRj3Qne +GgacFBVLtPCND+CF+HwEobwJqRSEmRks+QpqG4g5AoGAXpw9CZb+gYfwl2hphFGO +kT/NOfk8PN7WeZAe7ktStZByiGhHWaxqYE0q5favhNG6tMxSdmSOzYF8liHWuvJm +cDHqNVJeTGT8rjW7Iz08wj5F+ZAJYCMkM9aDpDUKJIHnOwYZCGfZxRJCiHTReyR7 +u03hoszfCn13l85qBnYlwaw= +-----END PRIVATE KEY----- diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py new file mode 100644 index 0000000000..7a4c55f9a6 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_client.py @@ -0,0 +1,169 @@ +import os +import pytest +from typing import Any, cast +from dlt.common.typing import TSecretStrValue +from dlt.sources.helpers.requests import Response, Request +from dlt.sources.helpers.rest_client import RESTClient +from dlt.sources.helpers.rest_client.client import Hooks +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator + +from dlt.sources.helpers.rest_client.auth import AuthConfigBase +from dlt.sources.helpers.rest_client.auth import ( + BearerTokenAuth, + APIKeyAuth, + HttpBasicAuth, + OAuthJWTAuth, +) +from dlt.sources.helpers.rest_client.exceptions import IgnoreResponseException + +from .conftest import assert_pagination + + +def load_private_key(name="private_key.pem"): + key_path = os.path.join(os.path.dirname(__file__), name) + with open(key_path, "r", encoding="utf-8") as key_file: + return key_file.read() + + +TEST_PRIVATE_KEY = load_private_key() + + +@pytest.fixture +def rest_client() -> RESTClient: + return RESTClient( + base_url="https://api.example.com", + headers={"Accept": "application/json"}, + ) + + +@pytest.mark.usefixtures("mock_api_server") +class TestRESTClient: + def test_get_single_resource(self, rest_client): + response = rest_client.get("/posts/1") + assert response.status_code == 200 + assert response.json() == {"id": "1", "body": "Post body 1"} + + def test_pagination(self, rest_client: RESTClient): + pages_iter = rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) + + def test_page_context(self, rest_client: RESTClient) -> None: + for page in rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + auth=AuthConfigBase(), + ): + # response that produced data + assert isinstance(page.response, Response) + # updated request + assert isinstance(page.request, Request) + # make request url should be same as next link in paginator + if page.paginator.has_next_page: + assert page.paginator.next_reference == page.request.url + + def test_default_paginator(self, rest_client: RESTClient): + pages_iter = rest_client.paginate("/posts") + + pages = list(pages_iter) + + assert_pagination(pages) + + def test_paginate_with_hooks(self, rest_client: RESTClient): + def response_hook(response: Response, *args: Any, **kwargs: Any) -> None: + if response.status_code == 
404: + raise IgnoreResponseException + + hooks: Hooks = { + "response": response_hook, + } + + pages_iter = rest_client.paginate( + "/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + hooks=hooks, + ) + + pages = list(pages_iter) + + assert_pagination(pages) + + pages_iter = rest_client.paginate( + "/posts/1/some_details_404", + paginator=JSONResponsePaginator(), + hooks=hooks, + ) + + pages = list(pages_iter) + assert pages == [] + + def test_basic_auth_success(self, rest_client: RESTClient): + response = rest_client.get( + "/protected/posts/basic-auth", + auth=HttpBasicAuth("user", cast(TSecretStrValue, "password")), + ) + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + pages_iter = rest_client.paginate( + "/protected/posts/basic-auth", + auth=HttpBasicAuth("user", cast(TSecretStrValue, "password")), + ) + + pages = list(pages_iter) + assert_pagination(pages) + + def test_bearer_token_auth_success(self, rest_client: RESTClient): + response = rest_client.get( + "/protected/posts/bearer-token", + auth=BearerTokenAuth(cast(TSecretStrValue, "test-token")), + ) + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + pages_iter = rest_client.paginate( + "/protected/posts/bearer-token", + auth=BearerTokenAuth(cast(TSecretStrValue, "test-token")), + ) + + pages = list(pages_iter) + assert_pagination(pages) + + def test_api_key_auth_success(self, rest_client: RESTClient): + response = rest_client.get( + "/protected/posts/api-key", + auth=APIKeyAuth( + name="x-api-key", api_key=cast(TSecretStrValue, "test-api-key") + ), + ) + assert response.status_code == 200 + assert response.json()["data"][0] == {"id": 0, "title": "Post 0"} + + def test_oauth_jwt_auth_success(self, rest_client: RESTClient): + auth = OAuthJWTAuth( + client_id="test-client-id", + private_key=TEST_PRIVATE_KEY, + auth_endpoint="https://api.example.com/oauth/token", + scopes=["read", "write"], + headers={"Content-Type": "application/json"}, + ) + + response = rest_client.get( + "/protected/posts/bearer-token", + auth=auth, + ) + + assert response.status_code == 200 + assert "test-token" in response.request.headers["Authorization"] + + pages_iter = rest_client.paginate( + "/protected/posts/bearer-token", + auth=auth, + ) + + assert_pagination(list(pages_iter)) diff --git a/tests/sources/helpers/rest_client/test_detector.py b/tests/sources/helpers/rest_client/test_detector.py new file mode 100644 index 0000000000..a9af1d36a4 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_detector.py @@ -0,0 +1,360 @@ +import pytest +from dlt.common import jsonpath + +from dlt.sources.helpers.rest_client.detector import ( + find_records, + find_next_page_path, + single_entity_path, +) + + +TEST_RESPONSES = [ + { + "response": { + "data": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + "pagination": {"offset": 0, "limit": 2, "total": 100}, + }, + "expected": { + "type": "offset_limit", + "records_path": "data", + }, + }, + { + "response": { + "items": [ + {"id": 11, "title": "Page Item 1"}, + {"id": 12, "title": "Page Item 2"}, + ], + "page_info": {"current_page": 1, "items_per_page": 2, "total_pages": 50}, + }, + "expected": { + "type": "page_number", + "records_path": "items", + }, + }, + { + "response": { + "products": [ + {"id": 101, "name": "Product 1"}, + {"id": 102, "name": "Product 2"}, + ], + "next_cursor": "eyJpZCI6MTAyfQ==", + }, + "expected": { + "type": "cursor", + "records_path": 
"products", + "next_path": ["next_cursor"], + }, + }, + { + "response": { + "results": [ + {"id": 201, "description": "Result 1"}, + {"id": 202, "description": "Result 2"}, + ], + "cursors": {"next": "NjM=", "previous": "MTk="}, + }, + "expected": { + "type": "cursor", + "records_path": "results", + "next_path": ["cursors", "next"], + }, + }, + { + "response": { + "entries": [{"id": 31, "value": "Entry 1"}, {"id": 32, "value": "Entry 2"}], + "next_id": 33, + "limit": 2, + }, + "expected": { + "type": "cursor", + "records_path": "entries", + "next_path": ["next_id"], + }, + }, + { + "response": { + "comments": [ + {"id": 51, "text": "Comment 1"}, + {"id": 52, "text": "Comment 2"}, + ], + "page_number": 3, + "total_pages": 15, + }, + "expected": { + "type": "page_number", + "records_path": "comments", + }, + }, + { + "response": { + "count": 1023, + "next": "https://api.example.org/accounts/?page=5", + "previous": "https://api.example.org/accounts/?page=3", + "results": [{"id": 1, "name": "Account 1"}, {"id": 2, "name": "Account 2"}], + }, + "expected": { + "type": "json_link", + "records_path": "results", + "next_path": ["next"], + }, + }, + { + "response": { + "_embedded": { + "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}] + }, + "_links": { + "first": {"href": "http://api.example.com/items?page=0&size=2"}, + "self": {"href": "http://api.example.com/items?page=1&size=2"}, + "next": {"href": "http://api.example.com/items?page=2&size=2"}, + "last": {"href": "http://api.example.com/items?page=50&size=2"}, + }, + "page": {"size": 2, "totalElements": 100, "totalPages": 50, "number": 1}, + }, + "expected": { + "type": "json_link", + "records_path": "_embedded.items", + "next_path": ["_links", "next", "href"], + }, + }, + { + "response": { + "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + "meta": { + "currentPage": 1, + "pageSize": 2, + "totalPages": 50, + "totalItems": 100, + }, + "links": { + "firstPage": "/items?page=1&limit=2", + "previousPage": "/items?page=0&limit=2", + "nextPage": "/items?page=2&limit=2", + "lastPage": "/items?page=50&limit=2", + }, + }, + "expected": { + "type": "json_link", + "records_path": "items", + "next_path": ["links", "nextPage"], + }, + }, + { + "response": { + "data": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + "pagination": { + "currentPage": 1, + "pageSize": 2, + "totalPages": 5, + "totalItems": 10, + }, + }, + "expected": { + "type": "page_number", + "records_path": "data", + }, + }, + { + "response": { + "items": [{"id": 1, "title": "Item 1"}, {"id": 2, "title": "Item 2"}], + "pagination": {"page": 1, "perPage": 2, "total": 10, "totalPages": 5}, + }, + "expected": { + "type": "page_number", + "records_path": "items", + }, + }, + { + "response": { + "data": [ + {"id": 1, "description": "Item 1"}, + {"id": 2, "description": "Item 2"}, + ], + "meta": { + "currentPage": 1, + "itemsPerPage": 2, + "totalItems": 10, + "totalPages": 5, + }, + "links": { + "first": "/api/items?page=1", + "previous": None, + "next": "/api/items?page=2", + "last": "/api/items?page=5", + }, + }, + "expected": { + "type": "json_link", + "records_path": "data", + "next_path": ["links", "next"], + }, + }, + { + "response": { + "page": 2, + "per_page": 10, + "total": 100, + "pages": 10, + "data": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + }, + "expected": { + "type": "page_number", + "records_path": "data", + }, + }, + { + "response": { + "currentPage": 1, + "pageSize": 10, + "totalPages": 5, + "totalRecords": 
50, + "items": [{"id": 1, "name": "Item 1"}, {"id": 2, "name": "Item 2"}], + }, + "expected": { + "type": "page_number", + "records_path": "items", + }, + }, + { + "response": { + "articles": [ + {"id": 21, "headline": "Article 1"}, + {"id": 22, "headline": "Article 2"}, + ], + "paging": {"current": 3, "size": 2, "total": 60}, + }, + "expected": { + "type": "page_number", + "records_path": "articles", + }, + }, + { + "response": { + "feed": [ + {"id": 41, "content": "Feed Content 1"}, + {"id": 42, "content": "Feed Content 2"}, + ], + "offset": 40, + "limit": 2, + "total_count": 200, + }, + "expected": { + "type": "offset_limit", + "records_path": "feed", + }, + }, + { + "response": { + "query_results": [ + {"id": 81, "snippet": "Result Snippet 1"}, + {"id": 82, "snippet": "Result Snippet 2"}, + ], + "page_details": { + "number": 1, + "size": 2, + "total_elements": 50, + "total_pages": 25, + }, + }, + "expected": { + "type": "page_number", + "records_path": "query_results", + }, + }, + { + "response": { + "posts": [ + {"id": 91, "title": "Blog Post 1"}, + {"id": 92, "title": "Blog Post 2"}, + ], + "pagination_details": { + "current_page": 4, + "posts_per_page": 2, + "total_posts": 100, + "total_pages": 50, + }, + }, + "expected": { + "type": "page_number", + "records_path": "posts", + }, + }, + { + "response": { + "catalog": [ + {"id": 101, "product_name": "Product A"}, + {"id": 102, "product_name": "Product B"}, + ], + "page_metadata": { + "index": 1, + "size": 2, + "total_items": 20, + "total_pages": 10, + }, + }, + "expected": { + "type": "page_number", + "records_path": "catalog", + }, + }, +] + + +@pytest.mark.parametrize("test_case", TEST_RESPONSES) +def test_find_records(test_case): + response = test_case["response"] + expected = test_case["expected"]["records_path"] + r = find_records(response) + # all of them look fine mostly because those are simple cases... 
+ # case 7 fails because it is nested but in fact we select a right response + # assert r is create_nested_accessor(expected)(response) + assert r == jsonpath.find_values(expected, response)[0] + + +@pytest.mark.parametrize("test_case", TEST_RESPONSES) +def test_find_next_page_key(test_case): + response = test_case["response"] + expected = test_case.get("expected").get( + "next_path", None + ) # Some cases may not have next_path + assert find_next_page_path(response) == expected + + +@pytest.mark.skip +@pytest.mark.parametrize( + "path", + [ + "/users/{user_id}", + "/api/v1/products/{product_id}/", + "/api/v1/products/{product_id}//", + "/api/v1/products/{product_id}?param1=value1", + "/api/v1/products/{product_id}#section", + "/api/v1/products/{product_id}/#section", + "/users/{user_id}/posts/{post_id}", + "/users/{user_id}/posts/{post_id}/comments/{comment_id}", + "{entity}", + "/{entity}", + "/{user_123}", + ], +) +def test_single_entity_path_valid(path): + assert single_entity_path(path) is True + + +@pytest.mark.parametrize( + "path", + [ + "/users/user_id", + "/api/v1/products/product_id/", + "/users/{user_id}/details", + "/", + "/{}", + "/users/{123}", + "/users/{user-id}", + "/users/{user id}", + "/users/{user_id}/{", # Invalid ending + ], +) +def test_single_entity_path_invalid(path): + assert single_entity_path(path) is False diff --git a/tests/sources/helpers/rest_client/test_paginators.py b/tests/sources/helpers/rest_client/test_paginators.py new file mode 100644 index 0000000000..cc4dea65dc --- /dev/null +++ b/tests/sources/helpers/rest_client/test_paginators.py @@ -0,0 +1,82 @@ +import pytest +from unittest.mock import Mock + +from requests.models import Response + +from dlt.sources.helpers.rest_client.paginators import ( + SinglePagePaginator, + OffsetPaginator, + HeaderLinkPaginator, + JSONResponsePaginator, +) + + +class TestHeaderLinkPaginator: + def test_update_state_with_next(self): + paginator = HeaderLinkPaginator() + response = Mock(Response) + response.links = {"next": {"url": "http://example.com/next"}} + paginator.update_state(response) + assert paginator.next_reference == "http://example.com/next" + assert paginator.has_next_page is True + + def test_update_state_without_next(self): + paginator = HeaderLinkPaginator() + response = Mock(Response) + response.links = {} + paginator.update_state(response) + assert paginator.has_next_page is False + + +class TestJSONResponsePaginator: + def test_update_state_with_next(self): + paginator = JSONResponsePaginator() + response = Mock( + Response, json=lambda: {"next": "http://example.com/next", "results": []} + ) + paginator.update_state(response) + assert paginator.next_reference == "http://example.com/next" + assert paginator.has_next_page is True + + def test_update_state_without_next(self): + paginator = JSONResponsePaginator() + response = Mock(Response, json=lambda: {"results": []}) + paginator.update_state(response) + assert paginator.next_reference is None + assert paginator.has_next_page is False + + +class TestSinglePagePaginator: + def test_update_state(self): + paginator = SinglePagePaginator() + response = Mock(Response) + paginator.update_state(response) + assert paginator.has_next_page is False + + def test_update_state_with_next(self): + paginator = SinglePagePaginator() + response = Mock( + Response, json=lambda: {"next": "http://example.com/next", "results": []} + ) + response.links = {"next": {"url": "http://example.com/next"}} + paginator.update_state(response) + assert paginator.has_next_page is False 
+ + +class TestOffsetPaginator: + def test_update_state(self): + paginator = OffsetPaginator(initial_offset=0, initial_limit=10) + response = Mock(Response, json=lambda: {"total": 20}) + paginator.update_state(response) + assert paginator.offset == 10 + assert paginator.has_next_page is True + + # Test for reaching the end + paginator.update_state(response) + assert paginator.has_next_page is False + + def test_update_state_without_total(self): + paginator = OffsetPaginator(0, 10) + response = Mock(Response, json=lambda: {}) + with pytest.raises(ValueError): + paginator.update_state(response) diff --git a/tests/sources/helpers/rest_client/test_requests_paginate.py b/tests/sources/helpers/rest_client/test_requests_paginate.py new file mode 100644 index 0000000000..5ea137c735 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_requests_paginate.py @@ -0,0 +1,17 @@ +import pytest + +from dlt.sources.helpers.requests import paginate +from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator +from .conftest import assert_pagination + + +@pytest.mark.usefixtures("mock_api_server") +def test_requests_paginate(): + pages_iter = paginate( + "https://api.example.com/posts", + paginator=JSONResponsePaginator(next_url_path="next_page"), + ) + + pages = list(pages_iter) + + assert_pagination(pages) diff --git a/tests/sources/helpers/rest_client/test_utils.py b/tests/sources/helpers/rest_client/test_utils.py new file mode 100644 index 0000000000..0de9729a42 --- /dev/null +++ b/tests/sources/helpers/rest_client/test_utils.py @@ -0,0 +1,90 @@ +import pytest +from dlt.sources.helpers.rest_client.utils import join_url + + +@pytest.mark.parametrize( + "base_url, path, expected", + [ + # Normal cases + ( + "http://example.com", + "path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com/", + "/path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com/", + "path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com", + "//path/to/resource", + "http://example.com/path/to/resource", + ), + ( + "http://example.com///", + "//path/to/resource", + "http://example.com/path/to/resource", + ), + # Trailing and leading slashes + ("http://example.com/", "/", "http://example.com/"), + ("http://example.com", "/", "http://example.com/"), + ("http://example.com/", "///", "http://example.com/"), + ("http://example.com", "///", "http://example.com/"), + ("/", "path/to/resource", "/path/to/resource"), + ("/", "/path/to/resource", "/path/to/resource"), + # Empty strings + ("", "", ""), + ( + "", + "http://example.com/path/to/resource", + "http://example.com/path/to/resource", + ), + ("", "path/to/resource", "path/to/resource"), + ("http://example.com", "", "http://example.com"), + # Query parameters and fragments + ( + "http://example.com", + "path/to/resource?query=123", + "http://example.com/path/to/resource?query=123", + ), + ( + "http://example.com/", + "path/to/resource#fragment", + "http://example.com/path/to/resource#fragment", + ), + # Special characters in the path + ( + "http://example.com", + "/path/to/resource with spaces", + "http://example.com/path/to/resource with spaces", + ), + ("http://example.com", "/path/with/中文", "http://example.com/path/with/中文"), + # Protocols and subdomains + ("https://sub.example.com", "path", "https://sub.example.com/path"), + ("ftp://example.com", "/path", "ftp://example.com/path"), + # Missing protocol in base_url + ("example.com", "path", "example.com/path"), + 
], +) +def test_join_url(base_url, path, expected): + assert join_url(base_url, path) == expected + + +@pytest.mark.parametrize( + "base_url, path, exception", + [ + (None, "path", ValueError), + ("http://example.com", None, AttributeError), + (123, "path", AttributeError), + ("http://example.com", 123, AttributeError), + ], +) +def test_join_url_invalid_input_types(base_url, path, exception): + with pytest.raises(exception): + join_url(base_url, path) diff --git a/tests/utils.py b/tests/utils.py index 924f44de73..00523486ea 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,7 @@ ConfigProvidersContext, ) from dlt.common.pipeline import PipelineContext -from dlt.common.runtime.logger import init_logging +from dlt.common.runtime.init import init_logging from dlt.common.runtime.telemetry import start_telemetry, stop_telemetry from dlt.common.schema import Schema from dlt.common.storages import FileStorage
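
For context on the new `tests/sources/helpers/rest_client` suite added above, here is a minimal sketch of the client and paginator API those tests exercise. It is not part of the patch: the base URL and bearer token are placeholders (the tests run against a `requests_mock` server), and only calls that appear verbatim in `test_client.py` are used.

```python
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.auth import BearerTokenAuth
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator

# Placeholder endpoint; the tests point this at a mocked https://api.example.com.
client = RESTClient(
    base_url="https://api.example.com",
    headers={"Accept": "application/json"},
)

# paginate() yields one page of records at a time, following the "next_page"
# link embedded in each JSON response until no further link is present.
for page in client.paginate(
    "/posts",
    paginator=JSONResponsePaginator(next_url_path="next_page"),
    auth=BearerTokenAuth("test-token"),  # placeholder token
):
    print(page)
```

This mirrors `test_pagination` and `test_bearer_token_auth_success` in `test_client.py`, which assert that the pages returned match the mock server's generated posts.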