From 9c53c9231230f1e800526491bdc179d018220060 Mon Sep 17 00:00:00 2001 From: loryruta Date: Wed, 10 Apr 2024 10:24:15 +0200 Subject: [PATCH] query: add nearest_neighbors_f32, find with scores functions #24 --- objectbox/query.py | 43 +++++++++++++++++++++++++++++++++++--- objectbox/query_builder.py | 32 ++++++++++++++++------------ 2 files changed, 59 insertions(+), 16 deletions(-) diff --git a/objectbox/query.py b/objectbox/query.py index 97ebe10..1d2a341 100644 --- a/objectbox/query.py +++ b/objectbox/query.py @@ -22,10 +22,10 @@ def __init__(self, c_query, box: 'Box'): self._ob = box._ob def find(self) -> list: + """ Finds a list of objects matching query. """ with self._ob.read_tx(): # OBX_bytes_array* c_bytes_array_p = obx_query_find(self._c_query) - try: # OBX_bytes_array c_bytes_array = c_bytes_array_p.contents @@ -36,11 +36,48 @@ def find(self) -> list: c_bytes = c_bytes_array.data[i] data = c_voidp_as_bytes(c_bytes.data, c_bytes.size) result.append(self._box._entity.unmarshal(data)) - return result finally: obx_bytes_array_free(c_bytes_array_p) + def find_ids(self) -> List[int]: + """ Finds a list of object IDs matching query. The result is sorted by ID (ascending order). """ + c_id_array_p = obx_query_find_ids(self._c_query) + try: + return list(c_id_array_p.contents) + finally: + obx_id_array_free(c_id_array_p) + + def find_with_scores(self): + """ Finds objects matching the query associated to their query score (e.g. distance in NN search). + The result is sorted by score in ascending order. """ + c_bytes_score_array_p = obx_query_find_with_scores(self._c_query) + try: + # OBX_bytes_score_array + c_bytes_score_array: OBX_bytes_score_array = c_bytes_score_array_p.contents + result = [] + for i in range(c_bytes_score_array.count): + # TODO implement + pass + return result + finally: + obx_bytes_score_array_free(c_bytes_score_array_p) + + def find_ids_with_scores(self) -> List[Tuple[int, float]]: + """ Finds object IDs matching the query associated to their query score (e.g. distance in NN search). + The resulting list is sorted by score in ascending order. """ + c_id_score_array_p = obx_query_find_ids_with_scores(self._c_query) + try: + # OBX_id_score_array + c_id_score_array: OBX_bytes_score_array = c_id_score_array_p.contents + result = [] + for i in range(c_id_score_array.count): + c_id_score: OBX_id_score = c_id_score_array.ids_scores[i] + result.append((c_id_score.id, c_id_score.score)) + return result + finally: + obx_id_score_array_free(c_id_score_array_p) + def count(self) -> int: count = ctypes.c_uint64() obx_query_count(self._c_query, ctypes.byref(count)) @@ -55,4 +92,4 @@ def offset(self, offset: int): return obx_query_offset(self._c_query, offset) def limit(self, limit: int): - return obx_query_limit(self._c_query, limit) \ No newline at end of file + return obx_query_limit(self._c_query, limit) diff --git a/objectbox/query_builder.py b/objectbox/query_builder.py index 3e7a2f7..9c35030 100644 --- a/objectbox/query_builder.py +++ b/objectbox/query_builder.py @@ -1,17 +1,17 @@ -from objectbox.model.entity import _Entity +import ctypes +import numpy as np +from typing import * + from objectbox.objectbox import ObjectBox from objectbox.query import Query from objectbox.c import * class QueryBuilder: - def __init__(self, ob: ObjectBox, box: 'Box', entity: '_Entity', condition: 'QueryCondition'): - if not isinstance(entity, _Entity): - raise Exception("Given type is not an Entity") + def __init__(self, ob: ObjectBox, box: 'Box'): self._box = box - self._entity = entity - self._condition = condition - self._c_builder = obx_query_builder(ob._c_store, entity.id) + self._entity = box._entity + self._c_builder = obx_query_builder(ob._c_store, box._entity.id) def close(self) -> int: return obx_qb_close(self) @@ -85,11 +85,17 @@ def less_or_equal_int(self, property_id: int, value: int): def between_2ints(self, property_id: int, value_a: int, value_b: int): obx_qb_between_2ints(self._c_builder, property_id, value_a, value_b) return self - - def apply_condition(self): - self._condition.apply(self) - + + def nearest_neighbors_f32(self, vector_property_id: int, query_vector: Union[np.ndarray, List[float]], element_count: int): + if isinstance(query_vector, np.ndarray): + if query_vector.dtype != np.float32: + raise Exception(f"query_vector dtype must be float32") + query_vector_data = query_vector.ctypes.data_as(ctypes.POINTER(ctypes.c_float)) + else: # List[float] + query_vector_data = (ctypes.c_float * len(query_vector))(*query_vector) + obx_qb_nearest_neighbors_f32(self._c_builder, vector_property_id, query_vector_data, element_count) + return self + def build(self) -> Query: - self.apply_condition() c_query = obx_query(self._c_builder) - return Query(c_query, self._box) \ No newline at end of file + return Query(c_query, self._box)