Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SupabaseVectorStore a subclass of BasePydanticVectorStore #11476

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import math
from collections import defaultdict
from typing import Any, List
from typing import Any, List, Optional

import vecs
from vecs.collection import Collection
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.vector_stores.types import (
MetadataFilters,
VectorStore,
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
Expand All @@ -22,7 +24,7 @@
logger = logging.getLogger(__name__)


class SupabaseVectorStore(VectorStore):
class SupabaseVectorStore(BasePydanticVectorStore):
"""Supbabase Vector.

In this vector store, embeddings are stored in Postgres table using pgvector.
Expand All @@ -41,6 +43,8 @@ class SupabaseVectorStore(VectorStore):

stores_text = True
flat_metadata = False
_client: Optional[Any] = PrivateAttr()
_collection: Optional[Collection] = PrivateAttr()

def __init__(
self,
Expand All @@ -49,17 +53,17 @@ def __init__(
dimension: int = DEFAULT_EMBEDDING_DIM,
**kwargs: Any,
) -> None:
"""Init params."""
client = vecs.create_client(postgres_connection_string)
super().__init__()
self._client = vecs.create_client(postgres_connection_string)

try:
self._collection = client.get_collection(name=collection_name)
self._collection = self._client.get_collection(name=collection_name)
except CollectionNotFound:
logger.info(
f"Collection {collection_name} does not exist, "
f"try creating one with dimension={dimension}"
)
self._collection = client.create_collection(
self._collection = self._client.create_collection(
name=collection_name, dimension=dimension
)

Expand All @@ -71,7 +75,7 @@ def client(self) -> None:
def _to_vecs_filters(self, filters: MetadataFilters) -> Any:
"""Convert llama filters to vecs filters. $eq is the only supported operator."""
vecs_filter = defaultdict(list)
filter_cond = f"${filters.condition}"
filter_cond = f"${filters.condition.value}"

for f in filters.legacy_filters():
sub_filter = {}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.supabase import SupabaseVectorStore


def test_class():
names_of_base_classes = [b.__name__ for b in SupabaseVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes
Loading