Skip to content

feat: allow injection of httpx client #591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ remove_pytest_asyncio_from_sync:
sed -i 's/@pytest.mark.asyncio//g' tests/_sync/test_client.py
sed -i 's/_async/_sync/g' tests/_sync/test_client.py
sed -i 's/Async/Sync/g' tests/_sync/test_client.py
sed -i 's/Async/Sync/g' postgrest/_sync/request_builder.py
sed -i 's/_client\.SyncClient/_client\.Client/g' tests/_sync/test_client.py
sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/test_client.py
sed -i 's/SyncHTTPTransport/HTTPTransport/g' tests/_sync/client.py

sleep:
sleep 2
37 changes: 37 additions & 0 deletions postgrest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,41 @@
from .deprecated_client import Client, PostgrestClient
from .deprecated_get_request_builder import GetRequestBuilder
from .exceptions import APIError
from .types import (
CountMethod,
Filters,
RequestMethod,
ReturnMethod,
)
from .version import __version__

__all__ = [
"AsyncPostgrestClient",
"AsyncFilterRequestBuilder",
"AsyncQueryRequestBuilder",
"AsyncRequestBuilder",
"AsyncRPCFilterRequestBuilder",
"AsyncSelectRequestBuilder",
"AsyncSingleRequestBuilder",
"AsyncMaybeSingleRequestBuilder",
"SyncPostgrestClient",
"SyncFilterRequestBuilder",
"SyncMaybeSingleRequestBuilder",
"SyncQueryRequestBuilder",
"SyncRequestBuilder",
"SyncRPCFilterRequestBuilder",
"SyncSelectRequestBuilder",
"SyncSingleRequestBuilder",
"APIResponse",
"DEFAULT_POSTGREST_CLIENT_HEADERS",
"Client",
"PostgrestClient",
"GetRequestBuilder",
"APIError",
"CountMethod",
"Filters",
"RequestMethod",
"ReturnMethod",
"Timeout",
"__version__",
]
50 changes: 46 additions & 4 deletions postgrest/_async/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Union, cast
from warnings import warn

from deprecation import deprecated
from httpx import Headers, QueryParams, Timeout
Expand All @@ -27,18 +28,50 @@ def __init__(
*,
schema: str = "public",
headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS,
timeout: Union[int, float, Timeout] = DEFAULT_POSTGREST_CLIENT_TIMEOUT,
verify: bool = True,
timeout: Union[int, float, Timeout, None] = None,
verify: Optional[bool] = None,
proxy: Optional[str] = None,
http_client: Optional[AsyncClient] = None,
) -> None:
if timeout is not None:
warn(
"The 'timeout' parameter is deprecated. Please configure it in the http client instead.",
DeprecationWarning,
stacklevel=2,
)
if verify is not None:
warn(
"The 'verify' parameter is deprecated. Please configure it in the http client instead.",
DeprecationWarning,
stacklevel=2,
)
if proxy is not None:
warn(
"The 'proxy' parameter is deprecated. Please configure it in the http client instead.",
DeprecationWarning,
stacklevel=2,
)

self.verify = bool(verify) if verify is not None else True
self.timeout = (
timeout
if isinstance(timeout, Timeout)
else (
int(abs(timeout))
if timeout is not None
else DEFAULT_POSTGREST_CLIENT_TIMEOUT
)
)

BasePostgrestClient.__init__(
self,
base_url,
schema=schema,
headers=headers,
timeout=timeout,
verify=verify,
timeout=self.timeout,
verify=self.verify,
proxy=proxy,
http_client=http_client,
)
self.session = cast(AsyncClient, self.session)

Expand All @@ -50,6 +83,15 @@ def create_session(
verify: bool = True,
proxy: Optional[str] = None,
) -> AsyncClient:
http_client = None
if isinstance(self.http_client, AsyncClient):
http_client = self.http_client

if http_client is not None:
http_client.base_url = base_url
http_client.headers.update({**headers})
return http_client

