diff --git a/chatnoir_pyterrier/retrieve.py b/chatnoir_pyterrier/retrieve.py index eb568a6..df33eab 100644 --- a/chatnoir_pyterrier/retrieve.py +++ b/chatnoir_pyterrier/retrieve.py @@ -8,7 +8,7 @@ search, search_phrases ) from chatnoir_api.defaults import ( - DEFAULT_INDEX, DEFAULT_SLOP, DEFAULT_RETRIES, DEFAULT_BACKOFF_SECONDS, DEFAULT_API_KEY + DEFAULT_INDEX, DEFAULT_SLOP, DEFAULT_RETRIES, DEFAULT_BACKOFF_SECONDS, DEFAULT_API_KEY, DEFAULT_RETRIEVAL_SYSTEM ) from pandas import DataFrame from pandas.core.groupby import DataFrameGroupBy @@ -35,6 +35,7 @@ class ChatNoirRetrieve(Transformer): backoff_seconds: float = DEFAULT_BACKOFF_SECONDS verbose: bool = False api_key: str = DEFAULT_API_KEY + retrieval_system: str = DEFAULT_RETRIEVAL_SYSTEM def _merge_result( self, @@ -135,6 +136,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results else: results = search( @@ -147,6 +149,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results else: if explain: @@ -161,6 +164,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results else: results = search_phrases( @@ -174,6 +178,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame: retries=self.retries, backoff_seconds=self.backoff_seconds, api_key=self.api_key, + retrieval_system=self.retrieval_system ).results if self.filter_unknown: diff --git a/tests/test_random_document_access.py b/tests/test_random_document_access.py new file mode 100644 index 0000000..c53f1ad --- /dev/null +++ b/tests/test_random_document_access.py @@ -0,0 +1,6 @@ +import unittest + +class TestRandomDocumentAccess(unittest.TestCase): + def test_document_on_ms_marco_v2_1(self): + +