diff --git a/.vscode/launch.json b/.vscode/launch.json
new file mode 100644
index 0000000..fee3db3
--- /dev/null
+++ b/.vscode/launch.json
@@ -0,0 +1,24 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Matchbox: Debug",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "purpose": ["debug-test"],
+            "console": "integratedTerminal",
+            "justMyCode": false,
+            "env": {
+                "PYTEST_ADDOPTS": "--no-cov",
+                "PYTHONPATH": "${workspaceFolder}"
+            },
+            "python": "${workspaceFolder}/.venv/bin/python",
+            "cwd": "${workspaceFolder}",
+            "args": [
+                "-v",
+                "-s"
+            ]
+        }
+    ]
+}
\ No newline at end of file
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 237f7f4..5562e2b 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -10,5 +10,6 @@
         "editor.codeActionsOnSave": {
             "source.fixAll": "explicit"
         }
-    }
+    },
+    "python.testing.pytestPath": "${workspaceFolder}/.venv/bin/pytest"
 }
\ No newline at end of file
diff --git a/src/matchbox/common/hash.py b/src/matchbox/common/hash.py
new file mode 100644
index 0000000..e4278f0
--- /dev/null
+++ b/src/matchbox/common/hash.py
@@ -0,0 +1,102 @@
+import hashlib
+from typing import Any, TypeVar
+from uuid import UUID
+
+from matchbox.server.base import IndexableDataset
+from pandas import DataFrame, Series
+from sqlalchemy import String, func, select
+from sqlalchemy.orm import Session
+
+T = TypeVar("T")
+HashableItem = TypeVar("HashableItem", bytes, bool, str, int, float, bytearray)
+
+HASH_FUNC = hashlib.sha1
+
+
+def dataset_to_hashlist(dataset: IndexableDataset, uuid: UUID) -> list[dict[str, Any]]:
+    """Retrieve and hash a dataset from its warehouse, ready to be inserted."""
+    with Session(dataset.database.engine) as warehouse_session:
+        source_table = dataset.to_table()
+
+        # Exclude the primary key from the columns to be hashed
+        cols = tuple(
+            [col for col in list(source_table.c.keys()) if col != dataset.db_pk]
+        )
+
+        # Group by row content and aggregate the primary keys as an array,
+        # because the warehouse data may contain duplicated rows
+        slct_stmt = select(
+            func.concat(*source_table.c[cols]).label("raw"),
+            func.array_agg(source_table.c[dataset.db_pk].cast(String)).label("id"),
+        ).group_by(*source_table.c[cols])
+
+        raw_result = warehouse_session.execute(slct_stmt)
+
+        to_insert = [
+            {
+                "sha1": hash_data(data.raw),
+                "id": data.id,
+                "dataset": uuid,
+            }
+            for data in raw_result.all()
+        ]
+
+    return to_insert
+
+
+def prep_for_hash(item: HashableItem) -> bytes:
+    """Encodes strings so they can be hashed; otherwise passes the item through."""
+    if isinstance(item, bytes):
+        return item
+    elif isinstance(item, str):
+        return bytes(item.encode())
+    elif isinstance(item, UUID):
+        return item.bytes
+    else:
+        return bytes(item)
+
+
+def hash_data(data: str) -> bytes:
+    """
+    Hash the given data using the globally defined hash function.
+    This function ties into the existing hashing utilities.
+    """
+    return HASH_FUNC(prep_for_hash(data)).digest()
+
+
+def list_to_value_ordered_hash(list_: list[T]) -> bytes:
+    """Returns a single hash of a list ordered by its values.
+
+    The list is sorted first, so different orderings of the same values produce the same hash.
+    """
+    try:
+        sorted_vals = sorted(list_)
+    except TypeError as e:
+        raise TypeError("Can only order lists or columns of the same datatype.") from e
+
+    hashed_vals_list = [HASH_FUNC(prep_for_hash(i)) for i in sorted_vals]
+
+    hashed_vals = hashed_vals_list[0]
+    for val in hashed_vals_list[1:]:
+        hashed_vals.update(val.digest())
+
+    return hashed_vals.digest()
+
+
+def columns_to_value_ordered_hash(data: DataFrame, columns: list[str]) -> Series:
+    """Returns the rowwise hash ordered by the row's values, ignoring column order.
+
+    This function is used to add a column to a dataframe that represents the
+    hash of each of its rows, but where the order of the row values doesn't change the
+    hash value. Column order is ignored in favour of value order.
+
+    This is primarily used to give a consistent hash to a new cluster no matter whether
+    its parent hashes were used in the left or right table.
+    """
+    bytes_records = data.filter(columns).astype(bytes).to_dict("records")
+
+    hashed_records = []
+
+    for record in bytes_records:
+        hashed_vals = list_to_value_ordered_hash(record.values())
+        hashed_records.append(hashed_vals)
+
+    return Series(hashed_records)
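For illustration only, a minimal sketch of the value-ordering property the two helpers above provide, assuming matchbox is importable; the sample values are invented:

from pandas import DataFrame

from matchbox.common.hash import (
    columns_to_value_ordered_hash,
    list_to_value_ordered_hash,
)

# The same values in a different order produce an identical digest
assert list_to_value_ordered_hash([b"b", b"a"]) == list_to_value_ordered_hash([b"a", b"b"])

# The same holds rowwise across columns, so a pair hashed as (left, right)
# matches the same pair hashed as (right, left), which is what to_records
# in results.py relies on below
pairs = DataFrame({"left_id": [b"a", b"c"], "right_id": [b"b", b"d"]})
swapped = DataFrame({"left_id": [b"b", b"d"], "right_id": [b"a", b"c"]})

assert columns_to_value_ordered_hash(pairs, ["left_id", "right_id"]).equals(
    columns_to_value_ordered_hash(swapped, ["left_id", "right_id"])
)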
diff --git a/src/matchbox/common/results.py b/src/matchbox/common/results.py
index 4935952..c1ba168 100644
--- a/src/matchbox/common/results.py
+++ b/src/matchbox/common/results.py
@@ -4,9 +4,9 @@
 import rustworkx as rx
 from dotenv import find_dotenv, load_dotenv
-from matchbox.common.sha1 import (
-    columns_to_value_ordered_sha1,
-    list_to_value_ordered_sha1,
+from matchbox.common.hash import (
+    columns_to_value_ordered_hash,
+    list_to_value_ordered_hash,
 )
 from matchbox.server.base import Cluster, MatchboxDBAdapter, Probability
 from pandas import DataFrame, concat
@@ -150,7 +150,7 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> list[Probability]:
         pre_prep_df[["left_id", "right_id"]] = pre_prep_df[
             ["left_id", "right_id"]
         ].astype("binary[pyarrow]")
-        pre_prep_df["sha1"] = columns_to_value_ordered_sha1(
+        pre_prep_df["sha1"] = columns_to_value_ordered_hash(
             data=pre_prep_df, columns=["left_id", "right_id"]
         )
         pre_prep_df["sha1"] = pre_prep_df["sha1"].astype("binary[pyarrow]")
@@ -257,7 +257,7 @@ def get_unclustered(
     Args:
         clusters (ClusterResults): a ClusterResults generated by a linker or deduper
         data (DataFrame): cleaned data that went into the model
-        key (str): the column that was matched, usually data_sha1 or cluster_sha1
+        key (str): the column that was matched, usually data_hash or cluster_hash
 
     Returns:
         A ClusterResults object
@@ -297,7 +297,7 @@ def to_clusters(
     Args:
        results (ProbabilityResults): an object of class ProbabilityResults
-        key (str): the column that was matched, usually data_sha1 or cluster_sha1
+        key (str): the column that was matched, usually data_hash or cluster_hash
        threshold (float): the value above which to consider probabilities true
        data (DataFrame): (optional) Any number of cleaned data that went into the model.
            Typically this is one dataset for a deduper or two for a
@@ -339,7 +339,7 @@ def to_clusters(
             res["child"].append(child_hash)
 
             # Must be sorted to be symmetric
-            parent_hash = list_to_value_ordered_sha1(child_hashes)
+            parent_hash = list_to_value_ordered_hash(child_hashes)
 
             res["parent"] += [parent_hash] * len(component)
diff --git a/src/matchbox/common/sha1.py b/src/matchbox/common/sha1.py
deleted file mode 100644
index f45b4d9..0000000
--- a/src/matchbox/common/sha1.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import hashlib
-import uuid
-from typing import TypeVar
-
-from pandas import DataFrame, Series
-
-T = TypeVar("T")
-HashableItem = TypeVar("HashableItem", bytes, bool, str, int, float, bytearray)
-
-
-def prep_for_hash(item: HashableItem) -> bytes:
-    """Encodes strings so they can be hashed, otherwises, passes through."""
-    if isinstance(item, bytes):
-        return item
-    elif isinstance(item, str):
-        return bytes(item.encode())
-    elif isinstance(item, uuid.UUID):
-        return item.bytes
-    else:
-        return bytes(item)
-
-
-def list_to_value_ordered_sha1(list_: list[T]) -> bytes:
-    """Returns a single SHA1 hash of a list ordered by its values.
-
-    List must be sorted as the different orders of value must produce the same hash.
-    """
-    try:
-        sorted_vals = sorted(list_)
-    except TypeError as e:
-        raise TypeError("Can only order lists or columns of the same datatype.") from e
-
-    hashed_vals_list = [hashlib.sha1(prep_for_hash(i)) for i in sorted_vals]
-
-    hashed_vals = hashed_vals_list[0]
-    for val in hashed_vals_list[1:]:
-        hashed_vals.update(val.digest())
-
-    return hashed_vals.digest()
-
-
-def columns_to_value_ordered_sha1(data: DataFrame, columns: list[str]) -> Series:
-    """Returns the rowwise SHA1 hash ordered by the row's values, ignoring column order.
-
-    This function is used to add a column to a dataframe that represents the SHA1
-    hash of each its rows, but where the order of the row values doesn't change the
-    hash value. Column order is ignored in favour of value order.
-
-    This is primarily used to give a consistent hash to a new cluster no matter whether
-    its parent hashes were used in the left or right table.
-    """
-    bytes_records = data.filter(columns).astype(bytes).to_dict("records")
-
-    hashed_records = []
-
-    for record in bytes_records:
-        hashed_vals = list_to_value_ordered_sha1(record.values())
-        hashed_records.append(hashed_vals)
-
-    return Series(hashed_records)
diff --git a/src/matchbox/dedupers/make_deduper.py b/src/matchbox/dedupers/make_deduper.py
index 8dfa293..bca7207 100644
--- a/src/matchbox/dedupers/make_deduper.py
+++ b/src/matchbox/dedupers/make_deduper.py
@@ -18,7 +18,7 @@ class DeduperSettings(BaseModel):
     @field_validator("id")
     @classmethod
     def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
-        enforce = "data_sha1"
+        enforce = "data_hash"
         if v != enforce:
             warnings.warn(
                 f"For offline deduplication, {info.field_name} can be any field. \n\n"
diff --git a/src/matchbox/linkers/make_linker.py b/src/matchbox/linkers/make_linker.py
index 3006aaa..297e593 100644
--- a/src/matchbox/linkers/make_linker.py
+++ b/src/matchbox/linkers/make_linker.py
@@ -19,7 +19,7 @@ class LinkerSettings(BaseModel):
     @field_validator("left_id", "right_id")
     @classmethod
     def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
-        enforce = "cluster_sha1"
+        enforce = "cluster_hash"
         if v != enforce:
             warnings.warn(
                 f"For offline deduplication, {info.field_name} can be any field. \n\n"
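For context, the renamed validators only warn rather than reject. A small stand-in sketch of the pattern; ExampleLinkerSettings is hypothetical, not the real LinkerSettings:

import warnings

from pydantic import BaseModel, ValidationInfo, field_validator


class ExampleLinkerSettings(BaseModel):
    left_id: str
    right_id: str

    @field_validator("left_id", "right_id")
    @classmethod
    def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
        # Mirrors the validator above: warn when the key isn't cluster_hash
        enforce = "cluster_hash"
        if v != enforce:
            warnings.warn(f"{info.field_name} is usually {enforce}", stacklevel=2)
        return v


ExampleLinkerSettings(left_id="cluster_hash", right_id="cluster_hash")  # silent
ExampleLinkerSettings(left_id="company_id", right_id="company_id")  # warns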
diff --git a/src/matchbox/linkers/weighteddeterministic.py b/src/matchbox/linkers/weighteddeterministic.py
index 9257dcd..e9f9e42 100644
--- a/src/matchbox/linkers/weighteddeterministic.py
+++ b/src/matchbox/linkers/weighteddeterministic.py
@@ -45,8 +45,8 @@ class WeightedDeterministicSettings(LinkerSettings):
 
     Example:
         >>> {
-        ...     left_id: "cluster_sha1",
-        ...     right_id: "cluster_sha1",
+        ...     left_id: "cluster_hash",
+        ...     right_id: "cluster_hash",
         ...     weighted_comparisons: [
         ...         ("l.company_name = r.company_name", .7),
         ...         ("l.postcode = r.postcode", .7),
diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py
index a522d86..b105264 100644
--- a/src/matchbox/server/base.py
+++ b/src/matchbox/server/base.py
@@ -7,7 +7,7 @@
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from rustworkx import PyDiGraph
-from sqlalchemy import create_engine
+from sqlalchemy import MetaData, Table, create_engine
 from sqlalchemy import text as sqltext
 from sqlalchemy.engine import Engine
 from sqlalchemy.engine.result import ChunkedIteratorResult
@@ -102,6 +102,12 @@ class Config:
     def __str__(self) -> str:
         return f"{self.db_schema}.{self.db_table}"
 
+    def to_table(self) -> Table:
+        """Returns the dataset as a SQLAlchemy Table object."""
+        metadata = MetaData(schema=self.db_schema)
+        table = Table(self.db_table, metadata, autoload_with=self.database.engine)
+        return table
+
 
 class MatchboxModelAdapter(ABC):
     """An abstract base class for Matchbox model adapters."""
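For context, the new IndexableDataset.to_table() reflects the warehouse table at call time using the dataset's own engine; in isolation the reflection looks roughly like this, where the schema, table name and connection string are made-up examples:

from sqlalchemy import MetaData, Table, create_engine

# Hypothetical warehouse connection, for illustration only
engine = create_engine("postgresql://user:pass@localhost:5432/warehouse")

metadata = MetaData(schema="companieshouse")
companies = Table("companies", metadata, autoload_with=engine)

print([column.name for column in companies.columns])

Because reflection runs against dataset.database.engine, callers such as index_dataset no longer need a separate warehouse_engine argument (see the adapter change below).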
diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py
index 5bc1200..5386974 100644
--- a/src/matchbox/server/postgresql/adapter.py
+++ b/src/matchbox/server/postgresql/adapter.py
@@ -26,6 +26,7 @@
 from matchbox.server.postgresql.models import Models, ModelsFrom
 from matchbox.server.postgresql.utils.db import get_model_subgraph
 from matchbox.server.postgresql.utils.delete import delete_model
+from matchbox.server.postgresql.utils.hash import table_name_to_uuid
 from matchbox.server.postgresql.utils.index import index_dataset
 from matchbox.server.postgresql.utils.insert import (
     insert_clusters,
@@ -34,7 +35,6 @@
     insert_probabilities,
 )
 from matchbox.server.postgresql.utils.selector import query
-from matchbox.server.postgresql.utils.sha1 import table_name_to_uuid
 
 
 class MergesUnion:
@@ -171,7 +171,6 @@ def index(self, dataset: IndexableDataset) -> None:
         index_dataset(
             dataset=dataset,
             engine=MBDB.get_engine(),
-            warehouse_engine=dataset.database.engine(),
         )
 
     def validate_hashes(
@@ -188,10 +187,10 @@
         """
         if hash_type == "data":
             Source = SourceData
-            tgt_col = "data_sha1"
+            tgt_col = "data_hash"
         elif hash_type == "cluster":
             Source = Clusters
-            tgt_col = "cluster_sha1"
+            tgt_col = "cluster_hash"
 
         with Session(MBDB.get_engine()) as session:
             data_inner_join = (
@@ -199,7 +198,7 @@ def validate_hashes(
                 .filter(
                     Source.sha1.in_(
                         bindparam(
-                            "ins_sha1s",
+                            "ins_hashes",
                             hashes,
                             expanding=True,
                         )
diff --git a/src/matchbox/server/postgresql/db.py b/src/matchbox/server/postgresql/db.py
index 2a3725e..247c379 100644
--- a/src/matchbox/server/postgresql/db.py
+++ b/src/matchbox/server/postgresql/db.py
@@ -1,14 +1,7 @@
 from dotenv import find_dotenv, load_dotenv
 from pydantic import BaseModel, Field
-from sqlalchemy import (
-    Engine,
-    create_engine,
-    text,
-)
-from sqlalchemy.orm import (
-    declarative_base,
-    sessionmaker,
-)
+from sqlalchemy import Engine, MetaData, create_engine, text
+from sqlalchemy.orm import declarative_base, sessionmaker
 
 from matchbox.server.base import MatchboxBackends, MatchboxSettings
 
@@ -47,7 +40,9 @@ def __init__(self, settings: MatchboxPostgresSettings):
         self.settings = settings
         self.engine: Engine | None = None
         self.SessionLocal: sessionmaker | None = None
-        self.MatchboxBase = declarative_base()
+        self.MatchboxBase = declarative_base(
+            metadata=MetaData(schema=settings.postgres.db_schema)
+        )
 
     def connect(self):
         """Connect to the database."""
@@ -62,7 +57,6 @@ def connect(self):
         self.SessionLocal = sessionmaker(
             autocommit=False, autoflush=False, bind=self.engine
         )
-        self.MatchboxBase.metadata.schema = self.settings.postgres.db_schema
 
     def get_engine(self) -> Engine:
         """Get the database engine."""
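Passing the schema when the declarative base is created means every ORM table is schema-qualified at class-definition time; mutating MatchboxBase.metadata.schema later (the line removed from connect()) would not re-qualify tables that had already been defined. A minimal sketch, with a made-up model:

from sqlalchemy import Column, Integer, MetaData
from sqlalchemy.orm import declarative_base

Base = declarative_base(metadata=MetaData(schema="mb"))


class Example(Base):  # hypothetical model, not part of Matchbox
    __tablename__ = "example"
    id = Column(Integer, primary_key=True)


# The schema is already attached when the class is defined
assert Example.__table__.schema == "mb"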
diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py
index c833068..32196ec 100644
--- a/src/matchbox/server/postgresql/utils/db.py
+++ b/src/matchbox/server/postgresql/utils/db.py
@@ -7,48 +7,13 @@
 import rustworkx as rx
 from pg_bulk_ingest import Delete, Upsert, ingest
-from sqlalchemy import Engine, MetaData, Table
+from sqlalchemy import Engine, Table
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.exc import NoSuchTableError
 from sqlalchemy.orm import Session
 
-from matchbox.common.exceptions import MatchboxSourceTableError
 from matchbox.server.postgresql.data import SourceDataset
 from matchbox.server.postgresql.models import Models, ModelsFrom
 
-# Data conversion
-
-
-def dataset_to_table(dataset: SourceDataset, engine: Engine) -> Table:
-    """Takes a CMF SourceDataset object and returns a SQLAlchemy Table."""
-    with Session(engine) as session:
-        source_schema = MetaData(schema=dataset.db_schema)
-        try:
-            source_table = Table(
-                dataset.db_table,
-                source_schema,
-                schema=dataset.db_schema,
-                autoload_with=session.get_bind(),
-            )
-        except NoSuchTableError as e:
-            raise MatchboxSourceTableError(
-                table_name=f"{dataset.db_schema}.{dataset.db_table}"
-            ) from e
-
-    return source_table
-
-
-def string_to_dataset(db_schema: str, db_table: str, engine: Engine) -> SourceDataset:
-    """Takes strings and returns a CMF SourceDataset"""
-    with Session(engine) as session:
-        dataset = (
-            session.query(SourceDataset)
-            .filter_by(db_schema=db_schema, db_table=db_table)
-            .first()
-        )
-    return dataset
-
-
 # Retrieval
diff --git a/src/matchbox/server/postgresql/utils/sha1.py b/src/matchbox/server/postgresql/utils/hash.py
similarity index 85%
rename from src/matchbox/server/postgresql/utils/sha1.py
rename to src/matchbox/server/postgresql/utils/hash.py
index 1ccdfbf..86f975b 100644
--- a/src/matchbox/server/postgresql/utils/sha1.py
+++ b/src/matchbox/server/postgresql/utils/hash.py
@@ -34,8 +34,8 @@ def table_name_to_uuid(schema_table: str, engine: Engine) -> bytes:
     return dataset_uuid
 
 
-def model_name_to_sha1(run_name: str, engine: Engine) -> bytes:
-    """Takes a model's name and returns its SHA-1 hash.
+def model_name_to_hash(run_name: str, engine: Engine) -> bytes:
+    """Takes a model's name and returns its hash.
 
     Args:
         run_name (str): The string name of the model in the database
@@ -45,13 +45,13 @@ def model_name_to_sha1(run_name: str, engine: Engine) -> bytes:
         CMFSourceError if model not found in database
 
     Returns:
-        The SHA-1 hash of the model
+        The hash of the model
     """
     with Session(engine) as session:
         stmt = select(Models.sha1).where(Models.name == run_name)
-        model_sha1 = session.execute(stmt).scalar()
+        model_hash = session.execute(stmt).scalar()
 
-        if model_sha1 is None:
+        if model_hash is None:
             raise MatchboxDBDataError(table=Models.__tablename__, data=run_name)
 
-        return model_sha1
+        return model_hash
diff --git a/src/matchbox/server/postgresql/utils/index.py b/src/matchbox/server/postgresql/utils/index.py
index f08edcc..2a89c8c 100644
--- a/src/matchbox/server/postgresql/utils/index.py
+++ b/src/matchbox/server/postgresql/utils/index.py
@@ -1,28 +1,26 @@
 import logging
 
-from sqlalchemy import Engine, String, func, select
+from sqlalchemy import Engine
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.orm import Session
 
+from matchbox.common.hash import dataset_to_hashlist
 from matchbox.server.base import IndexableDataset
-from matchbox.server.postgresql import utils as du
 from matchbox.server.postgresql.data import SourceData, SourceDataset
 
 
-def index_dataset(
-    dataset: IndexableDataset, engine: Engine, warehouse_engine: Engine
-) -> None:
+def index_dataset(dataset: IndexableDataset, engine: Engine) -> None:
     """Indexes a dataset from your data warehouse within Matchbox."""
 
     logic_logger = logging.getLogger("mb_logic")
 
     db_logger = logging.getLogger("sqlalchemy.engine")
     db_logger.setLevel(logging.WARNING)
 
-    with Session(engine) as session:
-        ##########################
-        # Insert dataset section #
-        ##########################
+    ##################
+    # Insert dataset #
+    ##################
 
+    with Session(engine) as session:
         logic_logger.info(f"Adding {dataset}")
 
         to_insert = [
@@ -49,38 +47,17 @@ def index_dataset(
             .first()
         )
 
+        new_dataset_uuid = new_dataset.uuid
+
         session.commit()
 
     logic_logger.info(f"{dataset} added to SourceDataset")
 
-    #######################
-    # Insert data section #
-    #######################
-
-    with Session(warehouse_engine) as warehouse_session:
-        source_table = du.dataset_to_table(new_dataset, warehouse_engine)
-
-        # Retrieve the SHA1 of data and an array of row IDs (as strings)
-        # Array because we can't guarantee non-duplicated data
-        cols = tuple(
-            [col for col in list(source_table.c.keys()) if col != dataset.db_pk]
-        )
-        slct_stmt = select(
-            func.digest(func.concat(*source_table.c[cols]), "sha1").label("sha1"),
-            func.array_agg(source_table.c[dataset.db_pk].cast(String)).label("id"),
-        ).group_by(*source_table.c[cols])
-
-        raw_result = warehouse_session.execute(slct_stmt)
-
-        logic_logger.info(f"Retrieved raw data from {dataset}")
+    ###############
+    # Insert data #
+    ###############
 
-        # Create list of (sha1, id, dataset)-keyed dicts using RowMapping:
-        # https://docs.sqlalchemy.org/en/20/core/
-        # connections.html#sqlalchemy.engine.Row._mapping
-        to_insert = [
-            dict(data._mapping, **{"dataset": new_dataset.uuid})
-            for data in raw_result.all()
-        ]
+    to_insert = dataset_to_hashlist(dataset=dataset, uuid=new_dataset_uuid)
 
     with Session(engine) as session:
         # Insert it using PostgreSQL upsert
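For context, dataset_to_hashlist groups warehouse rows by their hashed content, so one record can carry several primary keys when the source data contains duplicates. A rough sketch of the record shape it feeds into the upsert, with invented values and assuming matchbox is importable:

from uuid import uuid4

from matchbox.common.hash import hash_data

# Two hypothetical warehouse rows with identical content but different keys
rows = [
    {"pk": "10001", "company_name": "Acme Ltd", "postcode": "AB1 2CD"},
    {"pk": "10002", "company_name": "Acme Ltd", "postcode": "AB1 2CD"},
]

# Concatenate everything except the primary key, as the SELECT above does
content = "".join(value for key, value in rows[0].items() if key != "pk")

record = {"sha1": hash_data(content), "id": ["10001", "10002"], "dataset": uuid4()}
print(record)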
diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py
index e12038f..3c0d54b 100644
--- a/src/matchbox/server/postgresql/utils/insert.py
+++ b/src/matchbox/server/postgresql/utils/insert.py
@@ -8,15 +8,15 @@
 from sqlalchemy.orm import Session
 
 from matchbox.common.exceptions import MatchboxDBDataError
-from matchbox.common.sha1 import list_to_value_ordered_sha1
+from matchbox.common.hash import list_to_value_ordered_hash
 from matchbox.server.base import Cluster, Probability
 from matchbox.server.postgresql.clusters import Clusters, clusters_association
 from matchbox.server.postgresql.dedupe import DDupeContains, DDupeProbabilities, Dedupes
 from matchbox.server.postgresql.link import LinkContains, LinkProbabilities, Links
 from matchbox.server.postgresql.models import Models, ModelsFrom
 from matchbox.server.postgresql.utils.db import batch_ingest
-from matchbox.server.postgresql.utils.sha1 import (
-    model_name_to_sha1,
+from matchbox.server.postgresql.utils.hash import (
+    model_name_to_hash,
 )
 
 logic_logger = logging.getLogger("mb_logic")
@@ -30,12 +30,12 @@ def insert_deduper(
     logic_logger.info(f"{metadata} Registering model")
 
     with Session(engine) as session:
-        # Construct model SHA1 from name and what it deduplicates
-        model_sha1 = list_to_value_ordered_sha1([model, deduplicates])
+        # Construct model hash from name and what it deduplicates
+        model_hash = list_to_value_ordered_hash([model, deduplicates])
 
         # Insert model
         model = Models(
-            sha1=model_sha1,
+            sha1=model_hash,
             name=model,
             description=description,
             deduplicates=deduplicates,
@@ -55,17 +55,17 @@ def insert_linker(
     logic_logger.info(f"{metadata} Registering model")
 
     with Session(engine) as session:
-        # Construct model SHA1 from parent model SHA1s
-        left_sha1 = model_name_to_sha1(left, engine=engine)
-        right_sha1 = model_name_to_sha1(right, engine=engine)
+        # Construct model hash from parent model hashes
+        left_hash = model_name_to_hash(left, engine=engine)
+        right_hash = model_name_to_hash(right, engine=engine)
 
-        model_sha1 = list_to_value_ordered_sha1(
-            [bytes(model, encoding="utf-8"), left_sha1, right_sha1]
+        model_hash = list_to_value_ordered_hash(
+            [bytes(model, encoding="utf-8"), left_hash, right_hash]
         )
 
         # Insert model
         model = Models(
-            sha1=model_sha1,
+            sha1=model_hash,
             name=model,
             description=description,
             deduplicates=None,
@@ -76,8 +76,8 @@ def insert_linker(
 
         # Insert reference to parent models
         models_from_to_insert = [
-            {"parent": model_sha1, "child": left_sha1},
-            {"parent": model_sha1, "child": right_sha1},
+            {"parent": model_hash, "child": left_hash},
+            {"parent": model_hash, "child": right_hash},
         ]
 
         ins_stmt = insert(ModelsFrom)
@@ -121,7 +121,7 @@ def insert_probabilities(
     with Session(engine) as session:
         # Get model
         db_model = session.query(Models).filter_by(name=model).first()
-        model_sha1 = db_model.sha1
+        model_hash = db_model.sha1
 
         if db_model is None:
             raise MatchboxDBDataError(source=Models, data=model)
@@ -168,17 +168,17 @@ def probability_to_node(probability: Probability) -> dict:
 
         # Insert probabilities
         def probability_to_probability(
-            probability: Probability, model_sha1: bytes
+            probability: Probability, model_hash: bytes
         ) -> dict:
             return {
                 "ddupe" if is_deduper else "link": probability.sha1,
-                "model": model_sha1,
+                "model": model_hash,
                 "probability": probability.probability,
             }
 
         batch_ingest(
             records=[
-                probability_to_probability(prob, model_sha1) for prob in probabilities
+                probability_to_probability(prob, model_hash) for prob in probabilities
             ],
             table=ProbabilitiesTable,
             conn=conn,
@@ -216,7 +216,7 @@ def insert_clusters(
     with Session(engine) as session:
         # Get model
         db_model = session.query(Models).filter_by(name=model).first()
-        model_sha1 = db_model.sha1
+        model_hash = db_model.sha1
 
         if db_model is None:
             raise MatchboxDBDataError(source=Models, data=model)
@@ -272,16 +272,16 @@ def cluster_to_cluster_contains(cluster: Cluster) -> dict:
         )
 
         # Insert cluster proposed by
-        def cluster_to_cluster_association(cluster: Cluster, model_sha1: bytes) -> dict:
+        def cluster_to_cluster_association(cluster: Cluster, model_hash: bytes) -> dict:
             """Prepares a Cluster for the cluster association table."""
             return {
-                "parent": model_sha1,
+                "parent": model_hash,
                 "child": cluster.parent,
             }
 
         batch_ingest(
             records=[
-                cluster_to_cluster_association(cluster, model_sha1)
+                cluster_to_cluster_association(cluster, model_hash)
                 for cluster in clusters
             ],
             table=clusters_association,
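The model hashes built here inherit the value-ordering property, so a linker registered over (left, right) resolves to the same hash as one registered over (right, left). A short sketch of the insert_linker construction, with made-up parent digests and assuming matchbox is importable:

import hashlib

from matchbox.common.hash import list_to_value_ordered_hash

# Hypothetical parent model hashes, normally fetched via model_name_to_hash
left_hash = hashlib.sha1(b"naive_test.crn").digest()
right_hash = hashlib.sha1(b"naive_test.duns").digest()

one_way = list_to_value_ordered_hash([b"my_linker", left_hash, right_hash])
other_way = list_to_value_ordered_hash([b"my_linker", right_hash, left_hash])

assert one_way == other_way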
diff --git a/src/matchbox/server/postgresql/utils/selector.py b/src/matchbox/server/postgresql/utils/selector.py
index efe2302..6c6bc0a 100644
--- a/src/matchbox/server/postgresql/utils/selector.py
+++ b/src/matchbox/server/postgresql/utils/selector.py
@@ -18,11 +18,10 @@
 from matchbox.helpers.selector import get_schema_table_names, string_to_table
 from matchbox.server.postgresql.clusters import Clusters, clusters_association
-from matchbox.server.postgresql.data import SourceData
+from matchbox.server.postgresql.data import SourceData, SourceDataset
 from matchbox.server.postgresql.dedupe import DDupeContains
 from matchbox.server.postgresql.link import LinkContains
 from matchbox.server.postgresql.models import Models
-from matchbox.server.postgresql.utils.db import string_to_dataset
 
 
 def get_all_parents(model: Models | list[Models]) -> list[Models]:
@@ -147,7 +146,7 @@ def _tree_to_reachable_stmt(model_tree: list[bytes]) -> Select:
 
 
 def _reachable_to_parent_data_stmt(
-    reachable_stmt: Select, parent_sha1: bytes
+    reachable_stmt: Select, parent_hash: bytes
 ) -> Select:
     """
     Takes a select statement representing the reachable edges of a parent
@@ -159,7 +158,7 @@ def _reachable_to_parent_data_stmt(
     Args:
         reachable_stmt: a SQLAlchemy Select object that defines the reachable edges
             of the combined LinkContains and DDupeContains tables
-        parent_sha1: the SHA-1 to use as the ultimate parent model, the point
+        parent_hash: the SHA-1 to use as the ultimate parent model, the point
             of truth
 
     Returns:
@@ -172,7 +171,7 @@ def _reachable_to_parent_data_stmt(
         .join(Clusters, Clusters.sha1 == allowed.c.parent)
         .join(clusters_association, clusters_association.c.child == Clusters.sha1)
         .join(Models, clusters_association.c.parent == Models.sha1)
-        .where(Models.sha1 == parent_sha1)
+        .where(Models.sha1 == parent_hash)
         .cte("root")
     )
 
@@ -211,7 +210,13 @@ def _selector_to_data(
     for schema_table, fields in selector.items():
         db_schema, db_table = get_schema_table_names(schema_table)
-        mb_dataset = string_to_dataset(db_schema, db_table, engine=engine)
+        with Session(engine) as session:
+            mb_dataset = (
+                session.query(SourceDataset)
+                .filter_by(db_schema=db_schema, db_table=db_table)
+                .first()
+            )
+
         db_table = string_to_table(db_schema, db_table, engine=engine)
 
         # To handle array column
@@ -233,7 +238,7 @@ def _selector_to_data(
             where_stmts.append(db_table.c[mb_dataset.db_id] != None)  # NoQA E711
 
     stmt = select(
-        source_data_unested.c.sha1.label("data_sha1"), *select_stmt
+        source_data_unested.c.sha1.label("data_hash"), *select_stmt
     ).select_from(source_data_unested)
 
     for join_stmt in join_stmts:
@@ -336,8 +341,8 @@ def query(
     lookup_stmt = _reachable_to_parent_data_stmt(reachable_stmt, parent)
     data_stmt = _selector_to_data(selector, engine=engine).cte()
 
-    final_stmt = select(lookup_stmt.c.parent.label("cluster_sha1"), data_stmt).join(
-        lookup_stmt, lookup_stmt.c.child == data_stmt.c.data_sha1
+    final_stmt = select(lookup_stmt.c.parent.label("cluster_hash"), data_stmt).join(
+        lookup_stmt, lookup_stmt.c.child == data_stmt.c.data_hash
     )
 
     if limit is not None:
@@ -347,8 +352,8 @@ def query(
     # Detect datatypes
     selector_dtypes = _selector_to_pandas_dtypes(selector, engine=engine)
     default_dtypes = {
-        "cluster_sha1": "string[pyarrow]",
-        "data_sha1": "string[pyarrow]",
+        "cluster_hash": "string[pyarrow]",
+        "data_hash": "string[pyarrow]",
     }
 
     with engine.connect() as conn:
@@ -373,12 +378,12 @@ def query(
         ).convert_dtypes(dtype_backend="pyarrow")
 
         # Manually convert SHA-1s to bytes correctly
-        if "data_sha1" in res.columns:
-            res.data_sha1 = res.data_sha1.str[2:].apply(bytes.fromhex)
-            res.data_sha1 = res.data_sha1.astype("binary[pyarrow]")
-        if "cluster_sha1" in res.columns:
-            res.cluster_sha1 = res.cluster_sha1.str[2:].apply(bytes.fromhex)
-            res.cluster_sha1 = res.cluster_sha1.astype("binary[pyarrow]")
+        if "data_hash" in res.columns:
+            res.data_hash = res.data_hash.str[2:].apply(bytes.fromhex)
+            res.data_hash = res.data_hash.astype("binary[pyarrow]")
+        if "cluster_hash" in res.columns:
+            res.cluster_hash = res.cluster_hash.str[2:].apply(bytes.fromhex)
+            res.cluster_hash = res.cluster_hash.astype("binary[pyarrow]")
 
     elif return_type == "sqlalchemy":
         with Session(engine) as session:
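For context on the renamed columns: PostgreSQL returns the bytea digests as backslash-x-prefixed hex strings, which is why query() strips two characters before calling bytes.fromhex. In isolation, with an invented sample digest:

from pandas import Series

# A digest as the query returns it: '\x' followed by 40 hex characters
data_hash = Series([r"\x9f86d081884c7d659a2feaa0c55ad015a3bf4f1b"])

as_bytes = data_hash.str[2:].apply(bytes.fromhex)

assert as_bytes[0] == bytes.fromhex("9f86d081884c7d659a2feaa0c55ad015a3bf4f1b")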
diff --git a/test/fixtures/db.py b/test/fixtures/db.py
index d41e149..61363e0 100644
--- a/test/fixtures/db.py
+++ b/test/fixtures/db.py
@@ -7,7 +7,7 @@
 from matchbox.server.base import IndexableDataset, SourceWarehouse
 from matchbox.server.postgresql import MatchboxPostgres, MatchboxPostgresSettings
 from pandas import DataFrame
-from sqlalchemy import text
+from sqlalchemy import text as sqltext
 
 from .models import DedupeTestParams, LinkTestParams, ModelTestParams
 
@@ -81,7 +81,7 @@ def _db_add_dedupe_models_and_data(
                 deduped = deduper()
 
                 clustered = to_clusters(
-                    df, results=deduped, key="data_sha1", threshold=0
+                    df, results=deduped, key="data_hash", threshold=0
                 )
 
                 deduped.to_matchbox(backend=matchbox_postgres)
@@ -155,7 +155,7 @@ def _db_add_link_models_and_data(
                 linked = linker()
 
                 clustered = to_clusters(
-                    df_l, df_r, results=linked, key="cluster_sha1", threshold=0
+                    df_l, df_r, results=linked, key="cluster_hash", threshold=0
                 )
 
                 linked.to_matchbox(backend=matchbox_postgres)
@@ -192,29 +192,44 @@ def warehouse_data(
 ) -> Generator[list[IndexableDataset], None, None]:
     """Inserts data into the warehouse database for testing."""
     with warehouse.engine.connect() as conn:
-        conn.execute(text("drop schema if exists test cascade;"))
-        conn.execute(text("create schema test;"))
+        conn.execute(sqltext("drop schema if exists test cascade;"))
+        conn.execute(sqltext("create schema test;"))
         crn_companies.to_sql(
-            "crn",
+            name="crn",
             con=conn,
             schema="test",
             if_exists="replace",
             index=False,
         )
         duns_companies.to_sql(
-            "duns",
+            name="duns",
             con=conn,
             schema="test",
             if_exists="replace",
             index=False,
         )
         cdms_companies.to_sql(
-            "cdms",
+            name="cdms",
             con=conn,
             schema="test",
             if_exists="replace",
             index=False,
         )
+        conn.commit()
+
+    with warehouse.engine.connect() as conn:
+        assert (
+            conn.execute(sqltext("select count(*) from test.crn;")).scalar()
+            == crn_companies.shape[0]
+        )
+        assert (
+            conn.execute(sqltext("select count(*) from test.duns;")).scalar()
+            == duns_companies.shape[0]
+        )
+        assert (
+            conn.execute(sqltext("select count(*) from test.cdms;")).scalar()
+            == cdms_companies.shape[0]
+        )
 
     yield [
         IndexableDataset(
@@ -230,9 +245,9 @@ def warehouse_data(
 
     # Clean up the warehouse data
     with warehouse.engine.connect() as conn:
-        conn.execute(text("drop table if exists test.crn;"))
-        conn.execute(text("drop table if exists test.duns;"))
-        conn.execute(text("drop table if exists test.cdms;"))
+        conn.execute(sqltext("drop table if exists test.crn;"))
+        conn.execute(sqltext("drop table if exists test.duns;"))
+        conn.execute(sqltext("drop table if exists test.cdms;"))
         conn.commit()
 
 
@@ -250,7 +265,7 @@ def matchbox_settings() -> MatchboxPostgresSettings:
             "user": "matchbox_user",
             "password": "matchbox_password",
             "database": "matchbox",
-            "db_schema": "matchbox",
+            "db_schema": "mb",
         },
     )
 
@@ -263,6 +278,9 @@ def matchbox_postgres(
 
     adapter = MatchboxPostgres(settings=matchbox_settings)
 
+    # Clean up the Matchbox database before each test, just in case
+    adapter.clear(certain=True)
+
     yield adapter
 
     # Clean up the Matchbox database after each test
diff --git a/test/fixtures/models.py b/test/fixtures/models.py
index 29a0dfd..92e2787 100644
--- a/test/fixtures/models.py
+++ b/test/fixtures/models.py
@@ -176,7 +176,7 @@ class ModelTestParams(BaseModel):
 
 
 def make_naive_dd_settings(data: DedupeTestParams) -> dict[str, Any]:
-    return {"id": "data_sha1", "unique_fields": list(data.fields.keys())}
+    return {"id": "data_hash", "unique_fields": list(data.fields.keys())}
 
 
 dedupe_model_test_params = [
@@ -198,8 +198,8 @@ def make_deterministic_li_settings(data: LinkTestParams) -> dict[str, Any]:
         comparisons.append(f"l.{field_l} = r.{field_r}")
 
     return {
-        "left_id": "cluster_sha1",
-        "right_id": "cluster_sha1",
+        "left_id": "cluster_hash",
+        "right_id": "cluster_hash",
         "comparisons": " and ".join(comparisons),
     }
 
@@ -246,8 +246,8 @@ def make_splink_li_settings(data: LinkTestParams) -> dict[str, Any]:
     }
 
     return {
-        "left_id": "cluster_sha1",
-        "right_id": "cluster_sha1",
+        "left_id": "cluster_hash",
+        "right_id": "cluster_hash",
         "linker_class": DuckDBLinker,
         "linker_training_functions": linker_training_functions,
         "linker_settings": linker_settings,
@@ -262,8 +262,8 @@ def make_weighted_deterministic_li_settings(data: LinkTestParams) -> dict[str, A
         weighted_comparisons.append((f"l.{field_l} = r.{field_r}", 1))
 
     return {
-        "left_id": "cluster_sha1",
-        "right_id": "cluster_sha1",
+        "left_id": "cluster_hash",
+        "right_id": "cluster_hash",
         "weighted_comparisons": weighted_comparisons,
         "threshold": 1,
     }
diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py
index c0a528a..2385ca1 100644
--- a/test/server/test_adapter.py
+++ b/test/server/test_adapter.py
@@ -1,6 +1,7 @@
 from dotenv import find_dotenv, load_dotenv
 from matchbox.server.base import IndexableDataset
 from matchbox.server.postgresql import MatchboxPostgres
+from pandas import DataFrame
 
 from ..fixtures.db import AddIndexedDataCallable
 
@@ -12,9 +13,25 @@ def test_index(
     matchbox_postgres: MatchboxPostgres,
     db_add_indexed_data: AddIndexedDataCallable,
     warehouse_data: list[IndexableDataset],
+    crn_companies: DataFrame,
+    duns_companies: DataFrame,
+    cdms_companies: DataFrame,
 ):
-    # Test indexing a dataset
-    pass
+    """Test that indexing data works."""
+    assert matchbox_postgres.data.count() == 0
+
+    db_add_indexed_data(
+        matchbox_postgres=matchbox_postgres, warehouse_data=warehouse_data
+    )
+
+    def count_deduplicates(df: DataFrame) -> int:
+        return df.drop(columns=["id"]).drop_duplicates().shape[0]
+
+    unique = sum(
+        count_deduplicates(df) for df in [crn_companies, duns_companies, cdms_companies]
+    )
+
+    assert matchbox_postgres.data.count() == unique
 
 
 def test_query(matchbox_postgres, db_add_link_models_and_data):
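The new test_index assertion counts on the same rule dataset_to_hashlist applies: rows that are identical once the primary key is dropped collapse to a single hash. In pandas terms, with a toy frame of invented values:

from pandas import DataFrame

# Rows 1 and 2 are duplicates apart from their id, so only two hashes remain
df = DataFrame(
    {
        "id": [1, 2, 3],
        "company_name": ["Acme Ltd", "Acme Ltd", "Bravo PLC"],
        "crn": ["01234567", "01234567", "89012345"],
    }
)

assert df.drop(columns=["id"]).drop_duplicates().shape[0] == 2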
diff --git a/test/test_dedupers.py b/test/test_dedupers.py
index 3932a1d..e920c2f 100644
--- a/test/test_dedupers.py
+++ b/test/test_dedupers.py
@@ -38,7 +38,7 @@ def test_dedupers(
     if fx_deduper.rename_fields:
         df_renamed = df.copy().rename(columns=fx_data.fields)
         fields_renamed = list(fx_data.fields.values())
-        df_renamed = df_renamed.filter(["data_sha1"] + fields_renamed)
+        df_renamed = df_renamed.filter(["data_hash"] + fields_renamed)
 
     # 1. Input data is as expected
 
@@ -67,7 +67,7 @@ def test_dedupers(
     deduped_df = deduped.to_df()
     deduped_df_with_source = deduped.inspect_with_source(
-        left_data=df, left_key="data_sha1", right_data=df, right_key="data_sha1"
+        left_data=df, left_key="data_hash", right_data=df, right_key="data_hash"
     )
 
     assert isinstance(deduped_df, DataFrame)
@@ -89,11 +89,11 @@ def test_dedupers(
 
     # 4. Correct number of clusters are resolved
 
-    clusters_dupes = to_clusters(results=deduped, key="data_sha1", threshold=0)
+    clusters_dupes = to_clusters(results=deduped, key="data_hash", threshold=0)
 
     clusters_dupes_df = clusters_dupes.to_df()
     clusters_dupes_df_with_source = clusters_dupes.inspect_with_source(
-        left_data=df, left_key="data_sha1", right_data=df, right_key="data_sha1"
+        left_data=df, left_key="data_hash", right_data=df, right_key="data_hash"
     )
 
     assert isinstance(clusters_dupes_df, DataFrame)
@@ -105,11 +105,11 @@ def test_dedupers(
             clusters_dupes_df_with_source[field + "_y"]
         )
 
-    clusters_all = to_clusters(df, results=deduped, key="data_sha1", threshold=0)
+    clusters_all = to_clusters(df, results=deduped, key="data_hash", threshold=0)
 
     clusters_all_df = clusters_all.to_df()
     clusters_all_df_with_source = clusters_all.inspect_with_source(
-        left_data=df, left_key="data_sha1", right_data=df, right_key="data_sha1"
+        left_data=df, left_key="data_hash", right_data=df, right_key="data_hash"
     )
 
     assert isinstance(clusters_all_df, DataFrame)
diff --git a/test/test_helpers.py b/test/test_helpers.py
index c6320c3..28fc72d 100644
--- a/test/test_helpers.py
+++ b/test/test_helpers.py
@@ -69,7 +69,7 @@ def test_single_table_no_model_query(db_engine):
 
     assert df_crn_full.shape[0] == 3000
     assert set(df_crn_full.columns) == {
-        "data_sha1",
+        "data_hash",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_id",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_crn",
     }
@@ -108,7 +108,7 @@ def test_multi_table_no_model_query(db_engine):
     )
 
     assert set(df_crn_duns_full.columns) == {
-        "data_sha1",
+        "data_hash",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_id",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_crn",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_duns_id",
@@ -148,13 +148,13 @@ def test_single_table_with_model_query(
     assert isinstance(crn, DataFrame)
     assert crn.shape[0] == 3000
     assert set(crn.columns) == {
-        "cluster_sha1",
-        "data_sha1",
+        "cluster_hash",
+        "data_hash",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_crn",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_company_name",
     }
-    assert crn.data_sha1.nunique() == 3000
-    assert crn.cluster_sha1.nunique() == 1000
+    assert crn.data_hash.nunique() == 3000
+    assert crn.cluster_hash.nunique() == 1000
 
 
 def test_multi_table_with_model_query(
@@ -208,13 +208,13 @@ def test_multi_table_with_model_query(
     assert isinstance(crn_duns, DataFrame)
     assert crn_duns.shape[0] == 3500
     assert set(crn_duns.columns) == {
-        "cluster_sha1",
-        "data_sha1",
+        "cluster_hash",
+        "data_hash",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_crn_crn",
         f"{os.getenv('MB__POSTGRES__SCHEMA')}_duns_duns",
     }
-    assert crn_duns.data_sha1.nunique() == 3500
-    assert crn_duns.cluster_sha1.nunique() == 1000
+    assert crn_duns.data_hash.nunique() == 3500
+    assert crn_duns.cluster_hash.nunique() == 1000
 
 
 def test_cleaners():
diff --git a/test/test_linkers.py b/test/test_linkers.py
index 89fbdd9..02ab682 100644
--- a/test/test_linkers.py
+++ b/test/test_linkers.py
@@ -53,8 +53,8 @@ def test_linkers(
         df_r_renamed = df_r.copy().rename(columns=fx_data.fields_r)
         fields_l_renamed = list(fx_data.fields_l.values())
         fields_r_renamed = list(fx_data.fields_r.values())
-        df_l_renamed = df_l_renamed.filter(["cluster_sha1"] + fields_l_renamed)
-        df_r_renamed = df_r_renamed.filter(["cluster_sha1"] + fields_r_renamed)
+        df_l_renamed = df_l_renamed.filter(["cluster_hash"] + fields_l_renamed)
+        df_r_renamed = df_r_renamed.filter(["cluster_hash"] + fields_r_renamed)
         assert set(df_l_renamed.columns) == set(df_r_renamed.columns)
         assert df_l_renamed.dtypes.equals(df_r_renamed.dtypes)
 
@@ -99,9 +99,9 @@ def test_linkers(
 
     linked_df_with_source = linked.inspect_with_source(
         left_data=df_l,
-        left_key="cluster_sha1",
+        left_key="cluster_hash",
         right_data=df_r,
-        right_key="cluster_sha1",
+        right_key="cluster_hash",
     )
 
     assert isinstance(linked_df, DataFrame)
@@ -121,14 +121,14 @@ def test_linkers(
 
     # 4. Correct number of clusters are resolved
 
-    clusters_links = to_clusters(results=linked, key="cluster_sha1", threshold=0)
+    clusters_links = to_clusters(results=linked, key="cluster_hash", threshold=0)
 
     clusters_links_df = clusters_links.to_df()
     clusters_links_df_with_source = clusters_links.inspect_with_source(
         left_data=df_l,
-        left_key="cluster_sha1",
+        left_key="cluster_hash",
         right_data=df_r,
-        right_key="cluster_sha1",
+        right_key="cluster_hash",
     )
 
     assert isinstance(clusters_links_df, DataFrame)
@@ -166,15 +166,15 @@ def unique_non_null(s):
     assert cluster_vals.shape[0] == fx_data.tgt_clus_n
 
     clusters_all = to_clusters(
-        df_l, df_r, results=linked, key="cluster_sha1", threshold=0
+        df_l, df_r, results=linked, key="cluster_hash", threshold=0
     )
 
     clusters_all_df = clusters_all.to_df()
     clusters_all_df_with_source = clusters_all.inspect_with_source(
         left_data=df_l,
-        left_key="cluster_sha1",
+        left_key="cluster_hash",
         right_data=df_r,
-        right_key="cluster_sha1",
+        right_key="cluster_hash",
     )
 
     assert isinstance(clusters_all_df, DataFrame)
diff --git a/test/test_utils.py b/test/test_utils.py
index 8667c53..d991da2 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,10 +1,10 @@
-from matchbox.common.sha1 import columns_to_value_ordered_sha1
+from matchbox.common.hash import columns_to_value_ordered_hash
 from pandas import Series, concat
 
 
-def test_sha1_conversion(all_companies):
+def test_hash_conversion(all_companies):
     """Tests SHA1 conversion works as expected."""
-    sha1_series_1 = columns_to_value_ordered_sha1(
+    sha1_series_1 = columns_to_value_ordered_hash(
         data=all_companies,
         columns=["id", "company_name", "address", "crn", "duns", "cdms"],
     )
@@ -29,7 +29,7 @@ def test_sha1_conversion(all_companies):
         [all_companies_reordered_top, all_companies.tail(500)]
     )
 
-    sha1_series_2 = columns_to_value_ordered_sha1(
+    sha1_series_2 = columns_to_value_ordered_hash(
         data=all_companies_reodered,
         columns=["id", "company_name", "address", "crn", "duns", "cdms"],
     )