Skip to content

Commit a816d5a

Browse files
authored
Adding convenience method for NNS (#18)
* adding convenience nns method * bumping version * fixing copy-paste errors * feedback
1 parent 4bd479f commit a816d5a

File tree

3 files changed

+112
-58
lines changed

3 files changed

+112
-58
lines changed

cottontaildb_client/cottontaildb_client.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
EntityDefinition, CreateEntityMessage, InsertMessage, ColumnName, Scan, From, Type, ListSchemaMessage, \
99
ListEntityMessage, EntityDetailsMessage, DropEntityMessage, TruncateEntityMessage, OptimizeEntityMessage, \
1010
IndexName, IndexType, CreateIndexMessage, DropIndexMessage, RebuildIndexMessage, UpdateMessage, \
11-
DeleteMessage, Literal, Vector, FloatVector, BatchInsertMessage, Metadata, QueryMessage, Query
11+
DeleteMessage, Literal, Vector, FloatVector, BatchInsertMessage, Metadata, QueryMessage, Query, Expression, \
12+
FunctionName, Function, Projection, Order
1213
from .cottontail_pb2_grpc import DDLStub, DMLStub, TXNStub, DQLStub
1314

1415

@@ -343,6 +344,42 @@ def ping(self):
343344
"""Sends a ping message to the endpoint. If the method returns without raising an exception, the endpoint is connected."""
344345
self._dql.Ping(Empty())
345346

347+
def nns(self, schema, entity, query_vector, distance='manhattan', limit=None, vector_col='feature', id_col='id'):
348+
"""
349+
Runs a nearest-neighbour search (NNS) on the specified entity, ordering the results by ascending distance to the given query vector.
350+
351+
@param schema: the schema containing the queried entity
352+
@param entity: the entity being queried
353+
@param query_vector: the query vector, given as a plain sequence of floats
354+
@param distance: name of the distance function to be used (defaults to 'manhattan')
355+
@param limit: maximum number of rows to return
356+
@param vector_col: column name where the vector is stored
357+
@param id_col: column name where the id is stored
358+
"""
359+
schema_name = SchemaName(name=schema)
360+
entity_name = EntityName(schema=schema_name, name=entity)
361+
nns_col = Expression(column=ColumnName(entity=entity_name, name=vector_col))
362+
363+
distance_col = ColumnName(name='distance')
364+
id_expression = Expression(column=ColumnName(name=id_col))
365+
366+
nns_expression = Expression(literal=float_vector(*query_vector))
367+
fn = FunctionName(name=distance)
368+
fun = Function(name=fn, arguments=[nns_col, nns_expression])
369+
370+
expression = Expression(function=fun)
371+
372+
projection_element = Projection.ProjectionElement(alias=distance_col,
373+
expression=expression)
374+
projection = Projection(op=Projection.ProjectionOperation.SELECT,
375+
elements=[projection_element, Projection.ProjectionElement(expression=id_expression)])
376+
377+
order_component = Order.Component(column=distance_col, direction=Order.Direction.ASCENDING)
378+
379+
order = Order(components=[order_component])
380+
381+
return self.query(schema, entity, projection, None, limit=limit, order=order)
382+
346383
def query(self, schema, entity, projection, where, order=None, limit=None, skip=None, from_=None):
347384
"""
348385
Queries the specified entity where the provided conditions are met and applies the given projection.

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = cottontaildb-client
3-
version = 0.14.0
3+
version = 0.14.1
44
author = Florian Spiess
55
author_email = [email protected]
66
description = A Cottontail DB gRPC client.

tests/test_cottontaildb_client.py

Lines changed: 73 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@
99
DB_HOST = 'localhost'
1010
DB_PORT = 1865
1111

12-
TEST_SCHEMA = 'schema_test'
13-
TEST_ENTITY = 'entity_test'
14-
TEST_VECTOR_ENTITY = 'entity_test_vector'
12+
TEST_SCHEMA_STR = 'schema_test'
13+
TEST_ENTITY_STR = 'entity_test'
14+
TEST_VECTOR_ENTITY_STR = 'entity_test_vector'
15+
TEST_SCHEMA_NAME = SchemaName(name=TEST_SCHEMA_STR)
16+
TEST_ENTITY_NAME = EntityName(schema=TEST_SCHEMA_NAME, name=TEST_ENTITY_STR)
17+
TEST_VECTOR_ENTITY_NAME = EntityName(schema=TEST_SCHEMA_NAME, name=TEST_VECTOR_ENTITY_STR)
1518
TEST_INDEX = 'index_test'
1619
TEST_COLUMN_ID = 'id'
1720
TEST_COLUMN_VALUE = 'value'
@@ -32,74 +35,76 @@ def setUp(self):
3235
self.skipTest(f'error connecting to Cottontail DB at {DB_HOST}:{DB_PORT}')
3336

3437
def tearDown(self):
35-
if TEST_SCHEMA in [s.split('.')[-1] for s in self.client.list_schemas()]:
36-
self.client.drop_schema(TEST_SCHEMA)
38+
if TEST_SCHEMA_STR in [s.split('.')[-1] for s in self.client.list_schemas()]:
39+
self.client.drop_schema(TEST_SCHEMA_STR)
3740
self.client.close()
3841

3942
def test_create_drop_schema(self):
4043
self._create_schema()
41-
self.assert_in(TEST_SCHEMA, self.client.list_schemas(), 'schema was not created')
42-
self.client.create_schema(TEST_SCHEMA, exist_ok=True)
43-
self.client.drop_schema(TEST_SCHEMA)
44-
self.assert_not_in(TEST_SCHEMA, self.client.list_schemas(), 'schema was not dropped')
44+
self.assert_in(TEST_SCHEMA_STR, self.client.list_schemas(), 'schema was not created')
45+
self.client.create_schema(TEST_SCHEMA_STR, exist_ok=True)
46+
self.client.drop_schema(TEST_SCHEMA_STR)
47+
self.assert_not_in(TEST_SCHEMA_STR, self.client.list_schemas(), 'schema was not dropped')
4548

4649
def test_create_drop_entity(self):
4750
self._create_schema()
4851
self._create_entity()
4952
self._create_vector_entity()
50-
self.assert_in(TEST_ENTITY, self.client.list_entities(TEST_SCHEMA), 'entity was not created')
51-
self.assert_in(TEST_VECTOR_ENTITY, self.client.list_entities(TEST_SCHEMA), 'vector entity was not created')
52-
self.client.create_entity(TEST_SCHEMA, TEST_ENTITY, [], exist_ok=True)
53-
self.client.create_entity(TEST_SCHEMA, TEST_VECTOR_ENTITY, [], exist_ok=True)
54-
self.client.drop_entity(TEST_SCHEMA, TEST_ENTITY)
55-
self.client.drop_entity(TEST_SCHEMA, TEST_VECTOR_ENTITY)
56-
self.assert_not_in(TEST_ENTITY, self.client.list_entities(TEST_SCHEMA), 'entity was not dropped')
57-
self.assert_not_in(TEST_VECTOR_ENTITY, self.client.list_entities(TEST_SCHEMA), 'vector entity was not dropped')
53+
self.assert_in(TEST_ENTITY_STR, self.client.list_entities(TEST_SCHEMA_STR), 'entity was not created')
54+
self.assert_in(TEST_VECTOR_ENTITY_STR, self.client.list_entities(TEST_SCHEMA_STR),
55+
'vector entity was not created')
56+
self.client.create_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR, [], exist_ok=True)
57+
self.client.create_entity(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR, [], exist_ok=True)
58+
self.client.drop_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR)
59+
self.client.drop_entity(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR)
60+
self.assert_not_in(TEST_ENTITY_STR, self.client.list_entities(TEST_SCHEMA_STR), 'entity was not dropped')
61+
self.assert_not_in(TEST_VECTOR_ENTITY_STR, self.client.list_entities(TEST_SCHEMA_STR),
62+
'vector entity was not dropped')
5863

5964
def test_drop_not_exists_entity(self):
6065
self._create_schema()
61-
result = self.client.drop_entity(TEST_SCHEMA, TEST_ENTITY, not_exist_ok=True)
66+
result = self.client.drop_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR, not_exist_ok=True)
6267
self.assertIsNone(result, 'response received after dropping nonexistent entity')
6368

6469
def test_truncate_not_exists_entity(self):
6570
self._create_schema()
66-
result = self.client.truncate_entity(TEST_SCHEMA, TEST_ENTITY, not_exist_ok=True)
71+
result = self.client.truncate_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR, not_exist_ok=True)
6772
self.assertIsNone(result, 'response received after truncating nonexistent entity')
6873

6974
def test_create_truncate_entity(self):
7075
self._create_schema()
7176
self._create_entity()
72-
self.assert_in(TEST_ENTITY, self.client.list_entities(TEST_SCHEMA), 'entity was not created')
73-
self.client.truncate_entity(TEST_SCHEMA, TEST_ENTITY)
74-
self.assert_in(TEST_ENTITY, self.client.list_entities(TEST_SCHEMA), 'entity was dropped')
77+
self.assert_in(TEST_ENTITY_STR, self.client.list_entities(TEST_SCHEMA_STR), 'entity was not created')
78+
self.client.truncate_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR)
79+
self.assert_in(TEST_ENTITY_STR, self.client.list_entities(TEST_SCHEMA_STR), 'entity was dropped')
7580

7681
def test_insert(self):
7782
self._create_schema()
7883
self._create_entity()
7984
self._insert()
80-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
85+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
8186
self.assertEqual(details['rows'], 1, 'unexpected number of rows in entity after insert')
8287

8388
def test_vector_insert(self):
8489
self._create_schema()
8590
self._create_vector_entity()
8691
self._insert_vector()
87-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_VECTOR_ENTITY)
92+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR)
8893
self.assertEqual(details['rows'], 1, 'unexpected number of rows in vector entity after insert')
8994
print('success')
9095

9196
def test_batch_insert(self):
9297
self._create_schema()
9398
self._create_entity()
9499
self._batch_insert()
95-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
100+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
96101
self.assertEqual(details['rows'], 3, 'unexpected number of rows in entity after batch insert')
97102

98103
def test_batch_insert_vectors(self):
99104
self._create_schema()
100105
self._create_vector_entity()
101106
self._batch_insert_vectors()
102-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_VECTOR_ENTITY)
107+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR)
103108
self.assertEqual(details['rows'], 3, 'unexpected number of rows in entity after batch insert')
104109

105110
def test_query(self):
@@ -110,6 +115,19 @@ def test_query(self):
110115
query_result = self._query_value_with_key(query_key)
111116
self.assertEqual(len(query_result), 1, 'unexpected number of rows returned from query')
112117

118+
def test_query_vectors(self):
119+
self._create_schema()
120+
self._create_vector_entity()
121+
self._batch_insert_vectors()
122+
query = [0.1, 0.2, 0.4]
123+
124+
results = self.client.nns(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR, query, vector_col='value')
125+
self.assertEqual(len(results), 3, 'unexpected number of rows returned from query')
126+
# test with limit
127+
results = self.client.nns(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR, query, vector_col='value', limit=1)
128+
print(results)
129+
self.assertEqual(len(results), 1, 'unexpected number of rows returned from query')
130+
113131
def test_update(self):
114132
self._create_schema()
115133
self._create_entity()
@@ -124,20 +142,20 @@ def test_optimize(self):
124142
self._create_schema()
125143
self._create_entity()
126144
self._batch_insert()
127-
self.client.optimize_entity(TEST_SCHEMA, TEST_ENTITY)
145+
self.client.optimize_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR)
128146

129147
def test_create_rebuild_drop_index(self):
130148
self._create_schema()
131149
self._create_entity()
132150
self._batch_insert()
133-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
151+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
134152
self.assertEqual(len(details['indexes']), 0, 'unexpected number of indexes in entity before index creation')
135-
self.client.create_index(TEST_SCHEMA, TEST_ENTITY, TEST_INDEX, IndexType.BTREE, [TEST_COLUMN_VALUE])
136-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
153+
self.client.create_index(TEST_SCHEMA_STR, TEST_ENTITY_STR, TEST_INDEX, IndexType.BTREE, [TEST_COLUMN_VALUE])
154+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
137155
self.assertEqual(len(details['indexes']), 1, 'index was not created')
138-
self.client.rebuild_index(TEST_SCHEMA, TEST_ENTITY, TEST_INDEX)
139-
self.client.drop_index(TEST_SCHEMA, TEST_ENTITY, TEST_INDEX)
140-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
156+
self.client.rebuild_index(TEST_SCHEMA_STR, TEST_ENTITY_STR, TEST_INDEX)
157+
self.client.drop_index(TEST_SCHEMA_STR, TEST_ENTITY_STR, TEST_INDEX)
158+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
141159
self.assertEqual(len(details['indexes']), 0, 'index was not dropped')
142160

143161
def test_transaction_commit(self):
@@ -146,18 +164,18 @@ def test_transaction_commit(self):
146164
self.client.start_transaction()
147165
self._insert()
148166
self.client.commit_transaction()
149-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
167+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
150168
self.assertEqual(details['rows'], 1, 'unexpected number of rows in entity after committed insert')
151169

152170
def test_transaction_abort(self):
153171
self._create_schema()
154172
self._create_entity()
155173
self.client.start_transaction()
156174
self._insert()
157-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
175+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
158176
self.assertEqual(details['rows'], 1, 'unexpected number of rows in entity after transaction insert')
159177
self.client.abort_transaction()
160-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
178+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
161179
self.assertEqual(details['rows'], 0, 'unexpected number of rows in entity after aborted insert')
162180

163181
def test_delete(self):
@@ -166,8 +184,8 @@ def test_delete(self):
166184
self._insert()
167185
where = Where(atomic=AtomicBooleanPredicate(left=ColumnName(name=TEST_COLUMN_VALUE), right=AtomicBooleanOperand(
168186
expressions=Expressions(expression=[Expression(literal=Literal(intData=0))])), op=ComparisonOperator.EQUAL))
169-
self.client.delete(TEST_SCHEMA, TEST_ENTITY, where)
170-
details = self.client.get_entity_details(TEST_SCHEMA, TEST_ENTITY)
187+
self.client.delete(TEST_SCHEMA_STR, TEST_ENTITY_STR, where)
188+
details = self.client.get_entity_details(TEST_SCHEMA_STR, TEST_ENTITY_STR)
171189
self.assertEqual(details['rows'], 0, 'unexpected number of rows in entity after delete')
172190

173191
def assert_in(self, name, names, message):
@@ -179,51 +197,52 @@ def assert_not_in(self, name, names, message):
179197
self.assertNotIn(name, [n.split('.')[-1] for n in names], message)
180198

181199
def _create_schema(self):
182-
self.client.create_schema(TEST_SCHEMA)
200+
self.client.create_schema(TEST_SCHEMA_STR)
183201

184202
def _create_entity(self):
185203
columns = [
186204
column_def(TEST_COLUMN_ID, Type.STRING, nullable=False),
187205
column_def(TEST_COLUMN_VALUE, Type.INTEGER, nullable=False)
188206
]
189-
self.client.create_entity(TEST_SCHEMA, TEST_ENTITY, columns)
207+
self.client.create_entity(TEST_SCHEMA_STR, TEST_ENTITY_STR, columns)
190208

191209
def _create_vector_entity(self):
192210
columns = [
193211
column_def(TEST_COLUMN_ID, Type.STRING, nullable=False),
194212
column_def(TEST_COLUMN_VALUE, Type.FLOAT_VEC, length=3, nullable=False)
195213
]
196-
self.client.create_entity(TEST_SCHEMA, TEST_VECTOR_ENTITY, columns)
214+
self.client.create_entity(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR, columns)
197215

198216
def _insert(self):
199-
values = {'id': Literal(stringData='test_0'), 'value': Literal(intData=0)}
200-
self.client.insert(TEST_SCHEMA, TEST_ENTITY, values)
217+
values = {'id': Literal(stringData='test_0'), TEST_COLUMN_VALUE: Literal(intData=0)}
218+
self.client.insert(TEST_SCHEMA_STR, TEST_ENTITY_STR, values)
201219

202220
def _insert_vector(self):
203221
value_list = [0.2, 0.3, 0.5]
204-
values = {'id': Literal(stringData='test_0'), 'value': float_vector(*value_list)}
205-
self.client.insert(TEST_SCHEMA, TEST_VECTOR_ENTITY, values)
222+
values = {'id': Literal(stringData='test_0'), TEST_COLUMN_VALUE: float_vector(*value_list)}
223+
self.client.insert(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR, values)
206224

207225
def _batch_insert(self):
208-
columns = ['id', 'value']
226+
columns = ['id', TEST_COLUMN_VALUE]
209227
values = [
210228
[Literal(stringData='test_1'), Literal(intData=1)],
211229
[Literal(stringData='test_2'), Literal(intData=2)],
212230
[Literal(stringData='test_3'), Literal(intData=3)]
213231
]
214-
self.client.insert_batch(TEST_SCHEMA, TEST_ENTITY, columns, values)
232+
233+
self.client.insert_batch(TEST_SCHEMA_STR, TEST_ENTITY_STR, columns, values)
215234

216235
def _batch_insert_vectors(self):
217-
columns = ['id', 'value']
236+
columns = ['id', TEST_COLUMN_VALUE]
218237
one = [0.1, 0.2, 0.3]
219-
two = [0.000001, 0.2, 0.3]
220-
three = [0.1, 0.2, 0.3]
238+
two = [0.01, 0.02, 0.3]
239+
three = [0.9, 0.9, 0.9]
221240
values = [
222241
[Literal(stringData='test_1'), float_vector(*one)],
223242
[Literal(stringData='test_2'), float_vector(*two)],
224243
[Literal(stringData='test_3'), float_vector(*three)]
225244
]
226-
self.client.insert_batch(TEST_SCHEMA, TEST_VECTOR_ENTITY, columns, values)
245+
self.client.insert_batch(TEST_SCHEMA_STR, TEST_VECTOR_ENTITY_STR, columns, values)
227246

228247
def _update_value_with_key(self, key, value):
229248
where = Where(atomic=AtomicBooleanPredicate(
@@ -233,12 +252,10 @@ def _update_value_with_key(self, key, value):
233252
op=ComparisonOperator.EQUAL
234253
))
235254
updates = {TEST_COLUMN_VALUE: Expression(literal=Literal(intData=value))}
236-
self.client.update(TEST_SCHEMA, TEST_ENTITY, where, updates)
255+
self.client.update(TEST_SCHEMA_STR, TEST_ENTITY_STR, where, updates)
237256

238257
def _query_value_with_key(self, key):
239-
schema_name = SchemaName(name=TEST_SCHEMA)
240-
entity_name = EntityName(schema=schema_name, name=TEST_ENTITY)
241-
expression = Expression(column=ColumnName(entity=entity_name, name=TEST_COLUMN_VALUE))
258+
expression = Expression(column=ColumnName(entity=TEST_ENTITY_NAME, name=TEST_COLUMN_VALUE))
242259
projection_element = Projection.ProjectionElement(expression=expression)
243260
projection = Projection(op=Projection.ProjectionOperation.SELECT, elements=[projection_element])
244261
where = Where(atomic=AtomicBooleanPredicate(
@@ -247,4 +264,4 @@ def _query_value_with_key(self, key):
247264
expressions=Expressions(expression=[Expression(literal=Literal(stringData=key))])),
248265
op=ComparisonOperator.EQUAL
249266
))
250-
return self.client.query(TEST_SCHEMA, TEST_ENTITY, projection, where)
267+
return self.client.query(TEST_SCHEMA_STR, TEST_ENTITY_STR, projection, where)

0 commit comments

Comments
 (0)