Skip to content

Commit

Permalink
Changes
Browse files Browse the repository at this point in the history
- add asgi integration
- add integration docs
  • Loading branch information
devkral committed Sep 16, 2024
1 parent 539d77f commit f995d6a
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 3 deletions.
67 changes: 67 additions & 0 deletions databasez/core/asgi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from contextlib import suppress
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict

if TYPE_CHECKING:
from edgy.core.Database import Database

ASGIApp = Callable[
[
Dict[str, Any],
Callable[[], Awaitable[Dict[str, Any]]],
Callable[[Dict[str, Any]], Awaitable[None]],
],
Awaitable[None],
]


class MuteInteruptException(BaseException):
pass


@dataclass
class ASGIHelper:
app: ASGIApp
database: Database
handle_lifespan: bool = False

async def __call__(
self,
scope: Dict[str, Any],
receive: Callable[[], Awaitable[Dict[str, Any]]],
send: Callable[[Dict[str, Any]], Awaitable[None]],
) -> None:
if scope["type"] == "lifespan":
original_receive = receive

async def receive() -> Dict[str, Any]:
message = await original_receive()
if message["type"] == "lifespan.startup":
try:
await self.database.connect()
except Exception as exc:
await send({"type": "lifespan.startup.failed", "msg": str(exc)})
raise MuteInteruptException from None
elif message["type"] == "lifespan.shutdown":
try:
await self.database.disconnect()
except Exception as exc:
await send({"type": "lifespan.shutdown.failed", "msg": str(exc)})
raise MuteInteruptException from None
return message

if self.handle_lifespan:
with suppress(MuteInteruptException):
while True:
message = await receive()
if message["type"] == "lifespan.startup":
await send({"type": "lifespan.startup.complete"})
elif message["type"] == "lifespan.shutdown":
await send({"type": "lifespan.shutdown.complete"})
return
return

with suppress(MuteInteruptException):
await self.app(scope, receive, send)
27 changes: 26 additions & 1 deletion databasez/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import typing
import weakref
from contextvars import ContextVar
from functools import lru_cache
from functools import lru_cache, partial
from types import TracebackType

from databasez import interfaces
Expand All @@ -18,6 +18,7 @@
multiloop_protector,
)

from .asgi import ASGIApp, ASGIHelper
from .connection import Connection
from .databaseurl import DatabaseURL
from .transaction import Transaction
Expand Down Expand Up @@ -571,6 +572,30 @@ def connection(self, timeout: typing.Optional[float] = None) -> Connection:
def engine(self) -> typing.Optional[AsyncEngine]:
return self.backend.engine

@typing.overload
def asgi(
self,
app: None,
handle_lifespan: bool = False,
) -> typing.Callable[[ASGIApp], ASGIHelper]: ...

@typing.overload
def asgi(
self,
app: ASGIApp,
handle_lifespan: bool = False,
) -> ASGIHelper: ...

def asgi(
self,
app: typing.Optional[ASGIApp] = None,
handle_lifespan: bool = False,
) -> typing.Union[ASGIHelper, typing.Callable[[ASGIApp], ASGIHelper]]:
"""Return wrapper for asgi integration."""
if app is not None:
return ASGIHelper(app=app, database=self, handle_lifespan=handle_lifespan)
return partial(ASGIHelper, database=self, handle_lifespan=handle_lifespan)

