forked from chroma-core/chroma
-
Notifications
You must be signed in to change notification settings - Fork 0
/
new_api.py
317 lines (274 loc) · 9.64 KB
/
new_api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
# Flip to True to run against an in-process client instead of a REST server.
USE_LOCAL = False

if not USE_LOCAL:
    # Talk to a Chroma server over REST on localhost:8000.
    client = chromadb.Client(
        Settings(
            chroma_api_impl="rest",
            chroma_server_host="localhost",
            chroma_server_http_port="8000",
        )
    )
else:
    client = chromadb.Client()

# print(client)
print(client.heartbeat())

# Reset before running the checks below; the later assertions (empty
# collection, exactly one collection listed) presumably rely on this leaving
# the backend empty.
client.reset()
# Create a collection and check that the client returns a Collection object.
# isinstance() is used rather than comparing type() so the check expresses
# "is a Collection" and keeps passing if a subclass is ever returned.
collection = client.create_collection(name="test")
assert isinstance(collection, chromadb.api.models.Collection.Collection)
print(collection)
print(collection.name)
assert collection.count() == 0

# Fetching the same collection by name must also yield a Collection.
getcollection = client.get_collection(name="test")
assert isinstance(getcollection, chromadb.api.models.Collection.Collection)
print(getcollection)

# Test list, delete collections #
collections_list = client.list_collections()
assert len(collections_list) == 1
# Entries returned by list_collections are Collection objects too.
assert isinstance(collections_list[0], chromadb.api.models.Collection.Collection)

# Creating a second collection grows the listing; deleting it shrinks it.
collection2 = client.create_collection(name="test2")
assert len(client.list_collections()) == 2
client.delete_collection(name="test2")
assert len(client.list_collections()) == 1

# The same name can be created and deleted again.
client.create_collection(name="test2")
client.delete_collection(name="test2")
assert len(client.list_collections()) == 1
print(client.list_collections())

# collection.create_index # wipes out the index you have (if you have one) and creates a fresh one
# collection = client.update_collection(oldName="test", newName="test2") # this feels a little odd to me (Jeff) -> collection.update(name="test2")
# add many: eight records whose embeddings alternate between two vectors;
# only the second record carries style "style2".
collection.add(
    embeddings=[
        list(vector)
        for _ in range(4)
        for vector in ([1.1, 2.3, 3.2], [4.5, 6.9, 4.4])
    ],
    metadatas=[
        {"uri": f"img{i}.png", "style": "style2" if i == 2 else "style1"}
        for i in range(1, 9)
    ],
    documents=[f"doc{i}" for i in range(1, 9)],
    ids=[f"id{i}" for i in range(1, 9)],
)

# add one: every argument is passed as a bare value rather than a list.
collection.add(
    embeddings=[1.5, 2.9, 3.4],
    metadatas={"uri": "img9.png", "style": "style1"},
    documents="doc1000101",
    ids="uri9",
)

print(collection.peek(5))
print(collection.count())  # NIT: count could take a where filter too
# Eight batch records plus the single record.
assert collection.count() == 9
### Test get by ids ###
by_ids = collection.get(ids=["id1", "id2"])
print("\nGET ids\n", by_ids)
assert len(by_ids["embeddings"]) == 2

### Test get where clause ###
# Both metadata conditions together are expected to match a single record.
by_where = collection.get(where={"style": "style1", "uri": "img1.png"})
print("\nGet where\n", by_where)
assert len(by_where["ids"]) == 1

### Test get both ###
by_ids_and_where = collection.get(ids=["id1", "id3"], where={"style": "style1"})
print("\nGet both\n", by_ids_and_where)
assert len(by_ids_and_where["documents"]) == 2

# NIT: verify that querying with multiple embeddings at once actually works
query_result = collection.query(
    query_embeddings=[[1.1, 2.3, 3.2], [5.1, 4.3, 2.2]],
    # OR // COULD BE an AND and return a tuple
    # query_texts="doc10",
    n_results=2,
    # where={"style": "style2"},
)
print("\nquery", query_result)
### Test delete partial: by id ###
collection.delete(ids=["id1"])  # propagates to the index
assert collection.count() == 8

### Test delete partial: by where filter ###
collection.delete(where={"style": "style2"})  # propagates to the index
assert collection.count() == 7

### Test delete all ###
# Called with no arguments; the count is expected to drop to zero.
collection.delete()
assert collection.count() == 0

# Drop the collection itself so no collections remain.
client.delete_collection(name="test")
assert len(client.list_collections()) == 0
# Test embedding function: recreate "test" with an explicit embedding function.
collection = client.create_collection(
    name="test",
    embedding_function=SentenceTransformerEmbeddingFunction(),
)

# Add docs without embeddings (the embedding function is expected to be called)
collection.add(
    metadatas=[
        {"uri": f"img{i}.png", "style": "style2" if i == 2 else "style1"}
        for i in range(1, 9)
    ],
    documents=[f"doc{i}" for i in range(1, 9)],
    ids=[f"id{i}" for i in range(1, 9)],
)

