Skip to content

Commit

Permalink
c.py: bubble up HNSW -related functions objectbox#24
Browse files Browse the repository at this point in the history
  • Loading branch information
loryruta committed Apr 10, 2024
1 parent 1395fa9 commit cbdd07d
Showing 1 changed file with 109 additions and 3 deletions.
112 changes: 109 additions & 3 deletions objectbox/c.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import platform
from objectbox.version import Version
from typing import *

# This file contains C-API bindings based on lib/objectbox.h, linking to the 'objectbox' shared library.
# The bindings are implementing using ctypes, see https://docs.python.org/dev/library/ctypes.html for introduction.
Expand Down Expand Up @@ -72,6 +73,8 @@ def shlib_name(library: str) -> str:
OBXDebugFlags = ctypes.c_int
OBXPutMode = ctypes.c_int
OBXOrderFlags = ctypes.c_int
OBXHnswFlags = ctypes.c_int
OBXHnswDistanceType = ctypes.c_int


class OBX_model(ctypes.Structure):
Expand Down Expand Up @@ -115,6 +118,27 @@ class OBX_bytes_array(ctypes.Structure):
OBX_bytes_array_p = ctypes.POINTER(OBX_bytes_array)


class OBX_bytes_score(ctypes.Structure):
_fields_ = [
('data', ctypes.c_void_p),
('size', ctypes.c_size_t),
('score', ctypes.c_double),
]


OBX_bytes_score_p = ctypes.POINTER(OBX_bytes_score)


class OBX_bytes_score_array(ctypes.Structure):
_fields_ = [
('bytes_scores', OBX_bytes_score_p),
('count', ctypes.c_size_t),
]


OBX_bytes_score_array_p = ctypes.POINTER(OBX_bytes_score_array)


class OBX_id_array(ctypes.Structure):
_fields_ = [
('ids', ctypes.POINTER(obx_id)),
Expand All @@ -125,6 +149,26 @@ class OBX_id_array(ctypes.Structure):
OBX_id_array_p = ctypes.POINTER(OBX_id_array)


class OBX_id_score(ctypes.Structure):
_fields_ = [
('id', obx_id),
('score', ctypes.c_double)
]


OBX_id_score_p = ctypes.POINTER(OBX_id_score)


class OBX_id_score_array(ctypes.Structure):
_fields_ = [
('ids_scores', ctypes.POINTER(OBX_id_score)),
('count', ctypes.c_size_t)
]


OBX_id_score_array_p = ctypes.POINTER(OBX_id_score_array)


class OBX_txn(ctypes.Structure):
pass

Expand Down Expand Up @@ -223,7 +267,7 @@ def check_result(result, func, args):

# creates a global function "name" with the given restype & argtypes, calling C function with the same name.
# restype is used for error checking: if not None, check_result will throw an exception if the result is empty.
def c_fn(name: str, restype: type, argtypes):
def c_fn(name: str, restype: Optional[type], argtypes):
func = C.__getattr__(name)
func.argtypes = argtypes
func.restype = restype
Expand Down Expand Up @@ -272,8 +316,38 @@ def c_voidp_as_bytes(voidp, size):
[OBX_model_p, ctypes.c_char_p, OBXPropertyType, obx_schema_id, obx_uid])

# obx_err (OBX_model* model, OBXPropertyFlags flags);
obx_model_property_flags = c_fn_rc('obx_model_property_flags', [
OBX_model_p, OBXPropertyFlags])
obx_model_property_flags = c_fn_rc('obx_model_property_flags', [OBX_model_p, OBXPropertyFlags])

# obx_err obx_model_property_index_id(OBX_model* model, obx_schema_id index_id, obx_uid index_uid)
obx_model_property_index_id = c_fn_rc('obx_model_property_index_id', [OBX_model_p, obx_schema_id, obx_uid])

# obx_err obx_model_property_index_hnsw_dimensions(OBX_model* model, size_t value)
obx_model_property_index_hnsw_dimensions = \
c_fn_rc('obx_model_property_index_hnsw_dimensions', [OBX_model_p, ctypes.c_size_t])

# obx_err obx_model_property_index_hnsw_neighbors_per_node(OBX_model* model, uint32_t value)
obx_model_property_index_hnsw_neighbors_per_node = \
c_fn_rc('obx_model_property_index_hnsw_neighbors_per_node', [OBX_model_p, ctypes.c_uint32])

# obx_err obx_model_property_index_hnsw_indexing_search_count(OBX_model* model, uint32_t value)
obx_model_property_index_hnsw_indexing_search_count = \
c_fn_rc('obx_model_property_index_hnsw_indexing_search_count', [OBX_model_p, ctypes.c_uint32])

# obx_err obx_model_property_index_hnsw_flags(OBX_model* model, OBXHnswFlags value)
obx_model_property_index_hnsw_flags = \
c_fn_rc('obx_model_property_index_hnsw_flags', [OBX_model_p, OBXHnswFlags])

# obx_err obx_model_property_index_hnsw_distance_type(OBX_model* model, OBXHnswDistanceType value)
obx_model_property_index_hnsw_distance_type = \
c_fn_rc('obx_model_property_index_hnsw_distance_type', [OBX_model_p, OBXHnswDistanceType])

# obx_err obx_model_property_index_hnsw_reparation_backlink_probability(OBX_model* model, float value)
obx_model_property_index_hnsw_reparation_backlink_probability = \
c_fn_rc('obx_model_property_index_hnsw_reparation_backlink_probability', [OBX_model_p, ctypes.c_float])