@classmethod
def get_backends(
cls,
Expand Down
1 change: 1 addition & 0 deletions docs/connections-and-transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ to connect to the database.

* **force_rollback(force_rollback=True)**: - The magic attribute is also function returning a context-manager for temporary overwrites of force_rollback.

* **asgi** - ASGI lifespan interception shim.

## Connecting and disconnecting

Expand Down
4 changes: 4 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ Now from the console, we can run a simple example.
Check out the documentation on [making database queries](./queries.md)
for examples of how to start using databases together with SQLAlchemy core expressions.

For the integration in frameworks see:

[Integrations](./integrations.md)


[sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/
[sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/latest/core/tutorial.html
Expand Down
39 changes: 39 additions & 0 deletions docs/integrations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Integrations

Databasez has several ways to integrate in applications. Mainly recommended
are the async contextmanager one and the asgi based.

## AsyncContextManager

The recommended way of manually using databasez is via the async contextmanager protocol.
This way it is ensured that the database is tore down on errors.

Luckily starlette based apps support the lifespan protocol (startup, teardown of an ASGI server) via async contextmanagers.

```python
{!> ../docs_src/integrations/starlette.py !}
```

Note: This works also in different domains which are not web related.


## ASGI

This is a lifespan protocol interception shim for ASGI lifespan. Instead of using the lifespan parameter of starlette, it is possible
to wrap the ASGI application via the shim. This way databasez intercepts lifespans requests and injects its code.
By default it passes the lifespan request further down, but it has a compatibility option named `handle_lifespan`.
It is required for ASGI apps without lifespan support like django.


```python
{!> ../docs_src/integrations/django.py !}
```

## Manually

Some Server doesn't support the lifespan protocol. E.g. WSGI based servers. Here is an example how to integrate it.
As well as with the AsyncContextManager we are not limitted to web applications.

```python
{!> ../docs_src/integrations/esmerald.py !}
```
2 changes: 2 additions & 0 deletions docs/vendors.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ as this fork was possible because of their great work.

Commit number: [615c4d602beb5b067ad925215e3fe2944cf5150c](https://github.com/encode/databases/commit/615c4d602beb5b067ad925215e3fe2944cf5150c)

This package heavily depend on the awesome sqlalchemy library.
However we use sometimes more user-friendly defaults.
8 changes: 8 additions & 0 deletions docs_src/integrations/django.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from django.core.asgi import get_asgi_application

from databasez import Database

applications = Database("sqlite:///foo.sqlite").asgi(
# except you have a lifespan handler in django
handle_lifespan=True
)(get_asgi_application())
18 changes: 18 additions & 0 deletions docs_src/integrations/esmerald.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from esmerald import Esmerald

from databasez import Database

database = Database("sqlite:///foo.sqlite")


app = Esmerald(routes=[])


@app.on_event("startup")
async def startup():
await database.connect()


@app.on_event("shutdown")
async def shutdown():
await database.disconnect()
18 changes: 18 additions & 0 deletions docs_src/integrations/starlette.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import contextlib

from starlette.applications import Starlette

from databasez import Database

database = Database("sqlite:///foo.sqlite")


@contextlib.asynccontextmanager
async def lifespan(app):
async with database:
yield


application = Starlette(
lifespan=lifespan,
)
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ plugins:

nav:
- Databasez: "index.md"
- Integrations: "integrations.md"
- Queries: "queries.md"
- Connections & Transactions: "connections-and-transactions.md"
- Test Client: "test-client.md"
Expand Down
64 changes: 62 additions & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def database_url(request):
loop.run_until_complete(stop_database_client(database, metadata))


def get_app(database_url):
def get_starlette_app(database_url):
database = Database(database_url, force_rollback=True)

@contextlib.asynccontextmanager
Expand Down Expand Up @@ -67,6 +67,65 @@ async def add_note(request):
return app


def get_asgi_app(database_url):
database = Database(database_url, force_rollback=True)

async def list_notes(request):
query = notes.select()
results = await database.fetch_all(query)
content = [{"text": result.text, "completed": result.completed} for result in results]
return JSONResponse(content)

async def add_note(request):
data = await request.json()
query = notes.insert().values(text=data["text"], completed=data["completed"])
await database.execute(query)
return JSONResponse({"text": data["text"], "completed": data["completed"]})

app = database.asgi(
Starlette(
routes=[
Route("/notes", endpoint=list_notes, methods=["GET"]),
Route("/notes", endpoint=add_note, methods=["POST"]),
],
)
)

return app


def get_asgi_no_lifespan(database_url):
database = Database(database_url, force_rollback=True)

@contextlib.asynccontextmanager
async def lifespan(app):
raise

async def list_notes(request):
query = notes.select()
results = await database.fetch_all(query)
content = [{"text": result.text, "completed": result.completed} for result in results]
return JSONResponse(content)

async def add_note(request):
data = await request.json()
query = notes.insert().values(text=data["text"], completed=data["completed"])
await database.execute(query)
return JSONResponse({"text": data["text"], "completed": data["completed"]})

app = database.asgi(handle_lifespan=True)(
Starlette(
lifespan=lifespan,
routes=[
Route("/notes", endpoint=list_notes, methods=["GET"]),
Route("/notes", endpoint=add_note, methods=["POST"]),
],
)
)

return app


def get_esmerald_app(database_url):
database = Database(database_url, force_rollback=True)

Expand Down Expand Up @@ -97,7 +156,8 @@ async def shutdown():
return app


def test_integration(database_url):
@pytest.mark.parametrize("get_app", [get_starlette_app, get_asgi_app, get_asgi_no_lifespan])
def test_integration(database_url, get_app):
app = get_app(database_url)

with TestClient(app) as client:
Expand Down

0 comments on commit f995d6a

Please sign in to comment.