diff --git a/src/matchbox/common/hash.py b/src/matchbox/common/hash.py index aa2a279..8b6f2f8 100644 --- a/src/matchbox/common/hash.py +++ b/src/matchbox/common/hash.py @@ -1,6 +1,5 @@ import base64 import hashlib -from functools import lru_cache from typing import TYPE_CHECKING, Any, TypeVar from uuid import UUID @@ -112,38 +111,48 @@ def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series return Series(hashed_records) -@lru_cache(maxsize=None) -def combine_integers(*n: int) -> int: +class IntMap: """ - Combine n integers into a single negative integer. - - Used to create a symmetric deterministic hash of two integers that populates the - range of integers efficiently and reduces the likelihood of collisions. - - Aims to vectorise amazingly when used in Arrow. - - Does this by: - - * Using a Mersenne prime as a modulus - * Making negative integers positive with modulo, sped up with bitwise operations - * Combining using symmetric operations with coprime multipliers - - Args: - *args: Variable number of integers to combine - - Returns: - A negative integer + A data structure taking unordered sets of integers, and mapping them to an ID that + 1) is a negative integer; 2) does not collide with other IDs generated by other + instances of this class, as long as they are initialised with a different salt. + + The fact that IDs are always negative means that it's possible to build a hierarchy + where IDs are themselves parts of other sets, and it's easy to distinguish integers + mapped to raw data points (which will be non-negative) from integers that are IDs + (which will be negative). The salt makes it possible to work with a parallel execution + model, where each worker maintains their separate ID space, as long as each worker + operates on disjoint subsets of positive integers. 
""" - P = 2147483647 - - total = 0 - product = 1 - - for x in sorted(n): - x_pos = (x ^ (x >> 31)) - (x >> 31) - total = (total + x_pos) % P - product = (product * x_pos) % P - - result = (31 * total + 37 * product) % P - return -result + def __init__(self, salt: int): + self.mapping: dict[frozenset[int], int] = {} + if salt < 0: + raise ValueError("The salt must be a positive integer") + self.salt: int = salt + + def _salt_id(self, i: int) -> int: + """ + Use Cantor pairing function on the salt and a negative int ID, and minus it. + """ + if i >= 0: + raise ValueError("ID must be a negative integer") + return -int(0.5 * (self.salt - i) * (self.salt - i + 1) - i) + + def index(self, *values: int) -> int: + """ + Args: + values: the integers in the set you want to index + + Returns: + The old or new ID corresponding to the set + """ + value_set = frozenset(values) + if value_set in self.mapping: + return self.mapping[value_set] + + new_id: int = -len(self.mapping) - 1 + salted_id = self._salt_id(new_id) + self.mapping[value_set] = salted_id + + return salted_id diff --git a/src/matchbox/common/results.py b/src/matchbox/common/results.py index 9416a50..ff97e5c 100644 --- a/src/matchbox/common/results.py +++ b/src/matchbox/common/results.py @@ -26,8 +26,8 @@ from matchbox.common.db import Cluster, Probability from matchbox.common.hash import ( + IntMap, columns_to_value_ordered_hash, - combine_integers, list_to_value_ordered_hash, ) from matchbox.server.base import MatchboxDBAdapter, inject_backend @@ -303,7 +303,7 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults: ) # Get unique probability thresholds, sorted - thresholds = edges_df["probability"].unique() + thresholds = sorted(edges_df["probability"].unique()) # Process edges grouped by probability threshold for prob in thresholds: @@ -367,7 +367,7 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table: graph = rx.PyGraph(node_count_hint=n_nodes, edge_count_hint=n_edges) 
graph.add_nodes_from(range(n_nodes)) - edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=False)) + edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=True)) graph.add_edges_from_no_data(edges) components = rx.connected_components(graph) @@ -413,6 +413,7 @@ def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T: self._shadow_parent[x] = x self._shadow_rank[x] = 0 + # TODO: Instead of being a `while`, could this be an `if`? while parent_dict[x] != x: parent_dict[x] = parent_dict[parent_dict[x]] x = parent_dict[x] @@ -512,7 +513,9 @@ def diff(self) -> Iterator[tuple[set[T], set[T]]]: self._shadow_rank = self.rank.copy() -def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa.Table: +def component_to_hierarchy( + table: pa.Table, dtype: pa.DataType = pa.int32, salt: int = 1 +) -> pa.Table: """ Convert pairwise probabilities into a hierarchical representation. @@ -526,6 +529,7 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa """ hierarchy: list[tuple[int, int, float]] = [] uf = UnionFindWithDiff[int]() + im = IntMap(salt=salt) probs = pc.unique(table["probability"]) for threshold in probs: @@ -537,24 +541,24 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa for row in zip( current_probs["left"].to_numpy(), current_probs["right"].to_numpy(), - strict=False, + strict=True, ): left, right = row uf.union(left, right) - parent = combine_integers(left, right) + parent = im.index(left, right) hierarchy.extend([(parent, left, threshold), (parent, right, threshold)]) # Process union-find diffs for old_comp, new_comp in uf.diff(): if len(old_comp) > 1: - parent = combine_integers(*new_comp) - child = combine_integers(*old_comp) + parent = im.index(*new_comp) + child = im.index(*old_comp) hierarchy.extend([(parent, child, threshold)]) else: - parent = combine_integers(*new_comp) + parent = im.index(*new_comp) 
hierarchy.extend([(parent, old_comp.pop(), threshold)]) - parents, children, probs = zip(*hierarchy, strict=False) + parents, children, probs = zip(*hierarchy, strict=True) return pa.table( { "parent": pa.array(parents, type=dtype()), @@ -635,8 +639,8 @@ def to_hierarchical_clusters( with ProcessPoolExecutor(max_workers=n_cores) as executor: futures = [ - executor.submit(proc_func, component_table, dtype) - for component_table in component_tables + executor.submit(proc_func, component_table, dtype, salt) + for salt, component_table in enumerate(component_tables) ] for future in futures: diff --git a/test/client/test_hierarchy.py b/test/client/test_hierarchy.py index 6396a7b..4f73820 100644 --- a/test/client/test_hierarchy.py +++ b/test/client/test_hierarchy.py @@ -9,7 +9,6 @@ import pyarrow.compute as pc import pytest -from matchbox.common.hash import combine_integers from matchbox.common.results import ( attach_components_to_probabilities, component_to_hierarchy, @@ -119,22 +118,6 @@ def test_attach_components_to_probabilities(parameters: dict[str, Any]): assert len(pc.unique(with_components["component"])) == parameters["num_components"] -@pytest.mark.parametrize( - ("integer_list"), - [ - [1, 2, 3], - [9, 0], - [-4, -5, -6], - [7, -8, 9], - ], - ids=["positive", "pair_only", "negative", "mixed"], -) -def test_combine_integers(integer_list: list[int]): - res = combine_integers(*integer_list) - assert isinstance(res, int) - assert res < 0 - - @pytest.mark.parametrize( ("probabilities", "hierarchy"), [ @@ -195,9 +178,9 @@ def test_combine_integers(integer_list: list[int]): def test_component_to_hierarchy( probabilities: dict[str, list[str | float]], hierarchy: set[tuple[str, str, int]] ): - with patch( - "matchbox.common.results.combine_integers", side_effect=_combine_strings - ): + with patch("matchbox.common.results.IntMap") as MockIntMap: + instance = MockIntMap.return_value + instance.index.side_effect = _combine_strings probabilities_table = ( 
pa.Table.from_pydict(probabilities) .cast( @@ -340,8 +323,10 @@ def test_hierarchical_clusters(input_data, expected_hierarchy): "matchbox.common.results.ProcessPoolExecutor", lambda *args, **kwargs: parallel_pool_for_tests(timeout=30), ), - patch("matchbox.common.results.combine_integers", side_effect=_combine_strings), + patch("matchbox.common.results.IntMap") as MockIntMap, ): + instance = MockIntMap.return_value + instance.index.side_effect = _combine_strings result = to_hierarchical_clusters( probabilities, dtype=pa.string, proc_func=component_to_hierarchy ) diff --git a/test/common/test_hash.py b/test/common/test_hash.py new file mode 100644 index 0000000..7556ef7 --- /dev/null +++ b/test/common/test_hash.py @@ -0,0 +1,48 @@ +from matchbox.common.hash import IntMap + + +def test_intmap_basic(): + im1 = IntMap(salt=10) + a = im1.index(1, 2) + b = im1.index(3, 4) + c = im1.index(a, b) + + assert len({a, b, c}) == 3 + assert max(a, b, c) < 0 + + +def test_intmap_same(): + im1 = IntMap(salt=10) + a = im1.index(1, 2) + b = im1.index(3, 4) + c = im1.index(a, b) + + im2 = IntMap(salt=10) + x = im2.index(1, 2) + y = im2.index(3, 4) + z = im2.index(x, y) + + assert (a, b, c) == (x, y, z) + + +def test_intmap_different(): + im1 = IntMap(salt=10) + a = im1.index(1, 2) + b = im1.index(3, 4) + c = im1.index(a, b) + + im2 = IntMap(salt=11) + x = im2.index(1, 2) + y = im2.index(3, 4) + z = im2.index(x, y) + + for v1, v2 in zip([a, b, c], [x, y, z], strict=True): + assert v1 != v2 + + +def test_intmap_unordered(): + im1 = IntMap(salt=10) + a = im1.index(1, 2, 3) + b = im1.index(3, 1, 2) + + assert a == b