Skip to content

Commit

Permalink
fix(channels): use SQL function for psycopg calls
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Dec 30, 2024
1 parent ebdb6e5 commit 4ab7a11
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions litestar/channels/backends/psycopg.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,48 @@
from __future__ import annotations

from contextlib import AsyncExitStack
from typing import AsyncGenerator, Iterable
from typing import Any, AsyncGenerator, Iterable

import psycopg
from psycopg import AsyncConnection
from psycopg.sql import SQL, Identifier

from .base import ChannelsBackend


def _safe_quote(ident: str) -> str:
return '"{}"'.format(ident.replace('"', '""')) # sourcery skip


class PsycoPgChannelsBackend(ChannelsBackend):
_listener_conn: psycopg.AsyncConnection
_listener_conn: AsyncConnection[Any]

def __init__(self, pg_dsn: str) -> None:
self._pg_dsn = pg_dsn
self._subscribed_channels: set[str] = set()
self._exit_stack = AsyncExitStack()

async def on_startup(self) -> None:
self._listener_conn = await psycopg.AsyncConnection.connect(self._pg_dsn, autocommit=True)
self._listener_conn = await AsyncConnection[Any].connect(self._pg_dsn, autocommit=True)
await self._exit_stack.enter_async_context(self._listener_conn)

async def on_shutdown(self) -> None:
await self._exit_stack.aclose()

async def publish(self, data: bytes, channels: Iterable[str]) -> None:
dec_data = data.decode("utf-8")
async with await psycopg.AsyncConnection.connect(self._pg_dsn) as conn:
async with await AsyncConnection[Any].connect(self._pg_dsn) as conn:
for channel in channels:
await conn.execute("SELECT pg_notify(%s, %s);", (channel, dec_data))
await conn.execute(SQL("SELECT pg_notify(%s, %s);").format(Identifier(channel), dec_data))

async def subscribe(self, channels: Iterable[str]) -> None:
for channel in set(channels) - self._subscribed_channels:
# can't use placeholders in LISTEN
await self._listener_conn.execute(f"LISTEN {_safe_quote(channel)};") # pyright: ignore
channels_to_subscribe = set(channels) - self._subscribed_channels
if not channels_to_subscribe:
return

for channel in channels_to_subscribe:
await self._listener_conn.execute(SQL("LISTEN {}").format(Identifier(channel)))

self._subscribed_channels.add(channel)

async def unsubscribe(self, channels: Iterable[str]) -> None:
for channel in channels:
# can't use placeholders in UNLISTEN
await self._listener_conn.execute(f"UNLISTEN {_safe_quote(channel)};") # pyright: ignore
await self._listener_conn.execute(SQL("UNLISTEN {}").format(Identifier(channel)))
self._subscribed_channels = self._subscribed_channels - set(channels)

async def stream_events(self) -> AsyncGenerator[tuple[str, bytes], None]:
Expand Down

0 comments on commit 4ab7a11

Please sign in to comment.