From 0e4758f422f0706a87ba4bf8b5c200e7540e71fd Mon Sep 17 00:00:00 2001 From: SJ Date: Mon, 9 Dec 2024 14:24:19 -0600 Subject: [PATCH] add a new opertor for array inclusion filtering --- langchain_postgres/vectorstores.py | 267 ++++++++--------------------- 1 file changed, 73 insertions(+), 194 deletions(-) diff --git a/langchain_postgres/vectorstores.py b/langchain_postgres/vectorstores.py index e1630a18..b7792394 100644 --- a/langchain_postgres/vectorstores.py +++ b/langchain_postgres/vectorstores.py @@ -91,10 +91,7 @@ class DistanceStrategy(str, enum.Enum): LOGICAL_OPERATORS = {"$and", "$or", "$not"} SUPPORTED_OPERATORS = ( - set(COMPARISONS_TO_NATIVE) - .union(TEXT_OPERATORS) - .union(LOGICAL_OPERATORS) - .union(SPECIAL_CASED_OPERATORS) + set(COMPARISONS_TO_NATIVE).union(TEXT_OPERATORS).union(LOGICAL_OPERATORS).union(SPECIAL_CASED_OPERATORS) ) @@ -110,9 +107,7 @@ class CollectionStore(Base): __tablename__ = "langchain_pg_collection" - uuid = sqlalchemy.Column( - UUID(as_uuid=True), primary_key=True, default=uuid.uuid4 - ) + uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = sqlalchemy.Column(sqlalchemy.String, nullable=False, unique=True) cmetadata = sqlalchemy.Column(JSON) @@ -123,27 +118,13 @@ class CollectionStore(Base): ) @classmethod - def get_by_name( - cls, session: Session, name: str - ) -> Optional["CollectionStore"]: - return ( - session.query(cls) - .filter(typing_cast(sqlalchemy.Column, cls.name) == name) - .first() - ) + def get_by_name(cls, session: Session, name: str) -> Optional["CollectionStore"]: + return session.query(cls).filter(typing_cast(sqlalchemy.Column, cls.name) == name).first() @classmethod - async def aget_by_name( - cls, session: AsyncSession, name: str - ) -> Optional["CollectionStore"]: + async def aget_by_name(cls, session: AsyncSession, name: str) -> Optional["CollectionStore"]: return ( - ( - await session.execute( - select(CollectionStore).where( - typing_cast(sqlalchemy.Column, cls.name) == name - ) - ) - ) + (await session.execute(select(CollectionStore).where(typing_cast(sqlalchemy.Column, cls.name) == name))) .scalars() .first() ) @@ -197,9 +178,7 @@ class EmbeddingStore(Base): __tablename__ = "langchain_pg_embedding" - id = sqlalchemy.Column( - sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True - ) + id = sqlalchemy.Column(sqlalchemy.String, nullable=True, primary_key=True, index=True, unique=True) collection_id = sqlalchemy.Column( UUID(as_uuid=True), @@ -235,8 +214,7 @@ def _results_to_docs(docs_and_scores: Any) -> List[Document]: def _create_vector_extension(conn: Connection) -> None: statement = sqlalchemy.text( - "SELECT pg_advisory_xact_lock(1573678846307946496);" - "CREATE EXTENSION IF NOT EXISTS vector;" + "SELECT pg_advisory_xact_lock(1573678846307946496);" "CREATE EXTENSION IF NOT EXISTS vector;" ) conn.execute(statement) conn.commit() @@ -254,7 +232,7 @@ class PGVector(VectorStore): .. code-block:: bash pip install -qU langchain-postgres - docker run --name pgvector-container -e POSTGRES_USER=langchain -e POSTGRES_PASSWORD=langchain -e POSTGRES_DB=langchain -p 6024:5432 -d pgvector/pgvector:pg16 + docker run --name pgvector-container -e POSTGRESQL_USER=langchain -e POSTGRESQL_PASSWORD=langchain -e POSTGRES_DB=langchain -p 6024:5432 -d pgvector/pgvector:pg16 Key init args — indexing params: collection_name: str @@ -431,9 +409,7 @@ def __init__( if isinstance(connection, str): if async_mode: - self._async_engine = create_async_engine( - connection, **(engine_args or {}) - ) + self._async_engine = create_async_engine(connection, **(engine_args or {})) else: self._engine = create_engine(url=connection, **(engine_args or {})) elif isinstance(connection, Engine): @@ -469,9 +445,7 @@ def __post_init__( if self.create_extension: self.create_vector_extension() - EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length - ) + EmbeddingStore, CollectionStore = _get_embedding_collection_store(self._embedding_length) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore self.create_tables_if_not_exists() @@ -485,9 +459,7 @@ async def __apost_init__( return self._async_init = True - EmbeddingStore, CollectionStore = _get_embedding_collection_store( - self._embedding_length - ) + EmbeddingStore, CollectionStore = _get_embedding_collection_store(self._embedding_length) self.CollectionStore = CollectionStore self.EmbeddingStore = EmbeddingStore if self.create_extension: @@ -539,9 +511,7 @@ def create_collection(self) -> None: if self.pre_delete_collection: self.delete_collection() with self._make_sync_session() as session: - self.CollectionStore.get_or_create( - session, self.collection_name, cmetadata=self.collection_metadata - ) + self.CollectionStore.get_or_create(session, self.collection_name, cmetadata=self.collection_metadata) session.commit() async def acreate_collection(self) -> None: @@ -549,9 +519,7 @@ async def acreate_collection(self) -> None: async with self._make_async_session() as session: if self.pre_delete_collection: await self._adelete_collection(session) - await self.CollectionStore.aget_or_create( - session, self.collection_name, cmetadata=self.collection_metadata - ) + await self.CollectionStore.aget_or_create(session, self.collection_name, cmetadata=self.collection_metadata) await session.commit() def _delete_collection(self, session: Session) -> None: @@ -602,8 +570,7 @@ def delete( with self._make_sync_session() as session: if ids is not None: self.logger.debug( - "Trying to delete vectors by ids (represented by the model " - "using the custom ids field)" + "Trying to delete vectors by ids (represented by the model " "using the custom ids field)" ) stmt = delete(self.EmbeddingStore) @@ -614,9 +581,7 @@ def delete( self.logger.warning("Collection not found") return - stmt = stmt.where( - self.EmbeddingStore.collection_id == collection.uuid - ) + stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid) stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) session.execute(stmt) @@ -638,8 +603,7 @@ async def adelete( async with self._make_async_session() as session: if ids is not None: self.logger.debug( - "Trying to delete vectors by ids (represented by the model " - "using the custom ids field)" + "Trying to delete vectors by ids (represented by the model " "using the custom ids field)" ) stmt = delete(self.EmbeddingStore) @@ -650,9 +614,7 @@ async def adelete( self.logger.warning("Collection not found") return - stmt = stmt.where( - self.EmbeddingStore.collection_id == collection.uuid - ) + stmt = stmt.where(self.EmbeddingStore.collection_id == collection.uuid) stmt = stmt.where(self.EmbeddingStore.id.in_(ids)) await session.execute(stmt) @@ -699,9 +661,7 @@ def __from( **kwargs, ) - store.add_embeddings( - texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - ) + store.add_embeddings(texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs) return store @@ -738,9 +698,7 @@ async def __afrom( **kwargs, ) - await store.aadd_embeddings( - texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs - ) + await store.aadd_embeddings(texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs) return store @@ -783,9 +741,7 @@ def add_embeddings( "document": text, "cmetadata": metadata or {}, } - for text, metadata, embedding, id in zip( - texts, metadatas, embeddings, ids_ - ) + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids_) ] stmt = insert(self.EmbeddingStore).values(data) on_conflict_stmt = stmt.on_conflict_do_update( @@ -842,9 +798,7 @@ async def aadd_embeddings( "document": text, "cmetadata": metadata or {}, } - for text, metadata, embedding, id in zip( - texts, metadatas, embeddings, ids_ - ) + for text, metadata, embedding, id in zip(texts, metadatas, embeddings, ids_) ] stmt = insert(self.EmbeddingStore).values(data) on_conflict_stmt = stmt.on_conflict_do_update( @@ -989,9 +943,7 @@ def similarity_search_with_score( """ assert not self._async_engine, "This method must be called without async_mode" embedding = self.embeddings.embed_query(query) - docs = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) + docs = self.similarity_search_with_score_by_vector(embedding=embedding, k=k, filter=filter) return docs async def asimilarity_search_with_score( @@ -1012,9 +964,7 @@ async def asimilarity_search_with_score( """ await self.__apost_init__() # Lazy async init embedding = await self.embeddings.aembed_query(query) - docs = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) + docs = await self.asimilarity_search_with_score_by_vector(embedding=embedding, k=k, filter=filter) return docs @property @@ -1050,9 +1000,7 @@ async def asimilarity_search_with_score_by_vector( ) -> List[Tuple[Document, float]]: await self.__apost_init__() # Lazy async init async with self._make_async_session() as session: # type: ignore[arg-type] - results = await self.__aquery_collection( - session=session, embedding=embedding, k=k, filter=filter - ) + results = await self.__aquery_collection(session=session, embedding=embedding, k=k, filter=filter) return self._results_to_docs_and_scores(results) @@ -1089,21 +1037,14 @@ def _handle_field_filter( sqlalchemy expression """ if not isinstance(field, str): - raise ValueError( - f"field should be a string but got: {type(field)} with value: {field}" - ) + raise ValueError(f"field should be a string but got: {type(field)} with value: {field}") if field.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got an operator: " - f"{field}" - ) + raise ValueError(f"Invalid filter condition. Expected a field but got an operator: " f"{field}") # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters if not field.isidentifier(): - raise ValueError( - f"Invalid field name: {field}. Expected a valid identifier." - ) + raise ValueError(f"Invalid field name: {field}. Expected a valid identifier.") if isinstance(value, dict): # This is a filter specification @@ -1116,10 +1057,10 @@ def _handle_field_filter( ) operator, filter_value = list(value.items())[0] # Verify that that operator is an operator - if operator not in SUPPORTED_OPERATORS: + if operator not in SUPPORTED_OPERATORS.union({"$array_contains"}): raise ValueError( f"Invalid operator: {operator}. " - f"Expected one of {SUPPORTED_OPERATORS}" + f"Expected one of {SUPPORTED_OPERATORS.union({'$array_contains'})}" ) else: # Then we assume an equality operator operator = "$eq" @@ -1154,14 +1095,10 @@ def _handle_field_filter( if operator in {"$in", "$nin"}: for val in filter_value: if not isinstance(val, (str, int, float)): - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) + raise NotImplementedError(f"Unsupported type: {type(val)} for value: {val}") if isinstance(val, bool): # b/c bool is an instance of int - raise NotImplementedError( - f"Unsupported type: {type(val)} for value: {val}" - ) + raise NotImplementedError(f"Unsupported type: {type(val)} for value: {val}") queried_field = self.EmbeddingStore.cmetadata[field].astext @@ -1177,15 +1114,18 @@ def _handle_field_filter( raise NotImplementedError() elif operator == "$exists": if not isinstance(filter_value, bool): - raise ValueError( - "Expected a boolean value for $exists " - f"operator, but got: {filter_value}" - ) + raise ValueError("Expected a boolean value for $exists " f"operator, but got: {filter_value}") condition = func.jsonb_exists( self.EmbeddingStore.cmetadata, field, ) return condition if filter_value else ~condition + elif operator == "$array_contains": + return func.jsonb_path_exists( + self.EmbeddingStore.cmetadata, + cast(f"$.{field}[*] ? (@ == $value)", JSONPATH), + cast({"value": filter_value}, JSONB), + ) else: raise NotImplementedError() @@ -1203,53 +1143,31 @@ def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyp value_case_insensitive = {k.lower(): v for k, v in value.items()} if IN in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_( - value_case_insensitive[IN] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.in_(value_case_insensitive[IN]) elif NIN in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in( - value_case_insensitive[NIN] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.not_in(value_case_insensitive[NIN]) elif BETWEEN in map(str.lower, value): filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.between( str(value_case_insensitive[BETWEEN][0]), str(value_case_insensitive[BETWEEN][1]), ) elif GT in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str( - value_case_insensitive[GT] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext > str(value_case_insensitive[GT]) elif LT in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str( - value_case_insensitive[LT] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext < str(value_case_insensitive[LT]) elif NE in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str( - value_case_insensitive[NE] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext != str(value_case_insensitive[NE]) elif EQ in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( - value_case_insensitive[EQ] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(value_case_insensitive[EQ]) elif LIKE in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like( - value_case_insensitive[LIKE] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.like(value_case_insensitive[LIKE]) elif CONTAINS in map(str.lower, value): - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains( - value_case_insensitive[CONTAINS] - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext.contains(value_case_insensitive[CONTAINS]) elif OR in map(str.lower, value): - or_clauses = [ - self._create_filter_clause(key, sub_value) - for sub_value in value_case_insensitive[OR] - ] + or_clauses = [self._create_filter_clause(key, sub_value) for sub_value in value_case_insensitive[OR]] filter_by_metadata = sqlalchemy.or_(*or_clauses) elif AND in map(str.lower, value): - and_clauses = [ - self._create_filter_clause(key, sub_value) - for sub_value in value_case_insensitive[AND] - ] + and_clauses = [self._create_filter_clause(key, sub_value) for sub_value in value_case_insensitive[AND]] filter_by_metadata = sqlalchemy.and_(*and_clauses) else: @@ -1257,9 +1175,7 @@ def _create_filter_clause_deprecated(self, key, value): # type: ignore[no-untyp return filter_by_metadata - def _create_filter_clause_json_deprecated( - self, filter: Any - ) -> List[SQLColumnExpression]: + def _create_filter_clause_json_deprecated(self, filter: Any) -> List[SQLColumnExpression]: """Convert filters from IR to SQL clauses. **DEPRECATED** This functionality will be deprecated in the future. @@ -1276,9 +1192,7 @@ def _create_filter_clause_json_deprecated( if filter_by_metadata is not None: filter_clauses.append(filter_by_metadata) else: - filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str( - value - ) + filter_by_metadata = self.EmbeddingStore.cmetadata[key].astext == str(value) filter_clauses.append(filter_by_metadata) return filter_clauses @@ -1303,19 +1217,14 @@ def _create_filter_clause(self, filters: Any) -> Any: if key.startswith("$"): # Then it's an operator if key.lower() not in ["$and", "$or", "$not"]: - raise ValueError( - f"Invalid filter condition. Expected $and, $or or $not " - f"but got: {key}" - ) + raise ValueError(f"Invalid filter condition. Expected $and, $or or $not " f"but got: {key}") else: # Then it's a field return self._handle_field_filter(key, filters[key]) if key.lower() == "$and": if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") and_ = [self._create_filter_clause(el) for el in value] if len(and_) > 1: return sqlalchemy.and_(*and_) @@ -1323,14 +1232,11 @@ def _create_filter_clause(self, filters: Any) -> Any: return and_[0] else: raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" + "Invalid filter condition. Expected a dictionary " "but got an empty dictionary" ) elif key.lower() == "$or": if not isinstance(value, list): - raise ValueError( - f"Expected a list, but got {type(value)} for value: {value}" - ) + raise ValueError(f"Expected a list, but got {type(value)} for value: {value}") or_ = [self._create_filter_clause(el) for el in value] if len(or_) > 1: return sqlalchemy.or_(*or_) @@ -1338,41 +1244,27 @@ def _create_filter_clause(self, filters: Any) -> Any: return or_[0] else: raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" + "Invalid filter condition. Expected a dictionary " "but got an empty dictionary" ) elif key.lower() == "$not": if isinstance(value, list): - not_conditions = [ - self._create_filter_clause(item) for item in value - ] - not_ = sqlalchemy.and_( - *[ - sqlalchemy.not_(condition) - for condition in not_conditions - ] - ) + not_conditions = [self._create_filter_clause(item) for item in value] + not_ = sqlalchemy.and_(*[sqlalchemy.not_(condition) for condition in not_conditions]) return not_ elif isinstance(value, dict): not_ = self._create_filter_clause(value) return sqlalchemy.not_(not_) else: raise ValueError( - f"Invalid filter condition. Expected a dictionary " - f"or a list but got: {type(value)}" + f"Invalid filter condition. Expected a dictionary " f"or a list but got: {type(value)}" ) else: - raise ValueError( - f"Invalid filter condition. Expected $and, $or or $not " - f"but got: {key}" - ) + raise ValueError(f"Invalid filter condition. Expected $and, $or or $not " f"but got: {key}") elif len(filters) > 1: # Then all keys have to be fields (they cannot be operators) for key in filters.keys(): if key.startswith("$"): - raise ValueError( - f"Invalid filter condition. Expected a field but got: {key}" - ) + raise ValueError(f"Invalid filter condition. Expected a field but got: {key}") # These should all be fields and combined using an $and operator and_ = [self._handle_field_filter(k, v) for k, v in filters.items()] if len(and_) > 1: @@ -1380,16 +1272,11 @@ def _create_filter_clause(self, filters: Any) -> Any: elif len(and_) == 1: return and_[0] else: - raise ValueError( - "Invalid filter condition. Expected a dictionary " - "but got an empty dictionary" - ) + raise ValueError("Invalid filter condition. Expected a dictionary " "but got an empty dictionary") else: raise ValueError("Got an empty dictionary for filters.") else: - raise ValueError( - f"Invalid type: Expected a dictionary but got type: {type(filters)}" - ) + raise ValueError(f"Invalid type: Expected a dictionary but got type: {type(filters)}") def __query_collection( self, @@ -1495,9 +1382,7 @@ def similarity_search_by_vector( List of Documents most similar to the query vector. """ assert not self._async_engine, "This method must be called without async_mode" - docs_and_scores = self.similarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) + docs_and_scores = self.similarity_search_with_score_by_vector(embedding=embedding, k=k, filter=filter) return _results_to_docs(docs_and_scores) async def asimilarity_search_by_vector( @@ -1519,9 +1404,7 @@ async def asimilarity_search_by_vector( """ assert self._async_engine, "This method must be called with async_mode" await self.__apost_init__() # Lazy async init - docs_and_scores = await self.asimilarity_search_with_score_by_vector( - embedding=embedding, k=k, filter=filter - ) + docs_and_scores = await self.asimilarity_search_with_score_by_vector(embedding=embedding, k=k, filter=filter) return _results_to_docs(docs_and_scores) @classmethod @@ -1947,9 +1830,7 @@ async def amax_marginal_relevance_search_with_score_by_vector( """ await self.__apost_init__() # Lazy async init async with self._make_async_session() as session: - results = await self.__aquery_collection( - session=session, embedding=embedding, k=fetch_k, filter=filter - ) + results = await self.__aquery_collection(session=session, embedding=embedding, k=fetch_k, filter=filter) embedding_list = [result.EmbeddingStore.embedding for result in results] @@ -2192,15 +2073,13 @@ async def amax_marginal_relevance_search_by_vector( List[Document]: List of Documents selected by maximal marginal relevance. """ await self.__apost_init__() # Lazy async init - docs_and_scores = ( - await self.amax_marginal_relevance_search_with_score_by_vector( - embedding, - k=k, - fetch_k=fetch_k, - lambda_mult=lambda_mult, - filter=filter, - **kwargs, - ) + docs_and_scores = await self.amax_marginal_relevance_search_with_score_by_vector( + embedding, + k=k, + fetch_k=fetch_k, + lambda_mult=lambda_mult, + filter=filter, + **kwargs, ) return _results_to_docs(docs_and_scores)