Skip to content

Commit

Permalink
draft to allow for multiple retrieval backends
Browse files Browse the repository at this point in the history
  • Loading branch information
mam10eks committed Nov 13, 2024
1 parent a668f04 commit c525acd
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
7 changes: 6 additions & 1 deletion chatnoir_pyterrier/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
search, search_phrases
)
from chatnoir_api.defaults import (
DEFAULT_INDEX, DEFAULT_SLOP, DEFAULT_RETRIES, DEFAULT_BACKOFF_SECONDS, DEFAULT_API_KEY
DEFAULT_INDEX, DEFAULT_SLOP, DEFAULT_RETRIES, DEFAULT_BACKOFF_SECONDS, DEFAULT_API_KEY, DEFAULT_RETRIEVAL_SYSTEM
)
from pandas import DataFrame
from pandas.core.groupby import DataFrameGroupBy
Expand All @@ -35,6 +35,7 @@ class ChatNoirRetrieve(Transformer):
backoff_seconds: float = DEFAULT_BACKOFF_SECONDS
verbose: bool = False
api_key: str = DEFAULT_API_KEY
retrieval_system: str = DEFAULT_RETRIEVAL_SYSTEM

def _merge_result(
self,
Expand Down Expand Up @@ -135,6 +136,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame:
retries=self.retries,
backoff_seconds=self.backoff_seconds,
api_key=self.api_key,
retrieval_system=self.retrieval_system
).results
else:
results = search(
Expand All @@ -147,6 +149,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame:
retries=self.retries,
backoff_seconds=self.backoff_seconds,
api_key=self.api_key,
retrieval_system=self.retrieval_system
).results
else:
if explain:
Expand All @@ -161,6 +164,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame:
retries=self.retries,
backoff_seconds=self.backoff_seconds,
api_key=self.api_key,
retrieval_system=self.retrieval_system
).results
else:
results = search_phrases(
Expand All @@ -174,6 +178,7 @@ def _transform_query(self, topic: DataFrame) -> DataFrame:
retries=self.retries,
backoff_seconds=self.backoff_seconds,
api_key=self.api_key,
retrieval_system=self.retrieval_system
).results

if self.filter_unknown:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_random_document_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import unittest

class TestRandomDocumentAccess(unittest.TestCase):
def test_document_on_ms_marco_v2_1(self):


0 comments on commit c525acd

Please sign in to comment.