Skip to content

Commit

Permalink
add httpx client for HttpFetchProvider and make it default
Browse files Browse the repository at this point in the history
  • Loading branch information
omer9564 committed May 30, 2024
1 parent 5b80a48 commit e54829a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Simple HTTP get data fetcher using requests supports."""

from enum import Enum
from typing import Any
from typing import Any, Union, cast

import httpx
from aiohttp import ClientResponse, ClientSession
from pydantic import validator
from typing_extensions import Literal

from ...http import is_http_error_response
from ...security.sslcontext import get_custom_ssl_context
Expand Down Expand Up @@ -48,6 +50,7 @@ class Config:
class HttpFetchEvent(FetchEvent):
fetcher: str = "HttpFetchProvider"
config: HttpFetcherConfig = None
client_type: Literal["httpx", "aiohttp"] = "httpx"


class HttpFetchProvider(BaseFetchProvider):
Expand All @@ -71,13 +74,15 @@ async def __aenter__(self):
headers = {}
if self._event.config.headers is not None:
headers = self._event.config.headers
self._session = await ClientSession(
headers=headers, raise_for_status=True
).__aenter__()
if self._event.client_type == "httpx":
self._session = httpx.AsyncClient(headers=headers)
else:
self._session = ClientSession(headers=headers, raise_for_status=True)
self._session = self._session.__aenter__()
return self

async def __aexit__(self, exc_type=None, exc_val=None, tb=None):
await self._session.__aexit__(exc_type=exc_type, exc_val=exc_val, exc_tb=tb)
await self._session.__aexit__(exc_type, exc_val, tb)

async def _fetch_(self):
logger.debug(f"{self.__class__.__name__} fetching from {self._url}")
Expand All @@ -93,21 +98,29 @@ async def _fetch_(self):
return result

@staticmethod
def match_http_method_from_type(session: ClientSession, method_type: HttpMethods):
def match_http_method_from_type(
session: Union[ClientSession, httpx.AsyncClient], method_type: HttpMethods
):
return getattr(session, method_type.value)

async def _process_(self, res: ClientResponse):
@staticmethod
async def _response_to_data(
res: Union[ClientResponse, httpx.Response], *, is_json: bool
) -> Any:
if isinstance(res, httpx.Response):
return res.json() if is_json else res.text
else:
res = cast(ClientResponse, res)
return await (res.json() if is_json else res.text())

async def _process_(self, res: Union[ClientResponse, httpx.Response]):
# do not process data when the http response is an error
if is_http_error_response(res):
return res

# if we are asked to process the data before we return it
if self._event.config.process_data:
# if data is JSON
if self._event.config.is_json:
data = await res.json()
else:
data = await res.text()
data = await self._response_to_data(res, is_json=self._event.config.is_json)
return data
# return raw result
else:
Expand Down
15 changes: 13 additions & 2 deletions packages/opal-common/opal_common/http.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
from typing import Union

import aiohttp
import httpx


def is_http_error_response(response: aiohttp.ClientResponse) -> bool:
def is_http_error_response(
response: Union[aiohttp.ClientResponse, httpx.Response]
) -> bool:
"""HTTP 400 and above are considered error responses."""
return response.status >= 400
status: int = (
response.status
if isinstance(response, aiohttp.ClientResponse)
else response.status_code
)

return status >= 400
1 change: 1 addition & 0 deletions packages/opal-common/requires.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ datadog>=0.44.0, <1
ddtrace>=2.8.1,<3
certifi>=2023.7.22 # not directly required, pinned by Snyk to avoid a vulnerability
requests>=2.31.0 # not directly required, pinned by Snyk to avoid a vulnerability
httpx==0.27.0

0 comments on commit e54829a

Please sign in to comment.