Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: ensure custom session can be provided to rest client #1396

Merged
merged 10 commits into from
May 27, 2024
9 changes: 8 additions & 1 deletion dlt/common/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type
from .specs.base_configuration import (
configspec,
is_valid_hint,
is_secret_hint,
resolve_type,
NotResolved,
)
from .specs import known_sections
from .resolve import resolve_configuration, inject_section
from .inject import with_config, last_config, get_fun_spec, create_resolved_partial
Expand All @@ -15,6 +21,7 @@
"configspec",
"is_valid_hint",
"is_secret_hint",
"NotResolved",
"resolve_type",
"known_sections",
"resolve_configuration",
Expand Down
6 changes: 3 additions & 3 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
StrAny,
TSecretValue,
get_all_types_of_class_in_union,
is_final_type,
is_optional_type,
is_union_type,
)
Expand All @@ -21,6 +20,7 @@
is_context_inner_hint,
is_base_configuration_inner_hint,
is_valid_hint,
is_hint_not_resolved,
)
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.configuration.specs.exceptions import NativeValueError
Expand Down Expand Up @@ -194,7 +194,7 @@ def _resolve_config_fields(
if explicit_values:
explicit_value = explicit_values.get(key)
else:
if is_final_type(hint):
if is_hint_not_resolved(hint):
# for fields that are not resolved (Final or NotResolved), the default value is treated like an explicit value
explicit_value = default_value
else:
Expand Down Expand Up @@ -258,7 +258,7 @@ def _resolve_config_fields(
unresolved_fields[key] = traces
# set resolved value in config
if default_value != current_value:
if not is_final_type(hint):
if not is_hint_not_resolved(hint):
# ignore types that are not resolved (Final or NotResolved)
setattr(config, key, current_value)

Expand Down
38 changes: 37 additions & 1 deletion dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ClassVar,
TypeVar,
)
from typing_extensions import get_args, get_origin, dataclass_transform
from typing_extensions import get_args, get_origin, dataclass_transform, Annotated, TypeAlias
from functools import wraps

if TYPE_CHECKING:
Expand All @@ -29,8 +29,11 @@
TDtcField = dataclasses.Field

from dlt.common.typing import (
AnyType,
TAnyClass,
extract_inner_type,
is_annotated,
is_final_type,
is_optional_type,
is_union_type,
)
Expand All @@ -48,6 +51,34 @@
_C = TypeVar("_C", bound="CredentialsConfiguration")


class NotResolved:
    """Marker used in type annotations to indicate types that should not be resolved.

    Intended to be placed in the metadata of an annotated hint, for example
    `Annotated[str, NotResolved()]`. The truthiness of an instance mirrors the
    `not_resolved` flag, so `NotResolved(False)` is falsy.

    Args:
        not_resolved: Flag reported by `bool()` on the instance. Defaults to True.
    """

    def __init__(self, not_resolved: bool = True):
        self.not_resolved = not_resolved

    def __bool__(self) -> bool:
        # coerce explicitly: `__bool__` must return a real bool, otherwise
        # `bool(NotResolved(1))` raises TypeError
        return bool(self.not_resolved)

    def __repr__(self) -> str:
        # readable form for hints that get printed in messages and debuggers
        return f"{type(self).__name__}({self.not_resolved!r})"


def is_hint_not_resolved(hint: AnyType) -> bool:
    """Check if `hint` should NOT be resolved.

    A hint is not resolved when `is_final_type` reports it as final, or when it
    is an annotated type whose metadata contains a `NotResolved` marker, e.g.
    `Annotated[str, NotResolved()]`. The first `NotResolved` marker found
    decides the outcome: its truthiness is returned, so a `NotResolved(False)`
    marker yields False.

    Args:
        hint: The type hint to inspect.

    Returns:
        True if the hint should not be resolved, False otherwise.
    """
    if is_final_type(hint):
        return True

    if is_annotated(hint):
        # the first argument is the annotated type, the remainder is metadata
        _, *metadata = get_args(hint)
        for annotation in metadata:
            if isinstance(annotation, NotResolved):
                return bool(annotation)
    return False


def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool:
    """Return True when `inner_hint` is a class that is a subclass of `BaseConfiguration`.

    Values that are not classes return False without reaching the `issubclass` check.
    """
    return inspect.isclass(inner_hint) and issubclass(inner_hint, BaseConfiguration)

Expand All @@ -70,6 +101,11 @@ def is_valid_hint(hint: Type[Any]) -> bool:
if get_origin(hint) is ClassVar:
# class vars are skipped by dataclass
return True

if is_hint_not_resolved(hint):
# all hints that are not resolved are valid
return True

hint = extract_inner_type(hint)
hint = get_config_if_union_hint(hint) or hint
hint = get_origin(hint) or hint
Expand Down
16 changes: 8 additions & 8 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
Any,
TypeVar,
Generic,
Final,
)
from typing_extensions import Annotated
import datetime # noqa: 251
from copy import deepcopy
import inspect
Expand All @@ -35,7 +35,7 @@
has_column_with_prop,
get_first_column_name_with_prop,
)
from dlt.common.configuration import configspec, resolve_configuration, known_sections
from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved
from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration
from dlt.common.configuration.accessors import config
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
Expand Down Expand Up @@ -78,7 +78,7 @@ class StateInfo(NamedTuple):

