Skip to content

Commit

Permalink
Merge pull request #48 from dleemiller/feature/stdlib-key-fn
Browse files Browse the repository at this point in the history
Feature/stdlib key fn
  • Loading branch information
dleemiller authored Feb 2, 2025
2 parents 0f49ac1 + bf839d7 commit 64ffc66
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 18 deletions.
89 changes: 72 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

## News and Updates 🔥

- **2025-02-01** Callable for stdlib functions (sorted/min/max)
- **2025-01-04** We're excited to announce support for model2vec static embeddings. See also: [Model2Vec](https://github.com/MinishLab/model2vec)
- **2024-10-04** Added semantic splitting inference algorithm. See our [technical overview](tutorials/blog/semantic_split/wl_semantic_blog.md).

Expand All @@ -20,6 +21,7 @@
- [How Fast?](#how-fast-zap)
- [Usage](#usage)
- [Embedding Text](#embedding-text)
- [Stdlib Examples](#stdlib-examples)
- [Calculating Similarity](#calculating-similarity)
- [Ranking Documents](#ranking-documents)
- [Fuzzy Deduplication](#fuzzy-deduplication)
Expand Down Expand Up @@ -52,22 +54,37 @@ from wordllama import WordLlama
# Load the default WordLlama model
wl = WordLlama.load()

# Calculate similarity between two sentences
similarity_score = wl.similarity("I went to the car", "I went to the pawn shop")
print(similarity_score) # Output: e.g., 0.0664

# Rank documents based on their similarity to a query
query = "I went to the car"
candidates = ["I went to the park", "I went to the shop", "I went to the truck", "I went to the vehicle"]
ranked_docs = wl.rank(query, candidates)
print(ranked_docs)
# Output:
# [
# ('I went to the vehicle', 0.7441),
# ('I went to the truck', 0.2832),
# ('I went to the shop', 0.1973),
# ('I went to the park', 0.1510)
# ]
query = "Machine learning methods"
candidates = [
"Foundations of neural science",
"Introduction to neural networks",
"Cooking delicious pasta at home",
"Introduction to philosophy: logic",
]

# Returns a Callable[[str], float] function
sim_key = wl.key(query)

# Sort candidates, most similar first
sorted_candidates = sorted(candidates, key=sim_key, reverse=True)

# Most similar candidate
best_candidate = max(candidates, key=sim_key)

# Print the results
print("Ranked Candidates:")
for i, candidate in enumerate(sorted_candidates, 1):
print(f"{i}. {candidate} (Score: {sim_key(candidate):.4f})")

print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})")

# Ranked Candidates:
# 1. Introduction to neural networks (Score: 0.3414)
# 2. Foundations of neural science (Score: 0.2115)
# 3. Introduction to philosophy: logic (Score: 0.1067)
# 4. Cooking delicious pasta at home (Score: 0.0045)
#
# Best Match: Introduction to neural networks (Score: 0.3414)
```

## Features
Expand Down Expand Up @@ -149,6 +166,44 @@ embeddings = wl.embed(["The quick brown fox jumps over the lazy dog", "And all t
print(embeddings.shape) # Output: (2, 64)
```

### Stdlib Examples

Call `.key(query)` to get a callable that scores a candidate string's similarity to the query — suitable as the `key=` argument to `sorted()`, `min()`, and `max()`.

```python
query = "Machine learning methods"
candidates = [
"Foundations of neural science",
"Introduction to neural networks",
"Cooking delicious pasta at home",
"Introduction to philosophy: logic",
]

# Returns a Callable[[str], float] function
sim_key = wl.key(query)

# Sort candidates, most similar first
sorted_candidates = sorted(candidates, key=sim_key, reverse=True)

# Most similar candidate
best_candidate = max(candidates, key=sim_key)

# Print the results
print("Ranked Candidates:")
for i, candidate in enumerate(sorted_candidates, 1):
print(f"{i}. {candidate} (Score: {sim_key(candidate):.4f})")

print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})")

# Ranked Candidates:
# 1. Introduction to neural networks (Score: 0.3414)
# 2. Foundations of neural science (Score: 0.2115)
# 3. Introduction to philosophy: logic (Score: 0.1067)
# 4. Cooking delicious pasta at home (Score: 0.0045)
#
# Best Match: Introduction to neural networks (Score: 0.3414)
```

### Calculating Similarity

Compute the similarity between two texts:
Expand Down Expand Up @@ -324,7 +379,7 @@ If you use WordLlama in your research or project, please consider citing it as f
title = {WordLlama: Recycled Token Embeddings from Large Language Models},
year = {2024},
url = {https://github.com/dleemiller/wordllama},
version = {0.3.7}
version = {0.3.9}
}
```

Expand Down
24 changes: 24 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,27 @@ def test_function_similarity_binary(self):
wl = WordLlama.load()
wl.binary = True
wl.similarity("a", "b")

def test_function_sorted(self):
    """sorted() with wl.key() must return the candidates ordered by
    descending similarity to the query.

    The original test only checked type and length, which ``sorted``
    guarantees unconditionally; this version also verifies that the
    result is a permutation of the input and that the similarity scores
    are actually non-increasing.
    """
    wl = WordLlama.load()
    query = "test query"
    candidates = ["example A", "example B", "example C"]

    sim_key = wl.key(query)
    sorted_candidates = sorted(candidates, key=sim_key, reverse=True)

    self.assertIsInstance(sorted_candidates, list)
    self.assertEqual(len(sorted_candidates), len(candidates))
    # Same elements, merely reordered.
    self.assertCountEqual(sorted_candidates, candidates)
    # reverse=True means most similar first: scores must be non-increasing.
    scores = [sim_key(c) for c in sorted_candidates]
    self.assertEqual(scores, sorted(scores, reverse=True))

def test_function_max(self):
    """max() with wl.key() must select the same candidate as a direct
    per-candidate similarity comparison against the query."""
    wl = WordLlama.load()
    query = "test query"
    candidates = ["example A", "example B", "example C"]

    sim_key = wl.key(query)
    best_candidate = max(candidates, key=sim_key)

    self.assertIn(best_candidate, candidates)
    expected = max(candidates, key=lambda text: wl.similarity(query, text))
    self.assertEqual(best_candidate, expected)
28 changes: 27 additions & 1 deletion wordllama/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from tokenizers import Tokenizer
from typing import Union, List, Tuple, Optional
from typing import Callable, List, Optional, Tuple, Union
import logging

from .algorithms import (
Expand Down Expand Up @@ -177,6 +177,32 @@ def similarity(self, text1: str, text2: str) -> float:
embedding2 = self.embed(text2)
return self.vector_similarity(embedding1[0], embedding2[0]).item()

def key(self, query: str, norm: bool = True) -> Callable[[str], float]:
    """
    Build a similarity key function for the given query.

    The returned callable embeds a candidate string and scores it against
    the query embedding (computed once, up front), making it suitable as
    the ``key=`` argument of built-ins such as ``sorted()``, ``min()``,
    and ``max()``.

    Args:
        query (str): The reference query text.
        norm (bool, optional): Whether to normalize embeddings before
            computing similarity. Defaults to True.

    Returns:
        Callable[[str], float]: Maps a candidate string to its similarity
        score against the precomputed query embedding.
    """
    # Embed the query once; every candidate reuses this vector.
    query_vector = self.embed(query, norm=norm)[0]

    def similarity_key(candidate: str) -> float:
        candidate_vector = self.embed(candidate, norm=norm)[0]
        return self.vector_similarity(query_vector, candidate_vector).item()

    return similarity_key

def rank(
self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64
) -> List[Tuple[str, float]]:
Expand Down

0 comments on commit 64ffc66

Please sign in to comment.