Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/reliable int mapping #30

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 42 additions & 33 deletions src/matchbox/common/hash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import base64
import hashlib
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TypeVar
from uuid import UUID

Expand Down Expand Up @@ -112,38 +111,48 @@ def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series
return Series(hashed_records)


class IntMap:
    """
    A data structure taking unordered sets of integers, and mapping them to an ID that
    1) is a negative integer; 2) does not collide with other IDs generated by other
    instances of this class, as long as they are initialised with a different salt.

    The fact that IDs are always negative means that it's possible to build a hierarchy
    where IDs are themselves parts of other sets, and it's easy to distinguish integers
    mapped to raw data points (which will be non-negative), to integers that are IDs
    (which will be negative). The salt allows to work with a parallel execution
    model, where each worker maintains their separate ID space, as long as each worker
    operates on disjoint subsets of positive integers.
    """

    def __init__(self, salt: int):
        """
        Args:
            salt: a non-negative integer identifying this instance's ID space.

        Raises:
            ValueError: if the salt is negative.
        """
        self.mapping: dict[frozenset[int], int] = {}
        if salt < 0:
            raise ValueError("The salt must be a non-negative integer")
        self.salt: int = salt

    def _salt_id(self, i: int) -> int:
        """
        Use Cantor pairing function on the salt and a negative int ID, and minus it.

        The pairing is computed with exact integer arithmetic: (salt - i) and
        (salt - i + 1) are consecutive integers, so their product is always even
        and the floor division by 2 is exact. Float multiplication (0.5 * ...)
        must be avoided here, as it loses precision beyond 2**53 and would
        produce wrong — and potentially colliding — IDs for large salts or
        deep hierarchies.

        Args:
            i: a negative integer ID local to this instance.

        Returns:
            A salted, globally-unique negative integer ID.

        Raises:
            ValueError: if `i` is not negative.
        """
        if i >= 0:
            raise ValueError("ID must be a negative integer")
        # Cantor pairing of (salt, -i): k(k+1)//2 + (-i), with k = salt + (-i)
        return -((self.salt - i) * (self.salt - i + 1) // 2 - i)

    def index(self, *values: int) -> int:
        """
        Args:
            values: the integers in the set you want to index

        Returns:
            The old or new ID corresponding to the set
        """
        # Sets are unordered, so index(1, 2) and index(2, 1) map to the same ID
        value_set = frozenset(values)
        if value_set in self.mapping:
            return self.mapping[value_set]

        # Fresh local IDs count down from -1; salting makes them globally unique
        new_id: int = -len(self.mapping) - 1
        salted_id = self._salt_id(new_id)
        self.mapping[value_set] = salted_id

        return salted_id
28 changes: 16 additions & 12 deletions src/matchbox/common/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from matchbox.common.db import Cluster, Probability
from matchbox.common.hash import (
IntMap,
columns_to_value_ordered_hash,
combine_integers,
list_to_value_ordered_hash,
)
from matchbox.server.base import MatchboxDBAdapter, inject_backend
Expand Down Expand Up @@ -303,7 +303,7 @@ def to_clusters(results: ProbabilityResults) -> ClusterResults:
)

# Get unique probability thresholds, sorted
thresholds = edges_df["probability"].unique()
thresholds = sorted(edges_df["probability"].unique())

# Process edges grouped by probability threshold
for prob in thresholds:
Expand Down Expand Up @@ -367,7 +367,7 @@ def attach_components_to_probabilities(probabilities: pa.Table) -> pa.Table:
graph = rx.PyGraph(node_count_hint=n_nodes, edge_count_hint=n_edges)
graph.add_nodes_from(range(n_nodes))

edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=False))
edges = tuple(zip(left_indices.to_numpy(), right_indices.to_numpy(), strict=True))
graph.add_edges_from_no_data(edges)

components = rx.connected_components(graph)
Expand Down Expand Up @@ -413,6 +413,7 @@ def find(self, x: T, parent_dict: dict[T, T] | None = None) -> T:
self._shadow_parent[x] = x
self._shadow_rank[x] = 0

# TODO: Instead of being a `while`, could this be an `if`?
while parent_dict[x] != x:
parent_dict[x] = parent_dict[parent_dict[x]]
x = parent_dict[x]
Expand Down Expand Up @@ -512,7 +513,9 @@ def diff(self) -> Iterator[tuple[set[T], set[T]]]:
self._shadow_rank = self.rank.copy()


def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa.Table:
def component_to_hierarchy(
table: pa.Table, dtype: pa.DataType = pa.int32, salt: int = 1
) -> pa.Table:
"""
Convert pairwise probabilities into a hierarchical representation.

Expand All @@ -526,6 +529,7 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa
"""
hierarchy: list[tuple[int, int, float]] = []
uf = UnionFindWithDiff[int]()
im = IntMap(salt=salt)
probs = pc.unique(table["probability"])

for threshold in probs:
Expand All @@ -537,24 +541,24 @@ def component_to_hierarchy(table: pa.Table, dtype: pa.DataType = pa.int32) -> pa
for row in zip(
current_probs["left"].to_numpy(),
current_probs["right"].to_numpy(),
strict=False,
strict=True,
):
left, right = row
uf.union(left, right)
parent = combine_integers(left, right)
parent = im.index(left, right)
hierarchy.extend([(parent, left, threshold), (parent, right, threshold)])

# Process union-find diffs
for old_comp, new_comp in uf.diff():
if len(old_comp) > 1:
parent = combine_integers(*new_comp)
child = combine_integers(*old_comp)
parent = im.index(*new_comp)
child = im.index(*old_comp)
hierarchy.extend([(parent, child, threshold)])
else:
parent = combine_integers(*new_comp)
parent = im.index(*new_comp)
hierarchy.extend([(parent, old_comp.pop(), threshold)])

parents, children, probs = zip(*hierarchy, strict=False)
parents, children, probs = zip(*hierarchy, strict=True)
return pa.table(
{
"parent": pa.array(parents, type=dtype()),
Expand Down Expand Up @@ -635,8 +639,8 @@ def to_hierarchical_clusters(

with ProcessPoolExecutor(max_workers=n_cores) as executor:
futures = [
executor.submit(proc_func, component_table, dtype)
for component_table in component_tables
executor.submit(proc_func, component_table, dtype, salt)
for salt, component_table in enumerate(component_tables)
]

for future in futures:
Expand Down
27 changes: 6 additions & 21 deletions test/client/test_hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pyarrow.compute as pc
import pytest

from matchbox.common.hash import combine_integers
from matchbox.common.results import (
attach_components_to_probabilities,
component_to_hierarchy,
Expand Down Expand Up @@ -119,22 +118,6 @@ def test_attach_components_to_probabilities(parameters: dict[str, Any]):
assert len(pc.unique(with_components["component"])) == parameters["num_components"]


@pytest.mark.parametrize(
("integer_list"),
[
[1, 2, 3],
[9, 0],
[-4, -5, -6],
[7, -8, 9],
],
ids=["positive", "pair_only", "negative", "mixed"],
)
def test_combine_integers(integer_list: list[int]):
res = combine_integers(*integer_list)
assert isinstance(res, int)
assert res < 0


@pytest.mark.parametrize(
("probabilities", "hierarchy"),
[
Expand Down Expand Up @@ -195,9 +178,9 @@ def test_combine_integers(integer_list: list[int]):
def test_component_to_hierarchy(
probabilities: dict[str, list[str | float]], hierarchy: set[tuple[str, str, int]]
):
with patch(
"matchbox.common.results.combine_integers", side_effect=_combine_strings
):
with patch("matchbox.common.results.IntMap") as MockIntMap:
instance = MockIntMap.return_value
instance.index.side_effect = _combine_strings
probabilities_table = (
pa.Table.from_pydict(probabilities)
.cast(
Expand Down Expand Up @@ -340,8 +323,10 @@ def test_hierarchical_clusters(input_data, expected_hierarchy):
"matchbox.common.results.ProcessPoolExecutor",
lambda *args, **kwargs: parallel_pool_for_tests(timeout=30),
),
patch("matchbox.common.results.combine_integers", side_effect=_combine_strings),
patch("matchbox.common.results.IntMap") as MockIntMap,
):
instance = MockIntMap.return_value
instance.index.side_effect = _combine_strings
result = to_hierarchical_clusters(
probabilities, dtype=pa.string, proc_func=component_to_hierarchy
)
Expand Down
48 changes: 48 additions & 0 deletions test/common/test_hash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from matchbox.common.hash import IntMap


def test_intmap_basic():
    """Indexing distinct sets — including sets of IDs — yields distinct negative IDs."""
    intmap = IntMap(salt=10)
    first = intmap.index(1, 2)
    second = intmap.index(3, 4)
    combined = intmap.index(first, second)

    assert len({first, second, combined}) == 3
    assert all(value < 0 for value in (first, second, combined))


def test_intmap_same():
    """Two IntMaps initialised with the same salt assign identical IDs."""
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    im2 = IntMap(salt=10)
    x = im2.index(1, 2)
    y = im2.index(3, 4)
    z = im2.index(a, b)

    assert (a, b, c) == (x, y, z)


def test_intmap_different():
    """Two IntMaps initialised with different salts never assign the same IDs."""
    im1 = IntMap(salt=10)
    a = im1.index(1, 2)
    b = im1.index(3, 4)
    c = im1.index(a, b)

    im2 = IntMap(salt=11)
    x = im2.index(1, 2)
    y = im2.index(3, 4)
    # Fixed copy-paste bug: mirror im1's construction using im2's own IDs
    # (x, y), not im1's (a, b) — otherwise im2 indexes a different set and
    # the comparison below is not like-for-like.
    z = im2.index(x, y)

    for v1, v2 in zip([a, b, c], [x, y, z], strict=True):
        assert v1 != v2


def test_intmap_unordered():
    """The order in which members are passed does not affect the assigned ID."""
    intmap = IntMap(salt=10)

    assert intmap.index(1, 2, 3) == intmap.index(3, 1, 2)
Loading