Skip to content

Commit

Permalink
initialize_retriever(), partialclass() in a more general place
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Oct 15, 2024
1 parent f68fb57 commit 13f94ef
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion blendsql/ingredients/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Union, List, Set, Dict
from typing import Union, List, Set, Dict, Callable
from functools import partialmethod, partial
from colorama import Fore

from ..utils import get_tablename_colname
from ..db import Database
from .few_shot import Example
from .._logger import logger


def unpack_options(
Expand All @@ -24,3 +28,32 @@ def unpack_options(
except ValueError:
unpacked_options = options.split(";")
return set(unpacked_options)


def initialize_retriever(
    examples: List[Example], k: Union[int, None] = None, **to_string_args
) -> Callable[[str], List[Example]]:
    """Build a callable that selects few-shot examples.

    Args:
        examples: The few-shot examples to choose from.
        k: Number of examples to return per call. When ``None`` or equal to
            ``len(examples)``, no retriever is built and every call simply
            returns all of ``examples``.
        **to_string_args: Forwarded to ``example.to_string(...)`` when each
            example is rendered into a document for the retriever.

    Returns:
        Either a function that ignores its positional arguments and returns
        ``examples`` unchanged, or ``retriever.retrieve_top_k`` with ``k``
        already bound.

    Raises:
        AssertionError: If ``k`` is greater than ``len(examples)``.
    """
    n_examples = len(examples)
    if k is None or k == n_examples:
        # Nothing to rank: hand back all the examples every time this is called
        return lambda *_: examples
    if k > n_examples:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            f"The `k` argument to an ingredient must be less than `len(few_shot_examples)`!\n`k` is {k}, `len(few_shot_examples)` is {len(examples)}"
        )
    # Imported only on the path that actually builds a retriever.
    from .retriever import Retriever

    logger.debug(Fore.YELLOW + "Processing documents with haystack..." + Fore.RESET)
    retriever = Retriever(
        documents=[example.to_string(**to_string_args) for example in examples],
        return_objs=examples,
    )
    return partial(retriever.retrieve_top_k, k=k)


def partialclass(cls, *args, **kwds):
    """Return a subclass of ``cls`` whose ``__init__`` has arguments pre-bound.

    Args:
        cls: The class to specialize.
        *args: Positional arguments bound to ``cls.__init__`` ahead of any
            positional arguments given at instantiation time.
        **kwds: Keyword arguments bound to ``cls.__init__``; keywords given at
            instantiation time override them.

    Returns:
        A new subclass of ``cls`` that carries the same ``__name__``,
        ``__qualname__``, ``__module__`` and ``__doc__`` as ``cls``.
    """

    # https://stackoverflow.com/a/38911383
    class NewCls(cls):
        __init__ = partialmethod(cls.__init__, *args, **kwds)

    # Copy the identifying metadata, not just `__name__`: `repr()` and
    # tracebacks use `__qualname__`/`__module__`, which would otherwise
    # expose the internal `partialclass.<locals>.NewCls` name.
    NewCls.__name__ = cls.__name__
    NewCls.__qualname__ = cls.__qualname__
    NewCls.__module__ = cls.__module__
    NewCls.__doc__ = cls.__doc__
    return NewCls

0 comments on commit 13f94ef

Please sign in to comment.