Fix/reliable int mapping #30

Open · wants to merge 5 commits into base: feature/new-ingest-process
75 changes: 42 additions & 33 deletions src/matchbox/common/hash.py
@@ -1,6 +1,5 @@
import base64
import hashlib
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar
from uuid import UUID

@@ -112,38 +111,48 @@ def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series
    return Series(hashed_records)


@lru_cache(maxsize=None)
def combine_integers(*n: int) -> int:
class IntMap:
"""
Combine n integers into a single negative integer.

Used to create a symmetric deterministic hash of two integers that populates the
range of integers efficiently and reduces the likelihood of collisions.

Aims to vectorise amazingly when used in Arrow.

Does this by:

* Using a Mersenne prime as a modulus
* Making negative integers positive with modulo, sped up with bitwise operations
* Combining using symmetric operations with coprime multipliers

Args:
*args: Variable number of integers to combine

Returns:
A negative integer
    A data structure taking unordered sets of integers, and mapping them to a key that
    1) is a negative integer; 2) does not collide with keys generated by other
    instances of this class, as long as each instance is initialised with a different
    salt.

    The fact that keys are always negative means that it's possible to build a
    hierarchy where keys are themselves parts of keyed sets, and it's easy to
    distinguish integers mapped to raw data points (which will be non-negative) from
    integers that are keys to sets (which will be negative). The salt makes it
    possible to work with a parallel execution model, where each worker maintains its
    own key space, as long as the workers operate on disjoint subsets of positive
    integers.
    """
    P = 2147483647

    total = 0
    product = 1

    for x in sorted(n):
        x_pos = (x ^ (x >> 31)) - (x >> 31)
        total = (total + x_pos) % P
        product = (product * x_pos) % P

    result = (31 * total + 37 * product) % P

    return -result
    def __init__(self, salt: int):
        self.mapping: dict[frozenset[int], int] = {}
        if salt < 0:
            raise ValueError("The salt must be a non-negative integer")
        self.salt: int = salt

    def _salt_key(self, k: int) -> int:
        """
        Apply the Cantor pairing function to the salt and the negated key, then
        negate the result, so distinct (salt, key) pairs yield distinct negative keys.
        """
        if k >= 0:
            raise ValueError("Key must be a negative integer")
        # Integer arithmetic: one of two consecutive integers is always even, so
        # // 2 is exact, avoiding float precision loss for very large keys
        return -((self.salt - k) * (self.salt - k + 1) // 2 - k)

    def index(self, *values: int) -> int:
        """
        Args:
            values: the integers in the set you want to index

        Returns:
            The old or new key corresponding to the set
        """
        value_set = frozenset(values)
        if value_set in self.mapping:
            return self.mapping[value_set]

        new_key: int = -len(self.mapping) - 1
        salted_key = self._salt_key(new_key)
        self.mapping[value_set] = salted_key

        return salted_key
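
For reviewers, a minimal sketch of how the new `IntMap` behaves. The concrete key values below are computed from the definition above (Cantor pairing of the salt with the negated running key); the variable names are illustrative only:

from matchbox.common.hash import IntMap

im = IntMap(salt=10)
a = im.index(1, 2)   # first key: -(Cantor(10, 1)) == -67
b = im.index(3, 4)   # second key: -(Cantor(10, 2)) == -80
c = im.index(a, b)   # negative keys can themselves be members of indexed sets

assert im.index(2, 1) == a               # input order is irrelevant
assert IntMap(salt=11).index(1, 2) != a  # a different salt yields a disjoint key space
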
31 changes: 18 additions & 13 deletions src/matchbox/common/results.py
@@ -26,8 +26,8 @@

from matchbox.common.db import Cluster, Probability
from matchbox.common.hash import (
    IntMap,
    columns_to_value_ordered_hash,
    combine_integers,
    list_to_value_ordered_hash,
)
from matchbox.server.base import MatchboxDBAdapter, inject_backend
@@ -303,7 +303,7 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults:
    )

    # Get unique probability thresholds, sorted
    thresholds = edges_df["probability"].unique()
    thresholds = sorted(edges_df["probability"].unique())

    # Process edges grouped by probability threshold
    for prob in thresholds:
@@ -367,7 +367,7 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
    graph = rx.PyGraph(node_count_hint=n_nodes, edge_count_hint=n_edges)
    graph.add_nodes_from(range(n_nodes))

    edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=False))
    edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=True))
    graph.add_edges_from_no_data(edges)

    components = rx.connected_components(graph)
@@ -413,6 +413,7 @@ def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T:
            self._shadow_parent[x] = x
            self._shadow_rank[x] = 0

        # TODO: Instead of being a `while`, could this be an `if`?
        while parent_dict[x] != x:
            parent_dict[x] = parent_dict[parent_dict[x]]
            x = parent_dict[x]
@@ -512,7 +513,9 @@ def diff(self) -> Iterator[tuple[set[T], set[T]]]:
        self._shadow_rank = self.rank.copy()


def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa.Table:
def component_to_hierarchy(
    table: pa.Table, dtype: pa.DataType = pa.int32, salt: int | None = None
) -> pa.Table:
"""
Convert pairwise probabilities into a hierarchical representation.

@@ -526,6 +529,7 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa
"""
hierarchy: list[tuple[int, int, float]] = []
uf = UnionFindWithDiff[int]()
im = IntMap(salt=salt)
probs = pc.unique(table["probability"])

for threshold in probs:
@@ -537,31 +541,32 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa
        for row in zip(
            current_probs["left"].to_numpy(),
            current_probs["right"].to_numpy(),
            strict=False,
            strict=True,
        ):
            left, right = row
            uf.union(left, right)
            parent = combine_integers(left, right)
            parent = im.index(left, right)
            hierarchy.extend([(parent, left, threshold), (parent, right, threshold)])

        # Process union-find diffs
        for old_comp, new_comp in uf.diff():
            if len(old_comp) > 1:
                parent = combine_integers(*new_comp)
                child = combine_integers(*old_comp)
                parent = im.index(*new_comp)
                child = im.index(*old_comp)
                hierarchy.extend([(parent, child, threshold)])
            else:
                parent = combine_integers(*new_comp)
                parent = im.index(*new_comp)
                hierarchy.extend([(parent, old_comp.pop(), threshold)])

    parents, children, probs = zip(*hierarchy, strict=False)
    return pa.table(
    parents, children, probs = zip(*hierarchy, strict=True)
    hierarchy_results = pa.table(
        {
            "parent": pa.array(parents, type=dtype()),
            "child": pa.array(children, type=dtype()),
            "probability": pa.array(probs, type=pa.uint8()),
        }
    )
    return hierarchy_results


def to_hierarchical_clusters(
@@ -635,8 +640,8 @@ def to_hierarchical_clusters(

    with ProcessPoolExecutor(max_workers=n_cores) as executor:
        futures = [
            executor.submit(proc_func, component_table, dtype)
            for component_table in component_tables
            executor.submit(proc_func, component_table, dtype, salt)
            for salt, component_table in enumerate(component_tables)
        ]

        for future in futures:
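
Each component table receives its own salt from `enumerate` above, so hierarchy keys minted by parallel workers can never collide. A minimal sketch of that guarantee (illustrative salts and inputs, assuming the `IntMap` from this diff):

from matchbox.common.hash import IntMap

# Two workers, each with its own IntMap, processing disjoint components
im0, im1 = IntMap(salt=0), IntMap(salt=1)
keys0 = {im0.index(1, 2), im0.index(3, 4)}
keys1 = {im1.index(5, 6), im1.index(7, 8)}

# Cantor pairing on (salt, key) is injective, so the key spaces are disjoint
assert keys0.isdisjoint(keys1)
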
27 changes: 6 additions & 21 deletions test/client/test_hierarchy.py
@@ -9,7 +9,6 @@
import pyarrow.compute as pc
import pytest

from matchbox.common.hash import combine_integers
from matchbox.common.results import (
    attach_components_to_probabilities,
    component_to_hierarchy,
@@ -119,22 +118,6 @@ def test_attach_components_to_probabilities(parameters: dict[str, Any]):
    assert len(pc.unique(with_components["component"])) == parameters["num_components"]


@pytest.mark.parametrize(
("integer_list"),
[
[1, 2, 3],
[9, 0],
[-4, -5, -6],
[7, -8, 9],
],
ids=["positive", "pair_only", "negative", "mixed"],
)
def test_combine_integers(integer_list: list[int]):
res = combine_integers(*integer_list)
assert isinstance(res, int)
assert res < 0


@pytest.mark.parametrize(
("probabilities", "hierarchy"),
[
@@ -195,9 +178,9 @@ def test_combine_integers(integer_list: list[int]):
def test_component_to_hierarchy(
    probabilities: dict[str, list[str | float]], hierarchy: set[tuple[str, str, int]]
):
    with patch(
        "matchbox.common.results.combine_integers", side_effect=_combine_strings
    ):
    with patch("matchbox.common.results.IntMap") as MockIntMap:
        instance = MockIntMap.return_value
        instance.index.side_effect = _combine_strings
        probabilities_table = (
            pa.Table.from_pydict(probabilities)
            .cast(
@@ -340,8 +323,10 @@ def test_hierarchical_clusters(input_data, expected_hierarchy):
"matchbox.common.results.ProcessPoolExecutor",
lambda *args, **kwargs: parallel_pool_for_tests(timeout=30),
),
patch("matchbox.common.results.combine_integers", side_effect=_combine_strings),
patch("matchbox.common.results.IntMap") as MockIntMap,
):
instance = MockIntMap.return_value
instance.index.side_effect = _combine_strings
result = to_hierarchical_clusters(
probabilities, dtype=pa.string, proc_func=component_to_hierarchy
)
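
Both tests swap `IntMap.index` for a string-based stand-in so expected hierarchies can be written with readable labels. The `_combine_strings` helper is defined earlier in this test module and isn't shown in this diff; purely for orientation, a plausible shape for such a helper (hypothetical, not the repository's actual code):

def _combine_strings(*values: str) -> str:
    # Stand-in for IntMap.index over string labels: the result depends only
    # on the unordered set of inputs, so it is deterministic and symmetric.
    return "".join(sorted(set(values)))
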
48 changes: 48 additions & 0 deletions test/common/test_hash.py
@@ -0,0 +1,48 @@
from matchbox.common.hash import IntMap


def test_intmap_basic():
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    assert len({a, b, c}) == 3
    assert max(a, b, c) < 0


def test_intmap_same():
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    im2 = IntMap(salt=10)
    x = im2.index(1, 2)
    y = im2.index(3, 4)
    z = im2.index(a, b)

    assert (a, b, c) == (x, y, z)


def test_intmap_different():
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    im2 = IntMap(salt=11)
    x = im2.index(1, 2)
    y = im2.index(3, 4)
    z = im2.index(a, b)

    for v1, v2 in zip([a, b, c], [x, y, z], strict=True):
        assert v1 != v2


def test_intmap_unordered():
    im1 = IntMap(salt=10)
    a = im1.index(1, 2, 3)
    b = im1.index(3, 1, 2)

    assert a == b