Skip to content

Commit

Permalink
model: add support for index, and HNSW index objectbox#24
Browse files Browse the repository at this point in the history
TODO: fix/extend index unit tests!
  • Loading branch information
loryruta committed Apr 10, 2024
1 parent c8ad84a commit 874520b
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 58 deletions.
19 changes: 11 additions & 8 deletions objectbox/model/entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,13 @@ def fill_properties(self):
elif self.id_property._ob_type != OBXPropertyType_Long:
raise Exception("ID property must be an int")

def get_property(self, name: str):
    """Look up one of this entity's properties by its name.

    The entries of ``self.properties`` are scanned in order and the first one
    whose ``_name`` equals *name* is returned.

    Raises:
        Exception: if none of the properties carries the requested name.
    """
    found = next((candidate for candidate in self.properties if candidate._name == name), None)
    if found is None:
        raise Exception(f"Property \"{name}\" not found in Entity: \"{self.name}\"")
    return found

def get_value(self, object, prop: Property):
# in case value is not overwritten on the object, it's the Property object itself (= as defined in the Class)
val = getattr(object, prop._name)
Expand Down Expand Up @@ -228,12 +235,8 @@ def unmarshal(self, data: bytes):
return obj


# entity decorator - wrap _Entity to allow @Entity(id=, uid=), i.e. no class argument
def Entity(cls=None, id: int = 0, uid: int = 0):
if cls:
def Entity(id: int = 0, uid: int = 0):
""" Entity decorator that wraps _Entity to allow @Entity(id=, uid=); i.e. no class arguments. """
def wrapper(cls):
return _Entity(cls, id, uid)
else:
def wrapper(cls):
return _Entity(cls, id, uid)

return wrapper
return wrapper
61 changes: 43 additions & 18 deletions objectbox/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@


from objectbox.model.entity import _Entity
from objectbox.model.properties import *
from objectbox.c import *
import logging


class IdUid:
Expand All @@ -36,35 +38,58 @@ def __init__(self):
self.last_index_id = IdUid(0, 0)
self.last_relation_id = IdUid(0, 0)

def entity(self, entity: _Entity, last_property_id: IdUid):
def _set_hnsw_params(self, index: HnswIndex):
    """Forward the configured HNSW options of *index* to the C model.

    Each option is passed to its obx_model_property_index_hnsw_*() setter only
    when its value is not None; options left as None trigger no call.
    """
    model = self._c_model

    dimensions = index.dimensions
    if dimensions is not None:
        obx_model_property_index_hnsw_dimensions(model, dimensions)

    neighbors = index.neighbors_per_node
    if neighbors is not None:
        obx_model_property_index_hnsw_neighbors_per_node(model, neighbors)

    search_count = index.indexing_search_count
    if search_count is not None:
        obx_model_property_index_hnsw_indexing_search_count(model, search_count)

    flags = index.flags
    if flags is not None:
        obx_model_property_index_hnsw_flags(model, flags)

    distance_type = index.distance_type
    if distance_type is not None:
        obx_model_property_index_hnsw_distance_type(model, distance_type)

    backlink_probability = index.reparation_backlink_probability
    if backlink_probability is not None:
        obx_model_property_index_hnsw_reparation_backlink_probability(model, backlink_probability)

    cache_hint_kb = index.vector_cache_hint_size_kb
    if cache_hint_kb is not None:
        obx_model_property_index_hnsw_vector_cache_hint_size_kb(model, cache_hint_kb)

def entity(self, entity: _Entity, last_property_id: IdUid, last_index_id: Optional[IdUid] = None):
    """Register *entity*, its properties and their indexes with the C model.

    Args:
        entity: a class wrapped by the ``Entity`` decorator, i.e. an ``_Entity`` instance.
        last_property_id: ID/UID pair stored on the entity and passed to the C model
            as the entity's last property ID.
        last_index_id: accepted in the signature but not read by this method.

    Raises:
        Exception: if *entity* is not an ``_Entity`` instance.
    """
    if not isinstance(entity, _Entity):
        raise Exception("Given type is not an Entity. Are you passing an instance instead of a type or did you "
                        "forget the '@Entity' annotation?")

    entity.last_property_id = last_property_id

    obx_model_entity(self._c_model, c_str(entity.name), entity.id, entity.uid)

    logging.debug(f"Creating entity \"{entity.name}\" (ID={entity.id}, UID={entity.uid})")

    for property_ in entity.properties:
        obx_model_property(self._c_model, c_str(property_._name), property_._ob_type, property_._id, property_._uid)

        logging.debug(f"Creating property \"{property_._name}\" (ID={property_._id}, UID={property_._uid})")

        if property_._flags != 0:
            obx_model_property_flags(self._c_model, property_._flags)

        index = property_._index
        if index is not None:
            if isinstance(index, HnswIndex):
                # HNSW options are sent before the index ID is registered, as in the original order.
                self._set_hnsw_params(index)
                logging.debug(f"  HNSW index (ID={index.id}, UID={index.uid}); Dimensions: {index.dimensions}")
            else:
                logging.debug(f"  Index (ID={index.id}, UID={index.uid}); Type: {index.type}")
            obx_model_property_index_id(self._c_model, index.id, index.uid)

    obx_model_entity_last_property_id(self._c_model, last_property_id.id, last_property_id.uid)

