diff --git a/pyensign/connection.py b/pyensign/connection.py
index d12300d..78be802 100644
--- a/pyensign/connection.py
+++ b/pyensign/connection.py
@@ -142,7 +142,7 @@ def _ensure_ready(self):
         is a bit easier.
         """
-        if not self.channel:
+        if not self.channel or self.channel._loop.is_closed():
             self.channel = self.connection.create_channel()
             self.stub = ensign_pb2_grpc.EnsignStub(self.channel)
diff --git a/pyensign/ml/loader.py b/pyensign/ml/loader.py
index 63687c1..ff9b9eb 100644
--- a/pyensign/ml/loader.py
+++ b/pyensign/ml/loader.py
@@ -1,4 +1,4 @@
-from typing import AsyncIterator, Iterator, List, Any
+from typing import AsyncIterator, Iterator, List, Any, Coroutine
 
 try:
     from langchain_core.document_loaders import BaseLoader
@@ -8,6 +8,7 @@
     raise ImportError("please `pip install pyensign[ml]` to use the EnsignLoader")
 
 from pyensign.ensign import Ensign
+from pyensign.sync import async_to_sync
 
 
 class EnsignLoader(BaseLoader):
@@ -47,14 +48,42 @@ def _convert_to_document(self, event: Any) -> Document:
             metadata=event.meta,
         )
 
-    def load(self) -> Iterator[Document]:
+    async def _get_event_cursor(self, limit: int = 0, offset: int = 0) -> Coroutine:
+        """
+        Returns an event cursor that retrieves all events from the Ensign topic.
+        """
+
+        query = "SELECT * FROM {}".format(self.topic)
+        if limit > 0:
+            query += " LIMIT {}".format(limit)
+        query += " OFFSET {}".format(offset)
+        return await self.ensign.query(query)
+
+    async def _load_documents(self) -> List[Document]:
+        """
+        A helper coroutine that loads all documents from Ensign into memory.
+        """
+        cursor = await self._get_event_cursor()
+        events = await cursor.fetchall()
+        return [self._convert_to_document(event) for event in events]
+
+    async def _fetchone(self, limit: int = 0, offset: int = 0) -> Document:
+        """
+        Fetch one document from Ensign.
+        """
+
+        cursor = await self._get_event_cursor(limit=limit, offset=offset)
+        return await cursor.fetchone()
+
+    def load(self) -> List[Document]:
         """
         Loads all documents from Ensign into memory.
 
         Note: In production, this should be used with caution since it loads all
         documents into memory at once.
         """
-        raise NotImplementedError
+        # FIXME: This blocks forever in Jupyter notebooks.
+        return async_to_sync(self._load_documents)()
@@ -62,22 +91,33 @@ async def aload(self) -> List[Document]:
         """
         TODO: Prevent SQL injection with the topic name.
         """
-        cursor = await self.ensign.query("SELECT * FROM {}".format(self.topic))
-        events = await cursor.fetchall()
-        return [self._convert_to_document(event) for event in events]
+        return await self._load_documents()
 
     def lazy_load(self) -> Iterator[Document]:
         """
-        Load documents from Ensign one by one lazily.
+        Load documents from Ensign one by one lazily. This method will be slower than
+        alazy_load because a new connection must be established to retrieve each
+        document. This is because Ensign queries currently only exist in an async
+        context.
         """
-        raise NotImplementedError
+        offset = 0
+        while True:
+            # FIXME: This blocks forever in Jupyter notebooks.
+            event = async_to_sync(self._fetchone)(limit=1, offset=offset)
+            if event is None:
+                break
+            offset += 1
+            yield self._convert_to_document(event)
 
     async def alazy_load(
         self,
-    ) -> AsyncIterator[Document]:  # <-- Does not take any arguments
+    ) -> AsyncIterator[Document]:
         """
-        Load documents from Ensign one by one lazily. This is an async generator.
+        An async generator that yields documents from Ensign one by one lazily. This is
+        the recommended method for loading documents in production.
         """
-        raise NotImplementedError
+        cursor = await self._get_event_cursor()
+        async for event in cursor:
+            yield self._convert_to_document(event)
diff --git a/tests/pyensign/ml/test_loader.py b/tests/pyensign/ml/test_loader.py
index 4d3538a..7b0dee6 100644
--- a/tests/pyensign/ml/test_loader.py
+++ b/tests/pyensign/ml/test_loader.py
@@ -1,10 +1,41 @@
+import os
 import pytest
+from unittest import mock
 from asyncmock import patch
 
 from pyensign.events import Event
+from pyensign.connection import Client, Connection
 from pyensign.ml.loader import EnsignLoader
 
 
+@pytest.fixture
+def live(request):
+    return request.config.getoption("--live")
+
+
+@pytest.fixture
+def authserver():
+    return os.environ.get("ENSIGN_AUTH_SERVER")
+
+
+@pytest.fixture
+def ensignserver():
+    return os.environ.get("ENSIGN_SERVER")
+
+
+@pytest.fixture
+def creds():
+    return {
+        "client_id": os.environ.get("ENSIGN_CLIENT_ID"),
+        "client_secret": os.environ.get("ENSIGN_CLIENT_SECRET"),
+    }
+
+
+@pytest.fixture
+def topic():
+    return os.environ.get("ENSIGN_TOPIC")
+
+
 class TestEnsignLoader:
     """
     Tests for the EnsignLoader class.
@@ -23,9 +54,49 @@ def test_init_errors(self, topic, kwargs):
         arguments.
         """
 
-        with pytest.raises(ValueError):
+        with pytest.raises(ValueError), mock.patch.dict(os.environ, {}, clear=True):
             EnsignLoader(topic, **kwargs)
 
+    @patch("pyensign.ensign.Ensign.query")
+    @pytest.mark.parametrize(
+        "events",
+        [
+            ([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]),
+            (
+                [
+                    Event(
+                        b'{"content": "hello"}',
+                        mimetype="application/json",
+                        meta={"page": "42"},
+                    ),
+                    Event(
+                        b"
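A minimal usage sketch of the loader API introduced in this diff, assuming credentials are picked up from the ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment variables (as the test fixtures suggest); the topic name "my-topic" is a placeholder and this snippet is illustrative only, not part of the patch:

import asyncio

from pyensign.ml.loader import EnsignLoader


async def main():
    # "my-topic" is a hypothetical topic; the loader issues SELECT * FROM <topic>.
    loader = EnsignLoader("my-topic")

    # Stream documents one at a time (the recommended production path).
    async for doc in loader.alazy_load():
        print(doc.metadata)

    # Or load everything into memory at once (use with caution on large topics).
    docs = await loader.aload()
    print(len(docs))


asyncio.run(main())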