Skip to content

Commit

Permalink
Add metadata column to SQLiteYStore, fix update broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 16, 2022
1 parent 0d9a54b commit 95ea62f
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 63 deletions.
1 change: 1 addition & 0 deletions ypy_websocket/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .websocket_provider import WebsocketProvider # noqa
from .websocket_server import WebsocketServer, YRoom # noqa
from .ydoc import YDoc # noqa
from .yutils import YMessageType # noqa

__version__ = "0.4.0"
3 changes: 2 additions & 1 deletion ypy_websocket/awareness.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def get_changes(self, message: bytes) -> Dict[str, Any]:
if client_id == self.client_id and self.states.get(client_id) is not None:
clock += 1
else:
del self.states[client_id]
if client_id in self.states:
del self.states[client_id]
else:
self.states[client_id] = state
self.meta[client_id] = {
Expand Down
44 changes: 22 additions & 22 deletions ypy_websocket/websocket_server.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import asyncio
import logging
from functools import partial
from typing import Callable, Dict, List, Optional

import y_py as Y

from .awareness import Awareness
from .ydoc import YDoc
from .ystore import BaseYStore
from .yutils import put_updates, sync, update
from .yutils import sync, update


class YRoom:

clients: List
ydoc: Y.YDoc
ydoc: YDoc
ystore: Optional[BaseYStore]
_on_message: Optional[Callable]
_update_queue: asyncio.Queue
_ready: bool

def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None):
self.ydoc = Y.YDoc()
self.awareness = Awareness(self.ydoc)
def __init__(self, ready: bool = True, ystore: Optional[BaseYStore] = None, log=None):
self._update_queue = asyncio.Queue()
self.ydoc = YDoc()
self.ydoc.init(self._update_queue) # FIXME: overriding Y.YDoc.__init__ doesn't seem to work
self.awareness = Awareness(self.ydoc)
self._ready = False
self.ready = ready
self.ystore = ystore
self.log = log or logging.getLogger(__name__)
self.clients = []
self._on_message = None
self._broadcast_task = asyncio.create_task(self._broadcast_updates())
Expand All @@ -36,7 +38,7 @@ def ready(self) -> bool:
def ready(self, value: bool) -> None:
self._ready = value
if value:
self.ydoc.observe_after_transaction(partial(put_updates, self._update_queue, self.ydoc))
self.ydoc._ready = True

@property
def on_message(self) -> Optional[Callable]:
Expand All @@ -47,17 +49,14 @@ def on_message(self, value: Optional[Callable]):
self._on_message = value

async def _broadcast_updates(self):
try:
while True:
update = await self._update_queue.get()
# broadcast internal ydoc's update to all clients
for client in self.clients:
try:
await client.send(update)
except Exception:
pass
except Exception:
pass
while True:
update = await self._update_queue.get()
# broadcast internal ydoc's update to all clients
for client in self.clients:
self.log.debug(
"Sending Y update from backend to client with endpoint: %s", client.path
)
asyncio.create_task(client.send(update))

