Skip to content

Commit

Permalink
Fix deletion bug (set size change while looping)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhuici committed Nov 28, 2024
1 parent a4aa76e commit c3b1022
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
28 changes: 16 additions & 12 deletions datalayer/hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,18 +356,22 @@ def _insert_node_to_layers(self, new_node, enter_point):
enter_point = currently_found_nn

def _delete_neighbors_connections(self, node):
"""Given a node, deletes the connections to their neighbors.
Arguments:
node -- the node to delete
"""

logger.debug(f"Deleting neighbors of \"{node.get_id()}\"")
for layer in range(node.get_max_layer() + 1):
for neighbor in node.get_neighbors_at_layer(layer):
logger.debug(f"Deleting at L{layer} link \"{neighbor.get_id()}\"")
neighbor.remove_neighbor(layer, node)
node.remove_neighbor(layer, neighbor)
"""Given a node, deletes the connections to their neighbors.
Arguments:
node -- the node to delete
"""

logger.debug(f"Deleting neighbors of \"{node.get_id()}\"")
for layer in range(node.get_max_layer() + 1):
neighbors_to_remove = set()
for neighbor in node.get_neighbors_at_layer(layer):
logger.debug(f"Deleting at L{layer} link \"{neighbor.get_id()}\"")
neighbors_to_remove.add(neighbor)

for neighbor in neighbors_to_remove: # bidirectionally remove links
node.remove_neighbor(layer, neighbor)
neighbor.remove_neighbor(layer, node)

def _delete_node_dict(self, node):
"""Deletes a node from the dict of the HNSW structure.
Expand Down
3 changes: 1 addition & 2 deletions tests/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,14 @@ def test_search_approximate(self):
self.assertEqual(actual_founds, expected_founds)
self.assertEqual(actual_distances, expected_distances)

@unittest.skip("?")
def test_deletion(self):
for hash in HASHES[:5]:
self.apo_model.delete(HashNode(hash, TLSHHashAlgorithm))

expected_founds = [False, False, False, False, False, True, True, True, True, True]
actual_founds = []
for hash in HASHES:
found, exact, result_dict = self.apo_model.knn_search(HashNode(hash, TLSHHashAlgorithm), 1)
found, _, _ = self.apo_model.knn_search(HashNode(hash, TLSHHashAlgorithm), 1)
actual_founds.append(found)

self.assertEqual(actual_founds, expected_founds)
Expand Down

0 comments on commit c3b1022

Please sign in to comment.