return AsyncClient(
base_url=base_url,
headers=headers,
Expand Down
50 changes: 46 additions & 4 deletions postgrest/_sync/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any, Dict, Optional, Union, cast
from warnings import warn

from deprecation import deprecated
from httpx import Headers, QueryParams, Timeout
Expand All @@ -27,18 +28,50 @@ def __init__(
*,
schema: str = "public",
headers: Dict[str, str] = DEFAULT_POSTGREST_CLIENT_HEADERS,
timeout: Union[int, float, Timeout] = DEFAULT_POSTGREST_CLIENT_TIMEOUT,
verify: bool = True,
timeout: Union[int, float, Timeout, None] = None,
verify: Optional[bool] = None,
proxy: Optional[str] = None,
http_client: Optional[SyncClient] = None,
) -> None:
if timeout is not None:
warn(
"The 'timeout' parameter is deprecated. Please configure it in the http client instead.",
DeprecationWarning,
stacklevel=2,
)
if verify is not None:
warn(
"The 'verify' parameter is deprecated. Please configure it in the http client instead.",
DeprecationWarning,
stacklevel=2,
)
if proxy is not None:
warn(
"The 'proxy' parameter is deprecated. Please configure it in the http client instead.",
DeprecationWarning,
stacklevel=2,
)

self.verify = bool(verify) if verify is not None else True
self.timeout = (
timeout
if isinstance(timeout, Timeout)
else (
int(abs(timeout))
if timeout is not None
else DEFAULT_POSTGREST_CLIENT_TIMEOUT
)
)

BasePostgrestClient.__init__(
self,
base_url,
schema=schema,
headers=headers,
timeout=timeout,
verify=verify,
timeout=self.timeout,
verify=self.verify,
proxy=proxy,
http_client=http_client,
)
self.session = cast(SyncClient, self.session)

Expand All @@ -50,6 +83,15 @@ def create_session(
verify: bool = True,
proxy: Optional[str] = None,
) -> SyncClient:
http_client = None
if isinstance(self.http_client, SyncClient):
http_client = self.http_client

if http_client is not None:
http_client.base_url = base_url
http_client.headers.update({**headers})
return http_client

return SyncClient(
base_url=base_url,
headers=headers,
Expand Down
10 changes: 5 additions & 5 deletions postgrest/_sync/request_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def select(
*columns: The names of the columns to fetch.
count: The method to use to get the count of rows returned.
Returns:
:class:`AsyncSelectRequestBuilder`
:class:`SyncSelectRequestBuilder`
"""
method, params, headers, json = pre_select(*columns, count=count, head=head)
return SyncSelectRequestBuilder[_ReturnT](
Expand All @@ -314,7 +314,7 @@ def insert(
Otherwise, use the default value for the column.
Only applies for bulk inserts.
Returns:
:class:`AsyncQueryRequestBuilder`
:class:`SyncQueryRequestBuilder`
"""
method, params, headers, json = pre_insert(
json,
Expand Down Expand Up @@ -350,7 +350,7 @@ def upsert(
not when merging with existing rows under `ignoreDuplicates: false`.
This also only applies when doing bulk upserts.
Returns:
:class:`AsyncQueryRequestBuilder`
:class:`SyncQueryRequestBuilder`
"""
method, params, headers, json = pre_upsert(
json,
Expand Down Expand Up @@ -378,7 +378,7 @@ def update(
count: The method to use to get the count of rows returned.
returning: Either 'minimal' or 'representation'
Returns:
:class:`AsyncFilterRequestBuilder`
:class:`SyncFilterRequestBuilder`
"""
method, params, headers, json = pre_update(
json,
Expand All @@ -401,7 +401,7 @@ def delete(
count: The method to use to get the count of rows returned.
returning: Either 'minimal' or 'representation'
Returns:
:class:`AsyncFilterRequestBuilder`
:class:`SyncFilterRequestBuilder`
"""
method, params, headers, json = pre_delete(
count=count,
Expand Down
8 changes: 7 additions & 1 deletion postgrest/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
timeout: Union[int, float, Timeout],
verify: bool = True,
proxy: Optional[str] = None,
http_client: Union[SyncClient, AsyncClient, None] = None,
) -> None:
if not is_http_url(base_url):
ValueError("base_url must be a valid HTTP URL string")
Expand All @@ -33,8 +34,13 @@ def __init__(
self.timeout = timeout
self.verify = verify
self.proxy = proxy
self.http_client = http_client
self.session = self.create_session(
self.base_url, self.headers, self.timeout, self.verify, self.proxy
self.base_url,
self.headers,
self.timeout,
self.verify,
self.proxy,
)

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions postgrest/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Any, Dict, Optional

from pydantic import BaseModel

Expand Down Expand Up @@ -34,7 +34,7 @@ class APIError(Exception):
details: Optional[str]
"""The error details."""

def __init__(self, error: Dict[str, str]) -> None:
def __init__(self, error: Dict[str, Any]) -> None:
self._raw_error = error
self.message = error.get("message")
self.code = error.get("code")
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ furo = ">=2023.9.10,<2025.0.0"

[tool.pytest.ini_options]
asyncio_mode = "auto"
filterwarnings = [
"ignore::DeprecationWarning", # ignore deprecation warnings globally
]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
20 changes: 20 additions & 0 deletions tests/_async/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from httpx import AsyncHTTPTransport, Limits

from postgrest import AsyncPostgrestClient
from postgrest.utils import AsyncClient

REST_URL = "http://127.0.0.1:3000"

Expand All @@ -7,3 +10,20 @@ def rest_client():
return AsyncPostgrestClient(
base_url=REST_URL,
)


def rest_client_httpx():
transport = AsyncHTTPTransport(
retries=4,
limits=Limits(
max_connections=1,
max_keepalive_connections=1,
keepalive_expiry=None,
),
)
headers = {"x-user-agent": "my-app/0.0.1"}
http_client = AsyncClient(transport=transport, headers=headers)
return AsyncPostgrestClient(
base_url=REST_URL,
http_client=http_client,
)
37 changes: 36 additions & 1 deletion tests/_async/test_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from unittest.mock import patch

import pytest
from httpx import BasicAuth, Headers, Request, Response
from httpx import (
AsyncHTTPTransport,
BasicAuth,
Headers,
Limits,
Request,
Response,
Timeout,
)

from postgrest import AsyncPostgrestClient
from postgrest.exceptions import APIError
from postgrest.utils import AsyncClient


@pytest.fixture
Expand Down Expand Up @@ -46,6 +55,32 @@ async def test_custom_headers(self):
assert session.headers.items() >= headers.items()


class TestHttpxClientConstructor:
@pytest.mark.asyncio
async def test_custom_httpx_client(self):
transport = AsyncHTTPTransport(
retries=10,
limits=Limits(
max_connections=1,
max_keepalive_connections=1,
keepalive_expiry=None,
),
)
headers = {"x-user-agent": "my-app/0.0.1"}
http_client = AsyncClient(transport=transport, headers=headers)
async with AsyncPostgrestClient(
"https://example.com", http_client=http_client, timeout=20.0
) as client:
session = client.session

assert session.base_url == "https://example.com"
assert session.timeout == Timeout(
timeout=5.0
) # Should be the default 5 since we use custom httpx client
assert session.headers.get("x-user-agent") == "my-app/0.0.1"
assert isinstance(session, AsyncClient)


class TestAuth:
def test_auth_token(self, postgrest_client: AsyncPostgrestClient):
postgrest_client.auth("s3cr3t")
Expand Down
Loading