From c525acd205ba95c1c4a3f2bbdcf68c6fc170ed3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maik=20Fr=C3=B6be?= Date: Wed, 13 Nov 2024 18:47:13 +0100 Subject: [PATCH] draft to allow for multiple retrieval backends --- chatnoir_pyterrier/retrieve.py | 7 ++++++- tests/test_random_document_access.py | 6 ++++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 tests/test_random_document_access.py diff --git a/chatnoir_pyterrier/retrieve.py b/chatnoir_pyterrier/retrieve.py index eb568a6..df33eab 100644 --- a/chatnoir_pyterrier/retrieve.py +++ b/chatnoir_pyterrier/retrieve.py @@ -8,7 +8,7 @@ search, search_phrases ) from chatnoir_api.defaults import ( - DEFAULT_INDEX, DEFAULT_SLOP, DEFAULT_RETRIES, DEFAULT_BACKOFF_SECONDS, DEFAULT_API_KEY + DEFAULT_INDEX, DEFAULT_SLOP, DEFAULT_RETRIES, DEFAULT_BACKOFF_SECONDS, DEFAULT_API_KEY, DEFAULT_RETRIEVAL_SYSTEM ) from pandas import DataFrame from pandas.core.groupby import DataFrameGroupBy @@ -35,6 +35,7 @@ class ChatNoirRetrieve(Transformer): backoff_seconds: float = DEFAULT_BACKOFF_SECONDS verbose: bool = False api_key: str = DEFAULT_API_KEY + retrieval_system: str = DEFAULT_RETRIEVAL_SYSTEM def _merge_result( self, @@ -135,6 +136,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results else: results = search( @@ -147,6 +149,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results else: if explain: @@ -161,6 +164,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results else: results = search_phrases( @@ -174,6 +178,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results if self.filter_unknown: diff --git a/tests/test_random_document_access.py b/tests/test_random_document_access.py new file mode 100644 index 0000000..c53f1ad --- /dev/null +++ b/tests/test_random_document_access.py @@ -0,0 +1,6 @@ +import unittest + +class TestRandomDocumentAccess(unittest.TestCase): + def test_document_on_ms_marco_v2_1(self): + +