diff --git a/deeplake/constants.py b/deeplake/constants.py
index edb1056f8e..81eca454f9 100644
--- a/deeplake/constants.py
+++ b/deeplake/constants.py
@@ -329,6 +329,7 @@ DEFAULT_VECTORSTORE_INDEX_PARAMS = {
     "threshold": -1,
+    "bm25": False,
     "distance_metric": DEFAULT_VECTORSTORE_DISTANCE_METRIC,
     "additional_params": {
         "efConstruction": 600,
diff --git a/deeplake/core/dataset/dataset.py b/deeplake/core/dataset/dataset.py
index 6fca825095..f15f6be034 100644
--- a/deeplake/core/dataset/dataset.py
+++ b/deeplake/core/dataset/dataset.py
@@ -549,6 +549,7 @@ def __getitem__(
                 enabled_tensors=self.enabled_tensors,
                 view_base=self._view_base or self,
                 libdeeplake_dataset=self.libdeeplake_dataset,
+                index_params=self.index_params,
             )
         elif "/" in item:
             splt = posixpath.split(item)
@@ -595,6 +596,7 @@ def __getitem__(
                 enabled_tensors=enabled_tensors,
                 view_base=self._view_base or self,
                 libdeeplake_dataset=self.libdeeplake_dataset,
+                index_params=self.index_params,
             )
         elif isinstance(item, tuple) and len(item) and isinstance(item[0], str):
             ret = self
@@ -624,6 +626,7 @@ def __getitem__(
                 enabled_tensors=self.enabled_tensors,
                 view_base=self._view_base or self,
                 libdeeplake_dataset=self.libdeeplake_dataset,
+                index_params=self.index_params,
             )
         else:
             raise InvalidKeyTypeError(item)
@@ -2904,6 +2907,7 @@ def parent(self):
             path=self.path,
             link_creds=self.link_creds,
             libdeeplake_dataset=self.libdeeplake_dataset,
+            index_params=self.index_params,
         )
         self.storage.autoflush = autoflush
         return ds
@@ -2927,6 +2931,7 @@ def root(self):
             link_creds=self.link_creds,
             view_base=self._view_base,
             libdeeplake_dataset=self.libdeeplake_dataset,
+            index_params=self.index_params,
         )
         self.storage.autoflush = autoflush
         return ds
@@ -2950,6 +2955,7 @@ def no_view_dataset(self):
             pad_tensors=self._pad_tensors,
             enabled_tensors=self.enabled_tensors,
             libdeeplake_dataset=self.libdeeplake_dataset,
+            index_params=self.index_params,
         )
 
     def _create_group(self, name: str) -> "Dataset":
@@ -4824,6 +4830,7 @@ def max_view(self):
             pad_tensors=True,
             enabled_tensors=self.enabled_tensors,
             libdeeplake_dataset=self.libdeeplake_dataset,
+            index_params=self.index_params,
         )
 
     def random_split(self, lengths: Sequence[Union[int, float]]):
diff --git a/deeplake/core/index_maintenance.py b/deeplake/core/index_maintenance.py
index 113b4f00aa..d844bd522d 100644
--- a/deeplake/core/index_maintenance.py
+++ b/deeplake/core/index_maintenance.py
@@ -28,6 +28,16 @@ def is_embedding_tensor(tensor):
         or tensor.key in valid_names
     )
 
+
+def is_text_tensor(tensor):
+    """Check if a tensor is a text tensor."""
+
+    valid_names = ["text"]
+
+    return (
+        tensor.htype == "text"
+        or tensor.meta.name in valid_names
+        or tensor.key in valid_names
+    )
 
 def validate_embedding_tensor(tensor):
     """Check if a tensor is an embedding tensor."""
@@ -40,6 +50,17 @@ def validate_embedding_tensor(tensor):
         or tensor.key in valid_names
     )
 
+
+def validate_text_tensor(tensor):
+    """Check if a tensor is a text tensor."""
+
+    valid_names = ["text"]
+
+    return (
+        tensor.meta.name in valid_names
+        and tensor.htype == "text"
+        and tensor.key in valid_names
+    )
 
 def fetch_embedding_tensor(dataset):
     tensors = dataset.tensors
@@ -48,8 +69,15 @@ def fetch_embedding_tensor(dataset):
             return tensor
     return None
 
+
+def fetch_text_tensor(dataset):
+    tensors = dataset.tensors
+    for _, tensor in tensors.items():
+        if validate_text_tensor(tensor):
+            return tensor
+    return None
 
-def index_exists(dataset):
+def index_exists_for_embedding_tensor(dataset):
     """Check if the Index already exists."""
     emb_tensor = fetch_embedding_tensor(dataset)
     if emb_tensor is not None:
@@ -61,6 +89,18 @@
     else:
         return False
 
+
+def index_exists_for_text_tensor(dataset):
+    """Check if an index already exists for the text tensor."""
+    text_tensor = fetch_text_tensor(dataset)
+    if text_tensor is not None:
+        vdb_indexes = text_tensor.fetch_vdb_indexes()
+        if len(vdb_indexes) == 0:
+            return False
+        else:
+            return True
+    else:
+        return False
 
 def index_used(exec_option):
     """Check if the index is used for the exec_option"""
@@ -110,7 +150,7 @@ def check_index_params(self):
 
 
 def index_operation_type_dataset(self, num_rows, changed_data_len):
-    if not index_exists(self):
+    if not index_exists_for_embedding_tensor(self):
         if self.index_params is None:
             return INDEX_OP_TYPE.NOOP
         threshold = self.index_params.get("threshold", -1)
@@ -183,13 +223,14 @@ def check_vdb_indexes(dataset):
 
 def _incr_maintenance_vdb_indexes(tensor, indexes, index_operation):
     try:
         is_embedding = tensor.htype == "embedding"
+        is_text = tensor.htype == "text"
         has_vdb_indexes = hasattr(tensor.meta, "vdb_indexes")
         try:
             vdb_index_ids_present = len(tensor.meta.vdb_indexes) > 0
         except AttributeError:
             vdb_index_ids_present = False
 
-        if is_embedding and has_vdb_indexes and vdb_index_ids_present:
+        if (is_embedding or is_text) and has_vdb_indexes and vdb_index_ids_present:
             for vdb_index in tensor.meta.vdb_indexes:
                 tensor.update_vdb_index(
                     operation_kind=index_operation,
@@ -204,44 +245,71 @@ def index_operation_vectorstore(self):
     if not index_used(self.exec_option):
         return None
 
-    emb_tensor = fetch_embedding_tensor(self.dataset)
-
-    if index_exists(self.dataset) and check_index_params(self):
-        return emb_tensor.get_vdb_indexes()[0]["distance"]
-
     threshold = self.index_params.get("threshold", -1)
     below_threshold = threshold < 0 or len(self.dataset) < threshold
     if below_threshold:
         return None
 
-    if not check_index_params(self):
-        try:
-            vdb_indexes = emb_tensor.get_vdb_indexes()
-            for vdb_index in vdb_indexes:
-                emb_tensor.delete_vdb_index(vdb_index["id"])
-        except Exception as e:
-            raise Exception(f"An error occurred while removing VDB indexes: {e}")
+    bm25 = self.index_params.get("bm25", False)
+    if bm25:
+        txt_tensor = fetch_text_tensor(self.dataset)
+
+    emb_tensor = fetch_embedding_tensor(self.dataset)
+
+    # TODO: revisit this logic later.
+    if index_exists_for_embedding_tensor(self.dataset) and check_index_params(self):
+        return emb_tensor.get_vdb_indexes()[0]["distance"]
+
+    if bm25 and index_exists_for_text_tensor(self.dataset):
+        return txt_tensor.get_vdb_indexes()[0]
+
+    # if not check_index_params(self):
+    #     try:
+    #         vdb_indexes = tensor.get_vdb_indexes()
+    #         for vdb_index in vdb_indexes:
+    #             tensor.delete_vdb_index(vdb_index["id"])
+    #     except Exception as e:
+    #         raise Exception(f"An error occurred while removing VDB indexes: {e}")
+
+    if bm25:
+        # Build the BM25 index on the text tensor.
+        txt_tensor.create_vdb_index("bm25")
+
     distance_str = self.index_params.get("distance_metric", "COS")
     additional_params_dict = self.index_params.get("additional_params", None)
     distance = get_index_metric(distance_str.upper())
     if additional_params_dict and len(additional_params_dict) > 0:
         param_dict = normalize_additional_params(additional_params_dict)
         emb_tensor.create_vdb_index(
             "hnsw_1", distance=distance, additional_params=param_dict
         )
     else:
         emb_tensor.create_vdb_index("hnsw_1", distance=distance)
     return distance
 
 
 def index_operation_dataset(self, dml_type, rowids):
+    if self.index_params is None:
+        return
+
+    bm25 = self.index_params.get("bm25", False)
+    txt_tensor = None
+    if bm25:
+        txt_tensor = fetch_text_tensor(self)
+
     emb_tensor = fetch_embedding_tensor(self)
-    if emb_tensor is None:
+    if emb_tensor is None and txt_tensor is None:
         return
 
+    num_rows = (
+        txt_tensor.chunk_engine.num_samples
+        if txt_tensor is not None
+        else emb_tensor.chunk_engine.num_samples
+    )
+
     index_operation_type = index_operation_type_dataset(
         self,
-        emb_tensor.chunk_engine.num_samples,
+        num_rows,
         len(rowids),
     )
@@ -254,13 +322,23 @@ def index_operation_dataset(self, dml_type, rowids):
     ):
         if index_operation_type == INDEX_OP_TYPE.REGENERATE_INDEX:
             try:
