From c661f491af75c2a5e1f776cd108b1c987eac4e09 Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Mon, 14 Oct 2024 09:02:51 -0600
Subject: [PATCH 1/4] add option in deduplicate() to return indices

---
 tests/test_inference.py | 14 ++++++++++++++
 wordllama/inference.py  | 13 +++++++++++--
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/tests/test_inference.py b/tests/test_inference.py
index f9e2e86..e02c6ba 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -123,6 +123,20 @@ def test_deduplicate_all_duplicates(self, mock_embed):
         self.assertEqual(len(deduplicated_docs), 1)
         self.assertIn("doc1", deduplicated_docs)
 
+    @patch.object(
+        WordLlamaInference,
+        "embed",
+        return_value=np.array([[0.1] * 64, [0.1] * 64, [0.1] * 64], dtype=np.float32),
+    )
+    def test_deduplicate_return_indices(self, mock_embed):
+        docs = ["doc1", "doc1_dup", "doc1_dup2"]
+        duplicated_idx = self.model.deduplicate(
+            docs, return_indices=True, threshold=0.9
+        )
+        self.assertEqual(len(duplicated_idx), 2)
+        self.assertIn(1, duplicated_idx)
+        self.assertIn(2, duplicated_idx)
+
     def test_tokenize(self):
         tokens = self.model.tokenize("test string")
         self.mock_tokenizer.encode_batch.assert_called_with(
diff --git a/wordllama/inference.py b/wordllama/inference.py
index da1abcf..4efae36 100644
--- a/wordllama/inference.py
+++ b/wordllama/inference.py
@@ -207,13 +207,18 @@ def rank(
         return similarities
 
     def deduplicate(
-        self, docs: List[str], threshold: float = 0.9, batch_size: Optional[int] = None
-    ) -> List[str]:
+        self,
+        docs: List[str],
+        threshold: float = 0.9,
+        return_indices: bool = False,
+        batch_size: Optional[int] = None,
+    ) -> List[Union[str, int]]:
         """Deduplicate documents based on a similarity threshold.
 
         Args:
             docs (List[str]): List of documents to deduplicate.
             threshold (float, optional): Similarity threshold above which documents are considered duplicates. Defaults to 0.9.
+            return_indices (bool, optional): If True, return the indices of duplicate documents instead of the deduplicated documents. Defaults to False.
             batch_size (Optional[int], optional): Batch size for processing embeddings. Defaults to None.
 
         Returns:
@@ -226,6 +231,10 @@ def deduplicate(
         duplicate_indices = deduplicate_embeddings(
             doc_embeddings, threshold, batch_size
         )
+        if return_indices:
+            # turn set of numpy int into sorted list of python int
+            duplicate_indices = list(map(lambda x: x.item(), duplicate_indices))
+            return sorted(duplicate_indices)
 
         unique_docs = [
             doc for idx, doc in enumerate(docs) if idx not in duplicate_indices

From c59a80c1a1aef73e272c73798df56b2b1aaecf6f Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Mon, 14 Oct 2024 09:05:14 -0600
Subject: [PATCH 2/4] adding kwarg to readme

---
 README.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index 80d36b0..1c84201 100644
--- a/README.md
+++ b/README.md
@@ -178,7 +178,7 @@ print(ranked_docs)
 Remove duplicate texts based on a similarity threshold:
 
 ```python
-deduplicated_docs = wl.deduplicate(candidates, threshold=0.5)
+deduplicated_docs = wl.deduplicate(candidates, return_indices=False, threshold=0.5)
 print(deduplicated_docs)
 # Output:
 # ['I went to the park',
@@ -294,7 +294,7 @@ If you use WordLlama in your research or project, please consider citing it as f
   title = {WordLlama: Recycled Token Embeddings from Large Language Models},
   year = {2024},
   url = {https://github.com/dleemiller/wordllama},
-  version = {0.3.1}
+  version = {0.3.2}
 }
 ```

From b7a1211019b7e25686d2f3226e2b79bb8525357c Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Mon, 14 Oct 2024 09:11:29 -0600
Subject: [PATCH 3/4] adding seed to setup

---
 tests/test_inference.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_inference.py b/tests/test_inference.py
index e02c6ba..110f960 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -15,6 +15,8 @@ class TestWordLlamaInference(unittest.TestCase):
     @patch("wordllama.inference.Tokenizer.from_pretrained")
     def setUp(self, mock_tokenizer):
+        np.random.seed(42)
+
         # Mock the tokenizer
         self.mock_tokenizer = MagicMock()

From c1cc6040464e47f52eee2b7c46a93d1e2f24aba2 Mon Sep 17 00:00:00 2001
From: Lee Miller
Date: Mon, 14 Oct 2024 09:16:51 -0600
Subject: [PATCH 4/4] add seed at top of file as well

---
 tests/test_inference.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/test_inference.py b/tests/test_inference.py
index 110f960..15a654f 100644
--- a/tests/test_inference.py
+++ b/tests/test_inference.py
@@ -11,6 +11,8 @@
     TokenizerInferenceConfig,
 )
 
+np.random.seed(42)
+
 
 class TestWordLlamaInference(unittest.TestCase):
     @patch("wordllama.inference.Tokenizer.from_pretrained")
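
Reviewer note (not part of the patch series): a minimal usage sketch of the new `return_indices` option added in PATCH 1/4. It assumes `wl` is a loaded WordLlama model as in the README examples; the `candidates` list and the resulting indices are hypothetical and depend on the embeddings and the threshold.

```python
# Sketch only: exercising the new return_indices kwarg.
# `wl` is assumed to be a loaded WordLlama model (as in the README examples).
candidates = ["I went to the park", "I went to the park!", "Something unrelated"]

# Default behavior (return_indices=False): the deduplicated documents themselves.
unique_docs = wl.deduplicate(candidates, threshold=0.9)

# New behavior: a sorted list of Python ints marking the documents judged duplicates.
duplicate_idx = wl.deduplicate(candidates, return_indices=True, threshold=0.9)

# Callers can then filter or inspect the original list themselves.
kept_docs = [doc for i, doc in enumerate(candidates) if i not in duplicate_idx]
```

Returning indices instead of a reduced copy lets callers keep metadata (labels, ids) aligned with the documents that were dropped.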