Skip to content

Commit 7601078

Browse files
committed
[ENH] Add nomic embedding function
1 parent a219bb3 commit 7601078

File tree

4 files changed

+148
-0
lines changed

4 files changed

+148
-0
lines changed

chromadb/test/ef/test_ef.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def test_get_builtins_holds() -> None:
3737
"InstructorEmbeddingFunction",
3838
"JinaEmbeddingFunction",
3939
"MistralEmbeddingFunction",
40+
"NomicEmbeddingFunction",
4041
"ONNXMiniLM_L6_V2",
4142
"OllamaEmbeddingFunction",
4243
"OpenAIEmbeddingFunction",

chromadb/utils/embedding_functions/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@
6565
from chromadb.utils.embedding_functions.mistral_embedding_function import (
6666
MistralEmbeddingFunction,
6767
)
68+
from chromadb.utils.embedding_functions.nomic_embedding_function import (
69+
NomicEmbeddingFunction,
70+
NomicQueryConfig,
71+
)
6872

6973
try:
7074
from chromadb.is_thin_client import is_thin_client
@@ -85,6 +89,7 @@
8589
"InstructorEmbeddingFunction",
8690
"JinaEmbeddingFunction",
8791
"MistralEmbeddingFunction",
92+
"NomicEmbeddingFunction",
8893
"VoyageAIEmbeddingFunction",
8994
"ONNXMiniLM_L6_V2",
9095
"OpenCLIPEmbeddingFunction",
@@ -146,6 +151,7 @@ def validate_config(config: Dict[str, Any]) -> None:
146151
"instructor": InstructorEmbeddingFunction,
147152
"jina": JinaEmbeddingFunction,
148153
"mistral": MistralEmbeddingFunction,
154+
"nomic": NomicEmbeddingFunction,
149155
"voyageai": VoyageAIEmbeddingFunction,
150156
"onnx_mini_lm_l6_v2": ONNXMiniLM_L6_V2,
151157
"open_clip": OpenCLIPEmbeddingFunction,
@@ -235,6 +241,8 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
235241
"JinaEmbeddingFunction",
236242
"JinaQueryConfig",
237243
"MistralEmbeddingFunction",
244+
"NomicEmbeddingFunction",
245+
"NomicQueryConfig",
238246
"VoyageAIEmbeddingFunction",
239247
"ONNXMiniLM_L6_V2",
240248
"OpenCLIPEmbeddingFunction",
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from chromadb.api.types import (
2+
Embeddings,
3+
Documents,
4+
EmbeddingFunction,
5+
Space,
6+
)
7+
from chromadb.utils.embedding_functions.schemas import validate_config_schema
8+
from typing import List, Dict, Any, TypedDict, Optional
9+
import os
10+
import numpy as np
11+
12+
13+
class NomicQueryConfig(TypedDict):
14+
task_type: str
15+
16+
17+
class NomicEmbeddingFunction(EmbeddingFunction[Documents]):
18+
"""
19+
This class is used to get embeddings for a list of texts using the Nomic API.
20+
"""
21+
22+
def __init__(
23+
self,
24+
model: str,
25+
api_key_env_var: str = "NOMIC_API_KEY",
26+
task_type: str = "search_document",
27+
query_config: Optional[NomicQueryConfig] = None,
28+
):
29+
try:
30+
from nomic import embed
31+
except ImportError:
32+
raise ValueError(
33+
"The nomic python package is not installed. Please install it with `pip install nomic`"
34+
)
35+
36+
self.model = model
37+
self.task_type = task_type
38+
self.api_key_env_var = api_key_env_var
39+
self.api_key = os.getenv(api_key_env_var)
40+
self.query_config = query_config
41+
if not self.api_key:
42+
raise ValueError(f"The {api_key_env_var} environment variable is not set.")
43+
self.embed = embed
44+
45+
def __call__(self, input: Documents) -> Embeddings:
46+
if not all(isinstance(item, str) for item in input):
47+
raise ValueError("Nomic only supports text documents, not images")
48+
output = self.embed.text(
49+
model=self.model,
50+
texts=input,
51+
task_type=self.task_type,
52+
)
53+
return [np.array(data.embedding) for data in output.data]
54+
55+
def embed_query(self, input: Documents) -> Embeddings:
56+
if not all(isinstance(item, str) for item in input):
57+
raise ValueError("Nomic only supports text queries, not images")
58+
59+
task_type = (
60+
self.query_config.get("task_type") if self.query_config else self.task_type
61+
)
62+
output = self.embed.text(
63+
model=self.model,
64+
texts=input,
65+
task_type=task_type,
66+
)
67+
return [np.array(data.embedding) for data in output.data]
68+
69+
@staticmethod
70+
def name() -> str:
71+
return "nomic"
72+
73+
def default_space(self) -> Space:
74+
return "cosine"
75+
76+
def supported_spaces(self) -> List[Space]:
77+
return ["cosine", "l2", "ip"]
78+
79+
@staticmethod
80+
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
81+
model = config.get("model")
82+
api_key_env_var = config.get("api_key_env_var")
83+
task_type = config.get("task_type")
84+
query_config = config.get("query_config")
85+
if model is None or api_key_env_var is None or task_type is None:
86+
assert False, "This code should not be reached" # this is for type checking
87+
return NomicEmbeddingFunction(
88+
model=model,
89+
api_key_env_var=api_key_env_var,
90+
task_type=task_type,
91+
query_config=query_config,
92+
)
93+
94+
def get_config(self) -> Dict[str, Any]:
95+
return {
96+
"model": self.model,
97+
"api_key_env_var": self.api_key_env_var,
98+
"task_type": self.task_type,
99+
"query_config": self.query_config,
100+
}
101+
102+
def validate_config_update(
103+
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
104+
) -> None:
105+
if "model" in new_config:
106+
raise ValueError(
107+
"The model cannot be changed after the embedding function has been initialized."
108+
)
109+
110+
@staticmethod
111+
def validate_config(config: Dict[str, Any]) -> None:
112+
"""
113+
Validate the configuration using the JSON schema.
114+
115+
Args:
116+
config: Configuration to validate
117+
"""
118+
validate_config_schema(config, "nomic")
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
{
2+
"$schema": "http://json-schema.org/draft-07/schema#",
3+
"title": "Nomic Embedding Function Schema",
4+
"description": "Schema for the Nomic embedding function configuration",
5+
"version": "1.0.0",
6+
"type": "object",
7+
"properties": {
8+
"model": {
9+
"type": "string",
10+
"description": "Parameter model for the Nomic embedding function"
11+
},
12+
"api_key_env_var": {
13+
"type": "string",
14+
"description": "Parameter api_key_env_var for the Nomic embedding function"
15+
},
16+
"task_type": {
17+
"type": "string",
18+
"description": "Parameter task_type for the Nomic embedding function"
19+
}
20+
}
21+
}

0 commit comments

Comments
 (0)