From a25cb3685dc5f94c53ea728037ef75dc6dab36cd Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 11 Oct 2024 23:54:20 +0100 Subject: [PATCH] Finished adapter methods, added bugbears to linting rules --- pyproject.toml | 1 + .../{server/postgresql => common}/results.py | 251 ++++-------------- src/matchbox/dedupers/make_deduper.py | 3 +- src/matchbox/helpers/selector.py | 8 +- src/matchbox/helpers/visualisation.py | 8 +- src/matchbox/linkers/make_linker.py | 5 +- src/matchbox/server/base.py | 3 +- src/matchbox/server/exceptions.py | 4 + src/matchbox/server/postgresql/adapter.py | 3 +- src/matchbox/server/postgresql/db.py | 4 +- src/matchbox/server/postgresql/models.py | 52 ++-- src/matchbox/server/postgresql/utils/db.py | 25 ++ .../server/postgresql/utils/insert.py | 244 +++++++++-------- test/fixtures/models.py | 6 +- test/test_linkers.py | 6 +- 15 files changed, 263 insertions(+), 360 deletions(-) rename src/matchbox/{server/postgresql => common}/results.py (59%) diff --git a/pyproject.toml b/pyproject.toml index f2cbb1a..0cdb7b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ select = [ "E", "F", "I", + "B", # "D" ] ignore = [] diff --git a/src/matchbox/server/postgresql/results.py b/src/matchbox/common/results.py similarity index 59% rename from src/matchbox/server/postgresql/results.py rename to src/matchbox/common/results.py index 4e8d7db..fc7816b 100644 --- a/src/matchbox/server/postgresql/results.py +++ b/src/matchbox/common/results.py @@ -1,27 +1,17 @@ import logging from abc import ABC, abstractmethod -from typing import List, Optional, Union +from typing import List, Optional import rustworkx as rx from dotenv import find_dotenv, load_dotenv +from matchbox.server.base import Cluster, MatchboxDBAdapter, Probability +from matchbox.server.postgresql.utils.sha1 import ( + columns_to_value_ordered_sha1, + list_to_value_ordered_sha1, +) from pandas import DataFrame, concat -from pg_bulk_ingest import Delete, Upsert, ingest from pydantic import BaseModel, ConfigDict, model_validator -from sqlalchemy import ( - Engine, - Table, - delete, -) -from sqlalchemy.orm import Session - -from matchbox.server.base import Cluster, MatchboxDBAdapter, Probability -from matchbox.server.exceptions import MatchboxDBDataError -from matchbox.server.postgresql import utils as du -from matchbox.server.postgresql.clusters import Clusters, clusters_association -from matchbox.server.postgresql.db import ENGINE -from matchbox.server.postgresql.dedupe import DDupeContains -from matchbox.server.postgresql.link import LinkContains -from matchbox.server.postgresql.models import Models +from sqlalchemy import Table logic_logger = logging.getLogger("mb_logic") @@ -66,39 +56,10 @@ def to_records(self) -> list[Probability | Cluster]: """Returns the results as a list of records suitable for insertion.""" return - def to_cmf(self, backend: MatchboxDBAdapter) -> None: - """Writes the results to the CMF database.""" - if self.left == self.right: - # Deduper - backend.insert_model( - model=self.run_name, - left=self.left, - description=self.description, - ) - - model = backend.get_model(model=self.run_name) - - model.insert_probabilities( - probabilites=self.to_records(), - probability_type="deduplications", - batch_size=backend.settings.batch_size, - ) - else: - # Linker - backend.insert_model( - model=self.run_name, - left=self.left, - right=self.right, - description=self.description, - ) - - model = backend.get_model(model=self.run_name) - - model.insert_probabilities( - 
probabilites=self.to_records(), - probability_type="links", - batch_size=backend.settings.batch_size, - ) + @abstractmethod + def to_matchbox(self, backend: MatchboxDBAdapter) -> None: + """Writes the results to the Matchbox database.""" + return class ProbabilityResults(ResultsBaseDataclass): @@ -184,32 +145,40 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> list[Probability]: hash_type=hash_type, ) - # Prep and return - pre_prep_df = self.dataframe.copy() - cols = ["left_id", "right_id"] - pre_prep_df[cols] = pre_prep_df[cols].astype("binary[pyarrow]") - pre_prep_df["sha1"] = du.columns_to_value_ordered_sha1( - data=self.dataframe, columns=cols - ) - pre_prep_df.sha1 = pre_prep_df.sha1.astype("binary[pyarrow]") - pre_prep_df = pre_prep_df.rename( - columns={"left_id": "left", "right_id": "right"} + # Preprocess the dataframe + pre_prep_df = self.dataframe[["left_id", "right_id", "probability"]].copy() + pre_prep_df[["left_id", "right_id"]] = pre_prep_df[ + ["left_id", "right_id"] + ].astype("binary[pyarrow]") + pre_prep_df["sha1"] = columns_to_value_ordered_sha1( + data=pre_prep_df, columns=["left_id", "right_id"] ) - - pre_prep_df = pre_prep_df[["sha1", "left", "right", "probability"]] + pre_prep_df["sha1"] = pre_prep_df["sha1"].astype("binary[pyarrow]") return [ - Probability( - sha1=sha1, - left=left, - right=right, - probability=probability, - ) - for sha1, left, right, probability in self.dataframe.itertuples( - index=False, name=None - ) + Probability(sha1=row[0], left=row[1], right=row[2], probability=row[3]) + for row in pre_prep_df[ + ["sha1", "left_id", "right_id", "probability"] + ].to_numpy() ] + def to_matchbox(self, backend: MatchboxDBAdapter) -> None: + """Writes the results to the Matchbox database.""" + backend.insert_model( + model=self.run_name, + left=self.left, + right=self.right if self.left != self.right else None, + description=self.description, + ) + + model = backend.get_model(model=self.run_name) + + model.insert_probabilities( + probabilites=self.to_records(), + probability_type="links" if self.left != self.right else "deduplications", + batch_size=backend.settings.batch_size, + ) + class ClusterResults(ResultsBaseDataclass): """Cluster data produced by using to_clusters on ProbabilityResults. @@ -262,127 +231,18 @@ def to_df(self) -> DataFrame: """Returns the results as a DataFrame.""" return self.dataframe.copy().convert_dtypes(dtype_backend="pyarrow") - def _to_mb_logic( - self, - contains_class: Union[DDupeContains, LinkContains], - engine: Engine = ENGINE, - ) -> None: - """Handles common logic for writing dedupe or link clusters to the database. - - In ClusterResults, the only difference is the tables being written to. 
- - * Adds the new cluster nodes - * Adds model endorsement of these nodes with "creates" edge - * Adds the contains edges to show which clusters contain which - - Args: - contains_class: the target table, one of DDupeContains or LinkContains - engine: a SQLAlchemy Engine object for the database - - Raises: - MatchboxDBDataError if model wasn't inserted correctly - """ - Contains = contains_class - with Session(engine) as session: - # Add clusters - # Get model - model = session.query(Models).filter_by(name=self.run_name).first() - model_sha1 = model.sha1 - - if model is None: - raise MatchboxDBDataError(source=Models, data=self.run_name) - - # Clear old model endorsements - old_cluster_creates_subquery = model.creates.select().with_only_columns( - Clusters.sha1 - ) - - session.execute( - delete(clusters_association).where( - clusters_association.c.child.in_(old_cluster_creates_subquery) - ) - ) - - session.commit() - - logic_logger.info(f"[{self.metadata}] Removed old clusters") - - with engine.connect() as conn: - logic_logger.info( - f"[{self.metadata}] Inserting %s cluster objects", - self.dataframe.shape[0], - ) - - clusters_prepped = self.dataframe.astype("binary[pyarrow]") - - # Upsert cluster nodes - # Create data batching function and pass it to ingest - fn_cluster_batch = du.data_to_batch( - dataframe=( - clusters_prepped.drop_duplicates(subset="parent").rename( - columns={"parent": "sha1"} - )[list(Clusters.__table__.columns.keys())] - ), - table=Clusters.__table__, - batch_size=self._batch_size, - ) - - ingest( - conn=conn, - metadata=Clusters.metadata, - batches=fn_cluster_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, - ) - - # Insert cluster contains - fn_cluster_contains_batch = du.data_to_batch( - dataframe=clusters_prepped[list(Contains.__table__.columns.keys())], - table=Contains.__table__, - batch_size=self._batch_size, - ) - - ingest( - conn=conn, - metadata=Contains.metadata, - batches=fn_cluster_contains_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, - ) - - # Insert cluster proposed by - fn_cluster_proposed_batch = du.data_to_batch( - dataframe=( - clusters_prepped.drop("child", axis=1) - .rename(columns={"parent": "child"}) - .assign(parent=model_sha1)[ - list(clusters_association.columns.keys()) - ] - ), - table=clusters_association, - batch_size=self._batch_size, - ) - - ingest( - conn=conn, - metadata=clusters_association.metadata, - batches=fn_cluster_proposed_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, - ) - - logic_logger.info( - f"[{self.metadata}] Inserted all %s cluster objects", - self.dataframe.shape[0], - ) - - def _deduper_to_cmf(self, engine: Engine = ENGINE) -> None: - """Writes the results of a deduper to the CMF database.""" - self._to_mb_logic(contains_class=DDupeContains, engine=engine) - - def _linker_to_cmf(self, engine: Engine = ENGINE) -> None: - """Writes the results of a linker to the CMF database.""" - self._to_mb_logic(contains_class=LinkContains, engine=engine) + def to_records(self) -> list[Cluster]: + """Returns the results as a list of records suitable for insertion.""" + parent_child_pairs = self.dataframe[["parent", "child"]].values + return [Cluster(parent=row[0], child=row[1]) for row in parent_child_pairs] + + def to_matchbox(self, backend: MatchboxDBAdapter) -> None: + """Writes the results to the Matchbox database.""" + model = backend.get_model(model=self.run_name) + model.insert_clusters( + clusters=self.to_records(), + batch_size=backend.settings.batch_size, + ) def get_unclustered( 
@@ -427,7 +287,7 @@ def to_clusters( *data: Optional[DataFrame], results: ProbabilityResults, key: str, - threshold: float = 0.0, + threshold: float = None, ) -> ClusterResults: """ Takes a models probabilistic outputs and turns them into clusters. @@ -445,6 +305,9 @@ def to_clusters( Returns A ClusterResults object """ + if not threshold: + threshold = 0.0 + all_edges = ( results.dataframe.query("probability >= @threshold") .filter(["left_id", "right_id"]) @@ -476,7 +339,7 @@ def to_clusters( res["child"].append(child_hash) # Must be sorted to be symmetric - parent_hash = du.list_to_value_ordered_sha1(child_hashes) + parent_hash = list_to_value_ordered_sha1(child_hashes) res["parent"] += [parent_hash] * len(component) diff --git a/src/matchbox/dedupers/make_deduper.py b/src/matchbox/dedupers/make_deduper.py index bdc2c40..ef7526a 100644 --- a/src/matchbox/dedupers/make_deduper.py +++ b/src/matchbox/dedupers/make_deduper.py @@ -24,7 +24,8 @@ def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str: f"For offline deduplication, {info.field_name} can be any field. \n\n" "When deduplicating to write back to the Company Matching " f"Framework database, the ID must be {enforce}, generated by " - "retrieving data with cmf.query()." + "retrieving data with cmf.query().", + stacklevel=3, ) return v diff --git a/src/matchbox/helpers/selector.py b/src/matchbox/helpers/selector.py index 0c60188..b1f7d88 100644 --- a/src/matchbox/helpers/selector.py +++ b/src/matchbox/helpers/selector.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session from matchbox.server import MatchboxDBAdapter -from matchbox.server.exceptions import MatchboxSourceTableError +from matchbox.server.exceptions import MatchboxSourceTableError, MatchboxValidatonError def get_schema_table_names(full_name: str, validate: bool = False) -> tuple[str, str]: @@ -28,7 +28,7 @@ def get_schema_table_names(full_name: str, validate: bool = False) -> tuple[str, Raises: ValueError: When the function can't detect either a schema.table or table format in the input - ValidationError: If both schema and table can't be detected + MatchboxValidatonError: If both schema and table can't be detected when the validate argument is True Returns: @@ -52,7 +52,9 @@ def get_schema_table_names(full_name: str, validate: bool = False) -> tuple[str, ) if validate and schema is None: - raise ("Schema could not be detected and validation required.") + raise MatchboxValidatonError( + "Schema could not be detected and validation required." + ) return (schema, table) diff --git a/src/matchbox/helpers/visualisation.py b/src/matchbox/helpers/visualisation.py index 6d93aad..cead056 100644 --- a/src/matchbox/helpers/visualisation.py +++ b/src/matchbox/helpers/visualisation.py @@ -1,17 +1,15 @@ import rustworkx as rx from matplotlib.figure import Figure from rustworkx.visualization import mpl_draw -from sqlalchemy import Engine -from matchbox.data import ENGINE -from matchbox.data.utils import get_model_subgraph +from matchbox.server.base import MatchboxDBAdapter -def draw_model_tree(engine: Engine = ENGINE) -> Figure: +def draw_model_tree(backend: MatchboxDBAdapter) -> Figure: """ Draws the model subgraph. 
""" - G = get_model_subgraph(engine=engine) + G = backend.get_model_subgraph() node_indices = G.node_indices() datasets = { diff --git a/src/matchbox/linkers/make_linker.py b/src/matchbox/linkers/make_linker.py index 599aea2..88ae58b 100644 --- a/src/matchbox/linkers/make_linker.py +++ b/src/matchbox/linkers/make_linker.py @@ -5,7 +5,7 @@ from pandas import DataFrame from pydantic import BaseModel, Field, ValidationInfo, field_validator -from matchbox.data.results import ProbabilityResults +from matchbox.helpers.results import ProbabilityResults class LinkerSettings(BaseModel): @@ -25,7 +25,8 @@ def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str: f"For offline deduplication, {info.field_name} can be any field. \n\n" "When deduplicating to write back to the Company Matching " f"Framework database, the ID must be {enforce}, generated by " - "retrieving data with cmf.query()." + "retrieving data with cmf.query().", + stacklevel=3, ) return v diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index b1a5cf5..40bb561 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -4,6 +4,7 @@ import pandas as pd from pydantic import AnyUrl, BaseModel, Field +from rustworkx import PyDiGraph from sqlalchemy import create_engine from sqlalchemy import text as sqltext from sqlalchemy.engine import Engine @@ -154,7 +155,7 @@ def validate_hashes( ) -> bool: ... @abstractmethod - def get_model_subgraph(self) -> dict: ... + def get_model_subgraph(self) -> PyDiGraph: ... @abstractmethod def get_model(self, model: str) -> MatchboxModelAdapter: ... diff --git a/src/matchbox/server/exceptions.py b/src/matchbox/server/exceptions.py index 987e76f..df0de7c 100644 --- a/src/matchbox/server/exceptions.py +++ b/src/matchbox/server/exceptions.py @@ -7,6 +7,10 @@ class MatchboxConnectionError(Exception): """Connection to Matchbox's backend database failed.""" +class MatchboxValidatonError(Exception): + """Validation of data failed.""" + + class MatchboxDBDataError(Exception): """Data doesn't exist in the Matchbox source table.""" diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index 79a2ee4..00351ba 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -3,6 +3,7 @@ import pandas as pd from dotenv import find_dotenv, load_dotenv from pydantic import BaseSettings, Field +from rustworkx import PyDiGraph from sqlalchemy import ( Engine, bindparam, @@ -170,7 +171,7 @@ def validate_hashes( source=Source, ) - def get_model_subgraph(self) -> dict: + def get_model_subgraph(self) -> PyDiGraph: """Get the full subgraph of a model.""" return get_model_subgraph(engine=self.engine) diff --git a/src/matchbox/server/postgresql/db.py b/src/matchbox/server/postgresql/db.py index 970fbda..4edd28d 100644 --- a/src/matchbox/server/postgresql/db.py +++ b/src/matchbox/server/postgresql/db.py @@ -46,7 +46,7 @@ class MatchboxBase(Base): try: with engine.connect() as connection: connection.execute("SELECT 1") - except Exception: - raise MatchboxConnectionError + except Exception as e: + raise MatchboxConnectionError from e return MatchboxBase, engine diff --git a/src/matchbox/server/postgresql/models.py b/src/matchbox/server/postgresql/models.py index 0abde89..db668d2 100644 --- a/src/matchbox/server/postgresql/models.py +++ b/src/matchbox/server/postgresql/models.py @@ -14,8 +14,8 @@ from matchbox.server.postgresql.link import LinkProbabilities from matchbox.server.postgresql.mixin import SHA1Mixin 
from matchbox.server.postgresql.utils.insert import ( - insert_deduplication_probabilities, - insert_link_probabilities, + insert_clusters, + insert_probabilities, ) if TYPE_CHECKING: @@ -32,6 +32,8 @@ def count(self): class Models(SHA1Mixin, MatchboxModelAdapter, MatchboxBase): + """The Matchbox PostgreSQL model class, and ModelAdaper for PostgreSQL.""" + __tablename__ = "mb__models" __table_args__ = (UniqueConstraint("name"),) @@ -78,16 +80,20 @@ class Models(SHA1Mixin, MatchboxModelAdapter, MatchboxBase): @property def clusters(self) -> Clusters: + """Returns all clusters the model endorses.""" return self.creates @property def probabilities(self) -> CombinedProbabilities: + """Returns all probabilities the model proposes.""" return CombinedProbabilities(self.proposes_dedupes, self.proposes_links) def parent_neighbours(self) -> list["Models"]: + """Returns the parent neighbours of the model.""" return [x.parent_model for x in self.child_edges] def child_neighbours(self) -> list["Models"]: + """Returns the child neighbours of the model.""" return [x.child_model for x in self.parent_edges] def insert_probabilities( @@ -96,25 +102,29 @@ def insert_probabilities( probability_type: Literal["deduplications", "links"], batch_size: int, ) -> None: - if probability_type == "deduplications": - insert_deduplication_probabilities( - model=self.name, - engine=self.get_session().get_bind(), - probabilities=probabilities, - batch_size=batch_size, - ) - elif probability_type == "links": - insert_link_probabilities( - model=self.name, - engine=self.get_session().get_bind(), - probabilities=probabilities, - batch_size=batch_size, - ) - - def insert_clusters(self, clusters: list[Cluster]) -> None: - for cluster in clusters: - self.creates.add(cluster) - self.session.flush() + """Inserts probabilities into the database.""" + insert_probabilities( + model=self.name, + engine=self.get_session().get_bind(), + probabilities=probabilities, + batch_size=batch_size, + is_deduper=True if probability_type == "deduplications" else False, + ) + + def insert_clusters( + self, + clusters: list[Cluster], + cluster_type: Literal["deduplications", "links"], + batch_size: int, + ) -> None: + """Inserts clusters into the database.""" + insert_clusters( + model=self.name, + engine=self.get_session().get_bind(), + clusters=clusters, + batch_size=batch_size, + is_deduper=True if cluster_type == "deduplications" else False, + ) # From diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 84d2030..f5389cf 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -6,7 +6,9 @@ from typing import Any, Callable, Iterable, Tuple import rustworkx as rx +from pg_bulk_ingest import Delete, Upsert, ingest from sqlalchemy import Engine, MetaData, Table +from sqlalchemy.engine.base import Connection from sqlalchemy.exc import NoSuchTableError from sqlalchemy.orm import Session @@ -128,3 +130,26 @@ def _batches() -> Iterable[Tuple[None, None, Iterable[Tuple[Table, dict]]]]: yield None, None, ((table, t) for t in batch) return _batches + + +def batch_ingest( + records: list[dict], + table: Table, + conn: Connection, + batch_size: int, +) -> None: + """Batch ingest records into a database table.""" + + fn_batch = data_to_batch( + records=records, + table=table, + batch_size=batch_size, + ) + + ingest( + conn=conn, + metadata=table.metadata, + batches=fn_batch, + upsert=Upsert.IF_PRIMARY_KEY, + delete=Delete.OFF, + ) diff --git 
a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py index 12b5265..8aba80b 100644 --- a/src/matchbox/server/postgresql/utils/insert.py +++ b/src/matchbox/server/postgresql/utils/insert.py @@ -1,6 +1,5 @@ import logging -from pg_bulk_ingest import Delete, Upsert, ingest from sqlalchemy import ( Engine, delete, @@ -8,12 +7,13 @@ from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from matchbox.server.base import Probability +from matchbox.server.base import Cluster, Probability from matchbox.server.exceptions import MatchboxDBDataError -from matchbox.server.postgresql.dedupe import DDupeProbabilities, Dedupes -from matchbox.server.postgresql.link import LinkProbabilities, Links +from matchbox.server.postgresql.clusters import Clusters, clusters_association +from matchbox.server.postgresql.dedupe import DDupeContains, DDupeProbabilities, Dedupes +from matchbox.server.postgresql.link import LinkContains, LinkProbabilities, Links from matchbox.server.postgresql.models import Models, ModelsFrom -from matchbox.server.postgresql.utils import data_to_batch +from matchbox.server.postgresql.utils.db import batch_ingest from matchbox.server.postgresql.utils.sha1 import ( list_to_value_ordered_sha1, model_name_to_sha1, @@ -91,213 +91,207 @@ def insert_linker( session.commit() -def insert_deduplication_probabilities( +def insert_probabilities( model: str, engine: Engine, probabilities: list[Probability], batch_size: int, + is_deduper: bool, ) -> None: - """Writes deduplication probabilities to Matchbox.""" - metadata = f"{model} [Deduplication]" + """Writes probabilities to Matchbox.""" + probability_type = "Deduplication" if is_deduper else "Linking" + metadata = f"{model} [{probability_type}]" - if not len(probabilities): - logic_logger.info(f"{metadata} No deduplication data to insert") + if not probabilities: + logic_logger.info(f"{metadata} No {probability_type.lower()} data to insert") return else: logic_logger.info( - f"{metadata} Writing deduplication data with batch size {batch_size}" + f"{metadata} Writing {probability_type.lower()} data " + f"with batch size {batch_size}" ) - def probability_to_ddupe(probability: Probability) -> dict: - """Prepares a Probability for the Dedupes table.""" - return { - "sha1": probability.sha1, - "left": probability.left, - "right": probability.right, - } - - def probability_to_ddupeprobability( - probability: Probability, model_sha1: bytes - ) -> dict: - """Prepares a Probability for the DDupeProbabilities table.""" - return { - "ddupe": probability.sha1, - "model": model_sha1, - "probability": probability.probability, - } + if is_deduper: + ProbabilitiesTable = DDupeProbabilities + NodesTable = Dedupes + else: + ProbabilitiesTable = LinkProbabilities + NodesTable = Links with Session(engine) as session: - # Add probabilities # Get model db_model = session.query(Models).filter_by(name=model).first() - model_sha1 = model.sha1 + model_sha1 = db_model.sha1 if db_model is None: raise MatchboxDBDataError(source=Models, data=model) # Clear old model probabilities - old_ddupe_probs_subquery = db_model.proposes_dedupes.select().with_only_columns( - DDupeProbabilities.model + old_probs_subquery = ( + (db_model.proposes_dedupes if is_deduper else db_model.proposes_links) + .select() + .with_only_columns(ProbabilitiesTable.model) ) session.execute( - delete(DDupeProbabilities).where( - DDupeProbabilities.model.in_(old_ddupe_probs_subquery) + delete(ProbabilitiesTable).where( + 
ProbabilitiesTable.model.in_(old_probs_subquery) ) ) session.commit() - logic_logger.info(f"{metadata} Removed old deduplication probabilities") + logic_logger.info( + f"{metadata} Removed old {probability_type.lower()} probabilities" + ) with engine.connect() as conn: logic_logger.info( - f"{metadata} Inserting %s deduplication objects", + f"{metadata} Inserting %s {probability_type.lower()} objects", len(probabilities), ) - # Upsert dedupe nodes - # Create data batching function and pass it to ingest - fn_dedupe_batch = data_to_batch( - records=[probability_to_ddupe(prob) for prob in probabilities], - table=Dedupes.__table__, - batch_size=batch_size, - ) - - ingest( + # Upsert nodes + def probability_to_node(probability: Probability) -> dict: + return { + "sha1": probability.sha1, + "left": probability.left, + "right": probability.right, + } + + batch_ingest( + records=[probability_to_node(prob) for prob in probabilities], + table=NodesTable, conn=conn, - metadata=Dedupes.metadata, - batches=fn_dedupe_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, + batch_size=batch_size, ) - # Insert dedupe probabilities - fn_dedupe_probs_batch = data_to_batch( + # Insert probabilities + def probability_to_probability( + probability: Probability, model_sha1: bytes + ) -> dict: + return { + "ddupe" if is_deduper else "link": probability.sha1, + "model": model_sha1, + "probability": probability.probability, + } + + batch_ingest( records=[ - probability_to_ddupeprobability(prob, model_sha1) - for prob in probabilities + probability_to_probability(prob, model_sha1) for prob in probabilities ], - table=DDupeProbabilities.__table__, - batch_size=batch_size, - ) - - ingest( + table=ProbabilitiesTable, conn=conn, - metadata=DDupeProbabilities.metadata, - batches=fn_dedupe_probs_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, + batch_size=batch_size, ) logic_logger.info( - f"{metadata} Inserted all %s deduplication objects", + f"{metadata} Inserted all %s {probability_type.lower()} objects", len(probabilities), ) logic_logger.info(f"{metadata} Complete!") -def insert_link_probabilities( +def insert_clusters( model: str, engine: Engine, - probabilities: list[Probability], + clusters: list[Cluster], batch_size: int, + is_deduper: bool, ) -> None: - """Writes link probabilities to Matchbox.""" - metadata = f"{model} [Linking]" + """Writes clusters to Matchbox.""" + metadata = f"{model} [{'Deduplication' if is_deduper else 'Linking'}]" - if not len(probabilities): - logic_logger.info(f"{metadata} No link data to insert") + if not clusters: + logic_logger.info(f"{metadata} No cluster data to insert") return else: - logic_logger.info(f"{metadata} Writing link data with batch size {batch_size}") - - def probability_to_link(probability: Probability) -> dict: - """Prepares a Probability for the Links table.""" - return { - "sha1": probability.sha1, - "left": probability.left, - "right": probability.right, - } - - def probability_to_linkprobability( - probability: Probability, model_sha1: bytes - ) -> dict: - """Prepares a Probability for the LinkProbabilities table.""" - return { - "link": probability.sha1, - "model": model_sha1, - "probability": probability.probability, - } + logic_logger.info( + f"{metadata} Writing cluster data with batch size {batch_size}" + ) + + Contains = DDupeContains if is_deduper else LinkContains with Session(engine) as session: - # Add probabilities # Get model - model = session.query(Models).filter_by(name=model).first() - model_sha1 = model.sha1 + db_model = 
session.query(Models).filter_by(name=model).first() + model_sha1 = db_model.sha1 - if model is None: + if db_model is None: raise MatchboxDBDataError(source=Models, data=model) - # Clear old model probabilities - old_link_probs_subquery = model.proposes_links.select().with_only_columns( - LinkProbabilities.model + # Clear old model endorsements + old_cluster_creates_subquery = db_model.creates.select().with_only_columns( + Clusters.sha1 ) session.execute( - delete(LinkProbabilities).where( - LinkProbabilities.model.in_(old_link_probs_subquery) + delete(clusters_association).where( + clusters_association.c.child.in_(old_cluster_creates_subquery) ) ) session.commit() - logic_logger.info(f"{metadata} Removed old link probabilities") + logic_logger.info(f"{metadata} Removed old clusters") with engine.connect() as conn: logic_logger.info( - f"{metadata} Inserting %s link objects", - len(probabilities), + f"{metadata} Inserting %s cluster objects", + len(clusters), ) - # Upsert link nodes - # Create data batching function and pass it to ingest - fn_link_batch = data_to_batch( - records=[probability_to_link(prob) for prob in probabilities], - table=Links.__table__, + # Upsert cluster nodes + def cluster_to_cluster(cluster: Cluster) -> dict: + """Prepares a Cluster for the Clusters table.""" + return { + "sha1": cluster.parent, + } + + batch_ingest( + records=list({cluster_to_cluster(cluster) for cluster in clusters}), + table=Clusters, + conn=conn, batch_size=batch_size, ) - ingest( + # Insert cluster contains + def cluster_to_cluster_contains(cluster: Cluster) -> dict: + """Prepares a Cluster for the Contains tables.""" + return { + "parent": cluster.parent, + "child": cluster.child, + } + + batch_ingest( + records=[cluster_to_cluster_contains(cluster) for cluster in clusters], + table=Contains, conn=conn, - metadata=Links.metadata, - batches=fn_link_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, + batch_size=batch_size, ) - # Insert link probabilities - fn_link_probs_batch = data_to_batch( + # Insert cluster proposed by + def cluster_to_cluster_association(cluster: Cluster, model_sha1: bytes) -> dict: + """Prepares a Cluster for the cluster association table.""" + return { + "parent": model_sha1, + "child": cluster.parent, + } + + batch_ingest( records=[ - probability_to_linkprobability(prob, model_sha1) - for prob in probabilities + cluster_to_cluster_association(cluster, model_sha1) + for cluster in clusters ], - table=LinkProbabilities.__table__, - batch_size=batch_size, - ) - - ingest( + table=clusters_association, conn=conn, - metadata=LinkProbabilities.metadata, - batches=fn_link_probs_batch, - upsert=Upsert.IF_PRIMARY_KEY, - delete=Delete.OFF, + batch_size=batch_size, ) logic_logger.info( - f"{metadata} Inserted all %s link objects", - len(probabilities), + f"{metadata} Inserted all %s cluster objects", + len(clusters), ) logic_logger.info(f"{metadata} Complete!") diff --git a/test/fixtures/models.py b/test/fixtures/models.py index 386dcb5..f6c8204 100644 --- a/test/fixtures/models.py +++ b/test/fixtures/models.py @@ -197,7 +197,9 @@ def make_naive_dd_settings(data: DedupeTestParams) -> Dict[str, Any]: def make_deterministic_li_settings(data: LinkTestParams) -> Dict[str, Any]: comparisons = [] - for field_l, field_r in zip(data.fields_l.keys(), data.fields_r.keys()): + for field_l, field_r in zip( + data.fields_l.keys(), data.fields_r.keys(), strict=False + ): comparisons.append(f"l.{field_l} = r.{field_r}") return { @@ -261,7 +263,7 @@ def make_splink_li_settings(data: 
LinkTestParams) -> Dict[str, Any]: def make_weighted_deterministic_li_settings(data: LinkTestParams) -> Dict[str, Any]: weighted_comparisons = [] - for field_l, field_r in zip(data.fields_l, data.fields_r): + for field_l, field_r in zip(data.fields_l, data.fields_r, strict=False): weighted_comparisons.append((f"l.{field_l} = r.{field_r}", 1)) return { diff --git a/test/test_linkers.py b/test/test_linkers.py index 1afff01..80c48ee 100644 --- a/test/test_linkers.py +++ b/test/test_linkers.py @@ -110,7 +110,7 @@ def test_linkers( assert linked_df.shape[0] == fx_data.tgt_prob_n assert isinstance(linked_df_with_source, DataFrame) - for field_l, field_r in zip(fields_l, fields_r): + for field_l, field_r in zip(fields_l, fields_r, strict=False): assert linked_df_with_source[field_l].equals(linked_df_with_source[field_r]) # 3. Linked probabilities are inserted correctly @@ -137,7 +137,7 @@ def test_linkers( assert clusters_links_df.parent.nunique() == fx_data.tgt_clus_n assert isinstance(clusters_links_df_with_source, DataFrame) - for field_l, field_r in zip(fields_l, fields_r): + for field_l, field_r in zip(fields_l, fields_r, strict=False): # When we enrich the ClusterResults in a deduplication job, every child # hash will match something in the source data, because we're only using # one dataset. NaNs are therefore impossible. @@ -183,7 +183,7 @@ def unique_non_null(s): assert clusters_all_df.parent.nunique() == fx_data.unique_n assert isinstance(clusters_all_df_with_source, DataFrame) - for field_l, field_r in zip(fields_l, fields_r): + for field_l, field_r in zip(fields_l, fields_r, strict=False): # See above for method # Only change is that we've now introduced expected NaNs for data # that contains different number of entities
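
Note (not part of the diff): a minimal usage sketch of the results flow this patch introduces. `backend` stands for any MatchboxDBAdapter implementation; how it is constructed, the key name "data_sha1" and the 0.7 threshold are assumptions for illustration only. The `to_matchbox`, `to_clusters`, `insert_model` and `get_model` calls mirror the methods added in src/matchbox/common/results.py and src/matchbox/server/base.py above.

    from matchbox.common.results import ProbabilityResults, to_clusters
    from matchbox.server.base import MatchboxDBAdapter


    def write_results(backend: MatchboxDBAdapter, results: ProbabilityResults) -> None:
        # Register the model and insert its probabilities; to_matchbox() decides
        # between "deduplications" and "links" based on whether left == right.
        results.to_matchbox(backend=backend)

        # Resolve probabilities above a threshold into connected-component
        # clusters, then write those clusters back against the same model.
        clusters = to_clusters(results=results, key="data_sha1", threshold=0.7)
        clusters.to_matchbox(backend=backend)

This replaces the removed per-backend `_deduper_to_cmf` / `_linker_to_cmf` paths: callers now talk only to the adapter, and the PostgreSQL-specific ingestion lives behind `insert_probabilities` and `insert_clusters` in utils/insert.py.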