diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py index 03dab9d83153bd..c4e657850274d0 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/llama_index/vector_stores/supabase/base.py @@ -1,14 +1,16 @@ import logging import math from collections import defaultdict -from typing import Any, List +from typing import Any, List, Optional import vecs +from vecs.collection import Collection from llama_index.core.constants import DEFAULT_EMBEDDING_DIM from llama_index.core.schema import BaseNode, TextNode +from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.vector_stores.types import ( MetadataFilters, - VectorStore, + BasePydanticVectorStore, VectorStoreQuery, VectorStoreQueryResult, ) @@ -22,7 +24,7 @@ logger = logging.getLogger(__name__) -class SupabaseVectorStore(VectorStore): +class SupabaseVectorStore(BasePydanticVectorStore): """Supbabase Vector. In this vector store, embeddings are stored in Postgres table using pgvector. @@ -41,6 +43,8 @@ class SupabaseVectorStore(VectorStore): stores_text = True flat_metadata = False + _client: Optional[Any] = PrivateAttr() + _collection: Optional[Collection] = PrivateAttr() def __init__( self, @@ -49,17 +53,17 @@ def __init__( dimension: int = DEFAULT_EMBEDDING_DIM, **kwargs: Any, ) -> None: - """Init params.""" - client = vecs.create_client(postgres_connection_string) + super().__init__() + self._client = vecs.create_client(postgres_connection_string) try: - self._collection = client.get_collection(name=collection_name) + self._collection = self._client.get_collection(name=collection_name) except CollectionNotFound: logger.info( f"Collection {collection_name} does not exist, " f"try creating one with dimension={dimension}" ) - self._collection = client.create_collection( + self._collection = self._client.create_collection( name=collection_name, dimension=dimension ) diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/tests/test_vector_stores_supabase.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/tests/test_vector_stores_supabase.py index a83a334cb38158..d96c0489784a52 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/tests/test_vector_stores_supabase.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-supabase/tests/test_vector_stores_supabase.py @@ -1,7 +1,7 @@ -from llama_index.core.vector_stores.types import VectorStore +from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.supabase import SupabaseVectorStore def test_class(): names_of_base_classes = [b.__name__ for b in SupabaseVectorStore.__mro__] - assert VectorStore.__name__ in names_of_base_classes + assert BasePydanticVectorStore.__name__ in names_of_base_classes