Skip to content

Commit 1542b3b

Browse files
committed
[ENH] Add nomic embedding function
1 parent 36af0d1 commit 1542b3b

File tree

4 files changed

+159
-0
lines changed

4 files changed

+159
-0
lines changed

chromadb/test/ef/test_ef.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def test_get_builtins_holds() -> None:
3939
"JinaEmbeddingFunction",
4040
"MistralEmbeddingFunction",
4141
"MorphEmbeddingFunction",
42+
"NomicEmbeddingFunction",
4243
"ONNXMiniLM_L6_V2",
4344
"OllamaEmbeddingFunction",
4445
"OpenAIEmbeddingFunction",

chromadb/utils/embedding_functions/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@
6868
from chromadb.utils.embedding_functions.morph_embedding_function import (
6969
MorphEmbeddingFunction,
7070
)
71+
from chromadb.utils.embedding_functions.nomic_embedding_function import (
72+
NomicEmbeddingFunction,
73+
NomicQueryConfig,
74+
)
7175
from chromadb.utils.embedding_functions.huggingface_sparse_embedding_function import (
7276
HuggingFaceSparseEmbeddingFunction,
7377
)
@@ -103,6 +107,7 @@
103107
"JinaEmbeddingFunction",
104108
"MistralEmbeddingFunction",
105109
"MorphEmbeddingFunction",
110+
"NomicEmbeddingFunction",
106111
"VoyageAIEmbeddingFunction",
107112
"ONNXMiniLM_L6_V2",
108113
"OpenCLIPEmbeddingFunction",
@@ -142,6 +147,7 @@ def get_builtins() -> Set[str]:
142147
"jina": JinaEmbeddingFunction,
143148
"mistral": MistralEmbeddingFunction,
144149
"morph": MorphEmbeddingFunction,
150+
"nomic": NomicEmbeddingFunction,
145151
"voyageai": VoyageAIEmbeddingFunction,
146152
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
147153
"open_clip": OpenCLIPEmbeddingFunction,
@@ -265,6 +271,8 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
265271
"JinaQueryConfig",
266272
"MistralEmbeddingFunction",
267273
"MorphEmbeddingFunction",
274+
"NomicEmbeddingFunction",
275+
"NomicQueryConfig",
268276
"VoyageAIEmbeddingFunction",
269277
"ONNXMiniLM_L6_V2",
270278
"OpenCLIPEmbeddingFunction",
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from chromadb.api.types import (
2+
Embeddings,
3+
Documents,
4+
EmbeddingFunction,
5+
Space,
6+
)
7+
from chromadb.utils.embedding_functions.schemas import validate_config_schema
8+
from typing import List, Dict, Any, TypedDict, Optional
9+
import os
10+
import numpy as np
11+
12+
13+
class NomicQueryConfig(TypedDict):
14+
task_type: str
15+
16+
17+
class NomicEmbeddingFunction(EmbeddingFunction[Documents]):
18+
"""
19+
This class is used to get embeddings for a list of texts using the Nomic API.
20+
"""
21+
22+
def __init__(
23+
self,
24+
model: str,
25+
task_type: str,
26+
query_config: Optional[NomicQueryConfig],
27+
api_key_env_var: str = "NOMIC_API_KEY",
28+
):
29+
"""
30+
Initialize the NomicEmbeddingFunction.
31+
32+
Args:
33+
model (str): The name of the model to use for text embeddings.
34+
task_type (str): The type of task to embed with. See reference https://docs.nomic.ai/platform/embeddings-and-retrieval/text-embedding#embedding-task-types
35+
query_config (Optional[NomicQueryConfig]): The configuration for setting task type for queries
36+
api_key_env_var (str): The environment variable name for the Nomic API key. Defaults to "NOMIC_API_KEY".
37+
38+
Supported task types: search_document, search_query, classification, clustering
39+
"""
40+
try:
41+
from nomic import embed
42+
except ImportError:
43+
raise ValueError(
44+
"The nomic python package is not installed. Please install it with `pip install nomic`"
45+
)
46+
47+
self.model = model
48+
self.task_type = task_type
49+
self.api_key_env_var = api_key_env_var
50+
self.api_key = os.getenv(api_key_env_var)
51+
self.query_config = query_config
52+
if not self.api_key:
53+
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
54+
self.embed = embed
55+
56+
def __call__(self, input: Documents) -> Embeddings:
57+
if not all(isinstance(item, str) for item in input):
58+
raise ValueError("Nomic only supports text documents, not images")
59+
output = self.embed.text(
60+
model=self.model,
61+
texts=input,
62+
task_type=self.task_type,
63+
)
64+
return [np.array(data.embedding) for data in output.data]
65+
66+
def embed_query(self, input: Documents) -> Embeddings:
67+
if not all(isinstance(item, str) for item in input):
68+
raise ValueError("Nomic only supports text queries, not images")
69+
70+
task_type = (
71+
self.query_config.get("task_type") if self.query_config else self.task_type
72+
)
73+
output = self.embed.text(
74+
model=self.model,
75+
texts=input,
76+
task_type=task_type,
77+
)
78+
return [np.array(data.embedding) for data in output.data]
79+
80+
@staticmethod
81+
def name() -> str:
82+
return "nomic"
83+
84+
def default_space(self) -> Space:
85+
return "cosine"
86+
87+
def supported_spaces(self) -> List[Space]:
88+
return ["cosine", "l2", "ip"]
89+
90+
@staticmethod
91+
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
92+
model = config.get("model")
93+
api_key_env_var = config.get("api_key_env_var")
94+
task_type = config.get("task_type")
95+
query_config = config.get("query_config")
96+
if model is None or api_key_env_var is None or task_type is None:
97+
assert False, "This code should not be reached" # this is for type checking
98+
return NomicEmbeddingFunction(
99+
model=model,
100+
api_key_env_var=api_key_env_var,
101+
task_type=task_type,
102+
query_config=query_config,
103+
)
104+
105+
def get_config(self) -> Dict[str, Any]:
106+
return {
107+
"model": self.model,
108+
"api_key_env_var": self.api_key_env_var,
109+
"task_type": self.task_type,
110+
"query_config": self.query_config,
111+
}
112+
113+
def validate_config_update(
114+
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
115+
) -> None:
116+
if "model" in new_config:
117+
raise ValueError(
118+
"The model cannot be changed after the embedding function has been initialized."
119+
)
120+
121+
@staticmethod
122+
def validate_config(config: Dict[str, Any]) -> None:
123+
"""
124+
Validate the configuration using the JSON schema.
125+
126+
Args:
127+
config: Configuration to validate
128+
"""
129+
validate_config_schema(config, "nomic")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"$schema": "http://json-schema.org/draft-07/schema#",
3+
"title": "Nomic Embedding Function Schema",
4+
"description": "Schema for the Nomic embedding function configuration",
5+
"version": "1.0.0",
6+
"type": "object",
7+
"properties": {
8+
"model": {
9+
"type": "string",
10+
"description": "Parameter model for the Nomic embedding function"
11+
},
12+
"api_key_env_var": {
13+
"type": "string",
14+
"description": "Parameter api_key_env_var for the Nomic embedding function"
15+
},
16+
"task_type": {
17+
"type": "string",
18+
"description": "Parameter task_type for the Nomic embedding function"
19+
}
20+
}
21+
}

0 commit comments

Comments
 (0)