-                vdb_indexes = emb_tensor.get_vdb_indexes()
-                for vdb_index in vdb_indexes:
-                    emb_tensor.delete_vdb_index(vdb_index["id"])
+                if txt_tensor is not None:
+                    # Regenerate the BM25 index on the text tensor.
+                    vdb_indexes = txt_tensor.get_vdb_indexes()
+                    for vdb_index in vdb_indexes:
+                        txt_tensor.delete_vdb_index(vdb_index["id"])
+                else:
+                    vdb_indexes = emb_tensor.get_vdb_indexes()
+                    for vdb_index in vdb_indexes:
+                        emb_tensor.delete_vdb_index(vdb_index["id"])
             except Exception as e:
                 raise Exception(
                     f"An error occurred while regenerating VDB indexes: {e}"
                 )
+
+            if txt_tensor is not None:
+                # Recreate the BM25 index on the text tensor.
+                txt_tensor.create_vdb_index("bm25_1")
+
             distance_str = self.index_params.get("distance_metric", "COS")
             additional_params_dict = self.index_params.get("additional_params", None)
             distance = get_index_metric(distance_str.upper())
@@ -272,6 +350,10 @@ def index_operation_dataset(self, dml_type, rowids):
             else:
                 emb_tensor.create_vdb_index("hnsw_1", distance=distance)
         elif index_operation_type == INDEX_OP_TYPE.INCREMENTAL_INDEX:
+            if txt_tensor is not None:
+                # Incrementally maintain the BM25 index on the text tensor.
+                _incr_maintenance_vdb_indexes(txt_tensor, rowids, dml_type)
+
             _incr_maintenance_vdb_indexes(emb_tensor, rowids, dml_type)
         else:
             raise Exception("Unknown index operation")
diff --git a/deeplake/core/meta/tensor_meta.py b/deeplake/core/meta/tensor_meta.py
index 8d90d39208..2828d78832 100644
--- a/deeplake/core/meta/tensor_meta.py
+++ b/deeplake/core/meta/tensor_meta.py
@@ -229,6 +229,9 @@ def __setstate__(self, state: Dict[str, Any]):
         if self.htype == "embedding" and not hasattr(self, "vdb_indexes"):
            self.vdb_indexes = []
            self._required_meta_keys += ("vdb_indexes",)
