Skip to content

Commit

Permalink
Merge pull request #30 from uktrade/fix/reliable-int-mapping
Browse files Browse the repository at this point in the history
Fix/reliable int mapping
  • Loading branch information
lmazz1-dbt authored Dec 18, 2024
2 parents 6a4e93e + e28937c commit 68445d9
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 66 deletions.
75 changes: 42 additions & 33 deletions src/matchbox/common/hash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import hashlib
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -112,38 +111,48 @@ def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series
return Series(hashed_records)


@lru_cache(maxsize=None)
def combine_integers(*n: int) -> int:
class IntMap:
"""
Combine n integers into a single negative integer.
Used to create a symmetric deterministic hash of two integers that populates the
range of integers efficiently and reduces the likelihood of collisions.
Aims to vectorise amazingly when used in Arrow.
Does this by:
* Using a Mersenne prime as a modulus
* Making negative integers positive with modulo, sped up with bitwise operations
* Combining using symmetric operations with coprime multipliers
Args:
*args: Variable number of integers to combine
Returns:
A negative integer
A data structure taking unordered sets of integers, and mapping them to an ID that
1) is a negative integer; 2) does not collide with other IDs generated by other
instances of this class, as long as they are initialised with a different salt.
The fact that IDs are always negative means that it's possible to build a hierarchy
where IDs are themselves parts of other sets, and it's easy to distinguish integers
mapped to raw data points (which will be non-negative), to integers that are IDs
(which will be negative). The salt allows to work with a parallel execution
model, where each worker maintains their separate ID space, as long as each worker
operates on disjoint subsets of positive integers.
"""
P = 2147483647

total = 0
product = 1

for x in sorted(n):
x_pos = (x ^ (x >> 31)) - (x >> 31)
total = (total + x_pos) % P
product = (product * x_pos) % P

result = (31 * total + 37 * product) % P

return -result
def __init__(self, salt: int):
self.mapping: dict[frozenset[int], int] = {}
if salt < 0:
raise ValueError("The salt must be a positive integer")
self.salt: int = salt

def _salt_id(self, i: int) -> int:
"""
Use Cantor pairing function on the salt and a negative int ID, and minus it.
"""
if i >= 0:
raise ValueError("ID must be a negative integer")
return -int(0.5 * (self.salt - i) * (self.salt - i + 1) - i)

def index(self, *values: int) -> int:
"""
Args:
values: the integers in the set you want to index
Returns:
The old or new ID corresponding to the set
"""
value_set = frozenset(values)
if value_set in self.mapping:
return self.mapping[value_set]

new_id: int = -len(self.mapping) - 1
salted_id = self._salt_id(new_id)
self.mapping[value_set] = salted_id

return salted_id
28 changes: 16 additions & 12 deletions src/matchbox/common/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from matchbox.common.db import Cluster, Probability
from matchbox.common.hash import (
IntMap,
columns_to_value_ordered_hash,
combine_integers,
list_to_value_ordered_hash,
)
from matchbox.server.base import MatchboxDBAdapter, inject_backend
Expand Down Expand Up @@ -303,7 +303,7 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults:
)

# Get unique probability thresholds, sorted
thresholds = edges_df["probability"].unique()
thresholds = sorted(edges_df["probability"].unique())

# Process edges grouped by probability threshold
for prob in thresholds:
Expand Down Expand Up @@ -367,7 +367,7 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
graph = rx.PyGraph(node_count_hint=n_nodes, edge_count_hint=n_edges)
graph.add_nodes_from(range(n_nodes))

edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=False))
edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=True))
graph.add_edges_from_no_data(edges)

components = rx.connected_components(graph)
Expand Down Expand Up @@ -413,6 +413,7 @@ def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T:
self._shadow_parent[x] = x
self._shadow_rank[x] = 0

# TODO: Instead of being a `while`, could this be an `if`?
while parent_dict[x] != x:
parent_dict[x] = parent_dict[parent_dict[x]]
x = parent_dict[x]
Expand Down Expand Up @@ -512,7 +513,9 @@ def diff(self) -> Iterator[tuple[set[T], set[T]]]:
self._shadow_rank = self.rank.copy()


def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa.Table:
def component_to_hierarchy(
table: pa.Table, dtype: pa.DataType = pa.int32, salt: int = 1
) -> pa.Table:
"""
Convert pairwise probabilities into a hierarchical representation.
Expand All @@ -526,6 +529,7 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa
"""
hierarchy: list[tuple[int, int, float]] = []
uf = UnionFindWithDiff[int]()
im = IntMap(salt=salt)
probs = pc.unique(table["probability"])

for threshold in probs:
Expand All @@ -537,24 +541,24 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa
for row in zip(
current_probs["left"].to_numpy(),
current_probs["right"].to_numpy(),
strict=False,
strict=True,
):
left, right = row
uf.union(left, right)
parent = combine_integers(left, right)
parent = im.index(left, right)
hierarchy.extend([(parent, left, threshold), (parent, right, threshold)])