# Add a single doc without embeddings, passing bare values instead of lists
collection.add(
    metadatas={"uri": "img9.png", "style": "style1"},
    documents="doc9",
    ids="id9",
)

print(collection.peek(5))
assert collection.count() == 9

# Query with only text docs
# print(
# "query",
# collection.query(
# query_texts=["doc1", "doc2"],
# n_results=2,
# ),
# )
### TEST UPDATE ###
# The stub embedding function returns the same fixed vector for every
# document, so a recomputed embedding is recognisable by its first
# component (0.1).
collection = client.create_collection(
    "test_update",
    embedding_function=(lambda documents: [[0.1, 1.1, 1.2]] * len(documents)),
)
assert collection.count() == 0

# Seed eight records with explicit embeddings.
collection.add(
    embeddings=[
        list(vector)
        for _ in range(4)
        for vector in ([1.1, 2.3, 3.2], [4.5, 6.9, 4.4])
    ],
    metadatas=[
        {"uri": f"img{i}.png", "style": "style2" if i == 2 else "style1"}
        for i in range(1, 9)
    ],
    documents=[f"doc{i}" for i in range(1, 9)],
    ids=[f"id{i}" for i in range(1, 9)],
)

# Test update of all fields at once
collection.update(
    ids=["id1", "id2"],
    embeddings=[[0.0, 0.0, 0.5], [2.0, 0.0, 2.0]],
    metadatas=[
        {"uri": "img1.1.png", "style": "style1"},
        {"uri": "img2.1.png", "style": "style1"},
    ],
    documents=["cod1", "cod2"],
)
updated = collection.get(ids=["id1", "id2"])
for position, (document, uri) in enumerate(
    [("cod1", "img1.1.png"), ("cod2", "img2.1.png")]
):
    assert updated["documents"][position] == document
    assert updated["metadatas"][position]["uri"] == uri

# Test update of just the document; the embedding should get computed via the function
collection.update(ids=["id1"], documents=["cod1"])
first_item = collection.get(ids="id1")
assert first_item["metadatas"][0]["uri"] == "img1.1.png"
assert first_item["embeddings"][0][0] == 0.1

# Test update of just the metadata; document and embedding must be unchanged
collection.update(
    ids="id1",
    metadatas={"uri": "img1.2.png", "style": "style1"},
)
first_item = collection.get(ids="id1")
assert first_item["metadatas"][0]["uri"] == "img1.2.png"
assert first_item["embeddings"][0][0] == 0.1
assert first_item["documents"][0] == "cod1"

collection.delete()
### Test default embedding function ###
# Recreate "test" without passing an embedding function.
client.delete_collection(name="test")
collection = client.create_collection(name="test")

# Add docs without embeddings (the embedding function is expected to be called)
collection.add(
    metadatas=[
        {"uri": f"img{i}.png", "style": "style2" if i == 2 else "style1"}
        for i in range(1, 9)
    ],
    documents=[f"doc{i}" for i in range(1, 9)],
    ids=[f"id{i}" for i in range(1, 9)],
)

# Query with only text docs
text_query_result = collection.query(
    query_texts=["doc1", "doc2"],
    n_results=2,
)
print("query", text_query_result)
# Try to add embeddings of the wrong dimension.
# This should fail and not add any embeddings.
# TODO: Currently only works locally since the exception is raised on the database side
try:
    collection.add(
        embeddings=[
            list(vector)
            for _ in range(2)
            for vector in ([1.1, 2.3, 3.2], [4.5, 6.9, 4.4])
        ],
        metadatas=[{"uri": f"img{i}.png", "style": "style1"} for i in range(1, 5)],
        documents=[f"doc{i}" for i in range(1, 5)],
        ids=[f"id{i}" for i in range(1, 5)],
    )
except Exception as exc:
    # Deliberately broad: the script only reports the failure and carries on.
    print("Exception", exc)

# Try to query for more neighbors than exist.
# This should fail.
# TODO: Currently only works locally since the exception is raised on the database side
try:
    collection.query(
        query_texts=["doc1", "doc2"],
        n_results=100,
    )
except Exception as exc:
    # Deliberately broad: the script only reports the failure and carries on.
    print("Exception", exc)
# collection.upsert( # always succeeds
# embeddings=[[1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], [4.5, 6.9, 4.4], [1.1, 2.3, 3.2], [4.5, 6.9, 4.4]],
# metadatas=[{"uri": "img1.png", "style": "style1"}, {"uri": "img2.png", "style": "style1"}, {"uri": "img3.png", "style": "style1"}, {"uri": "img4.png", "style": "style1"}, {"uri": "img5.png", "style": "style1"}, {"uri": "img6.png", "style": "style1"}, {"uri": "img7.png", "style": "style1"}, {"uri": "img8.png", "style": "style1"}],
# documents=["doc1", "doc2", "doc3", "doc4", "doc5", "doc6", "doc7", "doc8"],
# ids=["id1", "id2", "id3", "id4", "id5", "id6", "id7", "id8"],
# )