Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added match function to the backend #17

Merged
merged 24 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3d5b897
Created initial structure and basic unit test for match function
wpfl-dbt Dec 3, 2024
c34df25
Updated pyproject.toml to work with uv 0.5+
wpfl-dbt Dec 3, 2024
fe8b797
Added API endpoint (unimplemented)
wpfl-dbt Dec 3, 2024
7e3529b
Initial run at the query and ORM changes
wpfl-dbt Dec 3, 2024
35b62f3
Attempting to copy indices
wpfl-dbt Dec 3, 2024
8ea880f
Working indices on ORM
wpfl-dbt Dec 4, 2024
e9d009c
Changed _resolve_cluster_hierarchy and lineage to handle datasets bet…
wpfl-dbt Dec 4, 2024
3a51657
Tidied up _resolve_cluster_hierarchy() with new logic
wpfl-dbt Dec 4, 2024
0097f61
Working match(), not yet passing unit tests
wpfl-dbt Dec 4, 2024
719f152
Moved match unit tests into separate functions
wpfl-dbt Dec 4, 2024
bcba176
Merged main
wpfl-dbt Dec 5, 2024
802643a
Factored out subqueries of match()
wpfl-dbt Dec 5, 2024
72a9b84
Wrote functions to visualise the subgraph to help debug
wpfl-dbt Dec 5, 2024
f6a3225
Test cases with data working
wpfl-dbt Dec 8, 2024
6a936d7
All unit tests passing
wpfl-dbt Dec 8, 2024
79ee5b5
Minor refactor of query so it's easier to read
wpfl-dbt Dec 8, 2024
b3af74f
Merged from main
wpfl-dbt Dec 12, 2024
1db389f
Updated Match validation
wpfl-dbt Dec 12, 2024
f300e96
Dealt with all comments
wpfl-dbt Dec 12, 2024
7b4a199
Fixed validation
wpfl-dbt Dec 12, 2024
1c29775
Removed data subgraph visualisation functions
wpfl-dbt Dec 13, 2024
fce3781
One to zero test added
wpfl-dbt Dec 13, 2024
592ca99
Removed a bunch of redundant code from processing of match results
wpfl-dbt Dec 13, 2024
ce58805
Fixed unit test
wpfl-dbt Dec 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/matchbox/common/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from matchbox.common.hash import HASH_FUNC
from pandas import DataFrame
from pyarrow import Table as ArrowTable
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import (
LABEL_STYLE_TABLENAME_PLUS_COL,
ColumnElement,
Expand All @@ -33,6 +33,28 @@
T = TypeVar("T")


class Match(BaseModel):
    """A match between primary keys in the Matchbox database.

    A valid instance is in one of two states: a cluster together with at least
    one non-empty ID set, or no cluster and two empty ID sets. Anything in
    between is rejected by :meth:`found_or_none`.
    """

    # Cluster the IDs belong to; None when nothing was matched.
    cluster: bytes | None
    source: str
    source_id: set[str] = Field(default_factory=set)
    target: str
    target_id: set[str] = Field(default_factory=set)

    @model_validator(mode="after")
    def found_or_none(self) -> "Match":
        """Check that the cluster and the ID sets agree with each other.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: If IDs are present without a cluster, or a cluster is
                present without any IDs.
        """
        has_cluster = self.cluster is not None
        has_ids = bool(self.source_id or self.target_id)

        if has_ids and not has_cluster:
            raise ValueError(
                "A match must have a cluster if source_id or target_id is set."
            )
        if has_cluster and not has_ids:
            raise ValueError(
                "A match must have source_id or target_id if cluster is set."
            )
        return self


class Probability(BaseModel):
"""A probability of a match in the Matchbox database.

Expand Down
36 changes: 35 additions & 1 deletion src/matchbox/helpers/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pyarrow import Table as ArrowTable
from sqlalchemy import Engine, inspect

from matchbox.common.db import Source, get_schema_table_names
from matchbox.common.db import Match, Source, get_schema_table_names
from matchbox.server import MatchboxDBAdapter, inject_backend


Expand Down Expand Up @@ -91,3 +91,37 @@ def query(
return_type="pandas" if not return_type else return_type,
limit=limit,
)


@inject_backend
def match(
    backend: MatchboxDBAdapter,
    source_id: str,
    source: str,
    target: str | list[str],
    model: str,
    threshold: float | dict[str, float] | None = None,
) -> Match | list[Match]:
    """Look up the matches for a source ID through the selected backend.

    This is a thin pass-through: every argument other than ``backend`` is
    forwarded by keyword to ``backend.match``.

    Args:
        backend: the backend to query
        source_id: The ID of the source to match.
        source: The name of the source dataset.
        target: The name of the target dataset(s).
        model: the model to use for filtering results
        threshold (optional): the threshold to use for creating clusters
            If None, uses the models' default threshold
            If a float, uses that threshold for the specified model, and the
            model's cached thresholds for its ancestors
            If a dictionary, expects a shape similar to model.ancestors, keyed
            by model name and valued by the threshold to use for that model. Will
            use these threshold values instead of the cached thresholds

    Returns:
        Whatever ``backend.match`` returns for these arguments.
    """
    lookup = {
        "source_id": source_id,
        "source": source,
        "target": target,
        "model": model,
        "threshold": threshold,
    }
    return backend.match(**lookup)
92 changes: 92 additions & 0 deletions src/matchbox/helpers/visualisation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections import defaultdict
from itertools import count

import rustworkx as rx
from matplotlib.figure import Figure
from rustworkx.visualization import mpl_draw
Expand Down Expand Up @@ -40,3 +43,92 @@ def draw_model_tree(backend: MatchboxDBAdapter) -> Figure:
edge_labels=lambda edge: edge["type"],
font_size=8,
)


def draw_data_tree(graph: rx.PyDiGraph) -> str:
    """Convert a rustworkx PyDiGraph to Mermaid graph visualization code.

    Every node payload is expected to be a dict with an ``id`` and a ``type``
    of ``"source"``, ``"data"`` or ``"cluster"``. Each node is given a short
    Mermaid variable name:

    * source nodes use the last dot-separated part of their ``id``;
    * data nodes that are predecessors of a source use that source's variable
      name followed by a running number;
    * every other data node gets a bare number;
    * cluster nodes become ``hash1``, ``hash2``, ...

    Nodes whose payload is not a dict, or whose type is not one of the three
    above, cannot be named. They are left out of the output, together with any
    edge touching them, instead of raising ``KeyError``.

    Args:
        graph (rx.PyDiGraph): A rustworkx directed graph with nodes containing 'id' and
            'type' attributes

    Returns:
        str: Mermaid graph definition code
    """

    def format_id(id_value):
        """Format ID value, converting bytes to hex if needed."""
        if isinstance(id_value, bytes):
            return f"\\x{id_value.hex()}"
        return f"['{str(id_value)}']"

    # Pass 1: record the declared type of every dict payload.
    node_types = {}
    for node_idx in graph.node_indices():
        node_data = graph.get_node_data(node_idx)
        if isinstance(node_data, dict):
            node_types[node_idx] = node_data.get("type", "")

    node_to_var = {}
    unnamed_data = {idx for idx, kind in node_types.items() if kind == "data"}

    # Pass 2: name the sources, and the data nodes feeding directly into them.
    for node_idx, node_type in node_types.items():
        if node_type != "source":
            continue
        table_name = graph.get_node_data(node_idx)["id"].split(".")[-1]
        node_to_var[node_idx] = table_name

        suffix = count(1)
        for predecessor in graph.predecessor_indices(node_idx):
            # A data node shared by several sources keeps its first name.
            if predecessor in unnamed_data:
                node_to_var[predecessor] = f"{table_name}{next(suffix)}"
                unnamed_data.remove(predecessor)

    # Pass 3: number the leftover data nodes. Sorting keeps the numbering
    # independent of set iteration order.
    remaining_counter = count(len(node_to_var) + 1)
    for node_idx in sorted(unnamed_data):
        node_to_var[node_idx] = str(next(remaining_counter))

    # Pass 4: name the clusters.
    hash_counter = count(1)
    for node_idx, node_type in node_types.items():
        if node_type == "cluster":
            node_to_var[node_idx] = f"hash{next(hash_counter)}"

    # Build the node definitions, grouped as sources, data, then clusters.
    sources = []
    data_defs = []
    clusters = []

    for node_idx, node_type in node_types.items():
        var_name = node_to_var.get(node_idx)
        if var_name is None:
            # Unrecognised type: nothing sensible to draw.
            continue
        node_data = graph.get_node_data(node_idx)

        if node_type == "source":
            sources.append(f' {var_name}["{node_data["id"]}"]')
        elif node_type == "data":
            # format_id wraps non-bytes IDs in ['...']; strip that wrapping
            # (and any list-repr brackets/quotes) back off for data labels.
            node_label = format_id(node_data["id"]).strip("[]'")
            data_defs.append(f' {var_name}["{node_label}"]')
        elif node_type == "cluster":
            node_label = format_id(node_data["id"])
            clusters.append(f' {var_name}["{node_label}"]')

    mermaid_lines = ["graph LR"]
    mermaid_lines.extend(sources)
    mermaid_lines.extend(data_defs)
    mermaid_lines.extend(clusters)

    # Blank line between the node definitions and the edges.
    mermaid_lines.append("")

    for edge in graph.edge_list():
        source_var = node_to_var.get(edge[0])
        target_var = node_to_var.get(edge[1])
        if source_var is None or target_var is None:
            # Edge touches a node that was left out above.
            continue
        mermaid_lines.append(f" {source_var} --> {target_var}")

    return "\n".join(mermaid_lines)
5 changes: 5 additions & 0 deletions src/matchbox/server/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ async def query():
raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/match")
async def match():
    """Placeholder for the match endpoint; always responds with HTTP 501."""
    raise HTTPException(status_code=501, detail="Not implemented")


@app.get("/validate/hash")
async def validate_hashes():
    """Placeholder for hash validation; always responds with HTTP 501."""
    raise HTTPException(status_code=501, detail="Not implemented")
Expand Down
12 changes: 11 additions & 1 deletion src/matchbox/server/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from rustworkx import PyDiGraph
from sqlalchemy import Engine

from matchbox.common.db import Source
from matchbox.common.db import Match, Source

if TYPE_CHECKING:
from pandas import DataFrame as PandasDataFrame
Expand Down Expand Up @@ -251,6 +251,16 @@ def query(
limit: int = None,
) -> PandasDataFrame | ArrowTable | PolarsDataFrame: ...

@abstractmethod
def match(
    self,
    source_id: str,
    source: str,
    target: str | list[str],
    model: str,
    threshold: float | dict[str, float] | None = None,
) -> Match | list[Match]:
    """Match an ID in a source dataset against one or more target datasets.

    Concrete adapters must implement this.

    Args:
        source_id: The ID of the source to match.
        source: The name of the source dataset.
        target: The name of the target dataset(s).
        model: The name of the model to use for matching.
        threshold (optional): The threshold to use for creating clusters,
            either a single float or a dictionary of per-model values.
            None leaves the choice of threshold to the implementation.

    Returns:
        A Match or a list of Match objects.
    """
    ...

@abstractmethod
def index(self, dataset: Source) -> None:
    """Index the given source dataset in the backend.

    Concrete adapters must implement this.

    Args:
        dataset: The Source to index.
    """
    ...

Expand Down
36 changes: 34 additions & 2 deletions src/matchbox/server/postgresql/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy import Engine, and_, bindparam, delete, func, or_, select
from sqlalchemy.orm import Session

from matchbox.common.db import Source, SourceWarehouse
from matchbox.common.db import Match, Source, SourceWarehouse
from matchbox.common.exceptions import (
MatchboxDataError,
MatchboxDatasetError,
Expand All @@ -28,7 +28,7 @@
insert_model,
insert_results,
)
from matchbox.server.postgresql.utils.query import query
from matchbox.server.postgresql.utils.query import match, query
from matchbox.server.postgresql.utils.results import (
get_model_clusters,
get_model_probabilities,
Expand Down Expand Up @@ -292,6 +292,38 @@ def query(
limit=limit,
)

def match(
    self,
    source_id: str,
    source: str,
    target: str | list[str],
    model: str,
    threshold: float | dict[str, float] | None = None,
) -> Match | list[Match]:
    """Match one source ID and return the corresponding keys in the targets.

    Delegates to the module-level ``match`` query helper, adding the engine
    obtained from ``MBDB.get_engine()``.

    Args:
        source_id: The ID of the source to match.
        source: The name of the source dataset.
        target: The name of the target dataset(s).
        model: The name of the model to use for matching.
        threshold (optional): the threshold to use for creating clusters
            If None, uses the models' default threshold
            If a float, uses that threshold for the specified model, and the
            model's cached thresholds for its ancestors
            If a dictionary, expects a shape similar to model.ancestors, keyed
            by model name and valued by the threshold to use for that model. Will
            use these threshold values instead of the cached thresholds

    Returns:
        Whatever the ``match`` query helper returns for these arguments.
    """
    query_args = {
        "source_id": source_id,
        "source": source,
        "target": target,
        "model": model,
        "engine": MBDB.get_engine(),
        "threshold": threshold,
    }
    # Inside the method body ``match`` resolves to the imported helper,
    # not to this method.
    return match(**query_args)

def index(self, dataset: Source) -> None:
"""Indexes a data from your data warehouse within Matchbox.

Expand Down
41 changes: 30 additions & 11 deletions src/matchbox/server/postgresql/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CheckConstraint,
Column,
ForeignKey,
Index,
select,
)
from sqlalchemy.dialects.postgresql import ARRAY, BYTEA
Expand Down Expand Up @@ -105,6 +106,22 @@ def descendants(self) -> set["Models"]:
)
return set(session.execute(descendant_query).scalars().all())

