diff --git a/packages/opal-common/opal_common/config.py b/packages/opal-common/opal_common/config.py index 0397196b7..7666d47e4 100644 --- a/packages/opal-common/opal_common/config.py +++ b/packages/opal-common/opal_common/config.py @@ -169,6 +169,12 @@ class OpalCommonConfig(Confi): False, description="Set if OPAL server should enable tracing with datadog APM", ) + HTTP_FETCHER_PROVIDER_CLIENT = confi.str( + "HTTP_FETCHER_PROVIDER_CLIENT", + "aiohttp", + description="The client to use for fetching data, can be either aiohttp or httpx." + "if provided different value, aiohttp will be used.", + ) opal_common_config = OpalCommonConfig(prefix="OPAL_") diff --git a/packages/opal-common/opal_common/fetcher/providers/http_fetch_provider.py b/packages/opal-common/opal_common/fetcher/providers/http_fetch_provider.py index 05189876f..7261b538b 100644 --- a/packages/opal-common/opal_common/fetcher/providers/http_fetch_provider.py +++ b/packages/opal-common/opal_common/fetcher/providers/http_fetch_provider.py @@ -1,9 +1,11 @@ """Simple HTTP get data fetcher using requests supports.""" from enum import Enum -from typing import Any +from typing import Any, Union, cast +import httpx from aiohttp import ClientResponse, ClientSession +from opal_common.config import opal_common_config from pydantic import validator from ...http import is_http_error_response @@ -71,13 +73,15 @@ async def __aenter__(self): headers = {} if self._event.config.headers is not None: headers = self._event.config.headers - self._session = await ClientSession( - headers=headers, raise_for_status=True - ).__aenter__() + if opal_common_config.HTTP_FETCHER_PROVIDER_CLIENT == "httpx": + self._session = httpx.AsyncClient(headers=headers) + else: + self._session = ClientSession(headers=headers, raise_for_status=True) + self._session = await self._session.__aenter__() return self async def __aexit__(self, exc_type=None, exc_val=None, tb=None): - await self._session.__aexit__(exc_type=exc_type, exc_val=exc_val, exc_tb=tb) + await self._session.__aexit__(exc_type, exc_val, tb) async def _fetch_(self): logger.debug(f"{self.__class__.__name__} fetching from {self._url}") @@ -85,29 +89,38 @@ async def _fetch_(self): self._session, self._event.config.method ) if self._event.config.data is not None: - result = await http_method( + result: Union[ClientResponse, httpx.Response] = await http_method( self._url, data=self._event.config.data, **self._ssl_context_kwargs ) else: result = await http_method(self._url, **self._ssl_context_kwargs) + result.raise_for_status() return result @staticmethod - def match_http_method_from_type(session: ClientSession, method_type: HttpMethods): + def match_http_method_from_type( + session: Union[ClientSession, httpx.AsyncClient], method_type: HttpMethods + ): return getattr(session, method_type.value) - async def _process_(self, res: ClientResponse): + @staticmethod + async def _response_to_data( + res: Union[ClientResponse, httpx.Response], *, is_json: bool + ) -> Any: + if isinstance(res, httpx.Response): + return res.json() if is_json else res.text + else: + res = cast(ClientResponse, res) + return await (res.json() if is_json else res.text()) + + async def _process_(self, res: Union[ClientResponse, httpx.Response]): # do not process data when the http response is an error if is_http_error_response(res): return res # if we are asked to process the data before we return it if self._event.config.process_data: - # if data is JSON - if self._event.config.is_json: - data = await res.json() - else: - data = await res.text() + data = await self._response_to_data(res, is_json=self._event.config.is_json) return data # return raw result else: diff --git a/packages/opal-common/opal_common/http.py b/packages/opal-common/opal_common/http.py index 8ff942320..9c2d35a76 100644 --- a/packages/opal-common/opal_common/http.py +++ b/packages/opal-common/opal_common/http.py @@ -1,6 +1,17 @@ +from typing import Union + import aiohttp +import httpx -def is_http_error_response(response: aiohttp.ClientResponse) -> bool: +def is_http_error_response( + response: Union[aiohttp.ClientResponse, httpx.Response] +) -> bool: """HTTP 400 and above are considered error responses.""" - return response.status >= 400 + status: int = ( + response.status + if isinstance(response, aiohttp.ClientResponse) + else response.status_code + ) + + return status >= 400 diff --git a/packages/opal-common/requires.txt b/packages/opal-common/requires.txt index 30494e4f0..90d21f1df 100644 --- a/packages/opal-common/requires.txt +++ b/packages/opal-common/requires.txt @@ -10,3 +10,4 @@ datadog>=0.44.0, <1 ddtrace>=2.8.1,<3 certifi>=2023.7.22 # not directly required, pinned by Snyk to avoid a vulnerability requests>=2.31.0 # not directly required, pinned by Snyk to avoid a vulnerability +httpx==0.27.0