diff --git a/cmsdials/utils/api_client.py b/cmsdials/utils/api_client.py index 4d4e81d..dcc6cf3 100644 --- a/cmsdials/utils/api_client.py +++ b/cmsdials/utils/api_client.py @@ -3,6 +3,8 @@ from urllib.parse import parse_qs, urlparse import requests +from requests import Response, Session +from requests.adapters import DEFAULT_RETRIES, HTTPAdapter from requests.exceptions import HTTPError from ..auth._base import BaseCredentials @@ -35,6 +37,19 @@ def __endswithslash(value: str) -> str: return value + "/" return value + @classmethod + def _requests_get(cls, *args, retries=DEFAULT_RETRIES, **kwargs) -> Response: + """ + requests.get() with an additional `retries` parameter. + + Specify retries= for simple use cases. + For advanced usage, see https://docs.python-requests.org/en/latest/user/advanced/ + """ + with Session() as s: + s.mount(cls.PRODUCTION_BASE_URL, HTTPAdapter(max_retries=retries)) + ret = s.get(*args, **kwargs) + return ret + @property def api_url(self): return self.base_url + self.route + self.version @@ -79,11 +94,17 @@ def get(self, id: int): # noqa: A002 response = response.json() return self.data_model(**response) - def list(self, filters=None): + def list(self, filters=None, retries=DEFAULT_RETRIES): filters = filters or self.filter_class() endpoint_url = self.api_url + self.lookup_url headers = self._build_headers() - response = requests.get(endpoint_url, headers=headers, params=filters.cleandict(), timeout=self.default_timeout) + response = self._requests_get( + endpoint_url, + headers=headers, + params=filters.cleandict(), + timeout=self.default_timeout, + retries=retries, + ) try: response.raise_for_status() @@ -99,7 +120,13 @@ def list(self, filters=None): raise ValueError("pagination model is None and response is not a list.") - def __list_sync(self, filters, max_pages: Optional[int] = None, enable_progress: bool = False): + def __list_sync( + self, + filters, + max_pages: Optional[int] = None, + enable_progress: bool = False, + retries=DEFAULT_RETRIES, + ): next_token = None results = [] is_last_page = False @@ -112,7 +139,7 @@ def __list_sync(self, filters, max_pages: Optional[int] = None, enable_progress: while is_last_page is False: curr_filters = self.filter_class(**filters.dict()) curr_filters.next_token = next_token - response = self.list(curr_filters) + response = self.list(curr_filters, retries=retries) results.extend(response.results) is_last_page = response.next is None next_token = parse_qs(urlparse(response.next).query).get("next_token") if response.next else None