Merge pull request #17 from uktrade/feature/match
Added match function to the backend
wpfl-dbt authored Dec 13, 2024
2 parents c426882 + ce58805 commit 0ed6046
Showing 10 changed files with 676 additions and 60 deletions.
20 changes: 20 additions & 0 deletions src/matchbox/common/db.py
@@ -40,6 +40,26 @@
T = TypeVar("T")


class Match(BaseModel):
    """A match between primary keys in the Matchbox database."""

    cluster: bytes | None
    source: str
    source_id: set[str] = Field(default_factory=set)
    target: str
    target_id: set[str] = Field(default_factory=set)

    @model_validator(mode="after")
    def found_or_none(self) -> "Match":
        if self.target_id and not (self.source_id and self.cluster):
            raise ValueError(
                "A match must have sources and a cluster if a target was found."
            )
        if self.cluster and not self.source_id:
            raise ValueError("A match must have sources if a cluster is set.")
        return self
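
A minimal sketch of the validator's contract (the dataset names and IDs below are hypothetical):

from matchbox.common.db import Match

# Valid: a found match carries a cluster, source IDs and target IDs
Match(
    cluster=b"\x01\x02",
    source="companies_house.companies",
    source_id={"8534735"},
    target="hmrc.exporters",
    target_id={"EXP-001"},
)

# Valid: nothing was found, so all matching fields are empty
Match(cluster=None, source="companies_house.companies", target="hmrc.exporters")

# Invalid: target IDs without source IDs and a cluster raise ValueError
Match(
    cluster=None,
    source="companies_house.companies",
    target="hmrc.exporters",
    target_id={"EXP-001"},
)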


class Probability(BaseModel):
    """A probability of a match in the Matchbox database.
36 changes: 35 additions & 1 deletion src/matchbox/helpers/selector.py
@@ -4,7 +4,7 @@
from pyarrow import Table as ArrowTable
from sqlalchemy import Engine, inspect

-from matchbox.common.db import Source, get_schema_table_names
+from matchbox.common.db import Match, Source, get_schema_table_names
from matchbox.server import MatchboxDBAdapter, inject_backend


@@ -91,3 +91,37 @@ def query(
return_type="pandas" if not return_type else return_type,
limit=limit,
)


@inject_backend
def match(
    backend: MatchboxDBAdapter,
    source_id: str,
    source: str,
    target: str | list[str],
    resolution: str,
    threshold: float | dict[str, float] | None = None,
) -> Match | list[Match]:
    """Matches IDs against the selected backend.

    Args:
        backend: The backend to query.
        source_id: The ID of the source to match.
        source: The name of the source dataset.
        target: The name of the target dataset(s).
        resolution: The name of the resolution to use for filtering results.
        threshold (optional): The threshold to use for creating clusters.
            If None, uses the resolution's default threshold.
            If a float, uses that threshold for the specified resolution, and the
            resolution's cached thresholds for its ancestors.
            If a dictionary, expects a shape similar to resolution.ancestors, keyed
            by resolution name and valued by the threshold to use for that
            resolution. Will use these threshold values instead of the cached
            thresholds.
    """
    return backend.match(
        source_id=source_id,
        source=source,
        target=target,
        resolution=resolution,
        threshold=threshold,
    )
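
A hedged usage sketch (the dataset and resolution names are hypothetical; the backend argument is injected by the @inject_backend decorator, so callers omit it, and a single target string is assumed to yield a single Match):

from matchbox.helpers.selector import match

result = match(
    source_id="8534735",
    source="companies_house.companies",
    target="hmrc.exporters",
    resolution="company_linker",
)
print(result.target_id)  # primary keys matched in the target dataset
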
5 changes: 5 additions & 0 deletions src/matchbox/server/api.py
@@ -150,6 +150,11 @@ async def query():
    raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/match")
async def match():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/validate/hash")
async def validate_hashes():
raise HTTPException(status_code=501, detail="Not implemented")
12 changes: 11 additions & 1 deletion src/matchbox/server/base.py
@@ -19,7 +19,7 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from sqlalchemy import Engine

-from matchbox.common.db import Source
+from matchbox.common.db import Match, Source
from matchbox.common.graph import ResolutionGraph

if TYPE_CHECKING:
@@ -251,6 +251,16 @@ def query(
        limit: int = None,
    ) -> PandasDataFrame | ArrowTable | PolarsDataFrame: ...

    @abstractmethod
    def match(
        self,
        source_id: str,
        source: str,
        target: str | list[str],
        resolution: str,
        threshold: float | dict[str, float] | None = None,
    ) -> Match | list[Match]: ...

    @abstractmethod
    def index(self, dataset: Source) -> None: ...
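
The str | list[str] target means implementations are expected to return a single Match for a single target and a list for several. A sketch of that dispatch, assuming a hypothetical per-target helper _match_one:

def match(self, source_id, source, target, resolution, threshold=None):
    targets = [target] if isinstance(target, str) else target
    results = [
        self._match_one(source_id, source, t, resolution, threshold)  # hypothetical helper
        for t in targets
    ]
    return results[0] if isinstance(target, str) else results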

37 changes: 35 additions & 2 deletions src/matchbox/server/postgresql/adapter.py
@@ -5,7 +5,7 @@
from sqlalchemy import Engine, and_, bindparam, delete, func, or_, select
from sqlalchemy.orm import Session

-from matchbox.common.db import Source, SourceWarehouse
+from matchbox.common.db import Match, Source, SourceWarehouse
from matchbox.common.exceptions import (
    MatchboxDataError,
    MatchboxDatasetError,
@@ -29,7 +29,7 @@
    insert_model,
    insert_results,
)
-from matchbox.server.postgresql.utils.query import query
+from matchbox.server.postgresql.utils.query import match, query
from matchbox.server.postgresql.utils.results import (
    get_model_clusters,
    get_model_probabilities,
@@ -300,6 +300,39 @@ def query(
            limit=limit,
        )

    def match(
        self,
        source_id: str,
        source: str,
        target: str | list[str],
        resolution: str,
        threshold: float | dict[str, float] | None = None,
    ) -> Match | list[Match]:
        """Matches an ID in a source dataset and returns the keys in the targets.

        Args:
            source_id: The ID of the source to match.
            source: The name of the source dataset.
            target: The name of the target dataset(s).
            resolution: The name of the resolution to use for matching.
            threshold (optional): The threshold to use for creating clusters.
                If None, uses the resolution's default threshold.
                If a float, uses that threshold for the specified resolution, and
                the resolution's cached thresholds for its ancestors.
                If a dictionary, expects a shape similar to resolution.ancestors,
                keyed by resolution name and valued by the threshold to use for
                that resolution. Will use these threshold values instead of the
                cached thresholds.
        """
        return match(
            source_id=source_id,
            source=source,
            target=target,
            resolution=resolution,
            engine=MBDB.get_engine(),
            threshold=threshold,
        )
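
The three accepted threshold forms, sketched with hypothetical dataset and resolution names:

common = dict(
    source_id="8534735",
    source="companies_house.companies",
    target="hmrc.exporters",
    resolution="company_linker",
)
adapter.match(**common, threshold=None)  # resolution's default truth, cached ancestor truths
adapter.match(**common, threshold=0.9)   # 0.9 for company_linker, cached truths for ancestors
adapter.match(**common, threshold={"company_linker": 0.9, "company_deduper": 0.7})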

    def index(self, dataset: Source) -> None:
        """Indexes data from your data warehouse within Matchbox.
43 changes: 30 additions & 13 deletions src/matchbox/server/postgresql/orm.py
@@ -5,6 +5,7 @@
    CheckConstraint,
    Column,
    ForeignKey,
    Index,
    UniqueConstraint,
    select,
)
@@ -98,6 +99,22 @@ def descendants(self) -> set["Resolutions"]:
            )
            return set(session.execute(descendant_query).scalars().all())

    def get_lineage(self) -> dict[bytes, float]:
        """Returns all ancestors and their cached truth values from this model."""
        with Session(MBDB.get_engine()) as session:
            lineage_query = (
                select(ResolutionFrom.parent, ResolutionFrom.truth_cache)
                .where(ResolutionFrom.child == self.hash)
                .order_by(ResolutionFrom.level.desc())
            )

            results = session.execute(lineage_query).all()

            lineage = {parent: truth for parent, truth in results}
            lineage[self.hash] = self.truth

            return lineage
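
Illustratively, for a model resolution with a parent and a grandparent (hashes abbreviated; the level ordering puts the most distant ancestor first):

lineage = model_resolution.get_lineage()
# {
#     b"\xaa...": 0.7,  # grandparent's truth_cache
#     b"\xbb...": 0.9,  # parent's truth_cache
#     model_resolution.hash: model_resolution.truth,
# }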

    def get_lineage_to_dataset(
        self, dataset: "Resolutions"
    ) -> tuple[bytes, dict[bytes, float]]:
@@ -108,31 +125,24 @@ def get_lineage_to_dataset(
            )

        if self.hash == dataset.hash:
-            return {}
+            return {dataset.hash: None}

        with Session(MBDB.get_engine()) as session:
            path_query = (
-                select(
-                    ResolutionFrom.parent, ResolutionFrom.truth_cache, Resolutions.type
-                )
+                select(ResolutionFrom.parent, ResolutionFrom.truth_cache)
                .join(Resolutions, Resolutions.hash == ResolutionFrom.parent)
                .where(ResolutionFrom.child == self.hash)
                .order_by(ResolutionFrom.level.desc())
            )

            results = session.execute(path_query).all()

-            if not any(parent == dataset.hash for parent, _, _ in results):
+            if not any(parent == dataset.hash for parent, _ in results):
                raise ValueError(
                    f"No path between resolution {self.name}, dataset {dataset.name}"
                )

-            lineage = {
-                parent: truth
-                for parent, truth, type in results
-                if type != ResolutionNodeType.DATASET.value
-            }
-
+            lineage = {parent: truth for parent, truth in results}
            lineage[self.hash] = self.truth

            return lineage
@@ -181,8 +191,12 @@ class Contains(CountMixin, MBDB.MatchboxBase):
        BYTEA, ForeignKey("clusters.hash", ondelete="CASCADE"), primary_key=True
    )

-    # Constraints
-    __table_args__ = (CheckConstraint("parent != child", name="no_self_containment"),)
+    # Constraints and indices
+    __table_args__ = (
+        CheckConstraint("parent != child", name="no_self_containment"),
+        Index("ix_contains_parent_child", "parent", "child"),
+        Index("ix_contains_child_parent", "child", "parent"),
+    )
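
The two composite indices cover hierarchy traversal in both directions. A sketch of the lookups they serve (cluster_hash is a hypothetical bytes value):

from sqlalchemy import select

select(Contains.child).where(Contains.parent == cluster_hash)   # descendants: ix_contains_parent_child
select(Contains.parent).where(Contains.child == cluster_hash)   # ancestors: ix_contains_child_parent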


class Clusters(CountMixin, MBDB.MatchboxBase):
@@ -211,6 +225,9 @@ class Clusters(CountMixin, MBDB.MatchboxBase):
backref="parents",
)

    # Constraints and indices
    __table_args__ = (Index("ix_clusters_id_gin", id, postgresql_using="gin"),)
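
A GIN index here assumes Clusters.id is an array-like column; under that assumption it serves containment lookups such as:

from sqlalchemy import select

# clusters whose id array contains a given source primary key
select(Clusters.hash).where(Clusters.id.contains(["8534735"]))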


class Probabilities(CountMixin, MBDB.MatchboxBase):
"""Table of probabilities that a cluster is correct, according to a resolution."""
59 changes: 46 additions & 13 deletions src/matchbox/server/postgresql/utils/db.py
@@ -3,10 +3,10 @@
import io
import pstats
from itertools import islice
-from typing import Any, Callable, Iterable, Tuple
+from typing import Any, Callable, Iterable

from pg_bulk_ingest import Delete, Upsert, ingest
-from sqlalchemy import Engine, MetaData, Table
+from sqlalchemy import Engine, Index, MetaData, Table
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import DeclarativeMeta, Session

@@ -16,7 +16,7 @@
    ResolutionNode,
    ResolutionNodeType,
)
-from matchbox.server.postgresql.orm import ResolutionFrom, Resolutions
+from matchbox.server.postgresql.orm import (
+    ResolutionFrom,
+    Resolutions,
+)

# Retrieval

@@ -79,18 +82,55 @@ def batched(iterable: Iterable, n: int) -> Iterable:

def data_to_batch(
    records: list[tuple], table: Table, batch_size: int
-) -> Callable[[str], Tuple[Any]]:
+) -> Callable[[str], tuple[Any]]:
    """Constructs a batches function for any dataframe and table."""

    def _batches(
        high_watermark,  # noqa ARG001 required for pg_bulk_ingest
-    ) -> Iterable[Tuple[None, None, Iterable[Tuple[Table, tuple]]]]:
+    ) -> Iterable[tuple[None, None, Iterable[tuple[Table, tuple]]]]:
        for batch in batched(records, batch_size):
            yield None, None, ((table, t) for t in batch)

    return _batches


def isolate_table(table: DeclarativeMeta) -> tuple[MetaData, Table]:
    """Creates an isolated copy of a SQLAlchemy table.

    This is used to prevent pg_bulk_ingest from attempting to drop unrelated tables
    in the same schema. The function creates a new Table instance with:

    * A fresh MetaData instance
    * Copied columns
    * Recreated indices properly bound to the new table

    Args:
        table: The DeclarativeMeta class whose table should be isolated

    Returns:
        A tuple of:

        * The isolated SQLAlchemy MetaData
        * A new SQLAlchemy Table instance with all columns and indices
    """
    isolated_metadata = MetaData(schema=table.__table__.schema)

    isolated_table = Table(
        table.__table__.name,
        isolated_metadata,
        *[c._copy() for c in table.__table__.columns],
        schema=table.__table__.schema,
    )

    for idx in table.__table__.indexes:
        Index(
            idx.name,
            *[isolated_table.c[col.name] for col in idx.columns],
            **{k: v for k, v in idx.kwargs.items()},
        )

    return isolated_metadata, isolated_table
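
A usage sketch, with Clusters standing in for any mapped class:

from matchbox.server.postgresql.orm import Clusters

isolated_metadata, isolated_table = isolate_table(Clusters)
assert isolated_table.metadata is isolated_metadata
assert isolated_table.metadata is not Clusters.__table__.metadata
# pg_bulk_ingest can now work against isolated_metadata without
# touching the other tables that share the schema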


def batch_ingest(
    records: list[tuple[Any]],
    table: DeclarativeMeta,
@@ -102,14 +142,7 @@ def batch_ingest(
    We isolate the table and metadata as pg_bulk_ingest will try to drop unrelated
    tables if they're in the same schema.
    """

-    isolated_metadata = MetaData(schema=table.__table__.schema)
-    isolated_table = Table(
-        table.__table__.name,
-        isolated_metadata,
-        *[c._copy() for c in table.__table__.columns],
-        schema=table.__table__.schema,
-    )
+    isolated_metadata, isolated_table = isolate_table(table)

    fn_batch = data_to_batch(
        records=records,
Expand Down