def _clean(self):
self._broadcast_task.cancel()
Expand All @@ -76,7 +75,7 @@ def __init__(self, rooms_ready: bool = True, auto_clean_rooms: bool = True, log=

def get_room(self, path: str) -> YRoom:
if path not in self.rooms.keys():
self.rooms[path] = YRoom(ready=self.rooms_ready)
self.rooms[path] = YRoom(ready=self.rooms_ready, log=self.log)
return self.rooms[path]

def get_room_name(self, room):
Expand Down Expand Up @@ -115,7 +114,8 @@ async def serve(self, websocket):
asyncio.create_task(update(message, room, websocket, self.log))
# forward messages to every other client in the background
for client in [c for c in room.clients if c != websocket]:
self.log.debug("Sending Y update to client with endpoint: %s", client.path)
self.log.debug("Sending Y update from client with endpoint: %s", websocket.path)
self.log.debug("... to client with endpoint: %s", client.path)
asyncio.create_task(client.send(message))
# remove this client
room.clients = [c for c in room.clients if c != websocket]
Expand Down
54 changes: 54 additions & 0 deletions ypy_websocket/ydoc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import asyncio
from types import TracebackType
from typing import Optional, Type

import y_py as Y

from .yutils import create_update_message


class YDoc(Y.YDoc):

_begin_transaction = Y.YDoc.begin_transaction
_update_queue: asyncio.Queue
_ready: bool

def init(self, update_queue: asyncio.Queue):
self._ready = False
self._update_queue = update_queue

def begin_transaction(self):
return Transaction(self, self._update_queue, self._ready)


class Transaction:

ydoc: YDoc
update_queue: asyncio.Queue
state: bytes
transaction: Y.YTransaction
ready: bool

def __init__(self, ydoc: YDoc, update_queue: asyncio.Queue, ready: bool):
self.ydoc = ydoc
self.update_queue = update_queue
self.ready = ready

def __enter__(self):
if self.ready:
self.state = Y.encode_state_vector(self.ydoc)
self.transaction = self.ydoc._begin_transaction()
return self.transaction.__enter__()

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
res = self.transaction.__exit__(exc_type, exc_value, exc_tb) # type: ignore
if self.ready:
update = Y.encode_state_as_update(self.ydoc, self.state)
message = create_update_message(update)
self.update_queue.put_nowait(message)
return res
87 changes: 47 additions & 40 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import AsyncIterator, Callable, Optional, Tuple

Expand All @@ -14,19 +15,21 @@ class YDocNotFound(Exception):
pass


class BaseYStore:
class BaseYStore(ABC):

metadata_callback: Optional[Callable] = None

@abstractmethod
def __init__(self, path: str, metadata_callback=None):
raise RuntimeError("Not implemented")
...

@abstractmethod
async def write(self, data: bytes) -> None:
raise RuntimeError("Not implemented")
...

@abstractmethod
async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
raise RuntimeError("Not implemented")
yield b"", b""
...

async def get_metadata(self) -> bytes:
metadata = b"" if not self.metadata_callback else await self.metadata_callback()
Expand All @@ -37,30 +40,35 @@ async def encode_state_as_update(self, ydoc: Y.YDoc):
await self.write(update)

async def apply_updates(self, ydoc: Y.YDoc):
async for update, metadata in self.read():
async for update, metadata in await self.read():
Y.apply_update(ydoc, update) # type: ignore


class FileYStore(BaseYStore):
"""A YStore which uses the local file system."""
"""A YStore which uses one file per document."""

path: str
metadata_callback: Optional[Callable]
lock: asyncio.Lock

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except Exception:
raise YDocNotFound
self.lock = asyncio.Lock()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async with self.lock:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except BaseException:
raise YDocNotFound
is_data = True
for d in Decoder(data).read_messages():
if is_data:
update = d
else:
# yield data and metadata
yield update, d
is_data = not is_data

Expand All @@ -71,16 +79,18 @@ async def write(self, data: bytes) -> None:
mode = "wb"
else:
mode = "ab"
async with aiofiles.open(self.path, mode) as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)
async with self.lock:
async with aiofiles.open(self.path, mode) as f:
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
metadata_len = write_var_uint(len(metadata))
await f.write(metadata_len + metadata)


class TempFileYStore(FileYStore):
"""A YStore which uses the system's temporary directory.
Files are writen under a common directory.
To prefix the directory name (e.g. /tmp/my_prefix_b4whmm7y/):
class PrefixTempFileYStore(TempFileYStore):
Expand All @@ -90,7 +100,7 @@ class PrefixTempFileYStore(TempFileYStore):
prefix_dir: Optional[str] = None
base_dir: Optional[str] = None

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
full_path = str(Path(self.get_base_dir()) / path)
super().__init__(full_path, metadata_callback=metadata_callback)

Expand All @@ -106,6 +116,8 @@ def make_directory(self):

class SQLiteYStore(BaseYStore):
"""A YStore which uses an SQLite database.
Unlike file-based YStores, the Y updates of all documents are stored in the same database.
Subclass to point to your database file:
class MySQLiteYStore(SQLiteYStore):
Expand All @@ -116,42 +128,37 @@ class MySQLiteYStore(SQLiteYStore):
path: str
db_created: asyncio.Event

def __init__(self, path: str, metadata_callback=None):
def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback
self.db_created = asyncio.Event()
asyncio.create_task(self.create_db())

async def create_db(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute("CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB)")
await db.execute(
"CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB, metadata BLOB)"
)
await db.commit()
self.db_created.set()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]:
async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
await self.db_created.wait()
try:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT * FROM yupdates WHERE path = ?", (self.path,)
"SELECT yupdate, metadata FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
found = False
is_data = True
async for _, d in cursor:
found = True
if is_data:
update = d
else:
yield update, d
is_data = not is_data
if not found:
async for update, metadata in cursor:
yield update, metadata
else:
raise YDocNotFound
except Exception:
except BaseException:
raise YDocNotFound

async def write(self, data: bytes) -> None:
await self.db_created.wait()
metadata = await self.get_metadata()
async with aiosqlite.connect(self.db_path) as db:
await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, data))
metadata = await self.get_metadata()
await db.execute("INSERT INTO yupdates VALUES (?, ?)", (self.path, metadata))
await db.execute("INSERT INTO yupdates VALUES (?, ?, ?)", (self.path, data, metadata))
await db.commit()

0 comments on commit 95ea62f

Please sign in to comment.