Commit
Finished adapter methods, added bugbears to linting rules
Will Langdale committed Oct 11, 2024
1 parent 802c3b3 commit a25cb36
Showing 15 changed files with 263 additions and 360 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -49,6 +49,7 @@ select = [
"E",
"F",
"I",
"B",
# "D"
]
ignore = []
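
For context: ruff's `B` selector turns on the flake8-bugbear rules, which target likely bugs rather than pure style. A minimal sketch of the kind of issue they flag (illustrative, not from this repository):

```python
# B006: a mutable default argument is created once and shared across calls,
# so state leaks between invocations -- one of the checks "B" enables.
def add_item(item, items=[]):  # flake8-bugbear flags this as B006
    items.append(item)
    return items

# The bugbear-clean pattern uses None as a sentinel and builds a fresh list.
def add_item_fixed(item, items=None):
    if items is None:
        items = []
    items.append(item)
    return items
```
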
@@ -1,27 +1,17 @@
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import List, Optional

import rustworkx as rx
from dotenv import find_dotenv, load_dotenv
from matchbox.server.base import Cluster, MatchboxDBAdapter, Probability
from matchbox.server.postgresql.utils.sha1 import (
columns_to_value_ordered_sha1,
list_to_value_ordered_sha1,
)
from pandas import DataFrame, concat
from pg_bulk_ingest import Delete, Upsert, ingest
from pydantic import BaseModel, ConfigDict, model_validator
from sqlalchemy import (
Engine,
Table,
delete,
)
from sqlalchemy.orm import Session

from matchbox.server.base import Cluster, MatchboxDBAdapter, Probability
from matchbox.server.exceptions import MatchboxDBDataError
from matchbox.server.postgresql import utils as du
from matchbox.server.postgresql.clusters import Clusters, clusters_association
from matchbox.server.postgresql.db import ENGINE
from matchbox.server.postgresql.dedupe import DDupeContains
from matchbox.server.postgresql.link import LinkContains
from matchbox.server.postgresql.models import Models
from sqlalchemy import Table

logic_logger = logging.getLogger("mb_logic")

@@ -66,39 +56,10 @@ def to_records(self) -> list[Probability | Cluster]:
"""Returns the results as a list of records suitable for insertion."""
return

def to_cmf(self, backend: MatchboxDBAdapter) -> None:
"""Writes the results to the CMF database."""
if self.left == self.right:
# Deduper
backend.insert_model(
model=self.run_name,
left=self.left,
description=self.description,
)

model = backend.get_model(model=self.run_name)

model.insert_probabilities(
probabilities=self.to_records(),
probability_type="deduplications",
batch_size=backend.settings.batch_size,
)
else:
# Linker
backend.insert_model(
model=self.run_name,
left=self.left,
right=self.right,
description=self.description,
)

model = backend.get_model(model=self.run_name)

model.insert_probabilities(
probabilities=self.to_records(),
probability_type="links",
batch_size=backend.settings.batch_size,
)
@abstractmethod
def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
"""Writes the results to the Matchbox database."""
return


class ProbabilityResults(ResultsBaseDataclass):
Expand Down Expand Up @@ -184,32 +145,40 @@ def to_records(self, backend: MatchboxDBAdapter | None) -> list[Probability]:
hash_type=hash_type,
)

# Prep and return
pre_prep_df = self.dataframe.copy()
cols = ["left_id", "right_id"]
pre_prep_df[cols] = pre_prep_df[cols].astype("binary[pyarrow]")
pre_prep_df["sha1"] = du.columns_to_value_ordered_sha1(
data=self.dataframe, columns=cols
)
pre_prep_df.sha1 = pre_prep_df.sha1.astype("binary[pyarrow]")
pre_prep_df = pre_prep_df.rename(
columns={"left_id": "left", "right_id": "right"}
# Preprocess the dataframe
pre_prep_df = self.dataframe[["left_id", "right_id", "probability"]].copy()
pre_prep_df[["left_id", "right_id"]] = pre_prep_df[
["left_id", "right_id"]
].astype("binary[pyarrow]")
pre_prep_df["sha1"] = columns_to_value_ordered_sha1(
data=pre_prep_df, columns=["left_id", "right_id"]
)

pre_prep_df = pre_prep_df[["sha1", "left", "right", "probability"]]
pre_prep_df["sha1"] = pre_prep_df["sha1"].astype("binary[pyarrow]")

return [
Probability(
sha1=sha1,
left=left,
right=right,
probability=probability,
)
for sha1, left, right, probability in self.dataframe.itertuples(
index=False, name=None
)
Probability(sha1=row[0], left=row[1], right=row[2], probability=row[3])
for row in pre_prep_df[
["sha1", "left_id", "right_id", "probability"]
].to_numpy()
]

def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
"""Writes the results to the Matchbox database."""
backend.insert_model(
model=self.run_name,
left=self.left,
right=self.right if self.left != self.right else None,
description=self.description,
)

model = backend.get_model(model=self.run_name)

model.insert_probabilities(
probabilities=self.to_records(),
probability_type="links" if self.left != self.right else "deduplications",
batch_size=backend.settings.batch_size,
)
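
The single `to_matchbox` above replaces the old two-branch `to_cmf`: a deduper is detected by `self.left == self.right`, which drops the `right` argument and switches the probability type. A hedged usage sketch (`backend` and `results` are assumed objects, not repository fixtures):

```python
# Hypothetical call site -- names are illustrative only.
results.to_matchbox(backend=backend)

# Equivalent to the old branches in to_cmf:
#   left == right -> insert_model(left=...),            probability_type="deduplications"
#   left != right -> insert_model(left=..., right=...), probability_type="links"
```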


class ClusterResults(ResultsBaseDataclass):
"""Cluster data produced by using to_clusters on ProbabilityResults.
Expand Down Expand Up @@ -262,127 +231,18 @@ def to_df(self) -> DataFrame:
"""Returns the results as a DataFrame."""
return self.dataframe.copy().convert_dtypes(dtype_backend="pyarrow")

def _to_mb_logic(
self,
contains_class: Union[DDupeContains, LinkContains],
engine: Engine = ENGINE,
) -> None:
"""Handles common logic for writing dedupe or link clusters to the database.
In ClusterResults, the only difference is the tables being written to.
* Adds the new cluster nodes
* Adds model endorsement of these nodes with "creates" edge
* Adds the contains edges to show which clusters contain which
Args:
contains_class: the target table, one of DDupeContains or LinkContains
engine: a SQLAlchemy Engine object for the database
Raises:
MatchboxDBDataError if model wasn't inserted correctly
"""
Contains = contains_class
with Session(engine) as session:
# Add clusters
# Get model
model = session.query(Models).filter_by(name=self.run_name).first()
model_sha1 = model.sha1

if model is None:
raise MatchboxDBDataError(source=Models, data=self.run_name)

# Clear old model endorsements
old_cluster_creates_subquery = model.creates.select().with_only_columns(
Clusters.sha1
)

session.execute(
delete(clusters_association).where(
clusters_association.c.child.in_(old_cluster_creates_subquery)
)
)

session.commit()

logic_logger.info(f"[{self.metadata}] Removed old clusters")

with engine.connect() as conn:
logic_logger.info(
f"[{self.metadata}] Inserting %s cluster objects",
self.dataframe.shape[0],
)

clusters_prepped = self.dataframe.astype("binary[pyarrow]")

# Upsert cluster nodes
# Create data batching function and pass it to ingest
fn_cluster_batch = du.data_to_batch(
dataframe=(
clusters_prepped.drop_duplicates(subset="parent").rename(
columns={"parent": "sha1"}
)[list(Clusters.__table__.columns.keys())]
),
table=Clusters.__table__,
batch_size=self._batch_size,
)

ingest(
conn=conn,
metadata=Clusters.metadata,
batches=fn_cluster_batch,
upsert=Upsert.IF_PRIMARY_KEY,
delete=Delete.OFF,
)

# Insert cluster contains
fn_cluster_contains_batch = du.data_to_batch(
dataframe=clusters_prepped[list(Contains.__table__.columns.keys())],
table=Contains.__table__,
batch_size=self._batch_size,
)

ingest(
conn=conn,
metadata=Contains.metadata,
batches=fn_cluster_contains_batch,
upsert=Upsert.IF_PRIMARY_KEY,
delete=Delete.OFF,
)

# Insert cluster proposed by
fn_cluster_proposed_batch = du.data_to_batch(
dataframe=(
clusters_prepped.drop("child", axis=1)
.rename(columns={"parent": "child"})
.assign(parent=model_sha1)[
list(clusters_association.columns.keys())
]
),
table=clusters_association,
batch_size=self._batch_size,
)

ingest(
conn=conn,
metadata=clusters_association.metadata,
batches=fn_cluster_proposed_batch,
upsert=Upsert.IF_PRIMARY_KEY,
delete=Delete.OFF,
)

logic_logger.info(
f"[{self.metadata}] Inserted all %s cluster objects",
self.dataframe.shape[0],
)

def _deduper_to_cmf(self, engine: Engine = ENGINE) -> None:
"""Writes the results of a deduper to the CMF database."""
self._to_mb_logic(contains_class=DDupeContains, engine=engine)

def _linker_to_cmf(self, engine: Engine = ENGINE) -> None:
"""Writes the results of a linker to the CMF database."""
self._to_mb_logic(contains_class=LinkContains, engine=engine)
def to_records(self) -> list[Cluster]:
"""Returns the results as a list of records suitable for insertion."""
parent_child_pairs = self.dataframe[["parent", "child"]].values
return [Cluster(parent=row[0], child=row[1]) for row in parent_child_pairs]

def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
"""Writes the results to the Matchbox database."""
model = backend.get_model(model=self.run_name)
model.insert_clusters(
clusters=self.to_records(),
batch_size=backend.settings.batch_size,
)
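
Taken together, the two `to_matchbox` implementations give one write path from raw model output to stored clusters. A hypothetical end-to-end sketch (every name below is assumed rather than taken from the repository):

```python
# probabilities: a ProbabilityResults produced by a deduper or linker run.
# backend: any MatchboxDBAdapter implementation.
probabilities.to_matchbox(backend=backend)  # writes the model and its probabilities

clusters = to_clusters(results=probabilities, key="data_sha1", threshold=0.9)
clusters.to_matchbox(backend=backend)  # writes the derived cluster records
```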


def get_unclustered(
Expand Down Expand Up @@ -427,7 +287,7 @@ def to_clusters(
*data: Optional[DataFrame],
results: ProbabilityResults,
key: str,
threshold: float = 0.0,
threshold: float | None = None,
) -> ClusterResults:
"""
Takes a model's probabilistic outputs and turns them into clusters.
@@ -445,6 +305,9 @@ def to_clusters(
Returns
A ClusterResults object
"""
if threshold is None:
threshold = 0.0

all_edges = (
results.dataframe.query("probability >= @threshold")
.filter(["left_id", "right_id"])
@@ -476,7 +339,7 @@ def to_clusters(
res["child"].append(child_hash)

# Must be sorted to be symmetric
parent_hash = du.list_to_value_ordered_sha1(child_hashes)
parent_hash = list_to_value_ordered_sha1(child_hashes)

res["parent"] += [parent_hash] * len(component)

3 changes: 2 additions & 1 deletion src/matchbox/dedupers/make_deduper.py
@@ -24,7 +24,8 @@ def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
f"For offline deduplication, {info.field_name} can be any field. \n\n"
"When deduplicating to write back to the Company Matching "
f"Framework database, the ID must be {enforce}, generated by "
"retrieving data with cmf.query()."
"retrieving data with cmf.query().",
stacklevel=3,
)
return v
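
The new `stacklevel=3` makes the warning point at the user's call site instead of at the validator internals. A minimal illustration of how `stacklevel` climbs the stack (function names hypothetical):

```python
import warnings

def _validator():
    # stacklevel=1 would blame this line; stacklevel=3 climbs two frames
    # higher, attributing the warning to whoever called public_entrypoint().
    warnings.warn("ID should come from cmf.query()", stacklevel=3)

def public_entrypoint():
    _validator()

public_entrypoint()  # the warning is reported against this line
```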

8 changes: 5 additions & 3 deletions src/matchbox/helpers/selector.py
@@ -11,7 +11,7 @@
from sqlalchemy.orm import Session

from matchbox.server import MatchboxDBAdapter
from matchbox.server.exceptions import MatchboxSourceTableError
from matchbox.server.exceptions import MatchboxSourceTableError, MatchboxValidationError


def get_schema_table_names(full_name: str, validate: bool = False) -> tuple[str, str]:
@@ -28,7 +28,7 @@ def get_schema_table_names(full_name: str, validate: bool = False) -> tuple[str, str]:
Raises:
ValueError: When the function can't detect either a
schema.table or table format in the input
ValidationError: If both schema and table can't be detected
MatchboxValidationError: If both schema and table can't be detected
when the validate argument is True
Returns:
@@ -52,7 +52,9 @@ def get_schema_table_names(full_name: str, validate: bool = False) -> tuple[str, str]:
)

if validate and schema is None:
raise ("Schema could not be detected and validation required.")
raise MatchboxValidationError(
"Schema could not be detected and validation required."
)

return (schema, table)
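
Sketched behaviour of `get_schema_table_names` with illustrative inputs (not taken from the repository's tests); note the old code's bare `raise ("...")` was itself a bug, since raising a plain string is a `TypeError`:

```python
# Illustrative inputs only.
assert get_schema_table_names("companieshouse.companies") == ("companieshouse", "companies")
assert get_schema_table_names("companies") == (None, "companies")

# validate=True now raises a typed, catchable error when the schema is missing.
try:
    get_schema_table_names("companies", validate=True)
except MatchboxValidationError:
    print("validation failed as expected")
```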

8 changes: 3 additions & 5 deletions src/matchbox/helpers/visualisation.py
@@ -1,17 +1,15 @@
import rustworkx as rx
from matplotlib.figure import Figure
from rustworkx.visualization import mpl_draw
from sqlalchemy import Engine

from matchbox.data import ENGINE
from matchbox.data.utils import get_model_subgraph
from matchbox.server.base import MatchboxDBAdapter


def draw_model_tree(engine: Engine = ENGINE) -> Figure:
def draw_model_tree(backend: MatchboxDBAdapter) -> Figure:
"""
Draws the model subgraph.
"""
G = get_model_subgraph(engine=engine)
G = backend.get_model_subgraph()

node_indices = G.node_indices()
datasets = {
5 changes: 3 additions & 2 deletions src/matchbox/linkers/make_linker.py
@@ -5,7 +5,7 @@
from pandas import DataFrame
from pydantic import BaseModel, Field, ValidationInfo, field_validator

from matchbox.data.results import ProbabilityResults
from matchbox.helpers.results import ProbabilityResults


class LinkerSettings(BaseModel):
@@ -25,7 +25,8 @@ def _id_for_cmf(cls, v: str, info: ValidationInfo) -> str:
f"For offline deduplication, {info.field_name} can be any field. \n\n"
"When deduplicating to write back to the Company Matching "
f"Framework database, the ID must be {enforce}, generated by "
"retrieving data with cmf.query()."
"retrieving data with cmf.query().",
stacklevel=3,
)
return v
