diff --git a/README.md b/README.md index 5fb8691..c5ce668 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ ## News and Updates 🔥 +- **2025-02-01** Callable for stdlib functions (sorted/min/max) - **2025-01-04** We're excited to announce support for model2vec static embeddings. See also: [Model2Vec](https://github.com/MinishLab/model2vec) - **2024-10-04** Added semantic splitting inference algorithm. See our [technical overview](tutorials/blog/semantic_split/wl_semantic_blog.md). @@ -20,6 +21,7 @@ - [How Fast?](#how-fast-zap) - [Usage](#usage) - [Embedding Text](#embedding-text) + - [Stdlib Examples](#stdlib-examples) - [Calculating Similarity](#calculating-similarity) - [Ranking Documents](#ranking-documents) - [Fuzzy Deduplication](#fuzzy-deduplication) @@ -52,22 +54,37 @@ from wordllama import WordLlama # Load the default WordLlama model wl = WordLlama.load() -# Calculate similarity between two sentences -similarity_score = wl.similarity("I went to the car", "I went to the pawn shop") -print(similarity_score) # Output: e.g., 0.0664 - -# Rank documents based on their similarity to a query -query = "I went to the car" -candidates = ["I went to the park", "I went to the shop", "I went to the truck", "I went to the vehicle"] -ranked_docs = wl.rank(query, candidates) -print(ranked_docs) -# Output: -# [ -# ('I went to the vehicle', 0.7441), -# ('I went to the truck', 0.2832), -# ('I went to the shop', 0.1973), -# ('I went to the park', 0.1510) -# ] +query = "Machine learning methods" +candidates = [ + "Foundations of neural science", + "Introduction to neural networks", + "Cooking delicious pasta at home", + "Introduction to philosophy: logic", +] + +# Returns a Callable[[str], float] function +sim_key = wl.key(query) + +# Sort candidates, most similar first +sorted_candidates = sorted(candidates, key=sim_key, reverse=True) + +# Most similar candidate +best_candidate = max(candidates, key=sim_key) + +# Print the results +print("Ranked Candidates:") 
+for i, candidate in enumerate(sorted_candidates, 1): + print(f"{i}. {candidate} (Score: {sim_key(candidate):.4f})") + +print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})") + +# Ranked Candidates: +# 1. Introduction to neural networks (Score: 0.3414) +# 2. Foundations of neural science (Score: 0.2115) +# 3. Introduction to philosophy: logic (Score: 0.1067) +# 4. Cooking delicious pasta at home (Score: 0.0045) +# +# Best Match: Introduction to neural networks (Score: 0.3414) ``` ## Features @@ -149,6 +166,44 @@ embeddings = wl.embed(["The quick brown fox jumps over the lazy dog", "And all t print(embeddings.shape) # Output: (2, 64) ``` +### Stdlib Examples + +Return a Callable function from `.key(query)`. + +```python +query = "Machine learning methods" +candidates = [ + "Foundations of neural science", + "Introduction to neural networks", + "Cooking delicious pasta at home", + "Introduction to philosophy: logic", +] + +# Returns a Callable[[str], float] function +sim_key = wl.key(query) + +# Sort candidates, most similar first +sorted_candidates = sorted(candidates, key=sim_key, reverse=True) + +# Most similar candidate +best_candidate = max(candidates, key=sim_key) + +# Print the results +print("Ranked Candidates:") +for i, candidate in enumerate(sorted_candidates, 1): + print(f"{i}. {candidate} (Score: {sim_key(candidate):.4f})") + +print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})") + +# Ranked Candidates: +# 1. Introduction to neural networks (Score: 0.3414) +# 2. Foundations of neural science (Score: 0.2115) +# 3. Introduction to philosophy: logic (Score: 0.1067) +# 4. 
Cooking delicious pasta at home (Score: 0.0045) +# +# Best Match: Introduction to neural networks (Score: 0.3414) +``` + ### Calculating Similarity Compute the similarity between two texts: @@ -324,7 +379,7 @@ If you use WordLlama in your research or project, please consider citing it as f title = {WordLlama: Recycled Token Embeddings from Large Language Models}, year = {2024}, url = {https://github.com/dleemiller/wordllama}, - version = {0.3.7} + version = {0.3.9} } ``` diff --git a/tests/test_functional.py b/tests/test_functional.py index 542d441..6876a46 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -15,3 +15,27 @@ def test_function_similarity_binary(self): wl = WordLlama.load() wl.binary = True wl.similarity("a", "b") + + def test_function_sorted(self): + wl = WordLlama.load() + query = "test query" + candidates = ["example A", "example B", "example C"] + + sim_key = wl.key(query) + sorted_candidates = sorted(candidates, key=sim_key, reverse=True) + + self.assertIsInstance(sorted_candidates, list) + self.assertEqual(len(sorted_candidates), len(candidates)) + + def test_function_max(self): + wl = WordLlama.load() + query = "test query" + candidates = ["example A", "example B", "example C"] + + sim_key = wl.key(query) + best_candidate = max(candidates, key=sim_key) + + self.assertIn(best_candidate, candidates) + self.assertEqual( + best_candidate, max(candidates, key=lambda x: wl.similarity(query, x)) + ) diff --git a/wordllama/inference.py b/wordllama/inference.py index ccd343a..f5159ac 100644 --- a/wordllama/inference.py +++ b/wordllama/inference.py @@ -1,6 +1,6 @@ import numpy as np from tokenizers import Tokenizer -from typing import Union, List, Tuple, Optional +from typing import Callable, List, Optional, Tuple, Union import logging from .algorithms import ( @@ -177,6 +177,32 @@ def similarity(self, text1: str, text2: str) -> float: embedding2 = self.embed(text2) return self.vector_similarity(embedding1[0], embedding2[0]).item() + 
def key(self, query: str, norm: bool = True) -> Callable[[str], float]: + """ + Returns a key function for comparing candidate strings based on their + similarity to the given query. This key function can be used with built-in + functions like sorted(), min(), and max(). + + Args: + query (str): The reference query text. + norm (bool, optional): Whether to normalize embeddings before computing + similarity. Defaults to True. + + Returns: + Callable[[str], float]: A function that computes the similarity between + the precomputed query embedding and a candidate string. + """ + # Precompute the embedding for the query + query_embedding = self.embed(query, norm=norm) + + def similarity_key(candidate: str) -> float: + candidate_embedding = self.embed(candidate, norm=norm) + return self.vector_similarity( + query_embedding[0], candidate_embedding[0] + ).item() + + return similarity_key + def rank( self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64 ) -> List[Tuple[str, float]]: