diff --git a/pyproject.toml b/pyproject.toml
index 503b308..5d98200 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -98,7 +98,7 @@
 filterwarnings = [
     "ignore::DeprecationWarning",
 ]
 
-[tool.ruff]
+[tool.ruff.lint]
 select = ["F", "E", "C90", "I", "B", "DJ", "RUF", "TRY", "C4", "TCH005", "TCH004"]
 ignore = ["TRY003", "E501", "RUF012"]
diff --git a/src/wagtail_vector_index/ai_utils/backends/litellm.py b/src/wagtail_vector_index/ai_utils/backends/litellm.py
index c8822b4..6031fed 100644
--- a/src/wagtail_vector_index/ai_utils/backends/litellm.py
+++ b/src/wagtail_vector_index/ai_utils/backends/litellm.py
@@ -3,6 +3,7 @@
 from typing import Any, NotRequired, Self
 
 import litellm
+import litellm.types.utils
 from django.core.exceptions import ImproperlyConfigured
 
 from ..types import (
@@ -174,7 +175,7 @@ class LiteLLMEmbeddingBackend(BaseEmbeddingBackend[LiteLLMEmbeddingBackendConfig
     def embed(self, inputs: Iterable[str], **kwargs) -> Iterator[list[float]]:
         response = litellm.embedding(model=self.config.model_id, input=inputs, **kwargs)
         # LiteLLM *should* return an EmbeddingResponse
-        assert isinstance(response, litellm.EmbeddingResponse)
+        assert isinstance(response, litellm.types.utils.EmbeddingResponse)
         yield from [data["embedding"] for data in response["data"]]
 
     async def aembed(self, inputs: Iterable[str], **kwargs) -> Iterator[list[float]]:
diff --git a/src/wagtail_vector_index/ai_utils/types.py b/src/wagtail_vector_index/ai_utils/types.py
index 7736b45..3caaeae 100644
--- a/src/wagtail_vector_index/ai_utils/types.py
+++ b/src/wagtail_vector_index/ai_utils/types.py
@@ -40,8 +40,13 @@ class AIStreamingResponse:
     def __iter__(self):
         return self
 
+    def __aiter__(self):
+        return self
+
     def __next__(self) -> AIResponseStreamingPart: ...
 
+    async def __anext__(self) -> AIResponseStreamingPart: ...
+
 
 class AIResponse:
     """Representation of a non-streaming response from an AI backend.
diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py
index 779ab28..8ee06db 100644
--- a/src/wagtail_vector_index/storage/base.py
+++ b/src/wagtail_vector_index/storage/base.py
@@ -1,5 +1,5 @@
 import copy
-from collections.abc import Generator, Iterable, Mapping, Sequence
+from collections.abc import AsyncGenerator, Generator, Iterable, Mapping, Sequence
 from dataclasses import dataclass
 from typing import Any, ClassVar, Generic, Protocol, TypeVar
 
@@ -94,6 +94,10 @@ def bulk_from_documents(
         self, documents: Iterable[Document]
     ) -> Generator[object, None, None]: ...
 
+    async def abulk_from_documents(
+        self, documents: Iterable[Document]
+    ) -> AsyncGenerator[object, None]: ...
+
 
 @dataclass
 class QueryResponse:
@@ -105,6 +109,14 @@ class QueryResponse:
     sources: Iterable[object]
 
 
+@dataclass
+class AsyncQueryResponse:
+    """Same as QueryResponse class, but with the response being an async generator."""
+
+    response: AsyncGenerator[str, None]
+    sources: Iterable[object]
+
+
 class VectorIndex(Generic[ConfigClass]):
     """Base class for a VectorIndex, representing some set of documents that can be queried"""
 
@@ -152,6 +164,50 @@ def query(
         response = chat_backend.chat(messages=messages)
         return QueryResponse(response=response.choices[0], sources=sources)
 
+    async def aquery(
+        self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default"
+    ) -> AsyncQueryResponse:
+        """
+        Replicates the features of `VectorIndex.query()`, but in an async way.
+        """
+        try:
+            query_embedding = next(await self.get_embedding_backend().aembed([query]))
+        except IndexError as e:
+            raise ValueError("No embeddings were generated for the given query.") from e
+
+        similar_documents = [
+            doc async for doc in self.aget_similar_documents(query_embedding)
+        ]
+
+        sources = [
+            source
+            async for source in self.get_converter().abulk_from_documents(
+                similar_documents
+            )
+        ]
+
+        merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents)
+        prompt = (
+            getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None)
+            or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer."
+        )
+        messages = [
+            {"content": prompt, "role": "system"},
+            {"content": merged_context, "role": "system"},
+            {"content": query, "role": "user"},
+        ]
+        chat_backend = get_chat_backend(chat_backend_alias)
+        response = await chat_backend.achat(messages=messages, stream=True)
+
+        async def async_stream_wrapper():
+            async for chunk in response:
+                yield chunk["content"]
+
+        return AsyncQueryResponse(
+            response=async_stream_wrapper(),
+            sources=sources,
+        )
+
     def find_similar(
         self, object, *, include_self: bool = False, limit: int = 5
     ) -> list:
@@ -209,3 +265,8 @@ def get_similar_documents(
         self, query_vector: Sequence[float], *, limit: int = 5
     ) -> Generator[Document, None, None]:
         raise NotImplementedError
+
+    def aget_similar_documents(
+        self, query_vector, *, limit: int = 5
+    ) -> AsyncGenerator[Document, None]:
+        raise NotImplementedError
diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py
index a56e3e8..e467541 100644
--- a/src/wagtail_vector_index/storage/models.py
+++ b/src/wagtail_vector_index/storage/models.py
@@ -1,6 +1,12 @@
 from collections import defaultdict
-from collections.abc import Generator, Iterable, MutableSequence, Sequence
-from typing import TYPE_CHECKING, ClassVar, Optional, TypeVar, cast
+from collections.abc import (
+    AsyncGenerator,
+    Generator,
+    Iterable,
+    MutableSequence,
+    Sequence,
+)
+from typing import TYPE_CHECKING, ClassVar, Optional, TypeAlias, TypeVar, cast
 
 from django.apps import apps
 from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
@@ -39,6 +45,10 @@
 - The EmbeddableFieldsDocumentConverter, which is a DocumentConverter that knows how to convert
   a model instance using the EmbeddableFieldsMixin protocol to and from a Document
 """
 
+ContentTypeId: TypeAlias = str
+ObjectId: TypeAlias = str
+ModelKey: TypeAlias = tuple[ContentTypeId, ObjectId]
+
 class Embedding(models.Model):
     """Stores an embedding for a model instance"""
@@ -183,6 +193,14 @@ def _model_class_from_ctid(id: str) -> type[models.Model]:
             raise ValueError(f"Failed to find model class for {ct!r}")
         return model_class
 
+    @classmethod
+    async def _amodel_class_from_ctid(cls, id: str) -> type[models.Model]:
+        ct = await cls._aget_content_type_for_id(int(id))
+        model_class = ct.model_class()
+        if model_class is None:
+            raise ValueError(f"Failed to find model class for {ct!r}")
+        return model_class
+
     def from_document(self, document: Document) -> models.Model:
         model_class = self._model_class_from_ctid(document.metadata["content_type_id"])
         try:
@@ -193,25 +211,44 @@
     def bulk_from_documents(
         self, documents: Iterable[Document]
     ) -> Generator[models.Model, None, None]:
+        documents = tuple(documents)
+
+        ids_by_content_type = self._get_ids_by_content_type(documents)
+        objects_by_key = self._get_models_by_key(ids_by_content_type)
+
+        yield from self._get_deduplicated_objects_generator(documents, objects_by_key)
+
+    async def abulk_from_documents(
+        self, documents: Iterable[Document]
+    ) -> AsyncGenerator[models.Model, None]:
+        """A copy of `bulk_from_documents`, but async"""
         # Force evaluate generators to allow value to be reused
         documents = tuple(documents)
 
-        ids_by_content_type: dict[str, list[str]] = defaultdict(list)
+        ids_by_content_type = self._get_ids_by_content_type(documents)
+        objects_by_key = await self._aget_models_by_key(ids_by_content_type)
+
+        # N.B. `yield from` cannot be used in async functions, so we have to use a loop
+        for object_from_document in self._get_deduplicated_objects_generator(
+            documents, objects_by_key
+        ):
+            yield object_from_document
+
+    @staticmethod
+    def _get_ids_by_content_type(
+        documents: Sequence[Document],
+    ) -> dict[ContentTypeId, list[ObjectId]]:
+        ids_by_content_type = defaultdict(list)
         for doc in documents:
             ids_by_content_type[doc.metadata["content_type_id"]].append(
                 doc.metadata["object_id"]
             )
+        return ids_by_content_type
 
-        # NOTE: (content_type_id, object_id) combo keys are required to
-        # reliably map data from multiple models
-        objects_by_key: dict[tuple[str, str], models.Model] = {}
-        for content_type_id, ids in ids_by_content_type.items():
-            model_class = self._model_class_from_ctid(content_type_id)
-            model_objects = model_class.objects.filter(pk__in=ids)
-            objects_by_key.update(
-                {(content_type_id, str(obj.pk)): obj for obj in model_objects}
-            )
-
+    @staticmethod
+    def _get_deduplicated_objects_generator(
+        documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model]
+    ) -> Generator[models.Model, None, None]:
         seen_keys = set()  # de-dupe as we go
         for doc in documents:
             key = (doc.metadata["content_type_id"], doc.metadata["object_id"])
@@ -220,6 +257,51 @@
             seen_keys.add(key)
             yield objects_by_key[key]
 
+    def _get_models_by_key(
+        self, ids_by_content_type: dict
+    ) -> dict[ModelKey, models.Model]:
+        """
+        (content_type_id, object_id) combo keys are required to reliably map data
+        from multiple models. This function loads the models from the database
+        and groups them by such a key.
+        """
+        objects_by_key: dict[ModelKey, models.Model] = {}
+        for content_type_id, ids in ids_by_content_type.items():
+            model_class = self._model_class_from_ctid(content_type_id)
+            model_objects = model_class.objects.filter(pk__in=ids)
+            objects_by_key.update(
+                {(content_type_id, str(obj.pk)): obj for obj in model_objects}
+            )
+        return objects_by_key
+
+    async def _aget_models_by_key(
+        self, ids_by_content_type: dict
+    ) -> dict[ModelKey, models.Model]:
+        """
+        Same as `_get_models_by_key`, but async.
+        """
+        objects_by_key: dict[ModelKey, models.Model] = {}
+        for content_type_id, ids in ids_by_content_type.items():
+            model_class = await self._amodel_class_from_ctid(content_type_id)
+            model_objects = model_class.objects.filter(pk__in=ids)
+            objects_by_key.update(
+                {(content_type_id, str(obj.pk)): obj async for obj in model_objects}
+            )
+        return objects_by_key
+
+    @staticmethod
+    async def _aget_content_type_for_id(id: int) -> ContentType:
+        """
+        Same as `ContentTypeManager.get_for_id`, but async.
+        """
+        manager = ContentType.objects
+        try:
+            ct = manager._cache[manager.db][id]  # type: ignore[reportAttributeAccessIssue]
+        except KeyError:
+            ct = await manager.aget(pk=id)
+            manager._add_to_cache(manager.db, ct)  # type: ignore[reportAttributeAccessIssue]
+        return ct
+
 
 class EmbeddableFieldsDocumentConverter(DocumentToModelMixin):
     """Implementation of DocumentConverter that knows how to convert a model instance using the
diff --git a/src/wagtail_vector_index/storage/pgvector/provider.py b/src/wagtail_vector_index/storage/pgvector/provider.py
index bb82058..bdaae6c 100644
--- a/src/wagtail_vector_index/storage/pgvector/provider.py
+++ b/src/wagtail_vector_index/storage/pgvector/provider.py
@@ -1,5 +1,11 @@
 import logging
-from collections.abc import Generator, Iterable, MutableSequence, Sequence
+from collections.abc import (
+    AsyncGenerator,
+    Generator,
+    Iterable,
+    MutableSequence,
+    Sequence,
+)
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, ClassVar, cast
 
@@ -76,16 +82,17 @@ def clear(self):
     def get_similar_documents(
         self, query_vector, *, limit: int = 5
    ) -> Generator[Document, None, None]:
-        for pgvector_embedding in (
-            self._get_queryset()
-            .select_related("embedding")
-            .filter(embedding_output_dimensions=len(query_vector))
-            .order_by_distance(
-                query_vector,
-                distance_method=self.distance_method,
-                fetch_distance=False,
-            )[:limit]
-            .iterator()
+        for pgvector_embedding in self._get_similar_documents_queryset(
+            query_vector, limit=limit
+        ).iterator():
+            embedding = pgvector_embedding.embedding
+            yield embedding.to_document()
+
+    async def aget_similar_documents(
+        self, query_vector, *, limit: int = 5
+    ) -> AsyncGenerator[Document, None]:
+        async for pgvector_embedding in self._get_similar_documents_queryset(
+            query_vector, limit=limit
         ):
             embedding = pgvector_embedding.embedding
             yield embedding.to_document()
@@ -97,6 +104,20 @@ def _get_queryset(self) -> "PgvectorEmbeddingQuerySet":
             type(self).__name__
         )
 
+    def _get_similar_documents_queryset(
+        self, query_vector: Sequence[float], *, limit: int
+    ) -> "PgvectorEmbeddingQuerySet":
+        return (
+            self._get_queryset()
+            .select_related("embedding")
+            .filter(embedding_output_dimensions=len(query_vector))
+            .order_by_distance(
+                query_vector,
+                distance_method=self.distance_method,
+                fetch_distance=False,
+            )[:limit]
+        )
+
     def _bulk_create(self, embeddings: Sequence["PgvectorEmbedding"]) -> None:
         _embedding_model().objects.bulk_create(
             embeddings,
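Reviewer note: a minimal end-to-end sketch of the new async query path follows. `MyPageIndex` and its import path are hypothetical stand-ins for any registered `VectorIndex` subclass; only `aquery()`, `AsyncQueryResponse`, and the async converter/storage methods come from this diff.

```python
import asyncio

from myapp.indexes import MyPageIndex  # hypothetical index subclass


async def main() -> None:
    index = MyPageIndex()
    result = await index.aquery("How do I register a new page type?")

    # AsyncQueryResponse.response is an async generator of text chunks,
    # so the chat backend's streamed output can be rendered as it arrives.
    async for chunk in result.response:
        print(chunk, end="", flush=True)

    # Sources are resolved eagerly via abulk_from_documents() before the
    # stream starts, so this is a plain list of model instances.
    print("\nSources:", [str(source) for source in result.sources])


asyncio.run(main())
```

The async ORM entry points this diff leans on (`Manager.aget()`, `async for` over querysets) require Django 4.1+; on Django 4.2+, the `result.response` generator can also be passed directly to a `StreamingHttpResponse`.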
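A second note on the `types.py` change: the new `__aiter__`/`__anext__` stubs make a single response type usable with both `for` and `async for`. A hypothetical concrete subclass (not part of this diff) would bridge the two protocols roughly as below; the key detail is that `StopIteration` must be translated into `StopAsyncIteration` on the async side.

```python
from wagtail_vector_index.ai_utils.types import AIStreamingResponse


class EchoStreamingResponse(AIStreamingResponse):
    """Illustrative only: streams pre-baked parts through both protocols."""

    def __init__(self, parts):
        self._parts = iter(parts)

    def __next__(self):
        return next(self._parts)  # raises StopIteration when exhausted

    async def __anext__(self):
        try:
            return next(self._parts)
        except StopIteration as e:
            # Async iteration signals exhaustion with StopAsyncIteration
            raise StopAsyncIteration from e
```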