Skip to content

Commit

Permalink
Make SupabaseVectorStore a subclass of BasePydanticVectorStore (run-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-tee authored and Izuki Matsuba committed Mar 29, 2024
1 parent 2fd2688 commit d2db21b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import math
from collections import defaultdict
from typing import Any, List
from typing import Any, List, Optional

import vecs
from vecs.collection import Collection
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.vector_stores.types import (
MetadataFilters,
VectorStore,
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
Expand All @@ -22,7 +24,7 @@
logger = logging.getLogger(__name__)


class SupabaseVectorStore(VectorStore):
class SupabaseVectorStore(BasePydanticVectorStore):
"""Supbabase Vector.
In this vector store, embeddings are stored in Postgres table using pgvector.
Expand All @@ -41,6 +43,8 @@ class SupabaseVectorStore(VectorStore):

stores_text = True
flat_metadata = False
_client: Optional[Any] = PrivateAttr()
_collection: Optional[Collection] = PrivateAttr()

def __init__(
self,
Expand All @@ -49,17 +53,17 @@ def __init__(
dimension: int = DEFAULT_EMBEDDING_DIM,
**kwargs: Any,
) -> None:
"""Init params."""
client = vecs.create_client(postgres_connection_string)
super().__init__()
self._client = vecs.create_client(postgres_connection_string)

try:
self._collection = client.get_collection(name=collection_name)
self._collection = self._client.get_collection(name=collection_name)
except CollectionNotFound:
logger.info(
f"Collection {collection_name} does not exist, "
f"try creating one with dimension={dimension}"
)
self._collection = client.create_collection(
self._collection = self._client.create_collection(
name=collection_name, dimension=dimension
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from llama_index.core.vector_stores.types import VectorStore
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.supabase import SupabaseVectorStore


def test_class():
names_of_base_classes = [b.__name__ for b in SupabaseVectorStore.__mro__]
assert VectorStore.__name__ in names_of_base_classes
assert BasePydanticVectorStore.__name__ in names_of_base_classes

0 comments on commit d2db21b

Please sign in to comment.