Feature/stdlib key fn #48

Merged 5 commits on Feb 2, 2025
89 changes: 72 additions & 17 deletions README.md
@@ -8,6 +8,7 @@

## News and Updates 🔥

- **2025-02-01** Added `key()`, a similarity callable for use with stdlib functions (sorted/min/max)
- **2025-01-04** We're excited to announce support for model2vec static embeddings. See also: [Model2Vec](https://github.com/MinishLab/model2vec)
- **2024-10-04** Added semantic splitting inference algorithm. See our [technical overview](tutorials/blog/semantic_split/wl_semantic_blog.md).

@@ -20,6 +21,7 @@
- [How Fast?](#how-fast-zap)
- [Usage](#usage)
- [Embedding Text](#embedding-text)
- [Stdlib Examples](#stdlib-examples)
- [Calculating Similarity](#calculating-similarity)
- [Ranking Documents](#ranking-documents)
- [Fuzzy Deduplication](#fuzzy-deduplication)
@@ -52,22 +54,37 @@ from wordllama import WordLlama
# Load the default WordLlama model
wl = WordLlama.load()

# Calculate similarity between two sentences
similarity_score = wl.similarity("I went to the car", "I went to the pawn shop")
print(similarity_score) # Output: e.g., 0.0664

# Rank documents based on their similarity to a query
query = "I went to the car"
candidates = ["I went to the park", "I went to the shop", "I went to the truck", "I went to the vehicle"]
ranked_docs = wl.rank(query, candidates)
print(ranked_docs)
# Output:
# [
# ('I went to the vehicle', 0.7441),
# ('I went to the truck', 0.2832),
# ('I went to the shop', 0.1973),
# ('I went to the park', 0.1510)
# ]
query = "Machine learning methods"
candidates = [
    "Foundations of neural science",
    "Introduction to neural networks",
    "Cooking delicious pasta at home",
    "Introduction to philosophy: logic",
]

# Returns a Callable[[str], float] function
sim_key = wl.key(query)

# Sort candidates, most similar first
sorted_candidates = sorted(candidates, key=sim_key, reverse=True)

# Most similar candidate
best_candidate = max(candidates, key=sim_key)

# Print the results
print("Ranked Candidates:")
for i, candidate in enumerate(sorted_candidates, 1):
    print(f"{i}. {candidate} (Score: {sim_key(candidate):.4f})")

print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})")

# Ranked Candidates:
# 1. Introduction to neural networks (Score: 0.3414)
# 2. Foundations of neural science (Score: 0.2115)
# 3. Introduction to philosophy: logic (Score: 0.1067)
# 4. Cooking delicious pasta at home (Score: 0.0045)
#
# Best Match: Introduction to neural networks (Score: 0.3414)
```
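One reason `key()` precomputes the query embedding: Python's `sorted()` calls its key function exactly once per element, so only candidate embeddings are computed during the sort. A self-contained sketch (no model required) that counts key invocations:

```python
# sorted() computes the key exactly once per element, so an expensive
# setup step (like embedding the query) should live outside the key.
calls = {"n": 0}

def counting_key(s: str) -> int:
    calls["n"] += 1
    return len(s)

items = ["bb", "a", "ccc"]
print(sorted(items, key=counting_key))  # ['a', 'bb', 'ccc']
print(calls["n"])  # 3: one key call per element
```

The same single-evaluation behavior applies to `min()` and `max()`.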

## Features
@@ -149,6 +166,44 @@ embeddings = wl.embed(["The quick brown fox jumps over the lazy dog", "And all t
print(embeddings.shape) # Output: (2, 64)
```
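For intuition: when embeddings are normalized, cosine similarity between two vectors reduces to a plain dot product. A minimal numpy sketch with toy unit-norm vectors (illustrative values, not actual model output):

```python
import numpy as np

# Toy stand-in for normalized embeddings: two unit-norm 4-d vectors.
emb = np.array([
    [0.5, 0.5, 0.5, 0.5],
    [1.0, 0.0, 0.0, 0.0],
])
# With unit-norm rows, cosine similarity is just the dot product.
score = float(emb[0] @ emb[1])
print(score)  # 0.5
```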

### Stdlib Examples

`.key(query)` returns a callable that scores a candidate string's similarity to the query, for direct use as the `key=` argument of `sorted()`, `min()`, and `max()`.

```python
query = "Machine learning methods"
candidates = [
    "Foundations of neural science",
    "Introduction to neural networks",
    "Cooking delicious pasta at home",
    "Introduction to philosophy: logic",
]

# Returns a Callable[[str], float] function
sim_key = wl.key(query)

# Sort candidates, most similar first
sorted_candidates = sorted(candidates, key=sim_key, reverse=True)

# Most similar candidate
best_candidate = max(candidates, key=sim_key)

# Print the results
print("Ranked Candidates:")
for i, candidate in enumerate(sorted_candidates, 1):
    print(f"{i}. {candidate} (Score: {sim_key(candidate):.4f})")

print(f"\nBest Match: {best_candidate} (Score: {sim_key(best_candidate):.4f})")

# Ranked Candidates:
# 1. Introduction to neural networks (Score: 0.3414)
# 2. Foundations of neural science (Score: 0.2115)
# 3. Introduction to philosophy: logic (Score: 0.1067)
# 4. Cooking delicious pasta at home (Score: 0.0045)
#
# Best Match: Introduction to neural networks (Score: 0.3414)
```
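Because `key()` returns an ordinary callable, it also composes with other stdlib tools that accept a `key=` argument, such as `heapq.nlargest` for top-k selection without sorting the full list. The sketch below substitutes a toy word-overlap (Jaccard) key for `wl.key()` so it runs without the model; `make_key` is illustrative only:

```python
import heapq

def make_key(query: str):
    """Toy stand-in for wl.key(): similarity = word-overlap (Jaccard)."""
    q = set(query.lower().split())

    def sim(candidate: str) -> float:
        c = set(candidate.lower().split())
        return len(q & c) / len(q | c)

    return sim

sim_key = make_key("machine learning methods")
candidates = [
    "learning to cook",
    "machine learning basics",
    "gardening methods",
]
# Top-2 most similar, without sorting the whole list
top2 = heapq.nlargest(2, candidates, key=sim_key)
print(top2)  # ['machine learning basics', 'gardening methods']
```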

### Calculating Similarity

Compute the similarity between two texts:
@@ -324,7 +379,7 @@ If you use WordLlama in your research or project, please consider citing it as f
title = {WordLlama: Recycled Token Embeddings from Large Language Models},
year = {2024},
url = {https://github.com/dleemiller/wordllama},
version = {0.3.7}
version = {0.3.9}
}
```

24 changes: 24 additions & 0 deletions tests/test_functional.py
@@ -15,3 +15,27 @@ def test_function_similarity_binary(self):
        wl = WordLlama.load()
        wl.binary = True
        wl.similarity("a", "b")

    def test_function_sorted(self):
        wl = WordLlama.load()
        query = "test query"
        candidates = ["example A", "example B", "example C"]

        sim_key = wl.key(query)
        sorted_candidates = sorted(candidates, key=sim_key, reverse=True)

        self.assertIsInstance(sorted_candidates, list)
        self.assertEqual(len(sorted_candidates), len(candidates))

    def test_function_max(self):
        wl = WordLlama.load()
        query = "test query"
        candidates = ["example A", "example B", "example C"]

        sim_key = wl.key(query)
        best_candidate = max(candidates, key=sim_key)

        self.assertIn(best_candidate, candidates)
        self.assertEqual(
            best_candidate, max(candidates, key=lambda x: wl.similarity(query, x))
        )
28 changes: 27 additions & 1 deletion wordllama/inference.py
@@ -1,6 +1,6 @@
import numpy as np
from tokenizers import Tokenizer
from typing import Union, List, Tuple, Optional
from typing import Callable, List, Optional, Tuple, Union
import logging

from .algorithms import (
@@ -177,6 +177,32 @@ def similarity(self, text1: str, text2: str) -> float:
        embedding2 = self.embed(text2)
        return self.vector_similarity(embedding1[0], embedding2[0]).item()

    def key(self, query: str, norm: bool = True) -> Callable[[str], float]:
        """
        Returns a key function for comparing candidate strings based on their
        similarity to the given query. This key function can be used with built-in
        functions like sorted(), min(), and max().

        Args:
            query (str): The reference query text.
            norm (bool, optional): Whether to normalize embeddings before computing
                similarity. Defaults to True.

        Returns:
            Callable[[str], float]: A function that computes the similarity between
                the precomputed query embedding and a candidate string.
        """
        # Precompute the embedding for the query
        query_embedding = self.embed(query, norm=norm)

        def similarity_key(candidate: str) -> float:
            candidate_embedding = self.embed(candidate, norm=norm)
            return self.vector_similarity(
                query_embedding[0], candidate_embedding[0]
            ).item()

        return similarity_key
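A possible extension, not part of this PR: since `sorted()`/`min()`/`max()` may be applied repeatedly to overlapping candidate lists, the returned key could memoize per-candidate scores with `functools.lru_cache`. A self-contained sketch using a toy model in place of WordLlama (`TinyModel` and its length-based "embedding" are purely illustrative):

```python
from functools import lru_cache

class TinyModel:
    """Toy stand-in for WordLlama: 'embeds' a string as its length."""

    def embed(self, text: str) -> list:
        return [float(len(text))]

    def key(self, query: str, cache_size: int = 1024):
        q = self.embed(query)[0]

        @lru_cache(maxsize=cache_size)
        def similarity_key(candidate: str) -> float:
            # Negative distance, so larger means more similar; cached per
            # candidate so repeated sorted()/min()/max() calls don't re-embed.
            return -abs(self.embed(candidate)[0] - q)

        return similarity_key

wl = TinyModel()
sim_key = wl.key("hello")
for s in ["hi", "hello!", "hi"]:  # "hi" is scored twice
    sim_key(s)
print(sim_key.cache_info().hits)  # 1: the repeated "hi" hit the cache
```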

    def rank(
        self, query: str, docs: List[str], sort: bool = True, batch_size: int = 64
    ) -> List[Tuple[str, float]]: