Skip to content

Commit 8e63a72

Browse files
committed
[ENH] Add muvera support
1 parent 998da94 commit 8e63a72

File tree

6 files changed

+873
-1
lines changed

6 files changed

+873
-1
lines changed

chromadb/test/ef/test_ef.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def test_get_builtins_holds() -> None:
5454
"Bm25EmbeddingFunction",
5555
"ChromaCloudQwenEmbeddingFunction",
5656
"ChromaCloudSpladeEmbeddingFunction",
57+
"PylateColBERTEmbeddingFunction",
5758
}
5859

5960
assert expected_builtins == embedding_functions.get_builtins()

chromadb/utils/embedding_functions/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@
8383
from chromadb.utils.embedding_functions.chroma_cloud_splade_embedding_function import (
8484
ChromaCloudSpladeEmbeddingFunction,
8585
)
86-
86+
from chromadb.utils.embedding_functions.pylate_colbert_embedding_function import (
87+
PylateColBERTEmbeddingFunction,
88+
)
8789

8890
# Get all the class names for backward compatibility
8991
_all_classes: Set[str] = {
@@ -116,6 +118,7 @@
116118
"Bm25EmbeddingFunction",
117119
"ChromaCloudQwenEmbeddingFunction",
118120
"ChromaCloudSpladeEmbeddingFunction",
121+
"PylateColBERTEmbeddingFunction",
119122
}
120123

121124

@@ -150,6 +153,7 @@ def get_builtins() -> Set[str]:
150153
"cloudflare_workers_ai": CloudflareWorkersAIEmbeddingFunction,
151154
"together_ai": TogetherAIEmbeddingFunction,
152155
"chroma-cloud-qwen": ChromaCloudQwenEmbeddingFunction,
156+
"pylate_colbert": PylateColBERTEmbeddingFunction,
153157
}
154158

155159
sparse_known_embedding_functions: Dict[str, Type[SparseEmbeddingFunction]] = { # type: ignore
@@ -273,6 +277,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
273277
"Bm25EmbeddingFunction",
274278
"ChromaCloudQwenEmbeddingFunction",
275279
"ChromaCloudSpladeEmbeddingFunction",
280+
"PylateColBERTEmbeddingFunction",
276281
"register_embedding_function",
277282
"config_to_embedding_function",
278283
"known_embedding_functions",

chromadb/utils/embedding_functions/jina_embedding_function.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import os
1212
import numpy as np
1313
import warnings
14+
from chromadb.utils.muvera import create_fdes
1415
import importlib
1516
import base64
1617
import io
@@ -37,6 +38,7 @@ def __init__(
3738
dimensions: Optional[int] = None,
3839
embedding_type: Optional[str] = None,
3940
normalized: Optional[bool] = None,
41+
return_multivector: Optional[bool] = None,
4042
query_config: Optional[JinaQueryConfig] = None,
4143
):
4244
"""
@@ -95,6 +97,7 @@ def __init__(
9597
self.dimensions = dimensions
9698
self.embedding_type = embedding_type
9799
self.normalized = normalized
100+
self.return_multivector = return_multivector
98101
self.query_config = query_config
99102

100103
self._api_url = "https://api.jina.ai/v1/embeddings"
@@ -143,6 +146,8 @@ def _build_payload(self, input: Embeddable, is_query: bool) -> Dict[str, Any]:
143146
payload["embedding_type"] = self.embedding_type
144147
if self.normalized is not None:
145148
payload["normalized"] = self.normalized
149+
if self.return_multivector is not None:
150+
payload["return_multivector"] = self.return_multivector
146151

147152
# overwrite parameteres when query payload is used
148153
if is_query and self.query_config is not None:
@@ -164,6 +169,30 @@ def _convert_resp(self, resp: Any, is_query: bool = False) -> Embeddings:
164169
if "data" not in resp:
165170
raise RuntimeError(resp.get("detail", "Unknown error"))
166171

172+
if self.return_multivector:
173+
# if it gives back multivector embeddings
174+
multi_embeddings_data: List[Dict[str, Any]] = resp["data"]
175+
sorted_multi_embeddings = sorted(
176+
multi_embeddings_data, key=lambda e: e["index"]
177+
)
178+
multi_embeddings: List[Embeddings] = [
179+
[
180+
np.array(vec, dtype=np.float32)
181+
for vec in multi_embedding_obj["embeddings"]
182+
]
183+
for multi_embedding_obj in sorted_multi_embeddings
184+
]
185+
186+
dims = len(multi_embeddings[0][0])
187+
fdes = create_fdes(
188+
multi_embeddings,
189+
dims=dims,
190+
is_query=is_query,
191+
fill_empty_partitions=not is_query,
192+
)
193+
194+
return fdes
195+
167196
embeddings_data: List[Dict[str, Union[int, List[float]]]] = resp["data"]
168197

169198
# Sort resulting embeddings by index
@@ -225,6 +254,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]"
225254
dimensions = config.get("dimensions")
226255
embedding_type = config.get("embedding_type")
227256
normalized = config.get("normalized")
257+
return_multivector = config.get("return_multivector")
228258
query_config = config.get("query_config")
229259

230260
if api_key_env_var is None or model_name is None:
@@ -239,6 +269,7 @@ def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Embeddable]"
239269
dimensions=dimensions,
240270
embedding_type=embedding_type,
241271
normalized=normalized,
272+
return_multivector=return_multivector,
242273
query_config=query_config,
243274
)
244275

@@ -252,6 +283,7 @@ def get_config(self) -> Dict[str, Any]:
252283
"dimensions": self.dimensions,
253284
"embedding_type": self.embedding_type,
254285
"normalized": self.normalized,
286+
"return_multivector": self.return_multivector,
255287
"query_config": self.query_config,
256288
}
257289

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from chromadb.api.types import Embeddings, Documents, EmbeddingFunction, Space
2+
from typing import List, Dict, Any
3+
from chromadb.utils.embedding_functions.schemas import validate_config_schema
4+
from chromadb.utils.muvera import create_fdes
5+
6+
7+
class PylateColBERTEmbeddingFunction(EmbeddingFunction[Documents]):
8+
"""
9+
This class is used to get embeddings for a list of texts using the ColBERT API.
10+
"""
11+
12+
def __init__(
13+
self,
14+
model_name: str,
15+
):
16+
"""
17+
Initialize the PylateColBERTEmbeddingFunction.
18+
19+
Args:
20+
model_name (str): The name of the model to use for text embeddings.
21+
Examples: "mixedbread-ai/mxbai-edge-colbert-v0-17m", "mixedbread-ai/mxbai-edge-colbert-v0-32m", "lightonai/colbertv2.0", "answerdotai/answerai-colbert-small-v1", "jinaai/jina-colbert-v2", "GTE-ModernColBERT-v1"
22+
"""
23+
try:
24+
from pylate import models
25+
except ImportError:
26+
raise ValueError(
27+
"The pylate colbert python package is not installed. Please install it with `pip install pylate-colbert`"
28+
)
29+
30+
self.model_name = model_name
31+
self.model = models.ColBERT(model_name_or_path=model_name)
32+
33+
def __call__(self, input: Documents) -> Embeddings:
34+
"""
35+
Get the embeddings for a list of texts.
36+
37+
Args:
38+
input (Documents): A list of texts to get embeddings for.
39+
40+
Returns:
41+
Embeddings: The embeddings for the texts.
42+
"""
43+
multivec = self.model.encode(input, batch_size=32, is_query=False)
44+
return create_fdes(
45+
multivec,
46+
dims=len(multivec[0][0]),
47+
is_query=False,
48+
fill_empty_partitions=True,
49+
)
50+
51+
def embed_query(self, input: Documents) -> Embeddings:
52+
"""
53+
Get the embeddings for a list of texts.
54+
55+
Args:
56+
input (Documents): A list of texts to get embeddings for.
57+
58+
Returns:
59+
Embeddings: The embeddings for the texts.
60+
"""
61+
multivec = self.model.encode(input, batch_size=32, is_query=True)
62+
return create_fdes(
63+
multivec,
64+
dims=len(multivec[0][0]),
65+
is_query=True,
66+
fill_empty_partitions=False,
67+
)
68+
69+
@staticmethod
70+
def name() -> str:
71+
return "pylate_colbert"
72+
73+
def default_space(self) -> Space:
74+
return "cosine"
75+
76+
def supported_spaces(self) -> List[Space]:
77+
return ["cosine", "l2", "ip"]
78+
79+
@staticmethod
80+
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
81+
model_name = config.get("model_name")
82+
83+
if model_name is None:
84+
assert False, "This code should not be reached"
85+
86+
return PylateColBERTEmbeddingFunction(model_name=model_name)
87+
88+
def get_config(self) -> Dict[str, Any]:
89+
return {"model_name": self.model_name}
90+
91+
def validate_config_update(
92+
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
93+
) -> None:
94+
if "model_name" in new_config:
95+
raise ValueError(
96+
"The model name cannot be changed after the embedding function has been initialized."
97+
)
98+
99+
@staticmethod
100+
def validate_config(config: Dict[str, Any]) -> None:
101+
"""
102+
Validate the configuration using the JSON schema.
103+
104+
Args:
105+
config: Configuration to validate
106+
107+
Raises:
108+
ValidationError: If the configuration does not match the schema
109+
"""
110+
validate_config_schema(config, "pylate_colbert")

0 commit comments

Comments
 (0)