diff --git a/dlt/common/configuration/__init__.py b/dlt/common/configuration/__init__.py index 8de57f7799..2abc31b17d 100644 --- a/dlt/common/configuration/__init__.py +++ b/dlt/common/configuration/__init__.py @@ -1,4 +1,10 @@ -from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type +from .specs.base_configuration import ( + configspec, + is_valid_hint, + is_secret_hint, + resolve_type, + NotResolved, +) from .specs import known_sections from .resolve import resolve_configuration, inject_section from .inject import with_config, last_config, get_fun_spec, create_resolved_partial @@ -15,6 +21,7 @@ "configspec", "is_valid_hint", "is_secret_hint", + "NotResolved", "resolve_type", "known_sections", "resolve_configuration", diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index ebfa7b6b89..9101cfdd9c 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -8,7 +8,6 @@ StrAny, TSecretValue, get_all_types_of_class_in_union, - is_final_type, is_optional_type, is_union_type, ) @@ -21,6 +20,7 @@ is_context_inner_hint, is_base_configuration_inner_hint, is_valid_hint, + is_hint_not_resolved, ) from dlt.common.configuration.specs.config_section_context import ConfigSectionContext from dlt.common.configuration.specs.exceptions import NativeValueError @@ -194,7 +194,7 @@ def _resolve_config_fields( if explicit_values: explicit_value = explicit_values.get(key) else: - if is_final_type(hint): + if is_hint_not_resolved(hint): # for final fields default value is like explicit explicit_value = default_value else: @@ -258,7 +258,7 @@ def _resolve_config_fields( unresolved_fields[key] = traces # set resolved value in config if default_value != current_value: - if not is_final_type(hint): + if not is_hint_not_resolved(hint): # ignore final types setattr(config, key, current_value) diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 1329feae6c..006cde8dce 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -20,7 +20,7 @@ ClassVar, TypeVar, ) -from typing_extensions import get_args, get_origin, dataclass_transform +from typing_extensions import get_args, get_origin, dataclass_transform, Annotated, TypeAlias from functools import wraps if TYPE_CHECKING: @@ -29,8 +29,11 @@ TDtcField = dataclasses.Field from dlt.common.typing import ( + AnyType, TAnyClass, extract_inner_type, + is_annotated, + is_final_type, is_optional_type, is_union_type, ) @@ -48,6 +51,34 @@ _C = TypeVar("_C", bound="CredentialsConfiguration") +class NotResolved: + """Used in type annotations to indicate types that should not be resolved.""" + + def __init__(self, not_resolved: bool = True): + self.not_resolved = not_resolved + + def __bool__(self) -> bool: + return self.not_resolved + + +def is_hint_not_resolved(hint: AnyType) -> bool: + """Checks if hint should NOT be resolved. Final and types annotated like + + >>> Annotated[str, NotResolved()] + + are not resolved. + """ + if is_final_type(hint): + return True + + if is_annotated(hint): + _, *a_m = get_args(hint) + for annotation in a_m: + if isinstance(annotation, NotResolved): + return bool(annotation) + return False + + def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool: return inspect.isclass(inner_hint) and issubclass(inner_hint, BaseConfiguration) @@ -70,6 +101,11 @@ def is_valid_hint(hint: Type[Any]) -> bool: if get_origin(hint) is ClassVar: # class vars are skipped by dataclass return True + + if is_hint_not_resolved(hint): + # all hints that are not resolved are valid + return True + hint = extract_inner_type(hint) hint = get_config_if_union_hint(hint) or hint hint = get_origin(hint) or hint diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 2ad5131e63..d4cdfb729d 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -18,8 +18,8 @@ Any, TypeVar, Generic, - Final, ) +from typing_extensions import Annotated import datetime # noqa: 251 from copy import deepcopy import inspect @@ -35,7 +35,7 @@ has_column_with_prop, get_first_column_name_with_prop, ) -from dlt.common.configuration import configspec, resolve_configuration, known_sections +from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.configuration.accessors import config from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -78,7 +78,7 @@ class StateInfo(NamedTuple): @configspec class DestinationClientConfiguration(BaseConfiguration): - destination_type: Final[str] = dataclasses.field( + destination_type: Annotated[str, NotResolved()] = dataclasses.field( default=None, init=False, repr=False, compare=False ) # which destination to load data to credentials: Optional[CredentialsConfiguration] = None @@ -103,11 +103,11 @@ def on_resolved(self) -> None: class DestinationClientDwhConfiguration(DestinationClientConfiguration): """Configuration of a destination that supports datasets/schemas""" - dataset_name: Final[str] = dataclasses.field( + dataset_name: Annotated[str, NotResolved()] = dataclasses.field( default=None, init=False, repr=False, compare=False - ) # dataset must be final so it is not configurable + ) # dataset cannot be resolved """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" - default_schema_name: Final[Optional[str]] = dataclasses.field( + default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( default=None, init=False, repr=False, compare=False ) """name of default schema to be used to name effective dataset to load data to""" @@ -121,8 +121,8 @@ def _bind_dataset_name( This method is intended to be used internally. """ - self.dataset_name = dataset_name # type: ignore[misc] - self.default_schema_name = default_schema_name # type: ignore[misc] + self.dataset_name = dataset_name + self.default_schema_name = default_schema_name return self def normalize_dataset_name(self, schema: Schema) -> str: diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index d589537742..fd11cc7dcb 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -1,7 +1,8 @@ import dataclasses from typing import Optional, Final +from typing_extensions import Annotated -from dlt.common.configuration import configspec +from dlt.common.configuration import configspec, NotResolved from dlt.common.configuration.specs.base_configuration import ( BaseConfiguration, CredentialsConfiguration, @@ -55,7 +56,9 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration): dataset_separator: str = "_" # make it optional so empty dataset is allowed - dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc] + dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # Batch size for generating embeddings embedding_batch_size: int = 32 diff --git a/dlt/destinations/impl/weaviate/configuration.py b/dlt/destinations/impl/weaviate/configuration.py index 90fb7ce5ce..1a053e41f4 100644 --- a/dlt/destinations/impl/weaviate/configuration.py +++ b/dlt/destinations/impl/weaviate/configuration.py @@ -1,8 +1,9 @@ import dataclasses from typing import Dict, Literal, Optional, Final +from typing_extensions import Annotated from urllib.parse import urlparse -from dlt.common.configuration import configspec +from dlt.common.configuration import configspec, NotResolved from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration from dlt.common.destination.reference import DestinationClientDwhConfiguration from dlt.common.utils import digest128 @@ -26,7 +27,9 @@ def __str__(self) -> str: class WeaviateClientConfiguration(DestinationClientDwhConfiguration): destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore # make it optional so empty dataset is allowed - dataset_name: Optional[str] = None # type: ignore[misc] + dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) batch_size: int = 100 batch_workers: int = 1 diff --git a/dlt/sources/helpers/requests/retry.py b/dlt/sources/helpers/requests/retry.py index c9a813598f..3f9d7d559e 100644 --- a/dlt/sources/helpers/requests/retry.py +++ b/dlt/sources/helpers/requests/retry.py @@ -239,7 +239,7 @@ def _make_session(self) -> Session: session.mount("http://", self._adapter) session.mount("https://", self._adapter) retry = _make_retry(**self._retry_kwargs) - session.request = retry.wraps(session.request) # type: ignore[method-assign] + session.send = retry.wraps(session.send) # type: ignore[method-assign] return session @property diff --git a/dlt/sources/helpers/rest_client/auth.py b/dlt/sources/helpers/rest_client/auth.py index 020c63a195..29e6d8c77a 100644 --- a/dlt/sources/helpers/rest_client/auth.py +++ b/dlt/sources/helpers/rest_client/auth.py @@ -1,4 +1,5 @@ from base64 import b64encode +import dataclasses import math from typing import ( List, @@ -12,12 +13,13 @@ Iterable, TYPE_CHECKING, ) +from typing_extensions import Annotated from requests.auth import AuthBase -from requests import PreparedRequest # noqa: I251 +from requests import PreparedRequest, Session as BaseSession # noqa: I251 from dlt.common import logger from dlt.common.exceptions import MissingDependencyException -from dlt.common.configuration.specs.base_configuration import configspec +from dlt.common.configuration.specs.base_configuration import configspec, NotResolved from dlt.common.configuration.specs import CredentialsConfiguration from dlt.common.configuration.specs.exceptions import NativeValueError from dlt.common.pendulum import pendulum @@ -146,7 +148,9 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest: class OAuthJWTAuth(BearerTokenAuth): """This is a form of Bearer auth, actually there's not standard way to declare it in openAPI""" - format: Final[Literal["JWT"]] = "JWT" # noqa: A003 + format: Final[Literal["JWT"]] = dataclasses.field( # noqa: A003 + default="JWT", init=False, repr=False, compare=False + ) client_id: str = None private_key: TSecretStrValue = None auth_endpoint: str = None @@ -154,11 +158,15 @@ class OAuthJWTAuth(BearerTokenAuth): headers: Optional[Dict[str, str]] = None private_key_passphrase: Optional[TSecretStrValue] = None default_token_expiration: int = 3600 + session: Annotated[BaseSession, NotResolved()] = None def __post_init__(self) -> None: self.scopes = self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes) self.token = None self.token_expiry: Optional[pendulum.DateTime] = None + # use default system session is not specified + if self.session is None: + self.session = requests.client.session def __call__(self, r: PreparedRequest) -> PreparedRequest: if self.token is None or self.is_token_expired(): @@ -183,7 +191,7 @@ def obtain_token(self) -> None: logger.debug(f"Obtaining token from {self.auth_endpoint}") - response = requests.post(self.auth_endpoint, headers=self.headers, data=data) + response = self.session.post(self.auth_endpoint, headers=self.headers, data=data) response.raise_for_status() token_response = response.json() diff --git a/dlt/sources/helpers/rest_client/client.py b/dlt/sources/helpers/rest_client/client.py index 7d1145a890..dc7304f159 100644 --- a/dlt/sources/helpers/rest_client/client.py +++ b/dlt/sources/helpers/rest_client/client.py @@ -82,8 +82,9 @@ def __init__( self.auth = auth if session: - self._validate_session_raise_for_status(session) - self.session = session + # dlt.sources.helpers.requests.session.Session + # has raise_for_status=True by default + self.session = _warn_if_raise_for_status_and_return(session) else: self.session = Client(raise_for_status=False).session @@ -92,15 +93,6 @@ def __init__( self.data_selector = data_selector - def _validate_session_raise_for_status(self, session: BaseSession) -> None: - # dlt.sources.helpers.requests.session.Session - # has raise_for_status=True by default - if getattr(self.session, "raise_for_status", False): - logger.warning( - "The session provided has raise_for_status enabled. " - "This may cause unexpected behavior." - ) - def _create_request( self, path: str, @@ -298,3 +290,12 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator: " instance of the paginator as some settings may not be guessed correctly." ) return paginator + + +def _warn_if_raise_for_status_and_return(session: BaseSession) -> BaseSession: + """A generic function to warn if the session has raise_for_status enabled.""" + if getattr(session, "raise_for_status", False): + logger.warning( + "The session provided has raise_for_status enabled. This may cause unexpected behavior." + ) + return session diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 84b2d1893d..43ccdf856c 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -12,11 +12,12 @@ Optional, Type, Union, - TYPE_CHECKING, ) +from typing_extensions import Annotated from dlt.common import json, pendulum, Decimal, Wei from dlt.common.configuration.providers.provider import ConfigProvider +from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolved from dlt.common.configuration.specs.gcp_credentials import ( GcpServiceAccountCredentialsWithoutDefaults, ) @@ -917,6 +918,58 @@ def test_is_valid_hint() -> None: assert is_valid_hint(Wei) is True # any class type, except deriving from BaseConfiguration is wrong type assert is_valid_hint(ConfigFieldMissingException) is False + # but final and annotated types are not ok because they are not resolved + assert is_valid_hint(Final[ConfigFieldMissingException]) is True # type: ignore[arg-type] + assert is_valid_hint(Annotated[ConfigFieldMissingException, NotResolved()]) is True # type: ignore[arg-type] + assert is_valid_hint(Annotated[ConfigFieldMissingException, "REQ"]) is False # type: ignore[arg-type] + + +def test_is_not_resolved_hint() -> None: + assert is_hint_not_resolved(Final[ConfigFieldMissingException]) is True + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved()]) is True + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(True)]) is True + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, NotResolved(False)]) is False + assert is_hint_not_resolved(Annotated[ConfigFieldMissingException, "REQ"]) is False + assert is_hint_not_resolved(str) is False + + +def test_not_resolved_hint() -> None: + class SentinelClass: + pass + + @configspec + class OptionalNotResolveConfiguration(BaseConfiguration): + trace: Final[Optional[SentinelClass]] = None + traces: Annotated[Optional[List[SentinelClass]], NotResolved()] = None + + c = resolve.resolve_configuration(OptionalNotResolveConfiguration()) + assert c.trace is None + assert c.traces is None + + s1 = SentinelClass() + s2 = SentinelClass() + + c = resolve.resolve_configuration(OptionalNotResolveConfiguration(s1, [s2])) + assert c.trace is s1 + assert c.traces[0] is s2 + + @configspec + class NotResolveConfiguration(BaseConfiguration): + trace: Final[SentinelClass] = None + traces: Annotated[List[SentinelClass], NotResolved()] = None + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolveConfiguration()) + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolveConfiguration(trace=s1)) + + with pytest.raises(ConfigFieldMissingException): + resolve.resolve_configuration(NotResolveConfiguration(traces=[s2])) + + c2 = resolve.resolve_configuration(NotResolveConfiguration(s1, [s2])) + assert c2.trace is s1 + assert c2.traces[0] is s2 def test_configspec_auto_base_config_derivation() -> None: diff --git a/tests/load/utils.py b/tests/load/utils.py index 81107e83d9..c03470676f 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -574,8 +574,8 @@ def yield_client( destination = Destination.from_reference(destination_type) # create initial config dest_config: DestinationClientDwhConfiguration = None - dest_config = destination.spec() # type: ignore[assignment] - dest_config.dataset_name = dataset_name # type: ignore[misc] + dest_config = destination.spec() # type: ignore + dest_config.dataset_name = dataset_name if default_config_values is not None: # apply the values to credentials, if dict is provided it will be used as default @@ -597,7 +597,7 @@ def yield_client( staging_config = DestinationClientStagingConfiguration( bucket_url=AWS_BUCKET, )._bind_dataset_name(dataset_name=dest_config.dataset_name) - staging_config.destination_type = "filesystem" # type: ignore[misc] + staging_config.destination_type = "filesystem" staging_config.resolve() dest_config.staging_config = staging_config # type: ignore[attr-defined] diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 11d3f13db9..8c3344f152 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -37,10 +37,10 @@ def drop_weaviate_schema() -> Iterator[None]: def get_client_instance(schema: Schema) -> WeaviateClient: - dest = weaviate(dataset_name="ClientTest" + uniq_id()) - return dest.client(schema, dest.spec()) - # with Container().injectable_context(ConfigSectionContext(sections=('destination', 'weaviate'))): - # return dest.client(schema, config) + dest = weaviate() + return dest.client( + schema, dest.spec()._bind_dataset_name(dataset_name="ClientTest" + uniq_id()) + ) @pytest.fixture(scope="function") diff --git a/tests/sources/helpers/rest_client/test_client.py b/tests/sources/helpers/rest_client/test_client.py index 7f03c6d167..79a57d0e82 100644 --- a/tests/sources/helpers/rest_client/test_client.py +++ b/tests/sources/helpers/rest_client/test_client.py @@ -1,10 +1,11 @@ import os import pytest from typing import Any, cast -from requests import PreparedRequest, Request +from dlt.common import logger +from requests import PreparedRequest, Request, Response from requests.auth import AuthBase from dlt.common.typing import TSecretStrValue -from dlt.sources.helpers.requests import Response +from dlt.sources.helpers.requests import Client from dlt.sources.helpers.rest_client import RESTClient from dlt.sources.helpers.rest_client.client import Hooks from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator @@ -35,6 +36,7 @@ def rest_client() -> RESTClient: return RESTClient( base_url="https://api.example.com", headers={"Accept": "application/json"}, + session=Client().session, ) @@ -168,6 +170,7 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): auth_endpoint="https://api.example.com/oauth/token", scopes=["read", "write"], headers={"Content-Type": "application/json"}, + session=Client().session, ) response = rest_client.get( @@ -185,6 +188,19 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient): assert_pagination(list(pages_iter)) + def test_custom_session_client(self, mocker): + mocked_warning = mocker.patch.object(logger, "warning") + RESTClient( + base_url="https://api.example.com", + headers={"Accept": "application/json"}, + session=Client(raise_for_status=True).session, + ) + assert ( + mocked_warning.call_args[0][0] + == "The session provided has raise_for_status enabled. This may cause unexpected" + " behavior." + ) + def test_custom_auth_success(self, rest_client: RESTClient): class CustomAuthConfigBase(AuthConfigBase): def __init__(self, token: str): diff --git a/tests/sources/helpers/test_requests.py b/tests/sources/helpers/test_requests.py index aefdf23e77..70776a50ee 100644 --- a/tests/sources/helpers/test_requests.py +++ b/tests/sources/helpers/test_requests.py @@ -1,4 +1,4 @@ -from typing import Iterator, Type +from typing import Any, Dict, Iterator, List, Type from unittest import mock import os import random @@ -29,7 +29,7 @@ def mock_sleep() -> Iterator[mock.MagicMock]: def test_default_session_retry_settings() -> None: - retry: Retrying = Client().session.request.retry # type: ignore + retry: Retrying = Client().session.send.retry # type: ignore assert retry.stop.max_attempt_number == 5 # type: ignore assert isinstance(retry.retry, retry_any) retries = retry.retry.retries @@ -51,7 +51,7 @@ def custom_retry_cond(response, exception): respect_retry_after_header=False, ).session - retry: Retrying = session.request.retry # type: ignore + retry: Retrying = session.send.retry # type: ignore assert retry.stop.max_attempt_number == 14 # type: ignore assert isinstance(retry.retry, retry_any) retries = retry.retry.retries @@ -63,11 +63,12 @@ def custom_retry_cond(response, exception): def test_retry_on_status_all_fails(mock_sleep: mock.MagicMock) -> None: session = Client().session url = "https://example.com/data" + m = requests_mock.Adapter() + session.mount("https://", m) + m.register_uri("GET", url, status_code=503) - with requests_mock.mock(session=session) as m: - m.get(url, status_code=503) - with pytest.raises(requests.HTTPError): - session.get(url) + with pytest.raises(requests.HTTPError): + session.get(url) assert m.call_count == RunConfiguration.request_max_attempts @@ -76,6 +77,8 @@ def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None: """Test successful request after 2 retries""" session = Client().session url = "https://example.com/data" + m = requests_mock.Adapter() + session.mount("https://", m) responses = [ dict(text="error", status_code=503), @@ -83,9 +86,8 @@ def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None: dict(text="error", status_code=200), ] - with requests_mock.mock(session=session) as m: - m.get(url, responses) - resp = session.get(url) + m.register_uri("GET", url, responses) + resp = session.get(url) assert resp.status_code == 200 assert m.call_count == 3 @@ -94,11 +96,12 @@ def test_retry_on_status_success_after_2(mock_sleep: mock.MagicMock) -> None: def test_retry_on_status_without_raise_for_status(mock_sleep: mock.MagicMock) -> None: url = "https://example.com/data" session = Client(raise_for_status=False).session + m = requests_mock.Adapter() + session.mount("https://", m) - with requests_mock.mock(session=session) as m: - m.get(url, status_code=503) - response = session.get(url) - assert response.status_code == 503 + m.register_uri("GET", url, status_code=503) + response = session.get(url) + assert response.status_code == 503 assert m.call_count == RunConfiguration.request_max_attempts @@ -106,18 +109,19 @@ def test_retry_on_status_without_raise_for_status(mock_sleep: mock.MagicMock) -> def test_hooks_with_raise_for_statue() -> None: url = "https://example.com/data" session = Client(raise_for_status=True).session + m = requests_mock.Adapter() + session.mount("https://", m) def _no_content(resp: requests.Response, *args, **kwargs) -> requests.Response: resp.status_code = 204 resp._content = b"[]" return resp - with requests_mock.mock(session=session) as m: - m.get(url, status_code=503) - response = session.get(url, hooks={"response": _no_content}) - # we simulate empty response - assert response.status_code == 204 - assert response.json() == [] + m.register_uri("GET", url, status_code=503) + response = session.get(url, hooks={"response": _no_content}) + # we simulate empty response + assert response.status_code == 204 + assert response.json() == [] assert m.call_count == 1 @@ -130,12 +134,13 @@ def test_retry_on_exception_all_fails( exception_class: Type[Exception], mock_sleep: mock.MagicMock ) -> None: session = Client().session + m = requests_mock.Adapter() + session.mount("https://", m) url = "https://example.com/data" - with requests_mock.mock(session=session) as m: - m.get(url, exc=exception_class) - with pytest.raises(exception_class): - session.get(url) + m.register_uri("GET", url, exc=exception_class) + with pytest.raises(exception_class): + session.get(url) assert m.call_count == RunConfiguration.request_max_attempts @@ -145,12 +150,13 @@ def retry_on(response: requests.Response, exception: BaseException) -> bool: return response.text == "error" session = Client(retry_condition=retry_on).session + m = requests_mock.Adapter() + session.mount("https://", m) url = "https://example.com/data" - with requests_mock.mock(session=session) as m: - m.get(url, text="error") - response = session.get(url) - assert response.content == b"error" + m.register_uri("GET", url, text="error") + response = session.get(url) + assert response.content == b"error" assert m.call_count == RunConfiguration.request_max_attempts @@ -160,12 +166,12 @@ def retry_on(response: requests.Response, exception: BaseException) -> bool: return response.text == "error" session = Client(retry_condition=retry_on).session + m = requests_mock.Adapter() + session.mount("https://", m) url = "https://example.com/data" - responses = [dict(text="error"), dict(text="error"), dict(text="success")] - with requests_mock.mock(session=session) as m: - m.get(url, responses) - resp = session.get(url) + m.register_uri("GET", url, [dict(text="error"), dict(text="error"), dict(text="success")]) + resp = session.get(url) assert resp.text == "success" assert m.call_count == 3 @@ -174,14 +180,16 @@ def retry_on(response: requests.Response, exception: BaseException) -> bool: def test_wait_retry_after_int(mock_sleep: mock.MagicMock) -> None: session = Client(request_backoff_factor=0).session url = "https://example.com/data" - responses = [ + m = requests_mock.Adapter() + session.mount("https://", m) + m.register_uri("GET", url, text="error") + responses: List[Dict[str, Any]] = [ dict(text="error", headers={"retry-after": "4"}, status_code=429), dict(text="success"), ] - with requests_mock.mock(session=session) as m: - m.get(url, responses) - session.get(url) + m.register_uri("GET", url, responses) + session.get(url) mock_sleep.assert_called_once() assert 4 <= mock_sleep.call_args[0][0] <= 5 # Adds jitter up to 1s @@ -206,7 +214,7 @@ def test_init_default_client(existing_session: bool) -> None: session = default_client.session assert session.timeout == cfg["RUNTIME__REQUEST_TIMEOUT"] - retry = session.request.retry # type: ignore[attr-defined] + retry = session.send.retry # type: ignore[attr-defined] assert retry.wait.multiplier == cfg["RUNTIME__REQUEST_BACKOFF_FACTOR"] assert retry.stop.max_attempt_number == cfg["RUNTIME__REQUEST_MAX_ATTEMPTS"] assert retry.wait.max == cfg["RUNTIME__REQUEST_MAX_RETRY_DELAY"] @@ -226,7 +234,7 @@ def test_client_instance_with_config(existing_session: bool) -> None: session = client.session assert session.timeout == cfg["RUNTIME__REQUEST_TIMEOUT"] - retry = session.request.retry # type: ignore[attr-defined] + retry = session.send.retry # type: ignore[attr-defined] assert retry.wait.multiplier == cfg["RUNTIME__REQUEST_BACKOFF_FACTOR"] assert retry.stop.max_attempt_number == cfg["RUNTIME__REQUEST_MAX_ATTEMPTS"] assert retry.wait.max == cfg["RUNTIME__REQUEST_MAX_RETRY_DELAY"]