Skip to content

Commit

Permalink
Allow setting base URL when constructing EnlyzeClient. (#31)
Browse files Browse the repository at this point in the history
With internal use-cases popping up, we need the ability to set a
non-standard base URL for the ENLYZE platform. For our users, nothing
should change as the same default will still be taken.
  • Loading branch information
daniel-k authored Dec 4, 2023
1 parent abb3023 commit bf1a82d
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 33 deletions.
6 changes: 3 additions & 3 deletions src/enlyze/api_clients/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel, ValidationError

from enlyze.auth import TokenAuth
from enlyze.constants import ENLYZE_BASE_URL, HTTPX_TIMEOUT
from enlyze.constants import HTTPX_TIMEOUT
from enlyze.errors import EnlyzeError, InvalidTokenError


Expand Down Expand Up @@ -51,9 +51,9 @@ class ApiBaseClient(ABC, Generic[R]):

def __init__(
self,
token: str,
*,
base_url: str | httpx.URL = ENLYZE_BASE_URL,
token: str,
base_url: str | httpx.URL,
timeout: float = HTTPX_TIMEOUT,
):
self._client = httpx.Client(
Expand Down
10 changes: 7 additions & 3 deletions src/enlyze/api_clients/production_runs/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import BaseModel

from enlyze.api_clients.base import ApiBaseClient, PaginatedResponseBaseModel
from enlyze.constants import ENLYZE_BASE_URL, PRODUCTION_RUNS_API_SUB_PATH
from enlyze.constants import PRODUCTION_RUNS_API_SUB_PATH


class _Metadata(BaseModel):
Expand All @@ -29,10 +29,14 @@ class ProductionRunsApiClient(ApiBaseClient[_PaginatedResponse]):
PaginatedResponseModel = _PaginatedResponse

def __init__(
self, token: str, *, base_url: str | httpx.URL = ENLYZE_BASE_URL, **kwargs: Any
self,
*,
token: str,
base_url: str | httpx.URL,
**kwargs: Any,
):
super().__init__(
token,
token=token,
base_url=httpx.URL(base_url).join(PRODUCTION_RUNS_API_SUB_PATH),
**kwargs,
)
Expand Down
10 changes: 6 additions & 4 deletions src/enlyze/api_clients/timeseries/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import AnyUrl

from enlyze.api_clients.base import ApiBaseClient, PaginatedResponseBaseModel
from enlyze.constants import ENLYZE_BASE_URL, TIMESERIES_API_SUB_PATH
from enlyze.constants import TIMESERIES_API_SUB_PATH


class _PaginatedResponse(PaginatedResponseBaseModel):
Expand All @@ -25,13 +25,15 @@ class TimeseriesApiClient(ApiBaseClient[_PaginatedResponse]):

def __init__(
self,
token: str,
*,
base_url: str | httpx.URL = ENLYZE_BASE_URL,
token: str,
base_url: str | httpx.URL,
**kwargs: Any,
):
super().__init__(
token, base_url=httpx.URL(base_url).join(TIMESERIES_API_SUB_PATH), **kwargs
token=token,
base_url=httpx.URL(base_url).join(TIMESERIES_API_SUB_PATH),
**kwargs,
)

def _transform_paginated_response_data(
Expand Down
17 changes: 13 additions & 4 deletions src/enlyze/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
from enlyze.api_clients.production_runs.client import ProductionRunsApiClient
from enlyze.api_clients.production_runs.models import ProductionRun
from enlyze.api_clients.timeseries.client import TimeseriesApiClient
from enlyze.constants import VARIABLE_UUID_AND_RESAMPLING_METHOD_SEPARATOR
from enlyze.constants import (
ENLYZE_BASE_URL,
VARIABLE_UUID_AND_RESAMPLING_METHOD_SEPARATOR,
)
from enlyze.errors import EnlyzeError
from enlyze.validators import (
validate_datetime,
Expand Down Expand Up @@ -49,9 +52,15 @@ class EnlyzeClient:
"""

def __init__(self, token: str) -> None:
self._timeseries_api_client = TimeseriesApiClient(token=token)
self._production_runs_api_client = ProductionRunsApiClient(token=token)
def __init__(self, token: str, *, _base_url: str | None = None) -> None:
self._timeseries_api_client = TimeseriesApiClient(
token=token,
base_url=_base_url or ENLYZE_BASE_URL,
)
self._production_runs_api_client = ProductionRunsApiClient(
token=token,
base_url=_base_url or ENLYZE_BASE_URL,
)

def _get_sites(self) -> Iterator[timeseries_api_models.Site]:
"""Get all sites from the API"""
Expand Down
5 changes: 5 additions & 0 deletions tests/enlyze/api_clients/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,8 @@ def string_model():
@pytest.fixture
def endpoint():
return "https://my-endpoint.com"


@pytest.fixture
def base_url():
return "http://api-client-base"
9 changes: 4 additions & 5 deletions tests/enlyze/api_clients/production_runs/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
_Metadata,
_PaginatedResponse,
)
from enlyze.constants import ENLYZE_BASE_URL, PRODUCTION_RUNS_API_SUB_PATH
from enlyze.constants import PRODUCTION_RUNS_API_SUB_PATH


@pytest.fixture
Expand Down Expand Up @@ -36,12 +36,11 @@ def paginated_response_with_next_page(response_data, metadata_next_page):


@pytest.fixture
def production_runs_client(auth_token):
return ProductionRunsApiClient(token=auth_token)
def production_runs_client(auth_token, base_url):
return ProductionRunsApiClient(token=auth_token, base_url=base_url)


def test_timeseries_api_appends_sub_path(auth_token):
base_url = ENLYZE_BASE_URL
def test_timeseries_api_appends_sub_path(auth_token, base_url):
expected = str(httpx.URL(base_url).join(PRODUCTION_RUNS_API_SUB_PATH))
client = ProductionRunsApiClient(token=auth_token, base_url=base_url)
assert client._full_url("") == expected
Expand Down
25 changes: 15 additions & 10 deletions tests/enlyze/api_clients/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import httpx
import pytest
import respx
from hypothesis import given
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

from enlyze.api_clients.base import (
ApiBaseClient,
ApiBaseModel,
PaginatedResponseBaseModel,
)
from enlyze.constants import ENLYZE_BASE_URL
from enlyze.errors import EnlyzeError, InvalidTokenError


Expand Down Expand Up @@ -65,7 +64,7 @@ def paginated_response_no_next_page(response_data_integers, last_page_metadata):


@pytest.fixture
def base_client(auth_token, string_model):
def base_client(auth_token, string_model, base_url):
mock_has_more = MagicMock()
mock_transform_paginated_response_data = MagicMock(side_effect=lambda e: e)
mock_next_page_call_args = MagicMock()
Expand All @@ -76,18 +75,22 @@ def base_client(auth_token, string_model):
_next_page_call_args=mock_next_page_call_args,
_transform_paginated_response_data=mock_transform_paginated_response_data,
):
client = ApiBaseClient[PaginatedResponseModel](token=auth_token)
client = ApiBaseClient[PaginatedResponseModel](
token=auth_token,
base_url=base_url,
)
client.PaginatedResponseModel = PaginatedResponseModel
yield client


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
token=st.text(string.printable, min_size=1),
)
@respx.mock
def test_token_auth(token):
def test_token_auth(token, base_url):
with patch.multiple(ApiBaseClient, __abstractmethods__=set()):
client = ApiBaseClient(token=token)
client = ApiBaseClient(token=token, base_url=base_url)

route_is_authenticated = respx.get(
"",
Expand All @@ -99,11 +102,11 @@ def test_token_auth(token):


@respx.mock
def test_base_url(base_client):
def test_base_url(base_client, base_url):
endpoint = "some-endpoint"

route = respx.get(
httpx.URL(ENLYZE_BASE_URL).join(endpoint),
httpx.URL(base_url).join(endpoint),
).respond(json={})

base_client.get(endpoint)
Expand Down Expand Up @@ -276,9 +279,11 @@ def test_get_paginated_transform_paginated_data(
assert data == expected_data


def test_transform_paginated_data_returns_unmutated_element_by_default(auth_token):
def test_transform_paginated_data_returns_unmutated_element_by_default(
auth_token, base_url
):
with patch.multiple(ApiBaseClient, __abstractmethods__=set()):
client = ApiBaseClient(auth_token)
client = ApiBaseClient(token=auth_token, base_url=base_url)
data = [1, 2, 3]
value = client._transform_paginated_response_data(data)
assert data == value
7 changes: 3 additions & 4 deletions tests/enlyze/api_clients/timeseries/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ def paginated_response_with_next_page(endpoint):


@pytest.fixture
def timeseries_client(auth_token):
return TimeseriesApiClient(token=auth_token)
def timeseries_client(auth_token, base_url):
return TimeseriesApiClient(token=auth_token, base_url=base_url)


def test_timeseries_api_appends_sub_path(auth_token):
base_url = "https://some-base-url.com"
def test_timeseries_api_appends_sub_path(auth_token, base_url):
expected = str(httpx.URL(base_url).join(TIMESERIES_API_SUB_PATH))
client = TimeseriesApiClient(token=auth_token, base_url=base_url)
assert client._full_url("") == expected
Expand Down

0 comments on commit bf1a82d

Please sign in to comment.