Skip to content

Commit

Permalink
Initial run at the query and ORM changes
Browse files Browse the repository at this point in the history
  • Loading branch information
wpfl-dbt committed Dec 3, 2024
1 parent fe8b797 commit 7e3529b
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 5 deletions.
28 changes: 26 additions & 2 deletions src/matchbox/server/postgresql/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CheckConstraint,
Column,
ForeignKey,
Index,
select,
)
from sqlalchemy.dialects.postgresql import ARRAY, BYTEA
Expand Down Expand Up @@ -105,6 +106,22 @@ def descendants(self) -> set["Models"]:
)
return set(session.execute(descendant_query).scalars().all())

def get_lineage(self) -> dict[bytes, float]:
"""Returns all ancestors and their cached truth values from this model."""
with Session(MBDB.get_engine()) as session:
lineage_query = (
select(ModelsFrom.parent, ModelsFrom.truth_cache)
.where(ModelsFrom.child == self.hash)
.order_by(ModelsFrom.level.desc())
)

results = session.execute(lineage_query).all()

lineage = {parent: truth for parent, truth in results}
lineage[self.hash] = self.truth

return lineage

def get_lineage_to_dataset(
self, model: "Models"
) -> tuple[bytes, dict[bytes, float]]:
Expand Down Expand Up @@ -179,8 +196,12 @@ class Contains(CountMixin, MBDB.MatchboxBase):
BYTEA, ForeignKey("clusters.hash", ondelete="CASCADE"), primary_key=True
)

# Constraints
__table_args__ = (CheckConstraint("parent != child", name="no_self_containment"),)
# Constraints and indices
__table_args__ = (
CheckConstraint("parent != child", name="no_self_containment"),
Index("ix_contains_parent_child", "parent", "child"),
Index("ix_contains_child_parent", "child", "parent"),
)


class Clusters(CountMixin, MBDB.MatchboxBase):
Expand Down Expand Up @@ -209,6 +230,9 @@ class Clusters(CountMixin, MBDB.MatchboxBase):
backref="parents",
)

# Constraints and indices
__table_args__ = (Index("ix_clusters_id_gin", id, postgresql_using="gin"),)


class Probabilities(CountMixin, MBDB.MatchboxBase):
"""Table of probabilities that a cluster merge is correct, according to a model."""
Expand Down
215 changes: 212 additions & 3 deletions src/matchbox/server/postgresql/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,21 @@

import pyarrow as pa
from pandas import ArrowDtype, DataFrame
from sqlalchemy import Engine, and_, cast, func, literal, null, select, union
from sqlalchemy import (
Engine,
and_,
cast,
func,
literal,
null,
select,
union,
)
from sqlalchemy.dialects.postgresql import BYTEA
from sqlalchemy.orm import Session
from sqlalchemy.sql.selectable import Select