def _finish(self): # Called by Builder
if self.last_relation_id:
obx_model_last_relation_id(
self._c_model, self.last_relation_id.id, self.last_relation_id.uid)
obx_model_last_relation_id(self._c_model, self.last_relation_id.id, self.last_relation_id.uid)

if self.last_index_id:
obx_model_last_index_id(
self._c_model, self.last_index_id.id, self.last_index_id.uid)
obx_model_last_index_id(self._c_model, self.last_index_id.id, self.last_index_id.uid)

if self.last_entity_id:
obx_model_last_entity_id(
self._c_model, self.last_entity_id.id, self.last_entity_id.uid)
obx_model_last_entity_id(self._c_model, self.last_entity_id.id, self.last_entity_id.uid)
85 changes: 59 additions & 26 deletions objectbox/model/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from objectbox.c import *
import flatbuffers.number_types
import numpy as np
from dataclasses import dataclass


class PropertyType(IntEnum):
Expand Down Expand Up @@ -72,42 +73,68 @@ class PropertyType(IntEnum):


class IndexType(IntEnum):
value = OBXPropertyFlags_INDEXED
hash = OBXPropertyFlags_INDEX_HASH
hash64 = OBXPropertyFlags_INDEX_HASH64
VALUE = OBXPropertyFlags_INDEXED
HASH = OBXPropertyFlags_INDEX_HASH
HASH64 = OBXPropertyFlags_INDEX_HASH64


@dataclass
class Index:
    """Describes a generic (non-HNSW) index attached to a property."""
    id: int  # index ID; passed to obx_model_property_index_id() when the model is built
    uid: int  # index UID; passed together with the ID
    # TODO HNSW isn't a type but HASH and HASH64 are, remove type member and make HashIndex and Hash64Index classes?
    type: IndexType = IndexType.VALUE  # OR-ed into the owning property's flags


class HnswFlags(IntEnum):
    """Flag values accepted by the ``flags`` field of ``HnswIndex``."""
    # NOTE(review): the non-zero values are distinct powers of two, presumably so that
    # they can be OR-ed together — confirm against the C API.
    NONE = 0
    DEBUG_LOGS = 1
    DEBUG_LOGS_DETAILED = 2
    VECTOR_CACHE_SIMD_PADDING_OFF = 4
    REPARATION_LIMIT_CANDIDATES = 8


class HnswDistanceType(IntEnum):
    """Distance types selectable for an HNSW index; values mirror the OBXHnswDistanceType_* constants."""
    # No trailing comma after a value: it would turn the assigned value into a one-element tuple.
    UNKNOWN = OBXHnswDistanceType_UNKNOWN
    EUCLIDEAN = OBXHnswDistanceType_EUCLIDEAN


@dataclass
class HnswIndex:
    """Describes an HNSW index attached to a property.

    When the model is built, every option whose value is not None is forwarded to
    its obx_model_property_index_hnsw_*() setter; options left as None are not sent.
    """
    id: int  # index ID; registered via obx_model_property_index_id()
    uid: int  # index UID; registered together with the ID
    dimensions: int  # NOTE(review): presumably the length of the indexed vectors — confirm
    neighbors_per_node: Optional[int] = None
    indexing_search_count: Optional[int] = None
    flags: HnswFlags = HnswFlags.NONE  # default is not None, so the flags setter is called by default
    distance_type: HnswDistanceType = HnswDistanceType.EUCLIDEAN  # likewise sent by default
    reparation_backlink_probability: Optional[float] = None
    # NOTE(review): a size in KB annotated as float — confirm whether the C setter expects an integer.
    vector_cache_hint_size_kb: Optional[float] = None


class Property:
def __init__(self, py_type: type, id: int, uid: int, type: PropertyType = None, index: bool = None, index_type: IndexType = None):
self._id = id
self._uid = uid
def __init__(self, pytype: Type, **kwargs):
self._id = kwargs['id']
self._uid = kwargs['uid']
self._name = "" # set in Entity.fill_properties()

