Commit
Working debug in unit tests, removed a bunch of hard-coded references to SHA-1, first working unit test
Will Langdale committed Oct 17, 2024
1 parent e4f60b6 commit 62040c4
Showing 23 changed files with 305 additions and 257 deletions.
24 changes: 24 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,24 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Matchbox: Debug",
            "type": "debugpy",
            "request": "launch",
            "program": "${file}",
            "purpose": ["debug-test"],
            "console": "integratedTerminal",
            "justMyCode": false,
            "env": {
                "PYTEST_ADDOPTS": "--no-cov",
                "PYTHONPATH": "${workspaceFolder}"
            },
            "python": "${workspaceFolder}/.venv/bin/python",
            "cwd": "${workspaceFolder}",
            "args": [
                "-v",
                "-s"
            ]
        }
    ]
}
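Worth noting for the env block above: pytest-cov and debugpy both rely on Python's tracing hooks, so breakpoints are often skipped while coverage is active. Setting PYTEST_ADDOPTS to --no-cov for debug runs is the usual workaround, which is presumably why it appears here.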
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -10,5 +10,6 @@
"editor.codeActionsOnSave": {
"source.fixAll": "explicit"
}
}
},
"python.testing.pytestPath": "${workspaceFolder}/.venv/bin/pytest"
}
102 changes: 102 additions & 0 deletions src/matchbox/common/hash.py
@@ -0,0 +1,102 @@
import hashlib
from typing import Any, TypeVar
from uuid import UUID

from matchbox.server.base import IndexableDataset
from pandas import DataFrame, Series
from sqlalchemy import String, func, select
from sqlalchemy.orm import Session

T = TypeVar("T")
HashableItem = TypeVar("HashableItem", bytes, bool, str, int, float, bytearray)

HASH_FUNC = hashlib.sha1


def dataset_to_hashlist(dataset: IndexableDataset, uuid: UUID) -> list[dict[str, Any]]:
    """Retrieve and hash a dataset from its warehouse, ready to be inserted."""
    with Session(dataset.database.engine) as warehouse_session:
        source_table = dataset.to_table()

        # Exclude the primary key from the columns to be hashed
        cols = tuple(
            [col for col in list(source_table.c.keys()) if col != dataset.db_pk]
        )

        slct_stmt = select(
            func.concat(*source_table.c[cols]).label("raw"),
            func.array_agg(source_table.c[dataset.db_pk].cast(String)).label("id"),
        ).group_by(*source_table.c[cols])

        raw_result = warehouse_session.execute(slct_stmt)

        to_insert = [
            {
                "sha1": hash_data(data.raw),
                "id": data.id,
                "dataset": uuid,
            }
            for data in raw_result.all()
        ]

    return to_insert


def prep_for_hash(item: HashableItem) -> bytes:
    """Encodes an item as bytes so it can be hashed; bytes pass through unchanged."""
    if isinstance(item, bytes):
        return item
    elif isinstance(item, str):
        return bytes(item.encode())
    elif isinstance(item, UUID):
        return item.bytes
    else:
        return bytes(item)


def hash_data(data: str) -> bytes:
    """Hash the given data using the globally defined hash function.

    This function ties into the existing hashing utilities.
    """
    return HASH_FUNC(prep_for_hash(data)).digest()


def list_to_value_ordered_hash(list_: list[T]) -> bytes:
    """Returns a single hash of a list ordered by its values.

    The list is sorted first, since different orderings of the same values must
    produce the same hash.
    """
    try:
        sorted_vals = sorted(list_)
    except TypeError as e:
        raise TypeError("Can only order lists or columns of the same datatype.") from e

    hashed_vals_list = [HASH_FUNC(prep_for_hash(i)) for i in sorted_vals]

[Check failure — Code scanning / CodeQL (High): sensitive data (id) is used in a hashing algorithm (SHA-1) that is insecure.]
    hashed_vals = hashed_vals_list[0]
    for val in hashed_vals_list[1:]:
        hashed_vals.update(val.digest())

    return hashed_vals.digest()


def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series:
    """Returns the row-wise hash ordered by the row's values, ignoring column order.

    This function is used to add a column to a dataframe that represents the
    hash of each of its rows, but where the order of the row values doesn't
    change the hash value. Column order is ignored in favour of value order.

    This is primarily used to give a consistent hash to a new cluster no matter
    whether its parent hashes were used in the left or right table.
    """
    bytes_records = data.filter(columns).astype(bytes).to_dict("records")

    hashed_records = []

    for record in bytes_records:
        hashed_vals = list_to_value_ordered_hash(record.values())
        hashed_records.append(hashed_vals)

    return Series(hashed_records)
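A minimal sketch of how these helpers compose, assuming the module above is importable; the dataframe values and column names are illustrative (they mirror the left_id/right_id usage in results.py below):

```python
from pandas import DataFrame

from matchbox.common.hash import (
    columns_to_value_ordered_hash,
    hash_data,
    list_to_value_ordered_hash,
)

# hash_data encodes strings before hashing, so str and bytes inputs agree
assert hash_data("alpha") == hash_data(b"alpha")

# list_to_value_ordered_hash sorts before hashing: value order is irrelevant
assert list_to_value_ordered_hash(["a", "b"]) == list_to_value_ordered_hash(["b", "a"])

# columns_to_value_ordered_hash applies the same idea row-wise, so swapping
# left/right values leaves a row's hash unchanged
df = DataFrame({"left_id": [b"x", b"y"], "right_id": [b"y", b"x"]})
hashes = columns_to_value_ordered_hash(data=df, columns=["left_id", "right_id"])
assert hashes.iloc[0] == hashes.iloc[1]
```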
14 changes: 7 additions & 7 deletions src/matchbox/common/results.py
@@ -4,9 +4,9 @@

import rustworkx as rx
from dotenv import find_dotenv, load_dotenv
from matchbox.common.sha1 import (
columns_to_value_ordered_sha1,
list_to_value_ordered_sha1,
from matchbox.common.hash import (
columns_to_value_ordered_hash,
list_to_value_ordered_hash,
)
from matchbox.server.base import Cluster, MatchboxDBAdapter, Probability
from pandas import DataFrame, concat
@@ -150,7 +150,7 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> list[Probability]:
pre_prep_df[["left_id", "right_id"]] = pre_prep_df[
["left_id", "right_id"]
].astype("binary[pyarrow]")
pre_prep_df["sha1"] = columns_to_value_ordered_sha1(
pre_prep_df["sha1"] = columns_to_value_ordered_hash(
data=pre_prep_df, columns=["left_id", "right_id"]
)
pre_prep_df["sha1"] = pre_prep_df["sha1"].astype("binary[pyarrow]")
@@ -257,7 +257,7 @@ def get_unclustered(
Args:
clusters (ClusterResults): a ClusterResults generated by a linker or deduper
data (DataFrame): cleaned data that went into the model
key (str): the column that was matched, usually data_sha1 or cluster_sha1
key (str): the column that was matched, usually data_hash or cluster_hash
Returns:
A ClusterResults object
@@ -297,7 +297,7 @@ def to_clusters(
Args:
results (ProbabilityResults): an object of class ProbabilityResults
key (str): the column that was matched, usually data_sha1 or cluster_sha1
key (str): the column that was matched, usually data_hash or cluster_hash
threshold (float): the value above which to consider probabilities true
data (DataFrame): (optional) Any number of cleaned data that went into
the model. Typically this is one dataset for a deduper or two for a
@@ -339,7 +339,7 @@ def to_clusters(
res["child"].append(child_hash)

# Must be sorted to be symmetric
parent_hash = list_to_value_ordered_sha1(child_hashes)
parent_hash = list_to_value_ordered_hash(child_hashes)

res["parent"] += [parent_hash] * len(component)

60 changes: 0 additions & 60 deletions src/matchbox/common/sha1.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/matchbox/dedupers/make_deduper.py
@@ -18,7 +18,7 @@ class DeduperSettings(BaseModel):
@field_validator("id")
@classmethod
def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
enforce = "data_sha1"
enforce = "data_hash"
if v != enforce:
warnings.warn(
f"For offline deduplication, {info.field_name} can be any field. \n\n"
2 changes: 1 addition & 1 deletion src/matchbox/linkers/make_linker.py
@@ -19,7 +19,7 @@ class LinkerSettings(BaseModel):
@field_validator("left_id", "right_id")
@classmethod
def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
enforce = "cluster_sha1"
enforce = "cluster_hash"
if v != enforce:
warnings.warn(
f"For offline deduplication, {info.field_name} can be any field. \n\n"
4 changes: 2 additions & 2 deletions src/matchbox/linkers/weighteddeterministic.py
@@ -45,8 +45,8 @@ class WeightedDeterministicSettings(LinkerSettings):
Example:
>>> {
... left_id: "cluster_sha1",
... right_id: "cluster_sha1",
... left_id: "cluster_hash",
... right_id: "cluster_hash",
... weighted_comparisons: [
... ("l.company_name = r.company_name", .7),
... ("l.postcode = r.postcode", .7),
8 changes: 7 additions & 1 deletion src/matchbox/server/base.py
@@ -7,7 +7,7 @@
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from rustworkx import PyDiGraph
from sqlalchemy import create_engine
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy import text as sqltext
from sqlalchemy.engine import Engine
from sqlalchemy.engine.result import ChunkedIteratorResult
@@ -102,6 +102,12 @@ class Config:
def __str__(self) -> str:
return f"{self.db_schema}.{self.db_table}"

def to_table(self) -> Table:
"""Returns the dataset as a SQLAlchemy Table object."""
metadata = MetaData(schema=self.db_schema)
table = Table(self.db_table, metadata, autoload_with=self.database.engine)
return table


class MatchboxModelAdapter(ABC):
"""An abstract base class for Matchbox model adapters."""
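The new to_table method relies on SQLAlchemy table reflection, which inspects the live database for column definitions at call time. A standalone sketch of the same pattern, where the connection string, schema, and table name are all invented:

```python
from sqlalchemy import MetaData, Table, create_engine, select

# Reflection queries the live database, so a working engine is required
engine = create_engine("postgresql://user:pass@localhost/warehouse")  # illustrative
metadata = MetaData(schema="companieshouse")  # illustrative schema
companies = Table("companies", metadata, autoload_with=engine)

# The reflected Table is usable like any declaratively defined one,
# e.g. in the select/group_by statement built by dataset_to_hashlist above
stmt = select(companies).limit(10)
```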
9 changes: 4 additions & 5 deletions src/matchbox/server/postgresql/adapter.py
@@ -26,6 +26,7 @@
from matchbox.server.postgresql.models import Models, ModelsFrom
from matchbox.server.postgresql.utils.db import get_model_subgraph
from matchbox.server.postgresql.utils.delete import delete_model
from matchbox.server.postgresql.utils.hash import table_name_to_uuid
from matchbox.server.postgresql.utils.index import index_dataset
from matchbox.server.postgresql.utils.insert import (
insert_clusters,
@@ -34,7 +35,6 @@
insert_probabilities,
)
from matchbox.server.postgresql.utils.selector import query
from matchbox.server.postgresql.utils.sha1 import table_name_to_uuid


class MergesUnion:
@@ -171,7 +171,6 @@ def index(self, dataset: IndexableDataset) -> None:
index_dataset(
dataset=dataset,
engine=MBDB.get_engine(),
warehouse_engine=dataset.database.engine(),
)

def validate_hashes(
@@ -188,18 +187,18 @@ def validate_hashes(
"""
if hash_type == "data":
Source = SourceData
tgt_col = "data_sha1"
tgt_col = "data_hash"
elif hash_type == "cluster":
Source = Clusters
tgt_col = "cluster_sha1"
tgt_col = "cluster_hash"

with Session(MBDB.get_engine()) as session:
data_inner_join = (
session.query(Source)
.filter(
Source.sha1.in_(
bindparam(
"ins_sha1s",
"ins_hashs",
hashes,
expanding=True,
)
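The renamed bind parameter in validate_hashes uses SQLAlchemy's expanding IN feature; a brief sketch of the pattern in isolation, assuming the SourceData model, session, and hashes list from the diff above:

```python
from sqlalchemy import bindparam

# An expanding bindparam renders as a variable-length IN (...) list at
# execution time, so one statement handles any number of hashes
stmt = session.query(SourceData).filter(
    SourceData.sha1.in_(bindparam("ins_hashs", hashes, expanding=True))
)
matched = stmt.all()
```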
16 changes: 5 additions & 11 deletions src/matchbox/server/postgresql/db.py
@@ -1,14 +1,7 @@
from dotenv import find_dotenv, load_dotenv
from pydantic import BaseModel, Field
from sqlalchemy import (
Engine,
create_engine,
text,
)
from sqlalchemy.orm import (
declarative_base,
sessionmaker,
)
from sqlalchemy import Engine, MetaData, create_engine, text
from sqlalchemy.orm import declarative_base, sessionmaker

from matchbox.server.base import MatchboxBackends, MatchboxSettings

@@ -47,7 +40,9 @@ def __init__(self, settings: MatchboxPostgresSettings):
self.settings = settings
self.engine: Engine | None = None
self.SessionLocal: sessionmaker | None = None
self.MatchboxBase = declarative_base()
self.MatchboxBase = declarative_base(
metadata=MetaData(schema=settings.postgres.db_schema)
)

def connect(self):
"""Connect to the database."""
@@ -62,7 +57,6 @@ def connect(self):
self.SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=self.engine
)
self.MatchboxBase.metadata.schema = self.settings.postgres.db_schema

def get_engine(self) -> Engine:
"""Get the database engine."""
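On the declarative_base change: a Table picks up its MetaData's schema at the moment the model class is defined, so passing MetaData(schema=...) when the base is constructed covers every model, whereas mutating MatchboxBase.metadata.schema in connect() only affects tables defined after that call. A sketch of the mechanics, with invented names:

```python
from sqlalchemy import Column, Integer, MetaData
from sqlalchemy.orm import declarative_base

# Schema fixed at construction: every table defined from this base inherits it
Base = declarative_base(metadata=MetaData(schema="mb"))


class Thing(Base):
    __tablename__ = "thing"
    id = Column(Integer, primary_key=True)


# The table is schema-qualified from the start; setting Base.metadata.schema
# after Thing was defined would have left Thing.__table__.schema as None
assert Thing.__table__.schema == "mb"
```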
(Diffs for the remaining 12 changed files did not load.)