# obx_err obx_model_property_index_hnsw_vector_cache_hint_size_kb(OBX_model* model, size_t value)
obx_model_property_index_hnsw_vector_cache_hint_size_kb = \
c_fn_rc('obx_model_property_index_hnsw_vector_cache_hint_size_kb', [OBX_model_p, ctypes.c_size_t])

# obx_err (OBX_model*, obx_schema_id entity_id, obx_uid entity_uid);
obx_model_last_entity_id = c_fn('obx_model_last_entity_id', None, [
Expand Down Expand Up @@ -536,9 +610,20 @@ def c_voidp_as_bytes(voidp, size):
# OBX_C_API obx_err obx_qb_param_alias(OBX_query_builder* builder, const char* alias);
obx_qb_param_alias = c_fn_rc('obx_qb_param_alias', [OBX_query_builder_p, ctypes.c_char_p])

# OBX_C_API obx_err obx_query_param_vector_float32(OBX_query* query, obx_schema_id entity_id, obx_schema_id property_id, const float* value, size_t element_count);
# TODO

# OBX_C_API obx_err obx_query_param_alias_vector_float32(OBX_query* query, const char* alias, const float* value, size_t element_count);
# TODO

# OBX_C_API obx_err obx_qb_order(OBX_query_builder* builder, obx_schema_id property_id, OBXOrderFlags flags);
obx_qb_order = c_fn_rc('obx_qb_order', [OBX_query_builder_p, obx_schema_id, OBXOrderFlags])

# OBX_C_API obx_qb_cond obx_qb_nearest_neighbors_f32(OBX_query_builder* builder, obx_schema_id vector_property_id, const float* query_vector, size_t max_result_count)
obx_qb_nearest_neighbors_f32 = \
c_fn('obx_qb_nearest_neighbors_f32', obx_qb_cond, [OBX_query_builder_p, obx_schema_id,
ctypes.pointer(ctypes.c_float), ctypes.c_size_t])

# OBX_C_API OBX_query* obx_query(OBX_query_builder* builder);
obx_query = c_fn('obx_query', OBX_query_p, [OBX_query_builder_p])

Expand Down Expand Up @@ -566,6 +651,9 @@ def c_voidp_as_bytes(voidp, size):
# OBX_C_API obx_err obx_query_find_unique(OBX_query* query, const void** data, size_t* size);
obx_query_find_unique = c_fn_rc('obx_query_find_unique', [OBX_query_p, ctypes.POINTER(ctypes.c_void_p), ctypes.POINTER(ctypes.c_size_t)])

# OBX_C_API OBX_bytes_score_array* obx_query_find_with_scores(OBX_query* query);
obx_query_find_with_scores = c_fn('obx_query_find_with_scores', OBX_bytes_score_array_p, [OBX_query_p]) # TODO

# typedef bool obx_data_visitor(void* user_data, const void* data, size_t size);

# OBX_C_API obx_err obx_query_visit(OBX_query* query, obx_data_visitor* visitor, void* user_data);
Expand All @@ -574,6 +662,9 @@ def c_voidp_as_bytes(voidp, size):
# OBX_C_API OBX_id_array* obx_query_find_ids(OBX_query* query);
obx_query_find_ids = c_fn('obx_query_find_ids', OBX_id_array_p, [OBX_query_p])

# OBX_C_API OBX_id_score_array* obx_query_find_ids_with_scores(OBX_query* query);
obx_query_find_ids_with_scores = c_fn('obx_query_find_ids_with_scores', OBX_id_score_array_p, [OBX_query_p]) # TODO

# OBX_C_API obx_err obx_query_count(OBX_query* query, uint64_t* out_count);
obx_query_count = c_fn_rc('obx_query_count', [OBX_query_p, ctypes.POINTER(ctypes.c_uint64)])

Expand All @@ -596,6 +687,12 @@ def c_voidp_as_bytes(voidp, size):
# void (OBX_bytes_array * array);
obx_bytes_array_free = c_fn('obx_bytes_array_free', None, [OBX_bytes_array_p])

# OBX_C_API void obx_bytes_score_array_free(OBX_bytes_score_array* array)
obx_bytes_score_array_free = c_fn('obx_bytes_score_array_free', None, [OBX_bytes_score_array_p])

# OBX_C_API void obx_id_score_array_free(OBX_id_score_array* array)
obx_id_score_array_free = c_fn('obx_id_score_array_free', None, [OBX_id_score_array_p])

OBXPropertyType_Bool = 1
OBXPropertyType_Byte = 2
OBXPropertyType_Short = 3
Expand Down Expand Up @@ -669,3 +766,12 @@ def c_voidp_as_bytes(voidp, size):

# null values should be treated equal to zero (scalars only).
OBXOrderFlags_NULLS_ZERO = 16

OBXHnswFlags_NONE = 0
OBXHnswFlags_DEBUG_LOGS = 1
OBXHnswFlags_DEBUG_LOGS_DETAILED = 2
OBXHnswFlags_VECTOR_CACHE_SIMD_PADDING_OFF = 4
OBXHnswFlags_REPARATION_LIMIT_CANDIDATES = 8

OBXHnswDistanceType_UNKNOWN = 0
OBXHnswDistanceType_EUCLIDEAN = 1

0 comments on commit cbdd07d

Please sign in to comment.