Skip to content

Commit

Permalink
Merge changes in main
Browse files Browse the repository at this point in the history
  • Loading branch information
jgbarah committed Oct 14, 2024
1 parent 67b1db9 commit 2d96f4d
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions wordllama/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,18 @@ def rank(
return similarities

def deduplicate(
self, docs: List[str], threshold: float = 0.9, batch_size: Optional[int] = None
) -> List[str]:
self,
docs: List[str],
threshold: float = 0.9,
return_indices: bool = False,º
batch_size: Optional[int] = None
) -> List[Union[str, int]]:
"""Deduplicate documents based on a similarity threshold.
Args:
docs (List[str]): List of documents to deduplicate.
threshold (float, optional): Similarity threshold above which documents are considered duplicates. Defaults to 0.9.
return_indices (bool, optional): Return indices of duplicated documents, rather than deduplicated list of documents.
batch_size (Optional[int], optional): Batch size for processing embeddings. Defaults to None.
Returns:
Expand All @@ -227,6 +232,10 @@ def deduplicate(
duplicate_indices = deduplicate_embeddings(
doc_embeddings, threshold, batch_size
)
if return_indices:
# turn set of numpy int into sorted list of python int
duplicate_indices = list(map(lambda x: x.item(), duplicate_indices))
return sorted(duplicate_indices)

unique_docs = [
doc for idx, doc in enumerate(docs) if idx not in duplicate_indices
Expand Down

0 comments on commit 2d96f4d

Please sign in to comment.