# Process union-find diffs
for old_comp, new_comp in uf.diff():
if len(old_comp) > 1:
parent = combine_integers(*new_comp)
child = combine_integers(*old_comp)
parent = im.index(*new_comp)
child = im.index(*old_comp)
hierarchy.extend([(parent, child, threshold)])
else:
parent = combine_integers(*new_comp)
parent = im.index(*new_comp)
hierarchy.extend([(parent, old_comp.pop(), threshold)])

parents, children, probs = zip(*hierarchy, strict=False)
parents, children, probs = zip(*hierarchy, strict=True)
return pa.table(
{
"parent": pa.array(parents, type=dtype()),
Expand Down Expand Up @@ -635,8 +639,8 @@ def to_hierarchical_clusters(

with ProcessPoolExecutor(max_workers=n_cores) as executor:
futures = [
executor.submit(proc_func, component_table, dtype)
for component_table in component_tables
executor.submit(proc_func, component_table, dtype, salt)
for salt, component_table in enumerate(component_tables)
]

for future in futures:
Expand Down
27 changes: 6 additions & 21 deletions test/client/test_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pyarrow.compute as pc
import pytest

from matchbox.common.hash import combine_integers
from matchbox.common.results import (
attach_components_to_probabilities,
component_to_hierarchy,
Expand Down Expand Up @@ -119,22 +118,6 @@ def test_attach_components_to_probabilities(parameters: dict[str, Any]):
assert len(pc.unique(with_components["component"])) == parameters["num_components"]


@pytest.mark.parametrize(
("integer_list"),
[
[1, 2, 3],
[9, 0],
[-4, -5, -6],
[7, -8, 9],
],
ids=["positive", "pair_only", "negative", "mixed"],
)
def test_combine_integers(integer_list: list[int]):
res = combine_integers(*integer_list)
assert isinstance(res, int)
assert res < 0


@pytest.mark.parametrize(
("probabilities", "hierarchy"),
[
Expand Down Expand Up @@ -195,9 +178,9 @@ def test_combine_integers(integer_list: list[int]):
def test_component_to_hierarchy(
probabilities: dict[str, list[str | float]], hierarchy: set[tuple[str, str, int]]
):
with patch(
"matchbox.common.results.combine_integers", side_effect=_combine_strings
):
with patch("matchbox.common.results.IntMap") as MockIntMap:
instance = MockIntMap.return_value
instance.index.side_effect = _combine_strings
probabilities_table = (
pa.Table.from_pydict(probabilities)
.cast(
Expand Down Expand Up @@ -340,8 +323,10 @@ def test_hierarchical_clusters(input_data, expected_hierarchy):
"matchbox.common.results.ProcessPoolExecutor",
lambda *args, **kwargs: parallel_pool_for_tests(timeout=30),
),
patch("matchbox.common.results.combine_integers", side_effect=_combine_strings),
patch("matchbox.common.results.IntMap") as MockIntMap,
):
instance = MockIntMap.return_value
instance.index.side_effect = _combine_strings
result = to_hierarchical_clusters(
probabilities, dtype=pa.string, proc_func=component_to_hierarchy
)
Expand Down
48 changes: 48 additions & 0 deletions test/common/test_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from matchbox.common.hash import IntMap


def test_intmap_basic():
    """Distinct sets, including sets of IDs, get distinct negative IDs."""
    im = IntMap(salt=10)
    first = im.index(1, 2)
    second = im.index(3, 4)
    combined = im.index(first, second)

    assert len({first, second, combined}) == 3
    assert all(value < 0 for value in (first, second, combined))


def test_intmap_same():
    """Two IntMaps initialised with the same salt assign identical IDs."""
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    im2 = IntMap(salt=10)
    x = im2.index(1, 2)
    y = im2.index(3, 4)
    # Index im2's own IDs. The original `im2.index(a, b)` was a copy-paste
    # slip that only worked because a == x and b == y under an equal salt.
    z = im2.index(x, y)

    assert (a, b, c) == (x, y, z)


def test_intmap_different():
    """Two IntMaps initialised with different salts never share IDs."""
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    im2 = IntMap(salt=11)
    x = im2.index(1, 2)
    y = im2.index(3, 4)
    # Index im2's own IDs. The original `im2.index(a, b)` indexed im1's IDs
    # (a != x here), so c was compared against the ID of a different set
    # and the test passed for the wrong reason.
    z = im2.index(x, y)

    for v1, v2 in zip([a, b, c], [x, y, z], strict=True):
        assert v1 != v2


def test_intmap_unordered():
    """Argument order must not affect the assigned ID."""
    im = IntMap(salt=10)
    forward = im.index(1, 2, 3)
    shuffled = im.index(3, 1, 2)

    assert forward == shuffled

0 comments on commit 68445d9

Please sign in to comment.