self.htype == "text" and not hasattr(self, "vdb_indexes"): + self.vdb_indexes = [] + self._required_meta_keys += ("vdb_indexes",) @property def nbytes(self): diff --git a/deeplake/core/tensor.py b/deeplake/core/tensor.py index 288a71f318..c910e1a970 100644 --- a/deeplake/core/tensor.py +++ b/deeplake/core/tensor.py @@ -1537,8 +1537,8 @@ def update_vdb_index( row_ids: List[int] = [], ): self.storage.check_readonly() - if self.meta.htype != "embedding": - raise Exception(f"Only supported for embedding tensors.") + #if self.meta.htype != "embedding": + # raise Exception(f"Only supported for embedding tensors.") self.invalidate_libdeeplake_dataset() self.dataset.flush() from deeplake.enterprise.convert_to_libdeeplake import ( @@ -1604,8 +1604,8 @@ def create_vdb_index( additional_params: Optional[Dict[str, int]] = None, ): self.storage.check_readonly() - if self.meta.htype != "embedding": - raise Exception(f"Only supported for embedding tensors.") + # if self.meta.htype != "embedding": + # raise Exception(f"Only supported for embedding tensors.") if not self.dataset.libdeeplake_dataset is None: ds = self.dataset.libdeeplake_dataset else: @@ -1617,11 +1617,36 @@ def create_vdb_index( ts = getattr(ds, self.meta.name) from indra import api # type: ignore + if self.meta.htype == "text": + self.meta.add_vdb_index( + id=id, type="bm25", distance=None + ) + try: + if additional_params is None: + index = api.vdb.generate_index( + ts, index_type="bm25" + ) + else: + index = api.vdb.generate_index( + ts, + index_type="bm25", + param=additional_params, + ) + b = index.serialize() + commit_id = self.version_state["commit_id"] + self.storage[get_tensor_vdb_index_key(self.key, commit_id, id)] = b + self.invalidate_libdeeplake_dataset() + #self.storage.flush() + except: + self.meta.remove_vdb_index(id=id) + raise + return index if type(distance) == DistanceType: distance = distance.value self.meta.add_vdb_index( id=id, type="hnsw", distance=distance, additional_params=additional_params ) + try: if additional_params is None: index = api.vdb.generate_index( @@ -1638,6 +1663,7 @@ def create_vdb_index( commit_id = self.version_state["commit_id"] self.storage[get_tensor_vdb_index_key(self.key, commit_id, id)] = b self.invalidate_libdeeplake_dataset() + except: self.meta.remove_vdb_index(id=id) raise @@ -1645,8 +1671,8 @@ def create_vdb_index( def delete_vdb_index(self, id: str): self.storage.check_readonly() - if self.meta.htype != "embedding": - raise Exception(f"Only supported for embedding tensors.") + #if self.meta.htype != "embedding": + # raise Exception(f"Only supported for embedding tensors.") commit_id = self.version_state["commit_id"] self.unload_vdb_index_cache() self.storage.pop(get_tensor_vdb_index_key(self.key, commit_id, id)) @@ -1656,7 +1682,8 @@ def delete_vdb_index(self, id: str): def _verify_and_delete_vdb_indexes(self): try: - is_embedding = self.htype == "embedding" + #is_embedding = self.htype == "embedding" + is_embedding = True has_vdb_indexes = hasattr(self.meta, "vdb_indexes") try: vdb_index_ids_present = len(self.meta.vdb_indexes) > 0 @@ -1671,8 +1698,8 @@ def _verify_and_delete_vdb_indexes(self): raise Exception(f"An error occurred while deleting VDB indexes: {e}") def load_vdb_index(self, id: str): - if self.meta.htype != "embedding": - raise Exception(f"Only supported for embedding tensors.") + #if self.meta.htype != "embedding": + # raise Exception(f"Only supported for embedding tensors.") if not self.meta.contains_vdb_index(id): raise ValueError(f"Tensor meta has no vdb index with 
name '{id}'.") if not self.dataset.libdeeplake_dataset is None: @@ -1693,8 +1720,8 @@ def load_vdb_index(self, id: str): raise ValueError(f"An error occurred while loading the VDB index {id}: {e}") def unload_vdb_index_cache(self): - if self.meta.htype != "embedding": - raise Exception(f"Only supported for embedding tensors.") + #if self.meta.htype != "embedding": + # raise Exception(f"Only supported for embedding tensors.") if not self.dataset.libdeeplake_dataset is None: ds = self.dataset.libdeeplake_dataset else: @@ -1713,15 +1740,15 @@ def unload_vdb_index_cache(self): raise Exception(f"An error occurred while cleaning VDB Cache: {e}") def get_vdb_indexes(self) -> List[Dict[str, str]]: - if self.meta.htype != "embedding": - raise Exception(f"Only supported for embedding tensors.") + #if self.meta.htype != "embedding": + # raise Exception(f"Only supported for embedding tensors.") return self.meta.vdb_indexes def fetch_vdb_indexes(self) -> List[Dict[str, str]]: vdb_indexes = [] - if self.meta.htype == "embedding": - if (not self.meta.vdb_indexes is None) and len(self.meta.vdb_indexes) > 0: - vdb_indexes.extend(self.meta.vdb_indexes) + #if self.meta.htype == "embedding": + if (not self.meta.vdb_indexes is None) and len(self.meta.vdb_indexes) > 0: + vdb_indexes.extend(self.meta.vdb_indexes) return vdb_indexes def _check_compatibility_with_htype(self, htype): diff --git a/deeplake/core/vectorstore/deep_memory/deep_memory.py b/deeplake/core/vectorstore/deep_memory/deep_memory.py index c6c46b100f..56e3d80a28 100644 --- a/deeplake/core/vectorstore/deep_memory/deep_memory.py +++ b/deeplake/core/vectorstore/deep_memory/deep_memory.py @@ -678,6 +678,7 @@ def recall_at_k( return avg_recalls, queries_data + def get_view( metric: str, query_emb: Union[List[float], np.ndarray], diff --git a/deeplake/htype.py b/deeplake/htype.py index 2933f20932..515ae45f06 100644 --- a/deeplake/htype.py +++ b/deeplake/htype.py @@ -95,7 +95,7 @@ class htype: "dtype": "Any", }, htype.LIST: {"dtype": "List"}, - htype.TEXT: {"dtype": "str"}, + htype.TEXT: {"dtype": "str", "vdb_indexes": []}, htype.TAG: {"dtype": "List"}, htype.DICOM: {"sample_compression": "dcm"}, htype.NIFTI: {},