@configspec
class DestinationClientConfiguration(BaseConfiguration):
destination_type: Final[str] = dataclasses.field(
destination_type: Annotated[str, NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
) # which destination to load data to
credentials: Optional[CredentialsConfiguration] = None
Expand All @@ -103,11 +103,11 @@ def on_resolved(self) -> None:
class DestinationClientDwhConfiguration(DestinationClientConfiguration):
"""Configuration of a destination that supports datasets/schemas"""

dataset_name: Final[str] = dataclasses.field(
dataset_name: Annotated[str, NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
) # dataset must be final so it is not configurable
) # dataset cannot be resolved
"""dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix"""
default_schema_name: Final[Optional[str]] = dataclasses.field(
default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)
"""name of default schema to be used to name effective dataset to load data to"""
Expand All @@ -121,8 +121,8 @@ def _bind_dataset_name(

This method is intended to be used internally.
"""
self.dataset_name = dataset_name # type: ignore[misc]
self.default_schema_name = default_schema_name # type: ignore[misc]
self.dataset_name = dataset_name
self.default_schema_name = default_schema_name
return self

def normalize_dataset_name(self, schema: Schema) -> str:
Expand Down
7 changes: 5 additions & 2 deletions dlt/destinations/impl/qdrant/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dataclasses
from typing import Optional, Final
from typing_extensions import Annotated

from dlt.common.configuration import configspec
from dlt.common.configuration import configspec, NotResolved
from dlt.common.configuration.specs.base_configuration import (
BaseConfiguration,
CredentialsConfiguration,
Expand Down Expand Up @@ -55,7 +56,9 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration):
dataset_separator: str = "_"

# make it optional so empty dataset is allowed
dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc]
dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

# Batch size for generating embeddings
embedding_batch_size: int = 32
Expand Down
7 changes: 5 additions & 2 deletions dlt/destinations/impl/weaviate/configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from typing import Dict, Literal, Optional, Final
from typing_extensions import Annotated
from urllib.parse import urlparse

from dlt.common.configuration import configspec
from dlt.common.configuration import configspec, NotResolved
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration
from dlt.common.destination.reference import DestinationClientDwhConfiguration
from dlt.common.utils import digest128
Expand All @@ -26,7 +27,9 @@ def __str__(self) -> str:
class WeaviateClientConfiguration(DestinationClientDwhConfiguration):
destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore
# make it optional so empty dataset is allowed
dataset_name: Optional[str] = None # type: ignore[misc]
dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

batch_size: int = 100
batch_workers: int = 1
Expand Down
2 changes: 1 addition & 1 deletion dlt/sources/helpers/requests/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _make_session(self) -> Session:
session.mount("http://", self._adapter)
session.mount("https://", self._adapter)
retry = _make_retry(**self._retry_kwargs)
session.request = retry.wraps(session.request) # type: ignore[method-assign]
session.send = retry.wraps(session.send) # type: ignore[method-assign]
return session

@property
Expand Down
16 changes: 12 additions & 4 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from base64 import b64encode
import dataclasses
import math
from typing import (
List,
Expand All @@ -12,12 +13,13 @@
Iterable,
TYPE_CHECKING,
)
from typing_extensions import Annotated
from requests.auth import AuthBase
from requests import PreparedRequest # noqa: I251
from requests import PreparedRequest, Session as BaseSession # noqa: I251

from dlt.common import logger
from dlt.common.exceptions import MissingDependencyException
from dlt.common.configuration.specs.base_configuration import configspec
from dlt.common.configuration.specs.base_configuration import configspec, NotResolved
from dlt.common.configuration.specs import CredentialsConfiguration
from dlt.common.configuration.specs.exceptions import NativeValueError
from dlt.common.pendulum import pendulum
Expand Down Expand Up @@ -146,19 +148,25 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""

format: Final[Literal["JWT"]] = "JWT" # noqa: A003
format: Final[Literal["JWT"]] = dataclasses.field( # noqa: A003
default="JWT", init=False, repr=False, compare=False
)
client_id: str = None
private_key: TSecretStrValue = None
auth_endpoint: str = None
scopes: Optional[Union[str, List[str]]] = None
headers: Optional[Dict[str, str]] = None
private_key_passphrase: Optional[TSecretStrValue] = None
default_token_expiration: int = 3600
session: Annotated[BaseSession, NotResolved()] = None

def __post_init__(self) -> None:
    """Normalize `scopes`, reset token state and pick the session used for requests."""
    # `scopes` may be a single string, a list of strings or None (the field default);
    # only non-string values are joined so an unset value does not raise TypeError
    if self.scopes is not None and not isinstance(self.scopes, str):
        self.scopes = " ".join(self.scopes)
    self.token = None
    self.token_expiry: Optional[pendulum.DateTime] = None
    # use default system session if not specified
    if self.session is None:
        self.session = requests.client.session

def __call__(self, r: PreparedRequest) -> PreparedRequest:
if self.token is None or self.is_token_expired():
Expand All @@ -183,7 +191,7 @@ def obtain_token(self) -> None:

logger.debug(f"Obtaining token from {self.auth_endpoint}")

response = requests.post(self.auth_endpoint, headers=self.headers, data=data)
response = self.session.post(self.auth_endpoint, headers=self.headers, data=data)
response.raise_for_status()

token_response = response.json()
Expand Down
23 changes: 12 additions & 11 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def __init__(
self.auth = auth

if session:
self._validate_session_raise_for_status(session)
self.session = session
# dlt.sources.helpers.requests.session.Session
# has raise_for_status=True by default
self.session = _warn_if_raise_for_status_and_return(session)
else:
self.session = Client(raise_for_status=False).session

Expand All @@ -92,15 +93,6 @@ def __init__(

self.data_selector = data_selector

def _validate_session_raise_for_status(self, session: BaseSession) -> None:
    """Log a warning when the provided session has `raise_for_status` enabled.

    Args:
        session: The user-supplied session to inspect.
    """
    # dlt.sources.helpers.requests.session.Session
    # has raise_for_status=True by default
    # inspect the argument, not `self.session`: this method is called before
    # `self.session` is assigned
    if getattr(session, "raise_for_status", False):
        logger.warning(
            "The session provided has raise_for_status enabled. "
            "This may cause unexpected behavior."
        )

def _create_request(
self,
path: str,
Expand Down Expand Up @@ -298,3 +290,12 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator:
" instance of the paginator as some settings may not be guessed correctly."
)
return paginator


def _warn_if_raise_for_status_and_return(session: BaseSession) -> BaseSession:
    """Warn if the session has `raise_for_status` enabled and hand the session back.

    Args:
        session: The session to inspect.

    Returns:
        The same `session` object, unchanged, so the call can be used inline
        in an assignment.
    """
    # a missing attribute is treated the same as a disabled flag
    if getattr(session, "raise_for_status", False):
        logger.warning(
            "The session provided has raise_for_status enabled. This may cause unexpected behavior."
        )
    return session
55 changes: 54 additions & 1 deletion tests/common/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
Optional,
Type,
Union,
TYPE_CHECKING,
)
from typing_extensions import Annotated

from dlt.common import json, pendulum, Decimal, Wei
from dlt.common.configuration.providers.provider import ConfigProvider
from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolved
from dlt.common.configuration.specs.gcp_credentials import (
GcpServiceAccountCredentialsWithoutDefaults,
)
Expand Down Expand Up @@ -917,6 +918,58 @@ def test_is_valid_hint() -> None:
assert is_valid_hint(Wei) is True
# any class type, except deriving from BaseConfiguration is wrong type
assert is_valid_hint(ConfigFieldMissingException) is False
# but final and annotated types are ok because they are not resolved
assert is_valid_hint(Final[ConfigFieldMissingException]) is True # type: ignore[arg-type]
assert is_valid_hint(Annotated[ConfigFieldMissingException, NotResolved()]) is True # type: ignore[arg-type]
assert is_valid_hint(Annotated[ConfigFieldMissingException, "REQ"]) is False # type: ignore[arg-type]


def test_is_not_resolved_hint() -> None:
    """Check `is_hint_not_resolved` for Final, NotResolved-annotated and plain hints."""
    assert is_hint_not_resolved(Final[ConfigFieldMissingException]) is True
    assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved()]) is True
    assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(True)]) is True
    # a marker created with False switches the behavior off
    assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(False)]) is False
    # unrelated annotation metadata is ignored
    assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, "REQ"]) is False
    assert is_hint_not_resolved(str) is False


def test_not_resolved_hint() -> None:
    """Check resolution of configurations whose fields are hinted as Final or NotResolved."""

    class SentinelClass:
        pass

    @configspec
    class OptionalNotResolveConfiguration(BaseConfiguration):
        trace: Final[Optional[SentinelClass]] = None
        traces: Annotated[Optional[List[SentinelClass]], NotResolved()] = None

    # optional fields may stay unset
    c = resolve.resolve_configuration(OptionalNotResolveConfiguration())
    assert c.trace is None
    assert c.traces is None

    s1 = SentinelClass()
    s2 = SentinelClass()

    # explicitly passed values are kept as the very same objects
    c = resolve.resolve_configuration(OptionalNotResolveConfiguration(s1, [s2]))
    assert c.trace is s1
    assert c.traces[0] is s2

    @configspec
    class NotResolveConfiguration(BaseConfiguration):
        trace: Final[SentinelClass] = None
        traces: Annotated[List[SentinelClass], NotResolved()] = None

    # non-optional fields must both be provided, otherwise resolution fails
    with pytest.raises(ConfigFieldMissingException):
        resolve.resolve_configuration(NotResolveConfiguration())

    with pytest.raises(ConfigFieldMissingException):
        resolve.resolve_configuration(NotResolveConfiguration(trace=s1))

    with pytest.raises(ConfigFieldMissingException):
        resolve.resolve_configuration(NotResolveConfiguration(traces=[s2]))

    c2 = resolve.resolve_configuration(NotResolveConfiguration(s1, [s2]))
    assert c2.trace is s1
    assert c2.traces[0] is s2


def test_configspec_auto_base_config_derivation() -> None:
Expand Down
Loading
Loading