Skip to content

Commit

Permalink
query: add nearest_neighbors_f32, find with scores functions objectbo…
Browse files Browse the repository at this point in the history
  • Loading branch information
loryruta committed Apr 10, 2024
1 parent 874520b commit 9c53c92
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 16 deletions.
43 changes: 40 additions & 3 deletions objectbox/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ def __init__(self, c_query, box: 'Box'):
self._ob = box._ob

def find(self) -> list:
""" Finds a list of objects matching query. """
with self._ob.read_tx():
# OBX_bytes_array*
c_bytes_array_p = obx_query_find(self._c_query)

try:
# OBX_bytes_array
c_bytes_array = c_bytes_array_p.contents
Expand All @@ -36,11 +36,48 @@ def find(self) -> list:
c_bytes = c_bytes_array.data[i]
data = c_voidp_as_bytes(c_bytes.data, c_bytes.size)
result.append(self._box._entity.unmarshal(data))

return result
finally:
obx_bytes_array_free(c_bytes_array_p)

def find_ids(self) -> List[int]:
""" Finds a list of object IDs matching query. The result is sorted by ID (ascending order). """
c_id_array_p = obx_query_find_ids(self._c_query)
try:
return list(c_id_array_p.contents)
finally:
obx_id_array_free(c_id_array_p)

def find_with_scores(self):
""" Finds objects matching the query associated to their query score (e.g. distance in NN search).
The result is sorted by score in ascending order. """
c_bytes_score_array_p = obx_query_find_with_scores(self._c_query)
try:
# OBX_bytes_score_array
c_bytes_score_array: OBX_bytes_score_array = c_bytes_score_array_p.contents
result = []
for i in range(c_bytes_score_array.count):
# TODO implement
pass
return result
finally:
obx_bytes_score_array_free(c_bytes_score_array_p)

def find_ids_with_scores(self) -> List[Tuple[int, float]]:
""" Finds object IDs matching the query associated to their query score (e.g. distance in NN search).
The resulting list is sorted by score in ascending order. """
c_id_score_array_p = obx_query_find_ids_with_scores(self._c_query)
try:
# OBX_id_score_array
c_id_score_array: OBX_bytes_score_array = c_id_score_array_p.contents
result = []
for i in range(c_id_score_array.count):
c_id_score: OBX_id_score = c_id_score_array.ids_scores[i]
result.append((c_id_score.id, c_id_score.score))
return result
finally:
obx_id_score_array_free(c_id_score_array_p)

def count(self) -> int:
count = ctypes.c_uint64()
obx_query_count(self._c_query, ctypes.byref(count))
Expand All @@ -55,4 +92,4 @@ def offset(self, offset: int):
return obx_query_offset(self._c_query, offset)

def limit(self, limit: int):
return obx_query_limit(self._c_query, limit)
return obx_query_limit(self._c_query, limit)
32 changes: 19 additions & 13 deletions objectbox/query_builder.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from objectbox.model.entity import _Entity
import ctypes
import numpy as np
from typing import *

from objectbox.objectbox import ObjectBox
from objectbox.query import Query
from objectbox.c import *


class QueryBuilder:
def __init__(self, ob: ObjectBox, box: 'Box', entity: '_Entity', condition: 'QueryCondition'):
if not isinstance(entity, _Entity):
raise Exception("Given type is not an Entity")
def __init__(self, ob: ObjectBox, box: 'Box'):
self._box = box
self._entity = entity
self._condition = condition
self._c_builder = obx_query_builder(ob._c_store, entity.id)
self._entity = box._entity
self._c_builder = obx_query_builder(ob._c_store, box._entity.id)

def close(self) -> int:
return obx_qb_close(self)
Expand Down Expand Up @@ -85,11 +85,17 @@ def less_or_equal_int(self, property_id: int, value: int):
def between_2ints(self, property_id: int, value_a: int, value_b: int):
obx_qb_between_2ints(self._c_builder, property_id, value_a, value_b)
return self

def apply_condition(self):
self._condition.apply(self)


def nearest_neighbors_f32(self, vector_property_id: int, query_vector: Union[np.ndarray, List[float]], element_count: int):
if isinstance(query_vector, np.ndarray):
if query_vector.dtype != np.float32:
raise Exception(f"query_vector dtype must be float32")
query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
else: # List[float]
query_vector_data = (ctypes.c_float * len(query_vector))(*query_vector)
obx_qb_nearest_neighbors_f32(self._c_builder, vector_property_id, query_vector_data, element_count)
return self

def build(self) -> Query:
self.apply_condition()
c_query = obx_query(self._c_builder)
return Query(c_query, self._box)
return Query(c_query, self._box)

0 comments on commit 9c53c92

Please sign in to comment.