def get_lineage(self) -> dict[bytes, float]:
    """Return the truth values of this model and all of its ancestors.

    Ancestors are read from ``ModelsFrom`` rows whose child is this model,
    ordered by descending level, with ``truth_cache`` as their value. This
    model's own ``truth`` is stored last under its own hash.

    Returns:
        A dictionary mapping model hash to truth value.
    """
    ancestors_stmt = (
        select(ModelsFrom.parent, ModelsFrom.truth_cache)
        .where(ModelsFrom.child == self.hash)
        .order_by(ModelsFrom.level.desc())
    )

    with Session(MBDB.get_engine()) as session:
        rows = session.execute(ancestors_stmt).all()

        lineage = {}
        for parent_hash, cached_truth in rows:
            lineage[parent_hash] = cached_truth
        lineage[self.hash] = self.truth

        return lineage

def get_lineage_to_dataset(
self, model: "Models"
) -> tuple[bytes, dict[bytes, float]]:
Expand All @@ -115,29 +132,24 @@ def get_lineage_to_dataset(
)

if self.hash == model.hash:
return {}
return {model.hash: None}

with Session(MBDB.get_engine()) as session:
path_query = (
select(ModelsFrom.parent, ModelsFrom.truth_cache, Models.type)
select(ModelsFrom.parent, ModelsFrom.truth_cache)
.join(Models, Models.hash == ModelsFrom.parent)
.where(ModelsFrom.child == self.hash)
.order_by(ModelsFrom.level.desc())
)

