diff --git a/pyensign/connection.py b/pyensign/connection.py index d12300d..78be802 100644 --- a/pyensign/connection.py +++ b/pyensign/connection.py @@ -142,7 +142,7 @@ def _ensure_ready(self): is a bit easier. """ - if not self.channel: + if not self.channel or self.channel._loop.is_closed(): self.channel = self.connection.create_channel() self.stub = ensign_pb2_grpc.EnsignStub(self.channel) diff --git a/pyensign/ml/loader.py b/pyensign/ml/loader.py index 63687c1..ff9b9eb 100644 --- a/pyensign/ml/loader.py +++ b/pyensign/ml/loader.py @@ -1,4 +1,4 @@ -from typing import AsyncIterator, Iterator, List, Any +from typing import AsyncIterator, Iterator, List, Any, Coroutine try: from langchain_core.document_loaders import BaseLoader @@ -8,6 +8,7 @@ raise ImportError("please `pip install pyensign[ml]` to use the EnsignLoader") from pyensign.ensign import Ensign +from pyensign.sync import async_to_sync class EnsignLoader(BaseLoader): @@ -47,14 +48,42 @@ def _convert_to_document(self, event: Any) -> Document: metadata=event.meta, ) - def load(self) -> Iterator[Document]: + async def _get_event_cursor(self, limit: int = 0, offset: int = 0) -> Coroutine: + """ + Returns an event cursor that retrieves all events from the Ensign topic. + """ + + query = "SELECT * FROM {}".format(self.topic) + if limit > 0: + query += " LIMIT {}".format(limit) + query += " OFFSET {}".format(offset) + return await self.ensign.query(query) + + async def _load_documents(self) -> List[Document]: + """ + A helper coroutine that loads all documents from Ensign into memory. + """ + cursor = await self._get_event_cursor() + events = await cursor.fetchall() + return [self._convert_to_document(event) for event in events] + + async def _fetchone(self, limit: int = 0, offset: int = 0) -> Document: + """ + Fetch one document from Ensign. + """ + + cursor = await self._get_event_cursor(limit=limit, offset=offset) + return await cursor.fetchone() + + def load(self) -> List[Document]: """ Loads all documents from Ensign into memory. Note: In production, this should be used with caution since it loads all documents into memory at once. """ - raise NotImplementedError + # FIXME: This blocks forever in Jupyter notebooks. + return async_to_sync(self._load_documents)() async def aload(self) -> List[Document]: """ @@ -62,22 +91,33 @@ async def aload(self) -> List[Document]: TODO: Prevent SQL injection with the topic name. """ - cursor = await self.ensign.query("SELECT * FROM {}".format(self.topic)) - events = await cursor.fetchall() - return [self._convert_to_document(event) for event in events] + return await self._load_documents() def lazy_load(self) -> Iterator[Document]: """ - Load documents from Ensign one by one lazily. + Load documents from Ensign one by one lazily. This method will be slower than + alazy_load because a new connection must be established to retrieve each + document. This is because Ensign queries currently only exist in an async + context. """ - raise NotImplementedError + offset = 0 + while True: + # FIXME: This blocks forever in Jupyter notebooks. + event = async_to_sync(self._fetchone)(limit=1, offset=offset) + if event is None: + break + offset += 1 + yield self._convert_to_document(event) async def alazy_load( self, - ) -> AsyncIterator[Document]: # <-- Does not take any arguments + ) -> AsyncIterator[Document]: """ - Load documents from Ensign one by one lazily. This is an async generator. + An async generator that yields documents from Ensign one by one lazily. This is + the recommended method for loading documents in production. """ - raise NotImplementedError + cursor = await self._get_event_cursor() + async for event in cursor: + yield self._convert_to_document(event) diff --git a/tests/pyensign/ml/test_loader.py b/tests/pyensign/ml/test_loader.py index 4d3538a..7b0dee6 100644 --- a/tests/pyensign/ml/test_loader.py +++ b/tests/pyensign/ml/test_loader.py @@ -1,10 +1,41 @@ +import os import pytest +from unittest import mock from asyncmock import patch from pyensign.events import Event +from pyensign.connection import Client, Connection from pyensign.ml.loader import EnsignLoader +@pytest.fixture +def live(request): + return request.config.getoption("--live") + + +@pytest.fixture +def authserver(): + return os.environ.get("ENSIGN_AUTH_SERVER") + + +@pytest.fixture +def ensignserver(): + return os.environ.get("ENSIGN_SERVER") + + +@pytest.fixture +def creds(): + return { + "client_id": os.environ.get("ENSIGN_CLIENT_ID"), + "client_secret": os.environ.get("ENSIGN_CLIENT_SECRET"), + } + + +@pytest.fixture +def topic(): + return os.environ.get("ENSIGN_TOPIC") + + class TestEnsignLoader: """ Tests for the EnsignLoader class. @@ -23,9 +54,49 @@ def test_init_errors(self, topic, kwargs): arguments. """ - with pytest.raises(ValueError): + with pytest.raises(ValueError), mock.patch.dict(os.environ, {}, clear=True): EnsignLoader(topic, **kwargs) + @patch("pyensign.ensign.Ensign.query") + @pytest.mark.parametrize( + "events", + [ + ([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]), + ( + [ + Event( + b'{"content": "hello"}', + mimetype="application/json", + meta={"page": "42"}, + ), + Event( + b"

Hello, world!

", + mimetype="text/html", + meta={"page": "23"}, + ), + ] + ), + ], + ) + def test_load(self, mock_query, events): + """ + Test loading all documents from an Ensign topic + """ + + loader = EnsignLoader( + "otters", + client_id="my_client_id", + client_secret="my_client_secret", + ) + mock_query.return_value.fetchall.return_value = events + documents = loader.load() + args, _ = mock_query.call_args + assert args[0] == "SELECT * FROM otters OFFSET 0" + assert len(documents) == len(events) + for i, document in enumerate(documents): + assert document.page_content == events[i].data.decode() + assert document.metadata == events[i].meta + @patch("pyensign.ensign.Ensign.query") @pytest.mark.parametrize( "events", @@ -60,8 +131,152 @@ async def test_aload(self, mock_query, events): mock_query.return_value.fetchall.return_value = events documents = await loader.aload() args, _ = mock_query.call_args - assert args[0] == "SELECT * FROM otters" + assert args[0] == "SELECT * FROM otters OFFSET 0" assert len(documents) == len(events) for i, document in enumerate(documents): assert document.page_content == events[i].data.decode() assert document.metadata == events[i].meta + + @patch("pyensign.ensign.Ensign.query") + @pytest.mark.parametrize( + "events", + [ + ([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]), + ( + [ + Event( + b'{"content": "hello"}', + mimetype="application/json", + meta={"page": "42"}, + ), + Event( + b"

Hello, world!

", + mimetype="text/html", + meta={"page": "23"}, + ), + ] + ), + ], + ) + def test_lazy_load(self, mock_query, events): + """ + Test lazily loading documents from an Ensign topic + """ + + loader = EnsignLoader( + "otters", + client_id="my_client_id", + client_secret="my_client_secret", + ) + mock_query.return_value.fetchone.side_effect = events + [None] + i = 0 + for document in loader.lazy_load(): + print(document.page_content) + assert document.page_content == events[i].data.decode() + assert document.metadata == events[i].meta + args, _ = mock_query.call_args + assert args[0] == "SELECT * FROM otters LIMIT 1 OFFSET {}".format(i) + i += 1 + assert i == len(events) + + @patch("pyensign.ensign.Ensign.query") + @pytest.mark.parametrize( + "events", + [ + ([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]), + ( + [ + Event( + b'{"content": "hello"}', + mimetype="application/json", + meta={"page": "42"}, + ), + Event( + b"

Hello, world!

", + mimetype="text/html", + meta={"page": "23"}, + ), + ] + ), + ], + ) + async def test_alazy_load(self, mock_query, events): + """ + Test lazily loading documents asynchronously from an Ensign topic + """ + + loader = EnsignLoader( + "otters", + client_id="my_client_id", + client_secret="my_client_secret", + ) + mock_query.return_value.__aiter__.return_value = events + i = 0 + async for document in loader.alazy_load(): + assert document.page_content == events[i].data.decode() + assert document.metadata == events[i].meta + i += 1 + args, _ = mock_query.call_args + assert args[0] == "SELECT * FROM otters OFFSET 0" + assert i == len(events) + + def test_live_lazy_load(self, live, ensignserver, authserver, creds, topic): + """ + Test loading documents from a live Ensign connection. This test exists to + assert that the live connection recovers automatically during a lazy load, + which is necessary because the lazy_load method closes the connection after + each iteration. + """ + + if not live: + pytest.skip("Skipping live test") + if not ensignserver: + pytest.fail("ENSIGN_SERVER environment variable not set") + if not authserver: + pytest.fail("ENSIGN_AUTH_SERVER environment variable not set") + if not creds: + pytest.fail( + "ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment variables not set" + ) + + loader = EnsignLoader( + topic, + content_field="wiki_text", + client_id=creds["client_id"], + client_secret=creds["client_secret"], + endpoint=ensignserver, + auth_url=authserver, + ) + + # Test lazy load + for document in loader.load(): + assert len(document.page_content) > 0 + + async def test_async_live_load(self, live, ensignserver, authserver, creds, topic): + """ + Test loading documents from a live Ensign connection. + """ + + if not live: + pytest.skip("Skipping live test") + if not ensignserver: + pytest.fail("ENSIGN_SERVER environment variable not set") + if not authserver: + pytest.fail("ENSIGN_AUTH_SERVER environment variable not set") + if not creds: + pytest.fail( + "ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment variables not set" + ) + + loader = EnsignLoader( + topic, + content_field="wiki_text", + client_id=creds["client_id"], + client_secret=creds["client_secret"], + endpoint=ensignserver, + auth_url=authserver, + ) + + # Test batch load + documents = await loader.aload() + assert len(documents) > 0