diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py index 007fd108aa7fd..c4e657850274d 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py @@ -1,14 +1,16 @@ import logging import math from collections import defaultdict -from typing import Any, List +from typing import Any, List, Optional import vecs +from vecs.collection import Collection from llama_index.core.constants import DEFAULT_EMBEDDING_DIM from llama_index.core.schema import BaseNode, TextNode +from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.vector_stores.types import ( MetadataFilters, - VectorStore, + BasePydanticVectorStore, VectorStoreQuery, VectorStoreQueryResult, ) @@ -22,7 +24,7 @@ logger = logging.getLogger(__name__) -class SupabaseVectorStore(VectorStore): +class SupabaseVectorStore(BasePydanticVectorStore): """Supbabase Vector. In this vector store, embeddings are stored in Postgres table using pgvector. @@ -41,6 +43,8 @@ class SupabaseVectorStore(VectorStore): stores_text = True flat_metadata = False + _client: Optional[Any] = PrivateAttr() + _collection: Optional[Collection] = PrivateAttr() def __init__( self, @@ -49,17 +53,17 @@ def __init__( dimension: int = DEFAULT_EMBEDDING_DIM, **kwargs: Any, ) -> None: - """Init params.""" - client = vecs.create_client(postgres_connection_string) + super().__init__() + self._client = vecs.create_client(postgres_connection_string) try: - self._collection = client.get_collection(name=collection_name) + self._collection = self._client.get_collection(name=collection_name) except CollectionNotFound: logger.info( f"Collection {collection_name} does not exist, " f"try creating one with dimension={dimension}" ) - self._collection = client.create_collection( + self._collection = self._client.create_collection( name=collection_name, dimension=dimension ) @@ -71,7 +75,7 @@ def client(self) -> None: def _to_vecs_filters(self, filters: MetadataFilters) -> Any: """Convert llama filters to vecs filters. $eq is the only supported operator.""" vecs_filter = defaultdict(list) - filter_cond = f"${filters.condition}" + filter_cond = f"${filters.condition.value}" for f in filters.legacy_filters(): sub_filter = {}