Merge pull request wagtail#80 from wagtail/add-async-support-to-vector-index

Add an async version of the query method to the VectorIndex class
tomusher authored Jul 30, 2024
2 parents 412632f + d5e26a0 commit ea41b7d
Showing 6 changed files with 197 additions and 27 deletions.
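
For orientation, a minimal sketch of how the new async API might be consumed. `MyPageIndex` is an illustrative VectorIndex subclass, not something this diff defines; `aquery` and `AsyncQueryResponse` are the additions below, and unlike the sync `query`, the async response arrives as a stream of chunks.

```python
# Sketch only: `MyPageIndex` is an assumed VectorIndex subclass.
from wagtail_vector_index.storage.base import AsyncQueryResponse


async def answer(question: str) -> str:
    index = MyPageIndex()
    result: AsyncQueryResponse = await index.aquery(question)
    # `result.response` is an async generator of string chunks,
    # because `aquery` calls the chat backend with stream=True.
    return "".join([chunk async for chunk in result.response])
```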
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -98,7 +98,7 @@ filterwarnings = [
"ignore::DeprecationWarning",
]

[tool.ruff]
[tool.ruff.lint]
select = ["F", "E", "C90", "I", "B", "DJ", "RUF", "TRY", "C4", "TCH005", "TCH004"]
ignore = ["TRY003", "E501", "RUF012"]

3 changes: 2 additions & 1 deletion src/wagtail_vector_index/ai_utils/backends/litellm.py
@@ -3,6 +3,7 @@
from typing import Any, NotRequired, Self

import litellm
import litellm.types.utils
from django.core.exceptions import ImproperlyConfigured

from ..types import (
@@ -174,7 +175,7 @@ class LiteLLMEmbeddingBackend(BaseEmbeddingBackend[LiteLLMEmbeddingBackendConfig
def embed(self, inputs: Iterable[str], **kwargs) -> Iterator[list[float]]:
response = litellm.embedding(model=self.config.model_id, input=inputs, **kwargs)
# LiteLLM *should* return an EmbeddingResponse
assert isinstance(response, litellm.EmbeddingResponse)
assert isinstance(response, litellm.types.utils.EmbeddingResponse)
yield from [data["embedding"] for data in response["data"]]

async def aembed(self, inputs: Iterable[str], **kwargs) -> Iterator[list[float]]:
5 changes: 5 additions & 0 deletions src/wagtail_vector_index/ai_utils/types.py
@@ -40,8 +40,13 @@ class AIStreamingResponse:
def __iter__(self):
return self

def __aiter__(self):
return self

def __next__(self) -> AIResponseStreamingPart: ...

async def __anext__(self) -> AIResponseStreamingPart: ...


class AIResponse:
"""Representation of a non-streaming response from an AI backend.
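
The `__aiter__`/`__anext__` stubs above declare that `AIStreamingResponse` supports async iteration alongside the existing sync protocol. As a toy illustration (not from the codebase), one iterator can back both protocols; note that the async side signals exhaustion with `StopAsyncIteration`:

```python
import asyncio


class FakeStream:
    """Toy class satisfying both protocols AIStreamingResponse declares."""

    def __init__(self, parts: list[str]) -> None:
        self._parts = iter(parts)

    def __iter__(self):
        return self

    def __next__(self) -> str:
        return next(self._parts)  # StopIteration propagates when exhausted

    def __aiter__(self):
        return self

    async def __anext__(self) -> str:
        try:
            return next(self._parts)
        except StopIteration:
            raise StopAsyncIteration from None


async def main() -> None:
    async for part in FakeStream(["Hello", ", ", "world"]):
        print(part, end="")


asyncio.run(main())
```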
63 changes: 62 additions & 1 deletion src/wagtail_vector_index/storage/base.py
@@ -1,5 +1,5 @@
import copy
from collections.abc import Generator, Iterable, Mapping, Sequence
from collections.abc import AsyncGenerator, Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, ClassVar, Generic, Protocol, TypeVar

@@ -94,6 +94,10 @@ def bulk_from_documents(
self, documents: Iterable[Document]
) -> Generator[object, None, None]: ...

async def abulk_from_documents(
self, documents: Iterable[Document]
) -> AsyncGenerator[object, None]: ...


@dataclass
class QueryResponse:
@@ -105,6 +109,14 @@ class QueryResponse:
sources: Iterable[object]


@dataclass
class AsyncQueryResponse:
"""Same as QueryResponse class, but with the response being an async generator."""

response: AsyncGenerator[str, None]
sources: Iterable[object]


class VectorIndex(Generic[ConfigClass]):
"""Base class for a VectorIndex, representing some set of documents that can be queried"""

@@ -152,6 +164,50 @@ def query(
response = chat_backend.chat(messages=messages)
return QueryResponse(response=response.choices[0], sources=sources)

async def aquery(
self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default"
) -> AsyncQueryResponse:
"""
Replicates the features of `VectorIndex.query()`, but in an async way.
"""
try:
query_embedding = next(await self.get_embedding_backend().aembed([query]))
except IndexError as e:
raise ValueError("No embeddings were generated for the given query.") from e

similar_documents = [
doc async for doc in self.aget_similar_documents(query_embedding)
]

sources = [
source
async for source in self.get_converter().abulk_from_documents(
similar_documents
)
]

merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents)
prompt = (
getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None)
or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer."
)
messages = [
{"content": prompt, "role": "system"},
{"content": merged_context, "role": "system"},
{"content": query, "role": "user"},
]
chat_backend = get_chat_backend(chat_backend_alias)
response = await chat_backend.achat(messages=messages, stream=True)

async def async_stream_wrapper():
async for chunk in response:
yield chunk["content"]

return AsyncQueryResponse(
response=async_stream_wrapper(),
sources=sources,
)

def find_similar(
self, object, *, include_self: bool = False, limit: int = 5
) -> list:
@@ -209,3 +265,8 @@ def get_similar_documents(
self, query_vector: Sequence[float], *, limit: int = 5
) -> Generator[Document, None, None]:
raise NotImplementedError

def aget_similar_documents(
self, query_vector, *, limit: int = 5
) -> AsyncGenerator[Document, None]:
raise NotImplementedError
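
Storage providers now override `aget_similar_documents` next to the sync hook. A toy sketch of such an override follows; the in-memory store, the dot-product scoring, and the exact import paths are all assumptions, not APIs from this codebase.

```python
# Toy override of the new async hook; store, scoring, and import
# paths are assumptions rather than APIs from this codebase.
from collections.abc import AsyncGenerator, Sequence

from wagtail_vector_index.storage.base import Document, VectorIndex  # paths assumed


class InMemoryVectorIndex(VectorIndex):
    def __init__(self, entries: list[tuple[Sequence[float], Document]]) -> None:
        self.entries = entries

    async def aget_similar_documents(
        self, query_vector: Sequence[float], *, limit: int = 5
    ) -> AsyncGenerator[Document, None]:
        def score(vector: Sequence[float]) -> float:
            # Plain dot product as a stand-in similarity measure.
            return sum(a * b for a, b in zip(vector, query_vector))

        ranked = sorted(self.entries, key=lambda e: score(e[0]), reverse=True)
        for _, document in ranked[:limit]:
            yield document
```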
108 changes: 95 additions & 13 deletions src/wagtail_vector_index/storage/models.py
@@ -1,6 +1,12 @@
from collections import defaultdict
from collections.abc import Generator, Iterable, MutableSequence, Sequence
from typing import TYPE_CHECKING, ClassVar, Optional, TypeVar, cast
from collections.abc import (
AsyncGenerator,
Generator,
Iterable,
MutableSequence,
Sequence,
)
from typing import TYPE_CHECKING, ClassVar, Optional, TypeAlias, TypeVar, cast

from django.apps import apps
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
@@ -39,6 +45,10 @@
- The EmbeddableFieldsDocumentConverter, which is a DocumentConverter that knows how to convert a model instance using the EmbeddableFieldsMixin protocol to and from a Document
"""

ContentTypeId: TypeAlias = str
ObjectId: TypeAlias = str
ModelKey: TypeAlias = tuple[ContentTypeId, ObjectId]


class Embedding(models.Model):
"""Stores an embedding for a model instance"""
@@ -183,6 +193,14 @@ def _model_class_from_ctid(id: str) -> type[models.Model]:
raise ValueError(f"Failed to find model class for {ct!r}")
return model_class

@classmethod
async def _amodel_class_from_ctid(cls, id: str) -> type[models.Model]:
ct = await cls._aget_content_type_for_id(int(id))
model_class = ct.model_class()
if model_class is None:
raise ValueError(f"Failed to find model class for {ct!r}")
return model_class

def from_document(self, document: Document) -> models.Model:
model_class = self._model_class_from_ctid(document.metadata["content_type_id"])
try:
@@ -193,25 +211,44 @@ def from_document(self, document: Document) -> models.Model:
def bulk_from_documents(
self, documents: Iterable[Document]
) -> Generator[models.Model, None, None]:
documents = tuple(documents)

ids_by_content_type = self._get_ids_by_content_type(documents)
objects_by_key = self._get_models_by_key(ids_by_content_type)

yield from self._get_deduplicated_objects_generator(documents, objects_by_key)

async def abulk_from_documents(
self, documents: Iterable[Document]
) -> AsyncGenerator[models.Model, None]:
"""A copy of `bulk_from_documents`, but async"""
# Force evaluate generators to allow value to be reused
documents = tuple(documents)

ids_by_content_type: dict[str, list[str]] = defaultdict(list)
ids_by_content_type = self._get_ids_by_content_type(documents)
objects_by_key = await self._aget_models_by_key(ids_by_content_type)

# N.B. `yield from` cannot be used in async functions, so we have to use a loop
for object_from_document in self._get_deduplicated_objects_generator(
documents, objects_by_key
):
yield object_from_document

@staticmethod
def _get_ids_by_content_type(
documents: Sequence[Document],
) -> dict[ContentTypeId, list[ObjectId]]:
ids_by_content_type = defaultdict(list)
for doc in documents:
ids_by_content_type[doc.metadata["content_type_id"]].append(
doc.metadata["object_id"]
)
return ids_by_content_type

# NOTE: (content_type_id, object_id) combo keys are required to
# reliably map data from multiple models
objects_by_key: dict[tuple[str, str], models.Model] = {}
for content_type_id, ids in ids_by_content_type.items():
model_class = self._model_class_from_ctid(content_type_id)
model_objects = model_class.objects.filter(pk__in=ids)
objects_by_key.update(
{(content_type_id, str(obj.pk)): obj for obj in model_objects}
)

@staticmethod
def _get_deduplicated_objects_generator(
documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model]
) -> Generator[models.Model, None, None]:
seen_keys = set() # de-dupe as we go
for doc in documents:
key = (doc.metadata["content_type_id"], doc.metadata["object_id"])
@@ -220,6 +257,51 @@ def bulk_from_documents(
seen_keys.add(key)
yield objects_by_key[key]

def _get_models_by_key(
self, ids_by_content_type: dict
) -> dict[ModelKey, models.Model]:
"""
(content_type_id, object_id) combo keys are required to reliably map data
from multiple models. This function loads the models from the database
and groups them by such a key.
"""
objects_by_key: dict[ModelKey, models.Model] = {}
for content_type_id, ids in ids_by_content_type.items():
model_class = self._model_class_from_ctid(content_type_id)
model_objects = model_class.objects.filter(pk__in=ids)
objects_by_key.update(
{(content_type_id, str(obj.pk)): obj for obj in model_objects}
)
return objects_by_key

async def _aget_models_by_key(
self, ids_by_content_type: dict
) -> dict[ModelKey, models.Model]:
"""
Same as `_get_models_by_key`, but async.
"""
objects_by_key: dict[ModelKey, models.Model] = {}
for content_type_id, ids in ids_by_content_type.items():
model_class = await self._amodel_class_from_ctid(content_type_id)
model_objects = model_class.objects.filter(pk__in=ids)
objects_by_key.update(
{(content_type_id, str(obj.pk)): obj async for obj in model_objects}
)
return objects_by_key

@staticmethod
async def _aget_content_type_for_id(id: int) -> ContentType:
"""
Same as `ContentTypeManager.get_for_id`, but async.
"""
manager = ContentType.objects
try:
ct = manager._cache[manager.db][id] # type: ignore[reportAttributeAccessIssue]
except KeyError:
ct = await manager.aget(pk=id)
manager._add_to_cache(manager.db, ct) # type: ignore[reportAttributeAccessIssue]
return ct


class EmbeddableFieldsDocumentConverter(DocumentToModelMixin):
"""Implementation of DocumentConverter that knows how to convert a model instance using the
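
To make the composite-key scheme concrete, here is a standalone demo of the grouping and order-preserving dedup implemented above, with toy metadata dicts standing in for real Documents:

```python
# Standalone demo of the (content_type_id, object_id) keying and
# order-preserving dedup; toy dicts stand in for Document metadata.
from collections import defaultdict

metadata = [
    {"content_type_id": "7", "object_id": "1"},  # e.g. a BlogPage chunk
    {"content_type_id": "9", "object_id": "1"},  # same pk, different model
    {"content_type_id": "7", "object_id": "1"},  # duplicate chunk
]

ids_by_content_type: dict[str, list[str]] = defaultdict(list)
for meta in metadata:
    ids_by_content_type[meta["content_type_id"]].append(meta["object_id"])
print(dict(ids_by_content_type))  # {'7': ['1', '1'], '9': ['1']}

seen_keys: set[tuple[str, str]] = set()
for meta in metadata:
    key = (meta["content_type_id"], meta["object_id"])
    if key in seen_keys:
        continue  # the duplicate ('7', '1') is skipped here
    seen_keys.add(key)
    print(key)  # ('7', '1') then ('9', '1'): similarity order preserved
```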
43 changes: 32 additions & 11 deletions src/wagtail_vector_index/storage/pgvector/provider.py
@@ -1,5 +1,11 @@
import logging
from collections.abc import Generator, Iterable, MutableSequence, Sequence
from collections.abc import (
AsyncGenerator,
Generator,
Iterable,
MutableSequence,
Sequence,
)
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, cast

@@ -76,16 +82,17 @@ def clear(self):
def get_similar_documents(
self, query_vector, *, limit: int = 5
) -> Generator[Document, None, None]:
for pgvector_embedding in (
self._get_queryset()
.select_related("embedding")
.filter(embedding_output_dimensions=len(query_vector))
.order_by_distance(
query_vector,
distance_method=self.distance_method,
fetch_distance=False,
)[:limit]
.iterator()
for pgvector_embedding in self._get_similar_documents_queryset(
query_vector, limit=limit
).iterator():
embedding = pgvector_embedding.embedding
yield embedding.to_document()

async def aget_similar_documents(
self, query_vector, *, limit: int = 5
) -> AsyncGenerator[Document, None]:
async for pgvector_embedding in self._get_similar_documents_queryset(
query_vector, limit=limit
):
embedding = pgvector_embedding.embedding
yield embedding.to_document()
@@ -97,6 +104,20 @@ def _get_queryset(self) -> "PgvectorEmbeddingQuerySet":
type(self).__name__
)

def _get_similar_documents_queryset(
self, query_vector: Sequence[float], *, limit: int
) -> "PgvectorEmbeddingQuerySet":
return (
self._get_queryset()
.select_related("embedding")
.filter(embedding_output_dimensions=len(query_vector))
.order_by_distance(
query_vector,
distance_method=self.distance_method,
fetch_distance=False,
)[:limit]
)

def _bulk_create(self, embeddings: Sequence["PgvectorEmbedding"]) -> None:
_embedding_model().objects.bulk_create(
embeddings,
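
A pattern worth noting in this refactor: one lazily built queryset is shared by the sync and async paths, which differ only in how they iterate it, since Django querysets support `async for` natively from 4.1 onward. Below is a minimal sketch of that pattern under the same assumption; `MyModel` is a stand-in model, not part of this codebase.

```python
# Minimal sketch of the shared-queryset pattern; `MyModel` is a
# stand-in. Only the iteration style differs between the two paths.
from collections.abc import AsyncGenerator, Generator

from django.db import models


def _limited_queryset(limit: int) -> models.QuerySet:
    # Built lazily; no query runs until the queryset is iterated.
    return MyModel.objects.order_by("pk")[:limit]


def rows(limit: int = 5) -> Generator[models.Model, None, None]:
    yield from _limited_queryset(limit).iterator()


async def arows(limit: int = 5) -> AsyncGenerator[models.Model, None]:
    # QuerySet.__aiter__ (Django 4.1+) runs the query on the async
    # path, so no sync_to_async wrapper is needed.
    async for obj in _limited_queryset(limit):
        yield obj
```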
