Implement document loader methods #94

Open · wants to merge 1 commit into develop
2 changes: 1 addition & 1 deletion pyensign/connection.py
@@ -142,7 +142,7 @@ def _ensure_ready(self):
is a bit easier.
"""

if not self.channel:
if not self.channel or self.channel._loop.is_closed():
self.channel = self.connection.create_channel()
self.stub = ensign_pb2_grpc.EnsignStub(self.channel)

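The guard above matters because the new sync loader methods in pyensign/ml/loader.py run their coroutines through pyensign.sync.async_to_sync, which (judging from the lazy_load docstring and the new _loop.is_closed() check) appears to create and tear down a short-lived event loop per call, leaving the cached gRPC channel bound to a closed loop. A minimal sketch of such a wrapper, purely hypothetical since the actual pyensign.sync implementation is not part of this diff:

import asyncio
import functools


def async_to_sync(coro_func):
    """Hypothetical sketch of a sync wrapper; not the actual pyensign.sync code."""

    @functools.wraps(coro_func)
    def wrapper(*args, **kwargs):
        # Run the coroutine on a temporary event loop and close it afterwards.
        # Closing the loop is what leaves channel._loop.is_closed() == True,
        # which the _ensure_ready() change above now detects and recovers from.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro_func(*args, **kwargs))
        finally:
            loop.close()

    return wrapper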
62 changes: 51 additions & 11 deletions pyensign/ml/loader.py
@@ -1,4 +1,4 @@
from typing import AsyncIterator, Iterator, List, Any
from typing import AsyncIterator, Iterator, List, Any, Coroutine

try:
from langchain_core.document_loaders import BaseLoader
@@ -8,6 +8,7 @@
raise ImportError("please `pip install pyensign[ml]` to use the EnsignLoader")

from pyensign.ensign import Ensign
from pyensign.sync import async_to_sync


class EnsignLoader(BaseLoader):
@@ -47,37 +48,76 @@ def _convert_to_document(self, event: Any) -> Document:
metadata=event.meta,
)

def load(self) -> Iterator[Document]:
async def _get_event_cursor(self, limit: int = 0, offset: int = 0) -> Coroutine:
"""
Returns an event cursor that retrieves all events from the Ensign topic.
"""

query = "SELECT * FROM {}".format(self.topic)
if limit > 0:
query += " LIMIT {}".format(limit)
query += " OFFSET {}".format(offset)
return await self.ensign.query(query)

async def _load_documents(self) -> List[Document]:
"""
A helper coroutine that loads all documents from Ensign into memory.
"""
cursor = await self._get_event_cursor()
events = await cursor.fetchall()
return [self._convert_to_document(event) for event in events]

async def _fetchone(self, limit: int = 0, offset: int = 0) -> Document:
"""
Fetch one document from Ensign.
"""

cursor = await self._get_event_cursor(limit=limit, offset=offset)
return await cursor.fetchone()

def load(self) -> List[Document]:
"""
Loads all documents from Ensign into memory.
Note: In production, this should be used with caution since it loads all
documents into memory at once.
"""

raise NotImplementedError
# FIXME: This blocks forever in Jupyter notebooks.
return async_to_sync(self._load_documents)()

async def aload(self) -> List[Document]:
"""
Load all documents from Ensign into memory asynchronously.
TODO: Prevent SQL injection with the topic name.
"""

cursor = await self.ensign.query("SELECT * FROM {}".format(self.topic))
events = await cursor.fetchall()
return [self._convert_to_document(event) for event in events]
return await self._load_documents()

def lazy_load(self) -> Iterator[Document]:
"""
Load documents from Ensign one by one lazily.
Load documents from Ensign one by one lazily. This method will be slower than
alazy_load because a new connection must be established to retrieve each
document. This is because Ensign queries currently only exist in an async
context.
"""

raise NotImplementedError
offset = 0
while True:
# FIXME: This blocks forever in Jupyter notebooks.
event = async_to_sync(self._fetchone)(limit=1, offset=offset)
if event is None:
break
offset += 1
yield self._convert_to_document(event)

async def alazy_load(
self,
) -> AsyncIterator[Document]: # <-- Does not take any arguments
) -> AsyncIterator[Document]:
"""
Load documents from Ensign one by one lazily. This is an async generator.
An async generator that yields documents from Ensign one by one lazily. This is
the recommended method for loading documents in production.
"""

raise NotImplementedError
cursor = await self._get_event_cursor()
async for event in cursor:
yield self._convert_to_document(event)
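
Taken together, the loader now supports four access patterns. A usage sketch follows; the topic name and credentials are placeholders borrowed from the tests below, not values defined by this PR:

import asyncio

from pyensign.ml.loader import EnsignLoader

loader = EnsignLoader(
    "otters",
    client_id="my_client_id",
    client_secret="my_client_secret",
)

# Synchronous: pulls every document into memory in one query.
documents = loader.load()

# Synchronous and lazy: issues a LIMIT 1 OFFSET n query per document,
# reconnecting each time, so it is the slowest option.
for document in loader.lazy_load():
    print(document.page_content)


async def main():
    # Async batch load.
    documents = await loader.aload()
    # Async lazy load, the recommended path for production.
    async for document in loader.alazy_load():
        print(document.metadata)


asyncio.run(main())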
219 changes: 217 additions & 2 deletions tests/pyensign/ml/test_loader.py
@@ -1,10 +1,41 @@
import os
import pytest
from unittest import mock
from asyncmock import patch

from pyensign.events import Event
from pyensign.connection import Client, Connection
from pyensign.ml.loader import EnsignLoader


@pytest.fixture
def live(request):
return request.config.getoption("--live")


@pytest.fixture
def authserver():
return os.environ.get("ENSIGN_AUTH_SERVER")


@pytest.fixture
def ensignserver():
return os.environ.get("ENSIGN_SERVER")


@pytest.fixture
def creds():
return {
"client_id": os.environ.get("ENSIGN_CLIENT_ID"),
"client_secret": os.environ.get("ENSIGN_CLIENT_SECRET"),
}


@pytest.fixture
def topic():
return os.environ.get("ENSIGN_TOPIC")


class TestEnsignLoader:
"""
Tests for the EnsignLoader class.
@@ -23,9 +54,49 @@ def test_init_errors(self, topic, kwargs):
arguments.
"""

with pytest.raises(ValueError):
with pytest.raises(ValueError), mock.patch.dict(os.environ, {}, clear=True):
EnsignLoader(topic, **kwargs)

@patch("pyensign.ensign.Ensign.query")
@pytest.mark.parametrize(
"events",
[
([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]),
(
[
Event(
b'{"content": "hello"}',
mimetype="application/json",
meta={"page": "42"},
),
Event(
b"<h1>Hello, world!</h1>",
mimetype="text/html",
meta={"page": "23"},
),
]
),
],
)
def test_load(self, mock_query, events):
"""
Test loading all documents from an Ensign topic
"""

loader = EnsignLoader(
"otters",
client_id="my_client_id",
client_secret="my_client_secret",
)
mock_query.return_value.fetchall.return_value = events
documents = loader.load()
args, _ = mock_query.call_args
assert args[0] == "SELECT * FROM otters OFFSET 0"
assert len(documents) == len(events)
for i, document in enumerate(documents):
assert document.page_content == events[i].data.decode()
assert document.metadata == events[i].meta

@patch("pyensign.ensign.Ensign.query")
@pytest.mark.parametrize(
"events",
@@ -60,8 +131,152 @@ async def test_aload(self, mock_query, events):
mock_query.return_value.fetchall.return_value = events
documents = await loader.aload()
args, _ = mock_query.call_args
assert args[0] == "SELECT * FROM otters"
assert args[0] == "SELECT * FROM otters OFFSET 0"
assert len(documents) == len(events)
for i, document in enumerate(documents):
assert document.page_content == events[i].data.decode()
assert document.metadata == events[i].meta

@patch("pyensign.ensign.Ensign.query")
@pytest.mark.parametrize(
"events",
[
([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]),
(
[
Event(
b'{"content": "hello"}',
mimetype="application/json",
meta={"page": "42"},
),
Event(
b"<h1>Hello, world!</h1>",
mimetype="text/html",
meta={"page": "23"},
),
]
),
],
)
def test_lazy_load(self, mock_query, events):
"""
Test lazily loading documents from an Ensign topic
"""

loader = EnsignLoader(
"otters",
client_id="my_client_id",
client_secret="my_client_secret",
)
mock_query.return_value.fetchone.side_effect = events + [None]
i = 0
for document in loader.lazy_load():
print(document.page_content)
assert document.page_content == events[i].data.decode()
assert document.metadata == events[i].meta
args, _ = mock_query.call_args
assert args[0] == "SELECT * FROM otters LIMIT 1 OFFSET {}".format(i)
i += 1
assert i == len(events)

@patch("pyensign.ensign.Ensign.query")
@pytest.mark.parametrize(
"events",
[
([Event(b"Hello, world!", mimetype="text/plain", meta={"id": "23"})]),
(
[
Event(
b'{"content": "hello"}',
mimetype="application/json",
meta={"page": "42"},
),
Event(
b"<h1>Hello, world!</h1>",
mimetype="text/html",
meta={"page": "23"},
),
]
),
],
)
async def test_alazy_load(self, mock_query, events):
"""
Test lazily loading documents asynchronously from an Ensign topic
"""

loader = EnsignLoader(
"otters",
client_id="my_client_id",
client_secret="my_client_secret",
)
mock_query.return_value.__aiter__.return_value = events
i = 0
async for document in loader.alazy_load():
assert document.page_content == events[i].data.decode()
assert document.metadata == events[i].meta
i += 1
args, _ = mock_query.call_args
assert args[0] == "SELECT * FROM otters OFFSET 0"
assert i == len(events)

def test_live_lazy_load(self, live, ensignserver, authserver, creds, topic):
"""
Test loading documents from a live Ensign connection. This test exists to
assert that the live connection recovers automatically during a lazy load,
which is necessary because the lazy_load method closes the connection after
each iteration.
"""

if not live:
pytest.skip("Skipping live test")
if not ensignserver:
pytest.fail("ENSIGN_SERVER environment variable not set")
if not authserver:
pytest.fail("ENSIGN_AUTH_SERVER environment variable not set")
if not creds:
pytest.fail(
"ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment variables not set"
)

loader = EnsignLoader(
topic,
content_field="wiki_text",
client_id=creds["client_id"],
client_secret=creds["client_secret"],
endpoint=ensignserver,
auth_url=authserver,
)

# Test lazy load
for document in loader.lazy_load():
assert len(document.page_content) > 0

async def test_async_live_load(self, live, ensignserver, authserver, creds, topic):
"""
Test loading documents from a live Ensign connection.
"""

if not live:
pytest.skip("Skipping live test")
if not ensignserver:
pytest.fail("ENSIGN_SERVER environment variable not set")
if not authserver:
pytest.fail("ENSIGN_AUTH_SERVER environment variable not set")
if not creds:
pytest.fail(
"ENSIGN_CLIENT_ID and ENSIGN_CLIENT_SECRET environment variables not set"
)

loader = EnsignLoader(
topic,
content_field="wiki_text",
client_id=creds["client_id"],
client_secret=creds["client_secret"],
endpoint=ensignserver,
auth_url=authserver,
)

# Test batch load
documents = await loader.aload()
assert len(documents) > 0
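
The live tests above gate on a --live pytest option plus a handful of environment variables. If that option is not already registered elsewhere in the repository, a minimal conftest.py along these lines would be needed (a sketch, not code from this PR):

# tests/conftest.py
import pytest


def pytest_addoption(parser):
    parser.addoption(
        "--live",
        action="store_true",
        default=False,
        help="run tests that require a live Ensign connection",
    )

The live tests would then be run with something like pytest --live after exporting ENSIGN_SERVER, ENSIGN_AUTH_SERVER, ENSIGN_CLIENT_ID, ENSIGN_CLIENT_SECRET, and ENSIGN_TOPIC.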