Skip to content

Commit

Permalink
feat: extend OpenSearch params support (#70)
Browse files Browse the repository at this point in the history
* feat: extend OpenSearch params support

* add defaults to docstrings
  • Loading branch information
tstadel authored Nov 30, 2023
1 parent 6dbec0e commit 2f453ce
Show file tree
Hide file tree
Showing 6 changed files with 204 additions and 11 deletions.
58 changes: 53 additions & 5 deletions integrations/opensearch/src/opensearch_haystack/bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,22 @@ def __init__(
fuzziness: str = "AUTO",
top_k: int = 10,
scale_score: bool = False,
all_terms_must_match: bool = False,
):
"""
Create the OpenSearchBM25Retriever component.
:param document_store: An instance of OpenSearchDocumentStore.
:param filters: Filters applied to the retrieved Documents. Defaults to None.
:param fuzziness: Fuzziness parameter for full-text queries. Defaults to "AUTO".
:param top_k: Maximum number of Documents to return, defaults to 10
:param scale_score: Whether to scale the score of retrieved documents between 0 and 1.
This is useful when comparing documents across different indexes. Defaults to False.
:param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
This is useful when searching for short text where even one term can make a difference. Defaults to False.
:raises ValueError: If `document_store` is not an instance of OpenSearchDocumentStore.
"""
if not isinstance(document_store, OpenSearchDocumentStore):
msg = "document_store must be an instance of OpenSearchDocumentStore"
raise ValueError(msg)
Expand All @@ -26,6 +41,7 @@ def __init__(
self._fuzziness = fuzziness
self._top_k = top_k
self._scale_score = scale_score
self._all_terms_must_match = all_terms_must_match

def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
Expand All @@ -45,12 +61,44 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchBM25Retriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(
    self,
    query: str,
    filters: Optional[Dict[str, Any]] = None,
    all_terms_must_match: Optional[bool] = None,
    top_k: Optional[int] = None,
    fuzziness: Optional[str] = None,
    scale_score: Optional[bool] = None,
):
    """
    Retrieve documents using BM25 retrieval.

    Every optional argument that is left as `None` falls back to the value stored
    on the retriever instance (`self._filters`, `self._all_terms_must_match`,
    `self._top_k`, `self._fuzziness`, `self._scale_score`).

    :param query: The query string
    :param filters: Optional filters to narrow down the search space.
    :param all_terms_must_match: If True, all terms in the query string must be present in the retrieved documents.
    :param top_k: Maximum number of Documents to return.
    :param fuzziness: Fuzziness parameter for full-text queries.
    :param scale_score: Whether to scale the score of retrieved documents between 0 and 1.
        This is useful when comparing documents across different indexes.
    :return: A dictionary containing the retrieved documents under the key `"documents"`.
    """
    # `is None` (not truthiness) so that explicit falsy values such as
    # `all_terms_must_match=False` or `scale_score=False` still override the stored ones.
    if filters is None:
        filters = self._filters
    if all_terms_must_match is None:
        all_terms_must_match = self._all_terms_must_match
    if top_k is None:
        top_k = self._top_k
    if fuzziness is None:
        fuzziness = self._fuzziness
    if scale_score is None:
        scale_score = self._scale_score

    docs = self._document_store._bm25_retrieval(
        query=query,
        filters=filters,
        fuzziness=fuzziness,
        top_k=top_k,
        scale_score=scale_score,
        all_terms_must_match=all_terms_must_match,
    )
    return {"documents": docs}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def _bm25_retrieval(
fuzziness: str = "AUTO",
top_k: int = 10,
scale_score: bool = False,
all_terms_must_match: bool = False,
) -> List[Document]:
"""
OpenSearch by default uses the BM25 search algorithm.
Expand All @@ -234,13 +235,13 @@ def _bm25_retrieval(
`query` must be a non empty string, otherwise a `ValueError` will be raised.
:param query: String to search in saved Documents' text.
:param filters: Filters applied to the retrieved Documents, for more info
see `OpenSearchDocumentStore.filter_documents`, defaults to None
:param filters: Optional filters to narrow down the search space.
:param fuzziness: Fuzziness parameter passed to OpenSearch, defaults to "AUTO".
see the official documentation for valid values:
https://www.elastic.co/guide/en/elasticsearch/reference/current/common-options.html#fuzziness
:param top_k: Maximum number of Documents to return, defaults to 10
:param scale_score: If `True` scales the Document's scores between 0 and 1, defaults to False
:param all_terms_must_match: If `True` all terms in `query` must be present in the Document, defaults to False
:raises ValueError: If `query` is an empty string
:return: List of Document that match `query`
"""
Expand All @@ -249,6 +250,7 @@ def _bm25_retrieval(
msg = "query must be a non empty string"
raise ValueError(msg)

operator = "AND" if all_terms_must_match else "OR"
body: Dict[str, Any] = {
"size": top_k,
"query": {
Expand All @@ -259,7 +261,7 @@ def _bm25_retrieval(
"query": query,
"fuzziness": fuzziness,
"type": "most_fields",
"operator": "AND",
"operator": operator,
}
}
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,23 @@ def from_dict(cls, data: Dict[str, Any]) -> "OpenSearchEmbeddingRetriever":
return default_from_dict(cls, data)

@component.output_types(documents=List[Document])
def run(self, query_embedding: List[float], filters: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None):
    """
    Retrieve documents using a vector similarity metric.

    `filters` and `top_k` fall back to the values stored on the retriever instance
    (`self._filters`, `self._top_k`) when they are left as `None`.

    :param query_embedding: Embedding of the query.
    :param filters: Optional filters to narrow down the search space.
    :param top_k: Maximum number of Documents to return.
    :return: A dictionary containing the Documents similar to `query_embedding`
        under the key `"documents"`.
    """
    if filters is None:
        filters = self._filters
    if top_k is None:
        top_k = self._top_k

    docs = self._document_store._embedding_retrieval(
        query_embedding=query_embedding,
        filters=filters,
        top_k=top_k,
    )
    return {"documents": docs}
58 changes: 58 additions & 0 deletions integrations/opensearch/tests/test_bm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,64 @@ def test_run():
fuzziness="AUTO",
top_k=10,
scale_score=False,
all_terms_must_match=False,
)
assert len(res) == 1
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"


def test_run_init_params():
    """Check that `run` called with only a query forwards the init-time settings to the store."""
    store = Mock(spec=OpenSearchDocumentStore)
    store._bm25_retrieval.return_value = [Document(content="Test doc")]

    bm25_retriever = OpenSearchBM25Retriever(
        document_store=store,
        filters={"from": "init"},
        all_terms_must_match=True,
        scale_score=True,
        top_k=11,
        fuzziness="1",
    )
    output = bm25_retriever.run(query="some query")

    # No run-time overrides were given, so every value must come from the constructor.
    store._bm25_retrieval.assert_called_once_with(
        query="some query",
        filters={"from": "init"},
        fuzziness="1",
        top_k=11,
        scale_score=True,
        all_terms_must_match=True,
    )

    assert len(output) == 1
    documents = output["documents"]
    assert len(documents) == 1
    assert documents[0].content == "Test doc"


def test_run_time_params():
mock_store = Mock(spec=OpenSearchDocumentStore)
mock_store._bm25_retrieval.return_value = [Document(content="Test doc")]
retriever = OpenSearchBM25Retriever(
document_store=mock_store,
filters={"from": "init"},
all_terms_must_match=True,
scale_score=True,
top_k=11,
fuzziness="1",
)
res = retriever.run(
query="some query",
filters={"from": "run"},
all_terms_must_match=False,
scale_score=False,
top_k=9,
fuzziness="2",
)
mock_store._bm25_retrieval.assert_called_once_with(
query="some query",
filters={"from": "run"},
fuzziness="2",
top_k=9,
scale_score=False,
all_terms_must_match=False,
)
assert len(res) == 1
assert len(res["documents"]) == 1
Expand Down
46 changes: 46 additions & 0 deletions integrations/opensearch/tests/test_document_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,52 @@ def test_bm25_retrieval_pagination(self, document_store: OpenSearchDocumentStore
assert len(res) == 11
assert all("programming" in doc.content for doc in res)

def test_bm25_retrieval_all_terms_must_match(self, document_store: OpenSearchDocumentStore):
    """With `all_terms_must_match=True`, "functional Haskell" must return only the Haskell document."""
    contents = [
        "Haskell is a functional programming language",
        "Lisp is a functional programming language",
        "Exilir is a functional programming language",
        "F# is a functional programming language",
        "C# is a functional programming language",
        "C++ is an object oriented programming language",
        "Dart is an object oriented programming language",
        "Go is an object oriented programming language",
        "Python is a object oriented programming language",
        "Ruby is a object oriented programming language",
        "PHP is a object oriented programming language",
    ]
    document_store.write_documents([Document(content=text) for text in contents])

    hits = document_store._bm25_retrieval("functional Haskell", top_k=3, all_terms_must_match=True)

    assert len(hits) == 1
    assert "Haskell is a functional programming language" in hits[0].content

def test_bm25_retrieval_all_terms_must_match_false(self, document_store: OpenSearchDocumentStore):
    """With `all_terms_must_match=False`, "functional Haskell" must return the five "functional" documents."""
    contents = [
        "Haskell is a functional programming language",
        "Lisp is a functional programming language",
        "Exilir is a functional programming language",
        "F# is a functional programming language",
        "C# is a functional programming language",
        "C++ is an object oriented programming language",
        "Dart is an object oriented programming language",
        "Go is an object oriented programming language",
        "Python is a object oriented programming language",
        "Ruby is a object oriented programming language",
        "PHP is a object oriented programming language",
    ]
    document_store.write_documents([Document(content=text) for text in contents])

    hits = document_store._bm25_retrieval("functional Haskell", top_k=10, all_terms_must_match=False)

    assert len(hits) == 5
    # The length check above guarantees this loop inspects exactly hits[0]..hits[4].
    for hit in hits:
        assert "functional" in hit.content

def test_bm25_retrieval_with_fuzziness(self, document_store: OpenSearchDocumentStore):
document_store.write_documents(
[
Expand Down
32 changes: 32 additions & 0 deletions integrations/opensearch/tests/test_embedding_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,35 @@ def test_run():
assert len(res["documents"]) == 1
assert res["documents"][0].content == "Test doc"
assert res["documents"][0].embedding == [0.1, 0.2]


def test_run_init_params():
    """Check that `run` called with only an embedding forwards the init-time `filters` and `top_k`."""
    store = Mock(spec=OpenSearchDocumentStore)
    store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]

    embedding_retriever = OpenSearchEmbeddingRetriever(
        document_store=store,
        filters={"from": "init"},
        top_k=11,
    )
    output = embedding_retriever.run(query_embedding=[0.5, 0.7])

    # No run-time overrides were given, so both values must come from the constructor.
    store._embedding_retrieval.assert_called_once_with(
        query_embedding=[0.5, 0.7],
        filters={"from": "init"},
        top_k=11,
    )

    assert len(output) == 1
    documents = output["documents"]
    assert len(documents) == 1
    first = documents[0]
    assert first.content == "Test doc"
    assert first.embedding == [0.1, 0.2]


def test_run_time_params():
    """Check that `filters` and `top_k` passed to `run` take precedence over the init-time values."""
    store = Mock(spec=OpenSearchDocumentStore)
    store._embedding_retrieval.return_value = [Document(content="Test doc", embedding=[0.1, 0.2])]

    embedding_retriever = OpenSearchEmbeddingRetriever(
        document_store=store,
        filters={"from": "init"},
        top_k=11,
    )
    output = embedding_retriever.run(
        query_embedding=[0.5, 0.7],
        filters={"from": "run"},
        top_k=9,
    )

    # The store must see the run-time values, not the ones given to the constructor.
    store._embedding_retrieval.assert_called_once_with(
        query_embedding=[0.5, 0.7],
        filters={"from": "run"},
        top_k=9,
    )

    assert len(output) == 1
    documents = output["documents"]
    assert len(documents) == 1
    first = documents[0]
    assert first.content == "Test doc"
    assert first.embedding == [0.1, 0.2]

0 comments on commit 2f453ce

Please sign in to comment.