self._py_type = py_type
self._ob_type = type if type != None else self.__determine_ob_type()
self._py_type = pytype
self._ob_type = kwargs['type'] if 'type' in kwargs else self._determine_ob_type()
self._fb_type = fb_type_map[self._ob_type]

self._is_id = isinstance(self, Id)
self._flags = OBXPropertyFlags(0)
self.__set_flags()
self._flags = 0

# FlatBuffers marshalling information
self._fb_slot = self._id - 1
self._fb_v_offset = 4 + 2*self._fb_slot

if index_type:
if index == True or index == None:
self._index = True
self._index_type = index_type
elif index == False:
raise Exception(f"trying to set index type on property with id {self._id} while index is set to False")
else:
self._index = index if index != None else False
if index:
self._index_type = IndexType.value if self._py_type != str else IndexType.hash
self._fb_v_offset = 4 + 2 * self._fb_slot

self._index = kwargs.get('index', None)

def __determine_ob_type(self) -> OBXPropertyType:
self._set_flags()

def _determine_ob_type(self) -> OBXPropertyType:
""" Tries to infer the OBX property type from the Python type. """
ts = self._py_type
if ts == str:
return OBXPropertyType_String
Expand All @@ -124,9 +151,15 @@ def __determine_ob_type(self) -> OBXPropertyType:
else:
raise Exception("unknown property type %s" % ts)

def _set_flags(self):
    """Derive this property's OBX flags from its ID status and index configuration.

    Flags are OR-ed into ``self._flags``: OBXPropertyFlags_ID for an ID property,
    OBXPropertyFlags_INDEXED when an index is configured, plus the index's
    ``type`` when the index is a generic ``Index``.
    """
    if self._is_id:
        self._flags |= OBXPropertyFlags_ID

    index = self._index
    if index is None:
        return

    self._flags |= OBXPropertyFlags_INDEXED
    if isinstance(index, Index):  # Generic index: its type is itself a flag value
        self._flags |= index.type

def op(self, op: _ConditionOp, value, case_sensitive: bool = True) -> QueryCondition:
return QueryCondition(self._id, op, value, case_sensitive)
Expand Down Expand Up @@ -165,4 +198,4 @@ def between(self, value_a, value_b) -> QueryCondition:
# ID property (primary key)
class Id(Property):
def __init__(self, py_type: type = int, id: int = 0, uid: int = 0):
super(Id, self).__init__(py_type, id, uid)
super(Id, self).__init__(py_type, id=id, uid=uid)
13 changes: 7 additions & 6 deletions tests/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from objectbox.model import *
from objectbox.model.properties import IndexType
from objectbox.model.properties import *
import numpy as np
from datetime import datetime
from typing import Generic, Dict, Any
Expand All @@ -10,16 +10,17 @@ class TestEntity:
id = Id(id=1, uid=1001)
# TODO Enable indexing dynamically, e.g. have a constructor to enable index(es).
# E.g. indexString=False (defaults to false). Same for bytes.
str = Property(str, id=2, uid=1002, index=True)
# TODO Test HASH and HASH64 indices (only supported for strings)
str = Property(str, id=2, uid=1002, index=Index(id=1, uid=10001))
bool = Property(bool, id=3, uid=1003)
int64 = Property(int, type=PropertyType.long, id=4, uid=1004, index=True)
int32 = Property(int, type=PropertyType.int, id=5, uid=1005, index=True, index_type=IndexType.hash)
int16 = Property(int, type=PropertyType.short, id=6, uid=1006, index_type=IndexType.hash)
int64 = Property(int, type=PropertyType.long, id=4, uid=1004, index=Index(id=2, uid=10002))
int32 = Property(int, type=PropertyType.int, id=5, uid=1005)
int16 = Property(int, type=PropertyType.short, id=6, uid=1006)
int8 = Property(int, type=PropertyType.byte, id=7, uid=1007)
float64 = Property(float, type=PropertyType.double, id=8, uid=1008)
float32 = Property(float, type=PropertyType.float, id=9, uid=1009)
bools = Property(np.ndarray, type=PropertyType.boolVector, id=10, uid=1010)
bytes = Property(bytes, id=11, uid=1011, index_type=IndexType.hash64)
bytes = Property(bytes, id=11, uid=1011)
shorts = Property(np.ndarray, type=PropertyType.shortVector, id=12, uid=1012)
chars = Property(np.ndarray, type=PropertyType.charVector, id=13, uid=1013)
ints = Property(np.ndarray, type=PropertyType.intVector, id=14, uid=1014)
Expand Down

0 comments on commit 874520b

Please sign in to comment.