Skip to content

Commit

Permalink
Make SupabaseVectorStore a subclass of BasePydanticVectorStore
Browse files Browse the repository at this point in the history
This bug was also present for SupabaseVectorStore, so it could not be used in a IngestionPipeline: #10688

This PR fixes the bug above by making SupabaseVectorStore a subclass of BasePydanticVectorStore.

This change is analogous to
https://github.com/run-llama/llama_index/pull/11435/files#diff-a903926a12a2a95032e32938ee7a8d5ab960dda9872690435e6f53a527f6368cR98
  • Loading branch information
dan-tee authored Feb 28, 2024
1 parent b2f0a59 commit 0094b5e
Showing 1 changed file with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import logging
import math
from collections import defaultdict
from typing import Any, List
from typing import Any, List, Optional

import vecs
from vecs.collection import Collection
from llama_index.core.constants import DEFAULT_EMBEDDING_DIM
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.vector_stores.types import (
MetadataFilters,
VectorStore,
BasePydanticVectorStore,
VectorStoreQuery,
VectorStoreQueryResult,
)
Expand All @@ -22,7 +24,7 @@
logger = logging.getLogger(__name__)


class SupabaseVectorStore(VectorStore):
class SupabaseVectorStore(BasePydanticVectorStore):
"""Supbabase Vector.
In this vector store, embeddings are stored in Postgres table using pgvector.
Expand All @@ -41,6 +43,8 @@ class SupabaseVectorStore(VectorStore):

stores_text = True
flat_metadata = False
_client: Optional[Any] = PrivateAttr()
_collection: Optional[Collection] = PrivateAttr()

def __init__(
self,
Expand All @@ -49,17 +53,17 @@ def __init__(
dimension: int = DEFAULT_EMBEDDING_DIM,
**kwargs: Any,
) -> None:
"""Init params."""
client = vecs.create_client(postgres_connection_string)
super().__init__()
self._client = vecs.create_client(postgres_connection_string)

try:
self._collection = client.get_collection(name=collection_name)
self._collection = self._client.get_collection(name=collection_name)
except CollectionNotFound:
logger.info(
f"Collection {collection_name} does not exist, "
f"try creating one with dimension={dimension}"
)
self._collection = client.create_collection(
self._collection = self._client.create_collection(
name=collection_name, dimension=dimension
)

Expand All @@ -71,7 +75,7 @@ def client(self) -> None:
def _to_vecs_filters(self, filters: MetadataFilters) -> Any:
"""Convert llama filters to vecs filters. $eq is the only supported operator."""
vecs_filter = defaultdict(list)
filter_cond = f"${filters.condition}"
filter_cond = f"${filters.condition.value}"

for f in filters.legacy_filters():
sub_filter = {}
Expand Down

0 comments on commit 0094b5e

Please sign in to comment.