From 645930ad4be9858a5940fe3ffdccc61ecbbebc6a Mon Sep 17 00:00:00 2001 From: Olivier Philippon Date: Fri, 12 Jul 2024 14:50:44 +0100 Subject: [PATCH 1/6] Add an async version of the `query` method to the VectorIndex class ...with all the dependencies it entails here and there in the code used by this method --- pyproject.toml | 2 +- src/wagtail_vector_index/storage/base.py | 66 ++++++++++++++++++- src/wagtail_vector_index/storage/models.py | 60 ++++++++++++++++- .../storage/pgvector/provider.py | 25 ++++++- 4 files changed, 149 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 503b308..5d98200 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,7 @@ filterwarnings = [ "ignore::DeprecationWarning", ] -[tool.ruff] +[tool.ruff.lint] select = ["F", "E", "C90", "I", "B", "DJ", "RUF", "TRY", "C4", "TCH005", "TCH004"] ignore = ["TRY003", "E501", "RUF012"] diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index 779ab28..5bf54b3 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -1,5 +1,5 @@ import copy -from collections.abc import Generator, Iterable, Mapping, Sequence +from collections.abc import AsyncGenerator, Generator, Iterable, Mapping, Sequence from dataclasses import dataclass from typing import Any, ClassVar, Generic, Protocol, TypeVar @@ -94,6 +94,10 @@ def bulk_from_documents( self, documents: Iterable[Document] ) -> Generator[object, None, None]: ... + async def abulk_from_documents( + self, documents: Iterable[Document] + ) -> AsyncGenerator[object, None]: ... 
+ @dataclass class QueryResponse: @@ -105,6 +109,14 @@ class QueryResponse: sources: Iterable[object] +@dataclass +class AsyncQueryResponse: + """Same as QueryResponse class, but with the response being an async generator.""" + + response: AsyncGenerator[str, None] + sources: Iterable[object] + + class VectorIndex(Generic[ConfigClass]): """Base class for a VectorIndex, representing some set of documents that can be queried""" @@ -152,6 +164,53 @@ def query( response = chat_backend.chat(messages=messages) return QueryResponse(response=response.choices[0], sources=sources) + async def aquery( + self, query: str, *, sources_limit: int = 5, chat_backend_alias: str = "default" + ) -> AsyncQueryResponse: + """ + Replicates the features of `VectorIndex.query()`, but in an async way. + """ + try: + # TODO: use `next()` instead of `[0]` once the `aembed()` method of the + # LiteLLMBackends class is updated to return the same data structure than + # its BaseEmbeddingBackend parent class. + query_embedding = (await self.get_embedding_backend().aembed([query]))[0] # type: ignore[reportIndexIssue] + except IndexError as e: + raise ValueError("No embeddings were generated for the given query.") from e + + similar_documents = [ + doc async for doc in self.aget_similar_documents(query_embedding) + ] + + sources = [ + source + async for source in self.get_converter().abulk_from_documents( + similar_documents + ) + ] + + merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents) + prompt = ( + getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None) + or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer." 
+ ) + messages = [ + {"content": prompt, "role": "system"}, + {"content": merged_context, "role": "system"}, + {"content": query, "role": "user"}, + ] + chat_backend = get_chat_backend(chat_backend_alias) + response = await chat_backend.achat(messages=messages, stream=True) + + async def async_stream_wrapper(): + async for chunk in response: + yield chunk["content"] + + return AsyncQueryResponse( + response=async_stream_wrapper(), + sources=sources, + ) + def find_similar( self, object, *, include_self: bool = False, limit: int = 5 ) -> list: @@ -209,3 +268,8 @@ def get_similar_documents( self, query_vector: Sequence[float], *, limit: int = 5 ) -> Generator[Document, None, None]: raise NotImplementedError + + def aget_similar_documents( + self, query_vector, *, limit: int = 5 + ) -> AsyncGenerator[Document, None]: + raise NotImplementedError diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index a56e3e8..2c32f51 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -1,5 +1,11 @@ from collections import defaultdict -from collections.abc import Generator, Iterable, MutableSequence, Sequence +from collections.abc import ( + AsyncGenerator, + Generator, + Iterable, + MutableSequence, + Sequence, +) from typing import TYPE_CHECKING, ClassVar, Optional, TypeVar, cast from django.apps import apps @@ -183,6 +189,14 @@ def _model_class_from_ctid(id: str) -> type[models.Model]: raise ValueError(f"Failed to find model class for {ct!r}") return model_class + @classmethod + async def _amodel_class_from_ctid(cls, id: str) -> type[models.Model]: + ct = await cls._aget_content_type_for_id(int(id)) + model_class = ct.model_class() + if model_class is None: + raise ValueError(f"Failed to find model class for {ct!r}") + return model_class + def from_document(self, document: Document) -> models.Model: model_class = self._model_class_from_ctid(document.metadata["content_type_id"]) 
try: @@ -220,6 +234,50 @@ def bulk_from_documents( seen_keys.add(key) yield objects_by_key[key] + async def abulk_from_documents( + self, documents: Iterable[Document] + ) -> AsyncGenerator[models.Model, None, None]: + """A copy of `bulk_from_documents`, but async""" + # Force evaluate generators to allow value to be reused + documents = tuple(documents) + + ids_by_content_type: dict[str, list[str]] = defaultdict(list) + for doc in documents: + ids_by_content_type[doc.metadata["content_type_id"]].append( + doc.metadata["object_id"] + ) + + # NOTE: (content_type_id, object_id) combo keys are required to + # reliably map data from multiple models + objects_by_key: dict[tuple[str, str], models.Model] = {} + for content_type_id, ids in ids_by_content_type.items(): + model_class = await self._amodel_class_from_ctid(content_type_id) + model_objects = model_class.objects.filter(pk__in=ids) + objects_by_key.update( + {(content_type_id, str(obj.pk)): obj async for obj in model_objects} + ) + + seen_keys = set() # de-dupe as we go + for doc in documents: + key = (doc.metadata["content_type_id"], doc.metadata["object_id"]) + if key in seen_keys: + continue + seen_keys.add(key) + yield objects_by_key[key] + + @staticmethod + async def _aget_content_type_for_id(id: int) -> ContentType: + """ + Same as `ContentTypeManager.get_for_id`, but async. 
+ """ + manager = ContentType.objects + try: + ct = manager._cache[manager.db][id] + except KeyError: + ct = await manager.aget(pk=id) + manager._add_to_cache(manager.db, ct) + return ct + class EmbeddableFieldsDocumentConverter(DocumentToModelMixin): """Implementation of DocumentConverter that knows how to convert a model instance using the diff --git a/src/wagtail_vector_index/storage/pgvector/provider.py b/src/wagtail_vector_index/storage/pgvector/provider.py index bb82058..89e2d45 100644 --- a/src/wagtail_vector_index/storage/pgvector/provider.py +++ b/src/wagtail_vector_index/storage/pgvector/provider.py @@ -1,5 +1,11 @@ import logging -from collections.abc import Generator, Iterable, MutableSequence, Sequence +from collections.abc import ( + AsyncGenerator, + Generator, + Iterable, + MutableSequence, + Sequence, +) from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, cast @@ -90,6 +96,23 @@ def get_similar_documents( embedding = pgvector_embedding.embedding yield embedding.to_document() + async def aget_similar_documents( + self, query_vector, *, limit: int = 5 + ) -> AsyncGenerator[Document, None]: + async for pgvector_embedding in ( + self._get_queryset() + .select_related("embedding") + .filter(embedding_output_dimensions=len(query_vector)) + .order_by_distance( + query_vector, + distance_method=self.distance_method, + fetch_distance=False, + )[:limit] + .aiterator() + ): + embedding = pgvector_embedding.embedding + yield embedding.to_document() + def _get_queryset(self) -> "PgvectorEmbeddingQuerySet": # objects is technically a Manager instance but we want to use the custom # queryset method From 188760fcedf091665b5c65fd3abef6707c4e8026 Mon Sep 17 00:00:00 2001 From: Olivier Philippon Date: Wed, 17 Jul 2024 11:17:30 +0100 Subject: [PATCH 2/6] Factorise code used by the sync & async versions of the `query` and `bulk_from_documents` methods --- src/wagtail_vector_index/storage/models.py | 96 ++++++++++++------- 
.../storage/pgvector/provider.py | 42 ++++---- 2 files changed, 80 insertions(+), 58 deletions(-) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index 2c32f51..e467541 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -6,7 +6,7 @@ MutableSequence, Sequence, ) -from typing import TYPE_CHECKING, ClassVar, Optional, TypeVar, cast +from typing import TYPE_CHECKING, ClassVar, Optional, TypeAlias, TypeVar, cast from django.apps import apps from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation @@ -45,6 +45,10 @@ - The EmbeddableFieldsDocumentConverter, which is a DocumentConverter that knows how to convert a model instance using the EmbeddableFieldsMixin protocol to and from a Document """ +ContentTypeId: TypeAlias = str +ObjectId: TypeAlias = str +ModelKey: TypeAlias = tuple[ContentTypeId, ObjectId] + class Embedding(models.Model): """Stores an embedding for a model instance""" @@ -207,25 +211,44 @@ def from_document(self, document: Document) -> models.Model: def bulk_from_documents( self, documents: Iterable[Document] ) -> Generator[models.Model, None, None]: + documents = tuple(documents) + + ids_by_content_type = self._get_ids_by_content_type(documents) + objects_by_key = self._get_models_by_key(ids_by_content_type) + + yield from self._get_deduplicated_objects_generator(documents, objects_by_key) + + async def abulk_from_documents( + self, documents: Iterable[Document] + ) -> AsyncGenerator[models.Model, None]: + """A copy of `bulk_from_documents`, but async""" # Force evaluate generators to allow value to be reused documents = tuple(documents) - ids_by_content_type: dict[str, list[str]] = defaultdict(list) + ids_by_content_type = self._get_ids_by_content_type(documents) + objects_by_key = await self._aget_models_by_key(ids_by_content_type) + + # N.B. 
`yield from` cannot be used in async functions, so we have to use a loop + for object_from_document in self._get_deduplicated_objects_generator( + documents, objects_by_key + ): + yield object_from_document + + @staticmethod + def _get_ids_by_content_type( + documents: Sequence[Document], + ) -> dict[ContentTypeId, list[ObjectId]]: + ids_by_content_type = defaultdict(list) for doc in documents: ids_by_content_type[doc.metadata["content_type_id"]].append( doc.metadata["object_id"] ) + return ids_by_content_type - # NOTE: (content_type_id, object_id) combo keys are required to - # reliably map data from multiple models - objects_by_key: dict[tuple[str, str], models.Model] = {} - for content_type_id, ids in ids_by_content_type.items(): - model_class = self._model_class_from_ctid(content_type_id) - model_objects = model_class.objects.filter(pk__in=ids) - objects_by_key.update( - {(content_type_id, str(obj.pk)): obj for obj in model_objects} - ) - + @staticmethod + def _get_deduplicated_objects_generator( + documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model] + ) -> Generator[models.Model, None, None]: seen_keys = set() # de-dupe as we go for doc in documents: key = (doc.metadata["content_type_id"], doc.metadata["object_id"]) @@ -234,36 +257,37 @@ def bulk_from_documents( seen_keys.add(key) yield objects_by_key[key] - async def abulk_from_documents( - self, documents: Iterable[Document] - ) -> AsyncGenerator[models.Model, None, None]: - """A copy of `bulk_from_documents`, but async""" - # Force evaluate generators to allow value to be reused - documents = tuple(documents) - - ids_by_content_type: dict[str, list[str]] = defaultdict(list) - for doc in documents: - ids_by_content_type[doc.metadata["content_type_id"]].append( - doc.metadata["object_id"] + def _get_models_by_key( + self, ids_by_content_type: dict + ) -> dict[ModelKey, models.Model]: + """ + (content_type_id, object_id) combo keys are required to reliably map data + from multiple 
models. This function loads the models from the database + and groups them by such a key. + """ + objects_by_key: dict[ModelKey, models.Model] = {} + for content_type_id, ids in ids_by_content_type.items(): + model_class = self._model_class_from_ctid(content_type_id) + model_objects = model_class.objects.filter(pk__in=ids) + objects_by_key.update( + {(content_type_id, str(obj.pk)): obj for obj in model_objects} ) + return objects_by_key - # NOTE: (content_type_id, object_id) combo keys are required to - # reliably map data from multiple models - objects_by_key: dict[tuple[str, str], models.Model] = {} + async def _aget_models_by_key( + self, ids_by_content_type: dict + ) -> dict[ModelKey, models.Model]: + """ + Same as `_get_models_by_key`, but async. + """ + objects_by_key: dict[ModelKey, models.Model] = {} for content_type_id, ids in ids_by_content_type.items(): model_class = await self._amodel_class_from_ctid(content_type_id) model_objects = model_class.objects.filter(pk__in=ids) objects_by_key.update( {(content_type_id, str(obj.pk)): obj async for obj in model_objects} ) - - seen_keys = set() # de-dupe as we go - for doc in documents: - key = (doc.metadata["content_type_id"], doc.metadata["object_id"]) - if key in seen_keys: - continue - seen_keys.add(key) - yield objects_by_key[key] + return objects_by_key @staticmethod async def _aget_content_type_for_id(id: int) -> ContentType: @@ -272,10 +296,10 @@ async def _aget_content_type_for_id(id: int) -> ContentType: """ manager = ContentType.objects try: - ct = manager._cache[manager.db][id] + ct = manager._cache[manager.db][id] # type: ignore[reportAttributeAccessIssue] except KeyError: ct = await manager.aget(pk=id) - manager._add_to_cache(manager.db, ct) + manager._add_to_cache(manager.db, ct) # type: ignore[reportAttributeAccessIssue] return ct diff --git a/src/wagtail_vector_index/storage/pgvector/provider.py b/src/wagtail_vector_index/storage/pgvector/provider.py index 89e2d45..7c50fdf 100644 --- 
a/src/wagtail_vector_index/storage/pgvector/provider.py +++ b/src/wagtail_vector_index/storage/pgvector/provider.py @@ -82,34 +82,18 @@ def clear(self): def get_similar_documents( self, query_vector, *, limit: int = 5 ) -> Generator[Document, None, None]: - for pgvector_embedding in ( - self._get_queryset() - .select_related("embedding") - .filter(embedding_output_dimensions=len(query_vector)) - .order_by_distance( - query_vector, - distance_method=self.distance_method, - fetch_distance=False, - )[:limit] - .iterator() - ): + for pgvector_embedding in self._get_similar_documents_queryset( + query_vector, limit=limit + ).iterator(): embedding = pgvector_embedding.embedding yield embedding.to_document() async def aget_similar_documents( self, query_vector, *, limit: int = 5 ) -> AsyncGenerator[Document, None]: - async for pgvector_embedding in ( - self._get_queryset() - .select_related("embedding") - .filter(embedding_output_dimensions=len(query_vector)) - .order_by_distance( - query_vector, - distance_method=self.distance_method, - fetch_distance=False, - )[:limit] - .aiterator() - ): + async for pgvector_embedding in self._get_similar_documents_queryset( + query_vector, limit=limit + ).aiterator(): embedding = pgvector_embedding.embedding yield embedding.to_document() @@ -120,6 +104,20 @@ def _get_queryset(self) -> "PgvectorEmbeddingQuerySet": type(self).__name__ ) + def _get_similar_documents_queryset( + self, query_vector: Sequence[float], *, limit: int + ) -> "PgvectorEmbeddingQuerySet": + return ( + self._get_queryset() + .select_related("embedding") + .filter(embedding_output_dimensions=len(query_vector)) + .order_by_distance( + query_vector, + distance_method=self.distance_method, + fetch_distance=False, + )[:limit] + ) + def _bulk_create(self, embeddings: Sequence["PgvectorEmbedding"]) -> None: _embedding_model().objects.bulk_create( embeddings, From 4000f98d2e78a7a7aa64cf2e46535e6b2364161d Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Thu, 25 Jul 2024 
13:07:17 +0000 Subject: [PATCH 3/6] Use next() when retrieving query embedding now that LiteLLM returns an iterator --- src/wagtail_vector_index/storage/base.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py index 5bf54b3..8ee06db 100644 --- a/src/wagtail_vector_index/storage/base.py +++ b/src/wagtail_vector_index/storage/base.py @@ -171,10 +171,7 @@ async def aquery( Replicates the features of `VectorIndex.query()`, but in an async way. """ try: - # TODO: use `next()` instead of `[0]` once the `aembed()` method of the - # LiteLLMBackends class is updated to return the same data structure than - # its BaseEmbeddingBackend parent class. - query_embedding = (await self.get_embedding_backend().aembed([query]))[0] # type: ignore[reportIndexIssue] + query_embedding = next(await self.get_embedding_backend().aembed([query])) except IndexError as e: raise ValueError("No embeddings were generated for the given query.") from e From 16a588d6f905595b26e17df9809f2ac3adcb2652 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Thu, 25 Jul 2024 14:40:12 +0000 Subject: [PATCH 4/6] Add async iterator methods to base AIStreamingResponse --- src/wagtail_vector_index/ai_utils/types.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/wagtail_vector_index/ai_utils/types.py b/src/wagtail_vector_index/ai_utils/types.py index 7736b45..3caaeae 100644 --- a/src/wagtail_vector_index/ai_utils/types.py +++ b/src/wagtail_vector_index/ai_utils/types.py @@ -40,8 +40,13 @@ class AIStreamingResponse: def __iter__(self): return self + def __aiter__(self): + return self + def __next__(self) -> AIResponseStreamingPart: ... + async def __anext__(self) -> AIResponseStreamingPart: ... + class AIResponse: """Representation of a non-streaming response from an AI backend. 
From 09d2b624b80a5a5adbca7f9ec8303c8a53c610e9 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Thu, 25 Jul 2024 14:48:30 +0000 Subject: [PATCH 5/6] Iterate directly over queryset in aget_similar_documents --- src/wagtail_vector_index/storage/pgvector/provider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/wagtail_vector_index/storage/pgvector/provider.py b/src/wagtail_vector_index/storage/pgvector/provider.py index 7c50fdf..bdaae6c 100644 --- a/src/wagtail_vector_index/storage/pgvector/provider.py +++ b/src/wagtail_vector_index/storage/pgvector/provider.py @@ -93,7 +93,7 @@ async def aget_similar_documents( ) -> AsyncGenerator[Document, None]: async for pgvector_embedding in self._get_similar_documents_queryset( query_vector, limit=limit - ).aiterator(): + ): embedding = pgvector_embedding.embedding yield embedding.to_document() From d5e26a0bb786bfbfd9424c430d1d35b2f3184750 Mon Sep 17 00:00:00 2001 From: Tom Usher Date: Tue, 30 Jul 2024 16:02:28 +0000 Subject: [PATCH 6/6] Refer to location of EmbeddingResponse type in LiteLLM --- src/wagtail_vector_index/ai_utils/backends/litellm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/wagtail_vector_index/ai_utils/backends/litellm.py b/src/wagtail_vector_index/ai_utils/backends/litellm.py index c8822b4..6031fed 100644 --- a/src/wagtail_vector_index/ai_utils/backends/litellm.py +++ b/src/wagtail_vector_index/ai_utils/backends/litellm.py @@ -3,6 +3,7 @@ from typing import Any, NotRequired, Self import litellm +import litellm.types.utils from django.core.exceptions import ImproperlyConfigured from ..types import ( @@ -174,7 +175,7 @@ class LiteLLMEmbeddingBackend(BaseEmbeddingBackend[LiteLLMEmbeddingBackendConfig def embed(self, inputs: Iterable[str], **kwargs) -> Iterator[list[float]]: response = litellm.embedding(model=self.config.model_id, input=inputs, **kwargs) # LiteLLM *should* return an EmbeddingResponse - assert isinstance(response, litellm.EmbeddingResponse) + 
assert isinstance(response, litellm.types.utils.EmbeddingResponse) yield from [data["embedding"] for data in response["data"]] async def aembed(self, inputs: Iterable[str], **kwargs) -> Iterator[list[float]]: