From bf1a82dd4082f8dbc9027e4c68b55330a5e05d7f Mon Sep 17 00:00:00 2001 From: Daniel Krebs Date: Mon, 4 Dec 2023 14:09:14 +0100 Subject: [PATCH] Allow setting base URL when constructing EnlyzeClient. (#31) With internal use-cases popping up, we need the ability to set a non-standard base URL for the ENLYZE platform. For our users, nothing should change as the same default will still be taken. --- src/enlyze/api_clients/base.py | 6 ++--- .../api_clients/production_runs/client.py | 10 +++++--- src/enlyze/api_clients/timeseries/client.py | 10 +++++--- src/enlyze/client.py | 17 ++++++++++--- tests/enlyze/api_clients/conftest.py | 5 ++++ .../production_runs/test_client.py | 9 +++---- tests/enlyze/api_clients/test_base.py | 25 +++++++++++-------- .../api_clients/timeseries/test_client.py | 7 +++--- 8 files changed, 56 insertions(+), 33 deletions(-) diff --git a/src/enlyze/api_clients/base.py b/src/enlyze/api_clients/base.py index 1d30d06..557deb3 100644 --- a/src/enlyze/api_clients/base.py +++ b/src/enlyze/api_clients/base.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ValidationError from enlyze.auth import TokenAuth -from enlyze.constants import ENLYZE_BASE_URL, HTTPX_TIMEOUT +from enlyze.constants import HTTPX_TIMEOUT from enlyze.errors import EnlyzeError, InvalidTokenError @@ -51,9 +51,9 @@ class ApiBaseClient(ABC, Generic[R]): def __init__( self, - token: str, *, - base_url: str | httpx.URL = ENLYZE_BASE_URL, + token: str, + base_url: str | httpx.URL, timeout: float = HTTPX_TIMEOUT, ): self._client = httpx.Client( diff --git a/src/enlyze/api_clients/production_runs/client.py b/src/enlyze/api_clients/production_runs/client.py index 73ab06f..6610913 100644 --- a/src/enlyze/api_clients/production_runs/client.py +++ b/src/enlyze/api_clients/production_runs/client.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from enlyze.api_clients.base import ApiBaseClient, PaginatedResponseBaseModel -from enlyze.constants import ENLYZE_BASE_URL, PRODUCTION_RUNS_API_SUB_PATH +from enlyze.constants import PRODUCTION_RUNS_API_SUB_PATH class _Metadata(BaseModel): @@ -29,10 +29,14 @@ class ProductionRunsApiClient(ApiBaseClient[_PaginatedResponse]): PaginatedResponseModel = _PaginatedResponse def __init__( - self, token: str, *, base_url: str | httpx.URL = ENLYZE_BASE_URL, **kwargs: Any + self, + *, + token: str, + base_url: str | httpx.URL, + **kwargs: Any, ): super().__init__( - token, + token=token, base_url=httpx.URL(base_url).join(PRODUCTION_RUNS_API_SUB_PATH), **kwargs, ) diff --git a/src/enlyze/api_clients/timeseries/client.py b/src/enlyze/api_clients/timeseries/client.py index a4824a4..b3385fd 100644 --- a/src/enlyze/api_clients/timeseries/client.py +++ b/src/enlyze/api_clients/timeseries/client.py @@ -4,7 +4,7 @@ from pydantic import AnyUrl from enlyze.api_clients.base import ApiBaseClient, PaginatedResponseBaseModel -from enlyze.constants import ENLYZE_BASE_URL, TIMESERIES_API_SUB_PATH +from enlyze.constants import TIMESERIES_API_SUB_PATH class _PaginatedResponse(PaginatedResponseBaseModel): @@ -25,13 +25,15 @@ class TimeseriesApiClient(ApiBaseClient[_PaginatedResponse]): def __init__( self, - token: str, *, - base_url: str | httpx.URL = ENLYZE_BASE_URL, + token: str, + base_url: str | httpx.URL, **kwargs: Any, ): super().__init__( - token, base_url=httpx.URL(base_url).join(TIMESERIES_API_SUB_PATH), **kwargs + token=token, + base_url=httpx.URL(base_url).join(TIMESERIES_API_SUB_PATH), + **kwargs, ) def _transform_paginated_response_data( diff --git a/src/enlyze/client.py b/src/enlyze/client.py index e04dd9d..ec4aa11 100644 --- a/src/enlyze/client.py +++ b/src/enlyze/client.py @@ -8,7 +8,10 @@ from enlyze.api_clients.production_runs.client import ProductionRunsApiClient from enlyze.api_clients.production_runs.models import ProductionRun from enlyze.api_clients.timeseries.client import TimeseriesApiClient -from enlyze.constants import VARIABLE_UUID_AND_RESAMPLING_METHOD_SEPARATOR +from enlyze.constants import ( + ENLYZE_BASE_URL, + VARIABLE_UUID_AND_RESAMPLING_METHOD_SEPARATOR, +) from enlyze.errors import EnlyzeError from enlyze.validators import ( validate_datetime, @@ -49,9 +52,15 @@ class EnlyzeClient: """ - def __init__(self, token: str) -> None: - self._timeseries_api_client = TimeseriesApiClient(token=token) - self._production_runs_api_client = ProductionRunsApiClient(token=token) + def __init__(self, token: str, *, _base_url: str | None = None) -> None: + self._timeseries_api_client = TimeseriesApiClient( + token=token, + base_url=_base_url or ENLYZE_BASE_URL, + ) + self._production_runs_api_client = ProductionRunsApiClient( + token=token, + base_url=_base_url or ENLYZE_BASE_URL, + ) def _get_sites(self) -> Iterator[timeseries_api_models.Site]: """Get all sites from the API""" diff --git a/tests/enlyze/api_clients/conftest.py b/tests/enlyze/api_clients/conftest.py index cc69a41..1773257 100644 --- a/tests/enlyze/api_clients/conftest.py +++ b/tests/enlyze/api_clients/conftest.py @@ -17,3 +17,8 @@ def string_model(): @pytest.fixture def endpoint(): return "https://my-endpoint.com" + + +@pytest.fixture +def base_url(): + return "http://api-client-base" diff --git a/tests/enlyze/api_clients/production_runs/test_client.py b/tests/enlyze/api_clients/production_runs/test_client.py index c357eb3..f64167d 100644 --- a/tests/enlyze/api_clients/production_runs/test_client.py +++ b/tests/enlyze/api_clients/production_runs/test_client.py @@ -7,7 +7,7 @@ _Metadata, _PaginatedResponse, ) -from enlyze.constants import ENLYZE_BASE_URL, PRODUCTION_RUNS_API_SUB_PATH +from enlyze.constants import PRODUCTION_RUNS_API_SUB_PATH @pytest.fixture @@ -36,12 +36,11 @@ def paginated_response_with_next_page(response_data, metadata_next_page): @pytest.fixture -def production_runs_client(auth_token): - return ProductionRunsApiClient(token=auth_token) +def production_runs_client(auth_token, base_url): + return ProductionRunsApiClient(token=auth_token, base_url=base_url) -def test_timeseries_api_appends_sub_path(auth_token): - base_url = ENLYZE_BASE_URL +def test_timeseries_api_appends_sub_path(auth_token, base_url): expected = str(httpx.URL(base_url).join(PRODUCTION_RUNS_API_SUB_PATH)) client = ProductionRunsApiClient(token=auth_token, base_url=base_url) assert client._full_url("") == expected diff --git a/tests/enlyze/api_clients/test_base.py b/tests/enlyze/api_clients/test_base.py index d37ca0b..922db32 100644 --- a/tests/enlyze/api_clients/test_base.py +++ b/tests/enlyze/api_clients/test_base.py @@ -4,7 +4,7 @@ import httpx import pytest import respx -from hypothesis import given +from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st from enlyze.api_clients.base import ( @@ -12,7 +12,6 @@ ApiBaseModel, PaginatedResponseBaseModel, ) -from enlyze.constants import ENLYZE_BASE_URL from enlyze.errors import EnlyzeError, InvalidTokenError @@ -65,7 +64,7 @@ def paginated_response_no_next_page(response_data_integers, last_page_metadata): @pytest.fixture -def base_client(auth_token, string_model): +def base_client(auth_token, string_model, base_url): mock_has_more = MagicMock() mock_transform_paginated_response_data = MagicMock(side_effect=lambda e: e) mock_next_page_call_args = MagicMock() @@ -76,18 +75,22 @@ def base_client(auth_token, string_model): _next_page_call_args=mock_next_page_call_args, _transform_paginated_response_data=mock_transform_paginated_response_data, ): - client = ApiBaseClient[PaginatedResponseModel](token=auth_token) + client = ApiBaseClient[PaginatedResponseModel]( + token=auth_token, + base_url=base_url, + ) client.PaginatedResponseModel = PaginatedResponseModel yield client +@settings(suppress_health_check=[HealthCheck.function_scoped_fixture]) @given( token=st.text(string.printable, min_size=1), ) @respx.mock -def test_token_auth(token): +def test_token_auth(token, base_url): with patch.multiple(ApiBaseClient, __abstractmethods__=set()): - client = ApiBaseClient(token=token) + client = ApiBaseClient(token=token, base_url=base_url) route_is_authenticated = respx.get( "", @@ -99,11 +102,11 @@ def test_token_auth(token): @respx.mock -def test_base_url(base_client): +def test_base_url(base_client, base_url): endpoint = "some-endpoint" route = respx.get( - httpx.URL(ENLYZE_BASE_URL).join(endpoint), + httpx.URL(base_url).join(endpoint), ).respond(json={}) base_client.get(endpoint) @@ -276,9 +279,11 @@ def test_get_paginated_transform_paginated_data( assert data == expected_data -def test_transform_paginated_data_returns_unmutated_element_by_default(auth_token): +def test_transform_paginated_data_returns_unmutated_element_by_default( + auth_token, base_url +): with patch.multiple(ApiBaseClient, __abstractmethods__=set()): - client = ApiBaseClient(auth_token) + client = ApiBaseClient(token=auth_token, base_url=base_url) data = [1, 2, 3] value = client._transform_paginated_response_data(data) assert data == value diff --git a/tests/enlyze/api_clients/timeseries/test_client.py b/tests/enlyze/api_clients/timeseries/test_client.py index 8cfe9ac..035b5dc 100644 --- a/tests/enlyze/api_clients/timeseries/test_client.py +++ b/tests/enlyze/api_clients/timeseries/test_client.py @@ -35,12 +35,11 @@ def paginated_response_with_next_page(endpoint): @pytest.fixture -def timeseries_client(auth_token): - return TimeseriesApiClient(token=auth_token) +def timeseries_client(auth_token, base_url): + return TimeseriesApiClient(token=auth_token, base_url=base_url) -def test_timeseries_api_appends_sub_path(auth_token): - base_url = "https://some-base-url.com" +def test_timeseries_api_appends_sub_path(auth_token, base_url): expected = str(httpx.URL(base_url).join(TIMESERIES_API_SUB_PATH)) client = TimeseriesApiClient(token=auth_token, base_url=base_url) assert client._full_url("") == expected