Skip to content

Commit

Permalink
Merge pull request #33 from jgbarah/patch-1
Browse files: browse the repository at this point in the history.
Allow WordLlama.rank to not sort the results
  • Loading branch information
dleemiller authored Oct 13, 2024
2 parents deddd6f + 2baa51f commit 3f834bf
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions wordllama/inference.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -239,15 +239,18 @@ def similarity(self, text1: str, text2: str) -> float:
embedding2 = self.embed(text2)
return self.vector_similarity(embedding1[0], embedding2[0]).item()

def rank(self, query: str, docs: List[str]) -> List[Tuple[str, float]]:
def rank(self, query: str, docs: List[str], sort: bool = True) -> List[Tuple[str, float]]:
"""Rank documents based on their similarity to a query.
Result may be sorted by similarity score in descending order, or not (see `sort` parameter)
Args:
query (str): The query text.
docs (List[str]): The list of document texts to rank.
sort (bool): Sort documents by similarity, or not (respect the order in `docs`)
Returns:
List[Tuple[str, float]]: A list of tuples `(doc, score)`, sorted by similarity score in descending order.
List[Tuple[str, float]]: A list of tuples `(doc, score)`.
"""
assert isinstance(query, str), "Query must be a string"
assert (
Expand All @@ -259,7 +262,8 @@ def rank(self, query: str, docs: List[str]) -> List[Tuple[str, float]]:

scores = np.atleast_1d(scores.squeeze())
similarities = list(zip(docs, scores.tolist()))
similarities.sort(key=lambda x: x[1], reverse=True)
if sort:
similarities.sort(key=lambda x: x[1], reverse=True)
return similarities

def deduplicate(
Expand Down

0 comments on commit 3f834bf

Please sign in to comment.