Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support retrying #16

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions cmsdials/utils/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from urllib.parse import parse_qs, urlparse

import requests
from requests import Response, Session
from requests.adapters import DEFAULT_RETRIES, HTTPAdapter
from requests.exceptions import HTTPError

from ..auth._base import BaseCredentials
Expand Down Expand Up @@ -35,6 +37,19 @@ def __endswithslash(value: str) -> str:
return value + "/"
return value

@classmethod
def _requests_get(cls, *args, retries=DEFAULT_RETRIES, **kwargs) -> Response:
"""
requests.get() with an additional `retries` parameter.

Specify retries=<number of attempts - 1> for simple use cases.
For advanced usage, see https://docs.python-requests.org/en/latest/user/advanced/
"""
with Session() as s:
s.mount(cls.PRODUCTION_BASE_URL, HTTPAdapter(max_retries=retries))
ret = s.get(*args, **kwargs)
return ret

@property
def api_url(self):
return self.base_url + self.route + self.version
Expand Down Expand Up @@ -79,11 +94,17 @@ def get(self, id: int): # noqa: A002
response = response.json()
return self.data_model(**response)

def list(self, filters=None):
def list(self, filters=None, retries=DEFAULT_RETRIES):
filters = filters or self.filter_class()
endpoint_url = self.api_url + self.lookup_url
headers = self._build_headers()
response = requests.get(endpoint_url, headers=headers, params=filters.cleandict(), timeout=self.default_timeout)
response = self._requests_get(
endpoint_url,
headers=headers,
params=filters.cleandict(),
timeout=self.default_timeout,
retries=retries,
)

try:
response.raise_for_status()
Expand All @@ -99,7 +120,13 @@ def list(self, filters=None):

raise ValueError("pagination model is None and response is not a list.")

def __list_sync(self, filters, max_pages: Optional[int] = None, enable_progress: bool = False):
def __list_sync(
self,
filters,
max_pages: Optional[int] = None,
enable_progress: bool = False,
retries=DEFAULT_RETRIES,
):
next_token = None
results = []
is_last_page = False
Expand All @@ -112,7 +139,7 @@ def __list_sync(self, filters, max_pages: Optional[int] = None, enable_progress:
while is_last_page is False:
curr_filters = self.filter_class(**filters.dict())
curr_filters.next_token = next_token
response = self.list(curr_filters)
response = self.list(curr_filters, retries=retries)
results.extend(response.results)
is_last_page = response.next is None
next_token = parse_qs(urlparse(response.next).query).get("next_token") if response.next else None
Expand Down