results = session.execute(path_query).all()

if not any(parent == model.hash for parent, _, _ in results):
if not any(parent == model.hash for parent, _ in results):
raise ValueError(
f"No path exists between model {self.name} and dataset {model.name}"
)

lineage = {
parent: truth
for parent, truth, type in results
if type != ModelType.DATASET.value
}

lineage = {parent: truth for parent, truth in results}
lineage[self.hash] = self.truth

return lineage
Expand Down Expand Up @@ -179,8 +191,12 @@ class Contains(CountMixin, MBDB.MatchboxBase):
BYTEA, ForeignKey("clusters.hash", ondelete="CASCADE"), primary_key=True
)

# Constraints
__table_args__ = (CheckConstraint("parent != child", name="no_self_containment"),)
# Constraints and indices
__table_args__ = (
CheckConstraint("parent != child", name="no_self_containment"),
Index("ix_contains_parent_child", "parent", "child"),
Index("ix_contains_child_parent", "child", "parent"),
)


class Clusters(CountMixin, MBDB.MatchboxBase):
Expand Down Expand Up @@ -209,6 +225,9 @@ class Clusters(CountMixin, MBDB.MatchboxBase):
backref="parents",
)

# Constraints and indices
__table_args__ = (Index("ix_clusters_id_gin", id, postgresql_using="gin"),)


class Probabilities(CountMixin, MBDB.MatchboxBase):
"""Table of probabilities that a cluster merge is correct, according to a model."""
Expand Down
Loading
Loading