from matchbox.common.db import sql_to_df
from matchbox.common.db import get_schema_table_names, sql_to_df
from matchbox.common.exceptions import (
MatchboxDatasetError,
MatchboxModelError,
Expand Down Expand Up @@ -362,4 +371,204 @@ def match(
* Retrieves all other IDs in the cluster in the target dataset
* Returns the results as Match objects, one per target
"""
pass
# Split source and target into schema/table
source_schema, source_table = get_schema_table_names(source, validate=True)
targets = [target] if isinstance(target, str) else target
target_pairs = [get_schema_table_names(t, validate=True) for t in targets]

with Session(engine) as session:
# Get truth model
truth_model = session.query(Models).filter(Models.name == model).first()
if truth_model is None:
raise MatchboxModelError(f"Model {model} not found")

# Get source dataset
source_dataset = (
session.query(Models)
.join(Sources, Sources.model == Models.hash)
.filter(
Sources.schema == source_schema,
Sources.table == source_table,
)
.first()
)
if source_dataset is None:
raise MatchboxDatasetError(
db_schema=source_schema,
db_table=source_table,
)

# Get target datasets
target_datasets = []
for schema, table in target_pairs:
dataset = (
session.query(Models)
.join(Sources, Sources.model == Models.hash)
.filter(
Sources.schema == schema,
Sources.table == table,
)
.first()
)
if dataset is None:
raise MatchboxDatasetError(db_schema=schema, db_table=table)
target_datasets.append(dataset)

# Get model lineage and resolve thresholds
lineage_truths = truth_model.get_lineage()
thresholds = _resolve_thresholds(
lineage_truths=lineage_truths,
model=truth_model,
threshold=threshold,
session=session,
)

# Get valid clusters across all models
valid_clusters = _union_valid_clusters(thresholds)

# Unnest cluster IDs
unnested_clusters = (
select(
Clusters.hash, Clusters.dataset, func.unnest(Clusters.id).label("id")
)
.select_from(Clusters)
.cte("unnested_clusters")
)

# Find source ID's initial cluster
source_cluster = (
select(unnested_clusters.c.hash)
.select_from(unnested_clusters)
.where(
and_(
unnested_clusters.c.dataset
== hash_to_hex_decode(source_dataset.hash),
unnested_clusters.c.id == source_id,
)
)
.scalar_subquery()
)

# Build recursive hierarchy CTE going up
hierarchy_up = (
# Base case: direct parents
select(
source_cluster.label("original_cluster"),
source_cluster.label("child"),
Contains.parent.label("parent"),
literal(1).label("level"),
)
.join(Contains, Contains.child == source_cluster)
.where(Contains.parent.in_(select(valid_clusters.c.cluster)))
.cte("hierarchy_up", recursive=True)
)

# Recursive case going up
recursive_up = (
select(
hierarchy_up.c.original_cluster,
hierarchy_up.c.parent.label("child"),
Contains.parent.label("parent"),
(hierarchy_up.c.level + 1).label("level"),
)
.select_from(hierarchy_up)
.join(Contains, Contains.child == hierarchy_up.c.parent)
.where(Contains.parent.in_(select(valid_clusters.c.cluster)))
)

hierarchy_up = hierarchy_up.union_all(recursive_up)

# Get highest parent
highest_parent = (
select(hierarchy_up.c.parent)
.order_by(hierarchy_up.c.level.desc())
.limit(1)
.scalar_subquery()
)

# Build recursive hierarchy CTE going down
hierarchy_down = (
# Base case: direct children from highest parent
select(
highest_parent.label("parent"),
Contains.child.label("child"),
literal(1).label("level"),
unnested_clusters.c.dataset.label("dataset"),
unnested_clusters.c.id.label("id"),
)
.select_from(Contains)
.join(unnested_clusters, unnested_clusters.c.hash == Contains.child)
.where(Contains.parent == highest_parent)
.cte("hierarchy_down", recursive=True)
)

# Recursive case going down
recursive_down = (
select(
hierarchy_down.c.parent,
Contains.child.label("child"),
(hierarchy_down.c.level + 1).label("level"),
unnested_clusters.c.dataset.label("dataset"),
unnested_clusters.c.id.label("id"),
)
.select_from(hierarchy_down)
.join(Contains, Contains.parent == hierarchy_down.c.child)
.join(unnested_clusters, unnested_clusters.c.hash == Contains.child)
)

hierarchy_down = hierarchy_down.union_all(recursive_down)

# Get all matched IDs
matches = session.execute(
select(
hierarchy_down.c.dataset,
hierarchy_down.c.id,
)
.distinct()
.select_from(hierarchy_down)
).all()

# Group matches by dataset
matches_by_dataset = {}
for dataset_hash, id in matches:
if dataset_hash not in matches_by_dataset:
matches_by_dataset[dataset_hash] = set()
matches_by_dataset[dataset_hash].add(id)

# Create Match objects for each target
result = []
for target_dataset in target_datasets:
# Get source/target table names
source_name = f"{source_schema}.{source_table}"
target_schema, target_table = next(
(schema, table)
for schema, table in target_pairs
if session.get(Sources, target_dataset.hash).schema == schema
and session.get(Sources, target_dataset.hash).table == table
)
target_name = f"{target_schema}.{target_table}"

highest_cluster = highest_parent.scalar() if matches else None

# Get source and target IDs
source_ids = {
id
for dataset_hash, id in matches
if dataset_hash == source_dataset.hash
}
target_ids = {
id
for dataset_hash, id in matches
if dataset_hash == target_dataset.hash
}

match_obj = Match(
cluster=highest_cluster,
source=source_name,
source_id=source_ids,
target=target_name,
target_id=target_ids,
)
result.append(match_obj)

return result[0] if isinstance(target, str) else result

0 comments on commit 7e3529b

Please sign in to comment.