From 3d5b897c58f22fdd4e0c4bb5959dd750f6d95abd Mon Sep 17 00:00:00 2001
From: Will Langdale
Date: Tue, 3 Dec 2024 07:49:33 +0000
Subject: [PATCH 01/22] Created initial structure and basic unit test for match function

---
 src/matchbox/helpers/selector.py | 36 ++++++++++-
 src/matchbox/server/base.py | 12 +++-
 src/matchbox/server/models.py | 10 ++++
 src/matchbox/server/postgresql/adapter.py | 36 ++++++++++-
 src/matchbox/server/postgresql/utils/query.py | 24 +++++++-
 test/fixtures/data.py | 42 +++++++++++--
 test/server/test_adapter.py | 60 ++++++++++++++++++-
 uv.lock | 4 +-
 8 files changed, 210 insertions(+), 14 deletions(-)

diff --git a/src/matchbox/helpers/selector.py b/src/matchbox/helpers/selector.py
index 41b3632..bb6b9ad 100644
--- a/src/matchbox/helpers/selector.py
+++ b/src/matchbox/helpers/selector.py
@@ -6,7 +6,7 @@
 from matchbox.common.db import get_schema_table_names
 from matchbox.server import MatchboxDBAdapter, inject_backend
-from matchbox.server.models import Source
+from matchbox.server.models import Match, Source
 
 
@@ -92,3 +92,37 @@ def query(
         return_type="pandas" if not return_type else return_type,
         limit=limit,
     )
+
+
+@inject_backend
+def match(
+    backend: MatchboxDBAdapter,
+    source_id: str,
+    source: str,
+    target: str | list[str],
+    model: str,
+    threshold: float | dict[str, float] | None = None,
+) -> Match | list[Match]:
+    """Matches IDs against the selected backend.
+
+    Args:
+        backend: The backend to use for matching.
+        source_id: The ID of the source to match.
+        source: The name of the source dataset.
+        target: The name of the target dataset(s).
+        model: The model to use for filtering results.
+        threshold (optional): The threshold to use for creating clusters.
+            If None, uses the model's default threshold
+            If a float, uses that threshold for the specified model, and the
+            model's cached thresholds for its ancestors
+            If a dictionary, expects a shape similar to model.ancestors, keyed
+            by model name and valued by the threshold to use for that model.
+            Will use these threshold values instead of the cached thresholds.
+    """
+    return backend.match(
+        source_id=source_id,
+        source=source,
+        target=target,
+        model=model,
+        threshold=threshold,
+    )
diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py
index 9b472f4..33303e5 100644
--- a/src/matchbox/server/base.py
+++ b/src/matchbox/server/base.py
@@ -20,7 +20,7 @@
 from rustworkx import PyDiGraph
 from sqlalchemy import Engine
-from matchbox.server.models import Source
+from matchbox.server.models import Match, Source
 
 if TYPE_CHECKING:
     from pandas import DataFrame as PandasDataFrame
@@ -251,6 +251,16 @@ def query(
         limit: int = None,
     ) -> PandasDataFrame | ArrowTable | PolarsDataFrame: ...
 
+    @abstractmethod
+    def match(
+        self,
+        source_id: str,
+        source: str,
+        target: str | list[str],
+        model: str,
+        threshold: float | dict[str, float] | None = None,
+    ) -> Match | list[Match]: ...
+
     @abstractmethod
     def index(self, dataset: Source) -> None: ...

diff --git a/src/matchbox/server/models.py b/src/matchbox/server/models.py
index 62603a0..3caf85e 100644
--- a/src/matchbox/server/models.py
+++ b/src/matchbox/server/models.py
@@ -21,6 +21,16 @@
 T = TypeVar("T")
 
 
+class Match(BaseModel):
+    """A match between primary keys in the Matchbox database."""
+
+    cluster: bytes
+    source: str
+    source_id: set[str] = Field(default_factory=set)
+    target: str
+    target_id: set[str] = Field(default_factory=set)
+
+
 class Probability(BaseModel):
     """A probability of a match in the Matchbox database.
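For orientation, a minimal sketch of how the new helper is meant to be called once the backends implement it. The table names, model name and key below are hypothetical, and @inject_backend supplies the backend argument, so callers omit it:

    from matchbox.helpers.selector import match

    matched = match(
        source_id="8534735b-2228-4294-bcb2-d886f2e2c12f",
        source="companieshouse.companies",
        target="dit.data_hub__companies",
        model="companies_linker",
        threshold=0.7,
    )
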
diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py
index e7f5612..b84bf74 100644
--- a/src/matchbox/server/postgresql/adapter.py
+++ b/src/matchbox/server/postgresql/adapter.py
@@ -12,7 +12,7 @@
 )
 from matchbox.common.results import ClusterResults, ProbabilityResults, Results
 from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter
-from matchbox.server.models import Source, SourceWarehouse
+from matchbox.server.models import Match, Source, SourceWarehouse
 from matchbox.server.postgresql.db import MBDB, MatchboxPostgresSettings
 from matchbox.server.postgresql.orm import (
     Clusters,
@@ -28,7 +28,7 @@
     insert_model,
     insert_results,
 )
-from matchbox.server.postgresql.utils.query import query
+from matchbox.server.postgresql.utils.query import match, query
 from matchbox.server.postgresql.utils.results import (
     get_model_clusters,
     get_model_probabilities,
@@ -292,6 +292,38 @@ def query(
             limit=limit,
         )
 
+    def match(
+        self,
+        source_id: str,
+        source: str,
+        target: str | list[str],
+        model: str,
+        threshold: float | dict[str, float] | None = None,
+    ) -> Match | list[Match]:
+        """Matches an ID in a source dataset and returns the keys in the targets.
+
+        Args:
+            source_id: The ID of the source to match.
+            source: The name of the source dataset.
+            target: The name of the target dataset(s).
+            model: The name of the model to use for matching.
+            threshold (optional): The threshold to use for creating clusters.
+                If None, uses the model's default threshold
+                If a float, uses that threshold for the specified model, and the
+                model's cached thresholds for its ancestors
+                If a dictionary, expects a shape similar to model.ancestors, keyed
+                by model name and valued by the threshold to use for that model.
+                Will use these threshold values instead of the cached thresholds.
+        """
+        return match(
+            source_id=source_id,
+            source=source,
+            target=target,
+            model=model,
+            engine=MBDB.get_engine(),
+            threshold=threshold,
+        )
+
     def index(self, dataset: Source) -> None:
         """Indexes a dataset from your data warehouse within Matchbox.
 
diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py
index 18add15..3c36823 100644
--- a/src/matchbox/server/postgresql/utils/query.py
+++ b/src/matchbox/server/postgresql/utils/query.py
@@ -13,7 +13,7 @@
     MatchboxDatasetError,
     MatchboxModelError,
 )
-from matchbox.server.models import Source
+from matchbox.server.models import Match, Source
 from matchbox.server.postgresql.orm import (
     Clusters,
     Contains,
@@ -341,3 +341,25 @@
         )
     else:
         raise ValueError(f"return_type of {return_type} not valid")
+
+
+def match(
+    source_id: str,
+    source: str,
+    target: str | list[str],
+    model: str,
+    engine: Engine,
+    threshold: float | dict[str, float] | None = None,
+) -> Match | list[Match]:
+    """Matches an ID in a source dataset and returns the keys in the targets.
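+
+    Used to translate a primary key in one source dataset into the matching
+    keys in one or more target datasets, via the clusters of the given model.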
+ + To accomplish this, the function: + + * Reconstructs the model lineage from the specified model + * Iterates through each target, and + * Retrieves its cluster hash according to the model + * Retrieves all other IDs in the cluster in the source dataset + * Retrieves all other IDs in the cluster in the target dataset + * Returns the results as Match objects, one per target + """ + pass diff --git a/test/fixtures/data.py b/test/fixtures/data.py index 0900be1..4a98562 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -1,6 +1,6 @@ import logging -import uuid from pathlib import Path +from uuid import UUID import numpy as np import pandas as pd @@ -34,7 +34,7 @@ def all_companies(test_root_dir: Path) -> DataFrame: df = pd.read_csv( Path(test_root_dir, "data", "all_companies.csv"), encoding="utf-8" ).reset_index(names="id") - df["id"] = df["id"].apply(lambda x: uuid.UUID(int=x)) + df["id"] = df["id"].apply(lambda x: UUID(int=x)) return df @@ -60,7 +60,7 @@ def crn_companies(all_companies: DataFrame) -> DataFrame: df_crn["id"] = range(df_crn.shape[0]) df_crn = df_crn.filter(["id", "company_name", "crn"]) - df_crn["id"] = df_crn["id"].apply(lambda x: uuid.UUID(int=x)) + df_crn["id"] = df_crn["id"].apply(lambda x: UUID(int=x)) df_crn = df_crn.convert_dtypes(dtype_backend="pyarrow") return df_crn @@ -79,12 +79,12 @@ def duns_companies(all_companies: DataFrame) -> DataFrame: """ df_duns = ( all_companies.filter(["company_name", "duns"]) - .sample(n=500) + .sample(n=500, random_state=1618) .reset_index(drop=True) .reset_index(names="id") .convert_dtypes(dtype_backend="pyarrow") ) - df_duns["id"] = df_duns["id"].apply(lambda x: uuid.UUID(int=x)) + df_duns["id"] = df_duns["id"].apply(lambda x: UUID(int=x)) return df_duns @@ -107,12 +107,42 @@ def cdms_companies(all_companies: DataFrame) -> DataFrame: df_cdms.columns = ["crn", "cdms"] df_cdms.reset_index(names="id", inplace=True) - df_cdms["id"] = df_cdms["id"].apply(lambda x: uuid.UUID(int=x)) + df_cdms["id"] = df_cdms["id"].apply(lambda x: UUID(int=x)) df_cdms = df_cdms.convert_dtypes(dtype_backend="pyarrow") return df_cdms +@pytest.fixture(scope="session") +def revolution_inc( + crn_companies: DataFrame, duns_companies: DataFrame, cdms_companies: DataFrame +) -> dict[str, str]: + """ + Revolution Inc. as it exists across all three datasets. + + UUIDs are converted to strings to mirror how Matchbox stores them. 
+ """ + crn_ids = crn_companies[ + crn_companies["company_name"].str.contains("Revolution", case=False) + ]["id"].tolist() + + duns_ids = duns_companies[ + duns_companies["company_name"].str.contains("Revolution", case=False) + ]["id"].tolist() + + revolution_crn = crn_companies[ + crn_companies["company_name"].str.contains("Revolution", case=False) + ]["crn"].iloc[0] + + cdms_ids = cdms_companies[cdms_companies["crn"] == revolution_crn]["id"].tolist() + + return { + "crn": [str(id) for id in crn_ids], + "duns": [str(id) for id in duns_ids], + "cdms": [str(id) for id in cdms_ids], + } + + @pytest.fixture(scope="function") def query_clean_crn( matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source] diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 439796e..1212be8 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -15,9 +15,9 @@ Results, to_clusters, ) -from matchbox.helpers.selector import query, selector, selectors +from matchbox.helpers.selector import match, query, selector, selectors from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter -from matchbox.server.models import Source +from matchbox.server.models import Match, Source from pandas import DataFrame from ..fixtures.db import SetupDatabaseCallable @@ -541,6 +541,62 @@ def test_query_with_link_model(self): } assert crn_duns.hash.nunique() == 1000 + def test_match(self, revolution_inc: dict[str, list[str]]): + """Test that matching data works.""" + self.setup_database("link") + + crn_x_duns = "deterministic_naive_test.crn_naive_test.duns" + crn_wh = self.warehouse_data[0] + duns_wh = self.warehouse_data[1] + + # Test 1:* match + + res = match( + backend=self.backend, + source_id=revolution_inc["duns"][0], + source=str(duns_wh), + target=str(crn_wh), + model=crn_x_duns, + ) + + assert isinstance(res, Match) + assert Match.source == str(duns_wh) + assert Match.target == str(crn_wh) + assert Match.source_id == set(revolution_inc["duns"]) + assert Match.target_id == set(revolution_inc["crn"]) + + # Test *:1 match + + res = match( + backend=self.backend, + source_id=revolution_inc["crn"][0], + source=str(crn_wh), + target=str(duns_wh), + model=crn_x_duns, + ) + + assert isinstance(res, Match) + assert Match.source == str(crn_wh) + assert Match.target == str(duns_wh) + assert Match.source_id == set(revolution_inc["crn"]) + assert Match.target_id == set(revolution_inc["duns"]) + + # Test 0:0 match + + res = match( + backend=self.backend, + source_id="foo", + source=str(crn_wh), + target=str(duns_wh), + model=crn_x_duns, + ) + + assert isinstance(res, Match) + assert Match.source == str(crn_wh) + assert Match.target == str(duns_wh) + assert Match.source_id == set() + assert Match.target_id == set() + def test_clear(self): """Test clearing the database.""" self.setup_database("dedupe") diff --git a/uv.lock b/uv.lock index c088139..ad842c5 100644 --- a/uv.lock +++ b/uv.lock @@ -909,7 +909,7 @@ wheels = [ [[package]] name = "matchbox" -version = "0.1.0" +version = "0.2.0" source = { editable = "." 
} dependencies = [ { name = "altair" }, @@ -1319,6 +1319,8 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, + { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, From c34df258f206dfb6424144b90c756f4b5ad41ce4 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Tue, 3 Dec 2024 08:06:28 +0000 Subject: [PATCH 02/22] Updated pyproject.toml to work with uv 0.5+ --- pyproject.toml | 15 ++++++++------- uv.lock | 12 ++++-------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 79cd3c7..f3111ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,13 +26,8 @@ dependencies = [ "tomli>=2.0.1", ] -[project.optional-dependencies] -typing = [ - "polars>=1.11.0", -] - -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "ipykernel>=6.29.5", "pre-commit>=3.8.0", "pytest>=8.3.3", @@ -41,6 +36,12 @@ dev-dependencies = [ "ruff>=0.6.8", "docker>=7.1.0", ] +typing = [ + "polars>=1.11.0", +] + +[tool.uv] +default-groups = ["dev", "typing"] package = true [tool.ruff] diff --git a/uv.lock b/uv.lock index ad842c5..7eabfb0 100644 --- a/uv.lock +++ b/uv.lock @@ -931,11 +931,6 @@ dependencies = [ { name = "tomli" }, ] -[package.optional-dependencies] -typing = [ - { name = "polars" }, -] - [package.dev-dependencies] dev = [ { name = "docker" }, @@ -946,6 +941,9 @@ dev = [ { name = "pytest-env" }, { name = "ruff" }, ] +typing = [ + { name = "polars" }, +] [package.metadata] requires-dist = [ @@ -957,7 +955,6 @@ requires-dist = [ { name = "matplotlib", specifier = ">=3.9.2" }, { name = "pandas", specifier = ">=2.2.3" }, { name = "pg-bulk-ingest", specifier = ">=0.0.54" }, - { name = "polars", marker = "extra == 'typing'", specifier = ">=1.11.0" }, { name = "psycopg2", specifier = ">=2.9.10" }, { name = "pyarrow", specifier = ">=17.0.0" }, { name = "pydantic", specifier = ">=2.9.2" }, @@ -979,6 +976,7 @@ dev = [ { name = "pytest-env", specifier 
= ">=1.1.5" }, { name = "ruff", specifier = ">=0.6.8" }, ] +typing = [{ name = "polars", specifier = ">=1.11.0" }] [[package]] name = "matplotlib" @@ -1319,8 +1317,6 @@ version = "6.0.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/66/78c9c3020f573c58101dc43a44f6855d01bbbd747e24da2f0c4491200ea3/psutil-6.0.0-cp27-none-win32.whl", hash = "sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35", size = 249766 }, - { url = "https://files.pythonhosted.org/packages/e1/3f/2403aa9558bea4d3854b0e5e567bc3dd8e9fbc1fc4453c0aa9aafeb75467/psutil-6.0.0-cp27-none-win_amd64.whl", hash = "sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1", size = 253024 }, { url = "https://files.pythonhosted.org/packages/0b/37/f8da2fbd29690b3557cca414c1949f92162981920699cd62095a984983bf/psutil-6.0.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0", size = 250961 }, { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, { url = "https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, From fe8b79712083e74ce2b21c4db3610d1df94dad9b Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Tue, 3 Dec 2024 08:24:22 +0000 Subject: [PATCH 03/22] Added API endpoint (unimplemented) --- src/matchbox/server/api.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/matchbox/server/api.py b/src/matchbox/server/api.py index 80b84bb..c5cfd45 100644 --- a/src/matchbox/server/api.py +++ b/src/matchbox/server/api.py @@ -149,6 +149,11 @@ async def query(): raise HTTPException(status_code=501, detail="Not implemented") +@app.get("/match") +async def match(): + raise HTTPException(status_code=501, detail="Not implemented") + + @app.get("/validate/hash") async def validate_hashes(): raise HTTPException(status_code=501, detail="Not implemented") From 7e3529b92503aeae097123467b5e97a4069259a2 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Tue, 3 Dec 2024 09:42:09 +0000 Subject: [PATCH 04/22] Initial run at the query and ORM changes --- src/matchbox/server/postgresql/orm.py | 28 ++- src/matchbox/server/postgresql/utils/query.py | 215 +++++++++++++++++- 2 files changed, 238 insertions(+), 5 deletions(-) diff --git a/src/matchbox/server/postgresql/orm.py b/src/matchbox/server/postgresql/orm.py index 8f1c667..c40279f 100644 --- a/src/matchbox/server/postgresql/orm.py +++ b/src/matchbox/server/postgresql/orm.py @@ -7,6 +7,7 @@ CheckConstraint, Column, ForeignKey, + Index, select, ) from sqlalchemy.dialects.postgresql import ARRAY, BYTEA @@ -105,6 +106,22 @@ def descendants(self) -> set["Models"]: ) return set(session.execute(descendant_query).scalars().all()) + def get_lineage(self) -> dict[bytes, float]: + 
"""Returns all ancestors and their cached truth values from this model.""" + with Session(MBDB.get_engine()) as session: + lineage_query = ( + select(ModelsFrom.parent, ModelsFrom.truth_cache) + .where(ModelsFrom.child == self.hash) + .order_by(ModelsFrom.level.desc()) + ) + + results = session.execute(lineage_query).all() + + lineage = {parent: truth for parent, truth in results} + lineage[self.hash] = self.truth + + return lineage + def get_lineage_to_dataset( self, model: "Models" ) -> tuple[bytes, dict[bytes, float]]: @@ -179,8 +196,12 @@ class Contains(CountMixin, MBDB.MatchboxBase): BYTEA, ForeignKey("clusters.hash", ondelete="CASCADE"), primary_key=True ) - # Constraints - __table_args__ = (CheckConstraint("parent != child", name="no_self_containment"),) + # Constraints and indices + __table_args__ = ( + CheckConstraint("parent != child", name="no_self_containment"), + Index("ix_contains_parent_child", "parent", "child"), + Index("ix_contains_child_parent", "child", "parent"), + ) class Clusters(CountMixin, MBDB.MatchboxBase): @@ -209,6 +230,9 @@ class Clusters(CountMixin, MBDB.MatchboxBase): backref="parents", ) + # Constraints and indices + __table_args__ = (Index("ix_clusters_id_gin", id, postgresql_using="gin"),) + class Probabilities(CountMixin, MBDB.MatchboxBase): """Table of probabilities that a cluster merge is correct, according to a model.""" diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 3c36823..d0e224a 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -3,12 +3,21 @@ import pyarrow as pa from pandas import ArrowDtype, DataFrame -from sqlalchemy import Engine, and_, cast, func, literal, null, select, union +from sqlalchemy import ( + Engine, + and_, + cast, + func, + literal, + null, + select, + union, +) from sqlalchemy.dialects.postgresql import BYTEA from sqlalchemy.orm import Session from sqlalchemy.sql.selectable import Select -from matchbox.common.db import sql_to_df +from matchbox.common.db import get_schema_table_names, sql_to_df from matchbox.common.exceptions import ( MatchboxDatasetError, MatchboxModelError, @@ -362,4 +371,204 @@ def match( * Retrieves all other IDs in the cluster in the target dataset * Returns the results as Match objects, one per target """ - pass + # Split source and target into schema/table + source_schema, source_table = get_schema_table_names(source, validate=True) + targets = [target] if isinstance(target, str) else target + target_pairs = [get_schema_table_names(t, validate=True) for t in targets] + + with Session(engine) as session: + # Get truth model + truth_model = session.query(Models).filter(Models.name == model).first() + if truth_model is None: + raise MatchboxModelError(f"Model {model} not found") + + # Get source dataset + source_dataset = ( + session.query(Models) + .join(Sources, Sources.model == Models.hash) + .filter( + Sources.schema == source_schema, + Sources.table == source_table, + ) + .first() + ) + if source_dataset is None: + raise MatchboxDatasetError( + db_schema=source_schema, + db_table=source_table, + ) + + # Get target datasets + target_datasets = [] + for schema, table in target_pairs: + dataset = ( + session.query(Models) + .join(Sources, Sources.model == Models.hash) + .filter( + Sources.schema == schema, + Sources.table == table, + ) + .first() + ) + if dataset is None: + raise MatchboxDatasetError(db_schema=schema, db_table=table) + target_datasets.append(dataset) + + # Get 
model lineage and resolve thresholds + lineage_truths = truth_model.get_lineage() + thresholds = _resolve_thresholds( + lineage_truths=lineage_truths, + model=truth_model, + threshold=threshold, + session=session, + ) + + # Get valid clusters across all models + valid_clusters = _union_valid_clusters(thresholds) + + # Unnest cluster IDs + unnested_clusters = ( + select( + Clusters.hash, Clusters.dataset, func.unnest(Clusters.id).label("id") + ) + .select_from(Clusters) + .cte("unnested_clusters") + ) + + # Find source ID's initial cluster + source_cluster = ( + select(unnested_clusters.c.hash) + .select_from(unnested_clusters) + .where( + and_( + unnested_clusters.c.dataset + == hash_to_hex_decode(source_dataset.hash), + unnested_clusters.c.id == source_id, + ) + ) + .scalar_subquery() + ) + + # Build recursive hierarchy CTE going up + hierarchy_up = ( + # Base case: direct parents + select( + source_cluster.label("original_cluster"), + source_cluster.label("child"), + Contains.parent.label("parent"), + literal(1).label("level"), + ) + .join(Contains, Contains.child == source_cluster) + .where(Contains.parent.in_(select(valid_clusters.c.cluster))) + .cte("hierarchy_up", recursive=True) + ) + + # Recursive case going up + recursive_up = ( + select( + hierarchy_up.c.original_cluster, + hierarchy_up.c.parent.label("child"), + Contains.parent.label("parent"), + (hierarchy_up.c.level + 1).label("level"), + ) + .select_from(hierarchy_up) + .join(Contains, Contains.child == hierarchy_up.c.parent) + .where(Contains.parent.in_(select(valid_clusters.c.cluster))) + ) + + hierarchy_up = hierarchy_up.union_all(recursive_up) + + # Get highest parent + highest_parent = ( + select(hierarchy_up.c.parent) + .order_by(hierarchy_up.c.level.desc()) + .limit(1) + .scalar_subquery() + ) + + # Build recursive hierarchy CTE going down + hierarchy_down = ( + # Base case: direct children from highest parent + select( + highest_parent.label("parent"), + Contains.child.label("child"), + literal(1).label("level"), + unnested_clusters.c.dataset.label("dataset"), + unnested_clusters.c.id.label("id"), + ) + .select_from(Contains) + .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) + .where(Contains.parent == highest_parent) + .cte("hierarchy_down", recursive=True) + ) + + # Recursive case going down + recursive_down = ( + select( + hierarchy_down.c.parent, + Contains.child.label("child"), + (hierarchy_down.c.level + 1).label("level"), + unnested_clusters.c.dataset.label("dataset"), + unnested_clusters.c.id.label("id"), + ) + .select_from(hierarchy_down) + .join(Contains, Contains.parent == hierarchy_down.c.child) + .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) + ) + + hierarchy_down = hierarchy_down.union_all(recursive_down) + + # Get all matched IDs + matches = session.execute( + select( + hierarchy_down.c.dataset, + hierarchy_down.c.id, + ) + .distinct() + .select_from(hierarchy_down) + ).all() + + # Group matches by dataset + matches_by_dataset = {} + for dataset_hash, id in matches: + if dataset_hash not in matches_by_dataset: + matches_by_dataset[dataset_hash] = set() + matches_by_dataset[dataset_hash].add(id) + + # Create Match objects for each target + result = [] + for target_dataset in target_datasets: + # Get source/target table names + source_name = f"{source_schema}.{source_table}" + target_schema, target_table = next( + (schema, table) + for schema, table in target_pairs + if session.get(Sources, target_dataset.hash).schema == schema + and session.get(Sources, 
target_dataset.hash).table == table + ) + target_name = f"{target_schema}.{target_table}" + + highest_cluster = highest_parent.scalar() if matches else None + + # Get source and target IDs + source_ids = { + id + for dataset_hash, id in matches + if dataset_hash == source_dataset.hash + } + target_ids = { + id + for dataset_hash, id in matches + if dataset_hash == target_dataset.hash + } + + match_obj = Match( + cluster=highest_cluster, + source=source_name, + source_id=source_ids, + target=target_name, + target_id=target_ids, + ) + result.append(match_obj) + + return result[0] if isinstance(target, str) else result From 35b62f3c1a5e2f60f62b020b793d83f4c02e1ddb Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Tue, 3 Dec 2024 15:54:46 +0000 Subject: [PATCH 05/22] Attempting to copy indices --- src/matchbox/server/postgresql/utils/db.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 1cd90f5..59bd90d 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -116,6 +116,7 @@ def batch_ingest( table.__table__.name, isolated_metadata, *[c._copy() for c in table.__table__.columns], + *[i for i in table.__table__.indexes], schema=table.__table__.schema, ) From 8ea880f62c6c036bec82521dd4679e42ce39a6f3 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 4 Dec 2024 07:45:15 +0000 Subject: [PATCH 06/22] Working indices on ORM --- src/matchbox/server/postgresql/utils/db.py | 55 +++++++++++++++++----- 1 file changed, 42 insertions(+), 13 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 59bd90d..613f556 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -3,11 +3,11 @@ import io import pstats from itertools import islice -from typing import Any, Callable, Iterable, Tuple +from typing import Any, Callable, Iterable import rustworkx as rx from pg_bulk_ingest import Delete, Upsert, ingest -from sqlalchemy import Engine, MetaData, Table +from sqlalchemy import Engine, Index, MetaData, Table from sqlalchemy.engine.base import Connection from sqlalchemy.orm import DeclarativeMeta, Session @@ -87,18 +87,55 @@ def batched(iterable: Iterable, n: int) -> Iterable: def data_to_batch( records: list[tuple], table: Table, batch_size: int -) -> Callable[[str], Tuple[Any]]: +) -> Callable[[str], tuple[Any]]: """Constructs a batches function for any dataframe and table.""" def _batches( high_watermark, # noqa ARG001 required for pg_bulk_ingest - ) -> Iterable[Tuple[None, None, Iterable[Tuple[Table, tuple]]]]: + ) -> Iterable[tuple[None, None, Iterable[tuple[Table, tuple]]]]: for batch in batched(records, batch_size): yield None, None, ((table, t) for t in batch) return _batches +def isolate_table(table: DeclarativeMeta) -> tuple[MetaData, Table]: + """Creates an isolated copy of a SQLAlchemy table. + + This is used to prevent pg_bulk_ingest from attempting to drop unrelated tables + in the same schema. 
The function creates a new Table instance with: + + * A fresh MetaData instance + * Copied columns + * Recreated indices properly bound to the new table + + Args: + table: The DeclarativeMeta class whose table should be isolated + + Returns: + A tuple of: + * The isolated SQLAlchemy MetaData + * A new SQLAlchemy Table instance with all columns and indices + """ + isolated_metadata = MetaData(schema=table.__table__.schema) + + isolated_table = Table( + table.__table__.name, + isolated_metadata, + *[c._copy() for c in table.__table__.columns], + schema=table.__table__.schema, + ) + + for idx in table.__table__.indexes: + Index( + idx.name, + *[isolated_table.c[col.name] for col in idx.columns], + **{k: v for k, v in idx.kwargs.items()}, + ) + + return isolated_metadata, isolated_table + + def batch_ingest( records: list[tuple[Any]], table: DeclarativeMeta, @@ -110,15 +147,7 @@ def batch_ingest( We isolate the table and metadata as pg_bulk_ingest will try and drop unrelated tables if they're in the same schema. """ - - isolated_metadata = MetaData(schema=table.__table__.schema) - isolated_table = Table( - table.__table__.name, - isolated_metadata, - *[c._copy() for c in table.__table__.columns], - *[i for i in table.__table__.indexes], - schema=table.__table__.schema, - ) + isolated_metadata, isolated_table = isolate_table(table) fn_batch = data_to_batch( records=records, From e9d009cd92575420b8173e2910d65129f6fd5131 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 4 Dec 2024 08:34:55 +0000 Subject: [PATCH 07/22] Changed _resolve_cluster_hierarchy and lineage to handle datasets better, ahead of using this in match() --- src/matchbox/server/postgresql/orm.py | 13 ++++--------- src/matchbox/server/postgresql/utils/query.py | 19 ++++++++++++++++--- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/matchbox/server/postgresql/orm.py b/src/matchbox/server/postgresql/orm.py index c40279f..0d1a94f 100644 --- a/src/matchbox/server/postgresql/orm.py +++ b/src/matchbox/server/postgresql/orm.py @@ -132,11 +132,11 @@ def get_lineage_to_dataset( ) if self.hash == model.hash: - return {} + return {model.hash: None} with Session(MBDB.get_engine()) as session: path_query = ( - select(ModelsFrom.parent, ModelsFrom.truth_cache, Models.type) + select(ModelsFrom.parent, ModelsFrom.truth_cache) .join(Models, Models.hash == ModelsFrom.parent) .where(ModelsFrom.child == self.hash) .order_by(ModelsFrom.level.desc()) @@ -144,17 +144,12 @@ def get_lineage_to_dataset( results = session.execute(path_query).all() - if not any(parent == model.hash for parent, _, _ in results): + if not any(parent == model.hash for parent, _ in results): raise ValueError( f"No path exists between model {self.name} and dataset {model.name}" ) - lineage = { - parent: truth - for parent, truth, type in results - if type != ModelType.DATASET.value - } - + lineage = {parent: truth for parent, truth in results} lineage[self.hash] = self.truth return lineage diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index d0e224a..944b5b2 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -111,7 +111,19 @@ def _union_valid_clusters(lineage_thresholds: dict[bytes, float]) -> Select: valid_clusters = None for model_hash, threshold in lineage_thresholds.items(): - model_valid = _get_valid_clusters_for_model(model_hash, threshold) + if threshold is None: + # This is a dataset - get all its clusters directly + 
model_valid = select(Clusters.hash.label("cluster")).where( + Clusters.dataset == hash_to_hex_decode(model_hash) + ) + else: + # This is a model - get clusters meeting threshold + model_valid = select(Probabilities.cluster.label("cluster")).where( + and_( + Probabilities.model == hash_to_hex_decode(model_hash), + Probabilities.probability >= threshold, + ) + ) if valid_clusters is None: valid_clusters = model_valid @@ -519,14 +531,15 @@ def match( hierarchy_down = hierarchy_down.union_all(recursive_down) # Get all matched IDs - matches = session.execute( + final_stmt = ( select( hierarchy_down.c.dataset, hierarchy_down.c.id, ) .distinct() .select_from(hierarchy_down) - ).all() + ) + matches = session.execute(final_stmt).all() # Group matches by dataset matches_by_dataset = {} From 3a5165744bcab51aec3c3c539a4409a512b51b9e Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 4 Dec 2024 08:52:12 +0000 Subject: [PATCH 08/22] Tidied up _resolve_cluster_hierarchy() with new logic --- src/matchbox/server/postgresql/utils/query.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 944b5b2..84a2358 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -72,6 +72,12 @@ def _resolve_thresholds( resolved_thresholds = {} for model_hash, default_truth in lineage_truths.items(): + # Dataset + if default_truth is None: + resolved_thresholds[model_hash] = None + continue + + # Model if threshold is None: resolved_thresholds[model_hash] = default_truth elif isinstance(threshold, float): @@ -93,16 +99,6 @@ def _resolve_thresholds( return resolved_thresholds -def _get_valid_clusters_for_model(model_hash: bytes, threshold: float) -> Select: - """Get clusters that meet the threshold for a specific model.""" - return select(Probabilities.cluster.label("cluster")).where( - and_( - Probabilities.model == hash_to_hex_decode(model_hash), - Probabilities.probability >= threshold, - ) - ) - - def _union_valid_clusters(lineage_thresholds: dict[bytes, float]) -> Select: """Creates a CTE of clusters that are valid for any model in the lineage. 
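As a worked example of the resolution rules above, take a hypothetical lineage of {dataset: None, dedupe: 0.8, linker: 0.9} queried at the linker with a float threshold of 0.85: the resolved thresholds become {dataset: None, dedupe: 0.8, linker: 0.85}. The float applies only to the queried model, its ancestors keep their cached truths, and datasets stay None.
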
@@ -158,6 +154,9 @@ def _resolve_cluster_hierarchy( """ with Session(engine) as session: dataset_model = session.get(Models, dataset_hash) + if dataset_model is None: + raise MatchboxDatasetError("Dataset not found") + try: lineage_truths = model.get_lineage_to_dataset(model=dataset_model) except ValueError as e: @@ -181,6 +180,7 @@ def _resolve_cluster_hierarchy( ) .where( and_( + Clusters.hash.in_(select(valid_clusters.c.cluster)), Clusters.dataset == hash_to_hex_decode(dataset_hash), Clusters.id.isnot(None), ) From 0097f616958d8e644a8f9aa2840a3f474e43f395 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 4 Dec 2024 09:04:17 +0000 Subject: [PATCH 09/22] Working match(), not yet passing unit tests --- src/matchbox/server/postgresql/utils/query.py | 23 +++++++++++------- test/server/test_adapter.py | 24 +++++++++---------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 84a2358..6c1ae0c 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -470,8 +470,13 @@ def match( Contains.parent.label("parent"), literal(1).label("level"), ) - .join(Contains, Contains.child == source_cluster) - .where(Contains.parent.in_(select(valid_clusters.c.cluster))) + .select_from(Contains) + .where( + and_( + Contains.child == source_cluster, + Contains.parent.in_(select(valid_clusters.c.cluster)), + ) + ) .cte("hierarchy_up", recursive=True) ) @@ -533,6 +538,7 @@ def match( # Get all matched IDs final_stmt = ( select( + hierarchy_down.c.parent.label("cluster"), hierarchy_down.c.dataset, hierarchy_down.c.id, ) @@ -542,8 +548,11 @@ def match( matches = session.execute(final_stmt).all() # Group matches by dataset + cluster = None matches_by_dataset = {} - for dataset_hash, id in matches: + for cluster_hash, dataset_hash, id in matches: + if cluster is None: + cluster = cluster_hash if dataset_hash not in matches_by_dataset: matches_by_dataset[dataset_hash] = set() matches_by_dataset[dataset_hash].add(id) @@ -561,22 +570,20 @@ def match( ) target_name = f"{target_schema}.{target_table}" - highest_cluster = highest_parent.scalar() if matches else None - # Get source and target IDs source_ids = { id - for dataset_hash, id in matches + for _, dataset_hash, id in matches if dataset_hash == source_dataset.hash } target_ids = { id - for dataset_hash, id in matches + for _, dataset_hash, id in matches if dataset_hash == target_dataset.hash } match_obj = Match( - cluster=highest_cluster, + cluster=cluster, source=source_name, source_id=source_ids, target=target_name, diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 1212be8..fe9b1cc 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -560,10 +560,10 @@ def test_match(self, revolution_inc: dict[str, list[str]]): ) assert isinstance(res, Match) - assert Match.source == str(duns_wh) - assert Match.target == str(crn_wh) - assert Match.source_id == set(revolution_inc["duns"]) - assert Match.target_id == set(revolution_inc["crn"]) + assert res.source == str(duns_wh) + assert res.target == str(crn_wh) + assert res.source_id == set(revolution_inc["duns"]) + assert res.target_id == set(revolution_inc["crn"]) # Test *:1 match @@ -576,10 +576,10 @@ def test_match(self, revolution_inc: dict[str, list[str]]): ) assert isinstance(res, Match) - assert Match.source == str(crn_wh) - assert Match.target == str(duns_wh) - assert Match.source_id == 
set(revolution_inc["crn"]) - assert Match.target_id == set(revolution_inc["duns"]) + assert res.source == str(crn_wh) + assert res.target == str(duns_wh) + assert res.source_id == set(revolution_inc["crn"]) + assert res.target_id == set(revolution_inc["duns"]) # Test 0:0 match @@ -592,10 +592,10 @@ def test_match(self, revolution_inc: dict[str, list[str]]): ) assert isinstance(res, Match) - assert Match.source == str(crn_wh) - assert Match.target == str(duns_wh) - assert Match.source_id == set() - assert Match.target_id == set() + assert res.source == str(crn_wh) + assert res.target == str(duns_wh) + assert res.source_id == set() + assert res.target_id == set() def test_clear(self): """Test clearing the database.""" From 719f152ae28ddd9fef2ef926c3bcc033dabc9ce8 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 4 Dec 2024 11:19:09 +0000 Subject: [PATCH 10/22] Moved match unit tests into separate functions --- test/server/test_adapter.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index fe9b1cc..c350f3d 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -541,16 +541,14 @@ def test_query_with_link_model(self): } assert crn_duns.hash.nunique() == 1000 - def test_match(self, revolution_inc: dict[str, list[str]]): - """Test that matching data works.""" + def test_match_one_to_many(self, revolution_inc: dict[str, list[str]]): + """Test that matching data works when the target has many IDs.""" self.setup_database("link") crn_x_duns = "deterministic_naive_test.crn_naive_test.duns" crn_wh = self.warehouse_data[0] duns_wh = self.warehouse_data[1] - # Test 1:* match - res = match( backend=self.backend, source_id=revolution_inc["duns"][0], @@ -565,7 +563,13 @@ def test_match(self, revolution_inc: dict[str, list[str]]): assert res.source_id == set(revolution_inc["duns"]) assert res.target_id == set(revolution_inc["crn"]) - # Test *:1 match + def test_match_many_to_one(self, revolution_inc: dict[str, list[str]]): + """Test that matching data works when the source has more possible IDs.""" + self.setup_database("link") + + crn_x_duns = "deterministic_naive_test.crn_naive_test.duns" + crn_wh = self.warehouse_data[0] + duns_wh = self.warehouse_data[1] res = match( backend=self.backend, @@ -581,7 +585,13 @@ def test_match(self, revolution_inc: dict[str, list[str]]): assert res.source_id == set(revolution_inc["crn"]) assert res.target_id == set(revolution_inc["duns"]) - # Test 0:0 match + def test_match_none_to_none(self): + """Test that matching data work when the supplied key doesn't exist.""" + self.setup_database("link") + + crn_x_duns = "deterministic_naive_test.crn_naive_test.duns" + crn_wh = self.warehouse_data[0] + duns_wh = self.warehouse_data[1] res = match( backend=self.backend, From 802643aaa2e3baa1c33739e0e00f8c5e0db274fe Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 5 Dec 2024 06:53:28 +0000 Subject: [PATCH 11/22] Factored out subqueries of match() --- src/matchbox/server/postgresql/utils/query.py | 243 +++++++++++------- 1 file changed, 146 insertions(+), 97 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index d758f3b..981082c 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -15,7 +15,7 @@ ) from sqlalchemy.dialects.postgresql import BYTEA from sqlalchemy.orm import Session -from sqlalchemy.sql.selectable import 
Select +from sqlalchemy.sql.selectable import CTE, Select from matchbox.common.db import Match, Source, get_schema_table_names, sql_to_df from matchbox.common.exceptions import ( @@ -363,6 +363,145 @@ def query( raise ValueError(f"return_type of {return_type} not valid") +def _build_unnested_clusters() -> CTE: + """Create CTE that unnests cluster IDs for easier joining.""" + return ( + select(Clusters.hash, Clusters.dataset, func.unnest(Clusters.id).label("id")) + .select_from(Clusters) + .cte("unnested_clusters") + ) + + +def _find_source_cluster( + unnested_clusters: CTE, source_dataset_hash: bytes, source_id: str +) -> Select: + """Find the initial cluster containing the source ID.""" + return ( + select(unnested_clusters.c.hash) + .select_from(unnested_clusters) + .where( + and_( + unnested_clusters.c.dataset == hash_to_hex_decode(source_dataset_hash), + unnested_clusters.c.id == source_id, + ) + ) + .scalar_subquery() + ) + + +def _build_hierarchy_up( + source_cluster: Select, valid_clusters: CTE | None = None +) -> CTE: + """ + Build recursive CTE that finds all parent clusters. + + Args: + source_cluster: Subquery that finds starting cluster + valid_clusters: Optional CTE of valid clusters to filter by + """ + # Base case: direct parents + base = ( + select( + source_cluster.label("original_cluster"), + source_cluster.label("child"), + Contains.parent.label("parent"), + literal(1).label("level"), + ) + .select_from(Contains) + .where(Contains.child == source_cluster) + ) + + # Add valid clusters filter if provided + if valid_clusters is not None: + base = base.where(Contains.parent.in_(select(valid_clusters.c.cluster))) + + hierarchy_up = base.cte("hierarchy_up", recursive=True) + + # Recursive case + recursive = ( + select( + hierarchy_up.c.original_cluster, + hierarchy_up.c.parent.label("child"), + Contains.parent.label("parent"), + (hierarchy_up.c.level + 1).label("level"), + ) + .select_from(hierarchy_up) + .join(Contains, Contains.child == hierarchy_up.c.parent) + ) + + # Add valid clusters filter to recursive part if provided + if valid_clusters is not None: + recursive = recursive.where( + Contains.parent.in_(select(valid_clusters.c.cluster)) + ) + + return hierarchy_up.union_all(recursive) + + +def _find_highest_parent(hierarchy_up: CTE) -> Select: + """Find the topmost parent cluster from the hierarchy.""" + return ( + select(hierarchy_up.c.parent) + .order_by(hierarchy_up.c.level.desc()) + .limit(1) + .scalar_subquery() + ) + + +def _build_hierarchy_down( + highest_parent: Select, unnested_clusters: CTE, valid_clusters: CTE | None = None +) -> CTE: + """ + Build recursive CTE that finds all child clusters and their IDs. 
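+    Mirrors _build_hierarchy_up, but walks from the highest parent back down
+    to the leaf clusters that hold source record IDs.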
+ + Args: + highest_parent: Subquery that finds top cluster + unnested_clusters: CTE with unnested cluster IDs + valid_clusters: Optional CTE of valid clusters to filter by + """ + # Base case: direct children + base = ( + select( + highest_parent.label("parent"), + Contains.child.label("child"), + literal(1).label("level"), + unnested_clusters.c.dataset.label("dataset"), + unnested_clusters.c.id.label("id"), + ) + .select_from(Contains) + .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) + .where(Contains.parent == highest_parent) + ) + + # Add valid clusters filter if provided + if valid_clusters is not None: + base = base.where(Contains.child.in_(select(valid_clusters.c.cluster))) + + hierarchy_down = base.cte("hierarchy_down", recursive=True) + + # Recursive case + recursive = ( + select( + hierarchy_down.c.parent, + Contains.child.label("child"), + (hierarchy_down.c.level + 1).label("level"), + unnested_clusters.c.dataset.label("dataset"), + unnested_clusters.c.id.label("id"), + ) + .select_from(hierarchy_down) + .join(Contains, Contains.parent == hierarchy_down.c.child) + .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) + ) + + # Add valid clusters filter to recursive part if provided + if valid_clusters is not None: + recursive = recursive.where( + Contains.child.in_(select(valid_clusters.c.cluster)) + ) + + return hierarchy_down.union_all(recursive) + + def match( source_id: str, source: str, @@ -437,102 +576,12 @@ def match( # Get valid clusters across all models valid_clusters = _union_valid_clusters(thresholds) - # Unnest cluster IDs - unnested_clusters = ( - select( - Clusters.hash, Clusters.dataset, func.unnest(Clusters.id).label("id") - ) - .select_from(Clusters) - .cte("unnested_clusters") - ) - - # Find source ID's initial cluster - source_cluster = ( - select(unnested_clusters.c.hash) - .select_from(unnested_clusters) - .where( - and_( - unnested_clusters.c.dataset - == hash_to_hex_decode(source_dataset.hash), - unnested_clusters.c.id == source_id, - ) - ) - .scalar_subquery() - ) - - # Build recursive hierarchy CTE going up - hierarchy_up = ( - # Base case: direct parents - select( - source_cluster.label("original_cluster"), - source_cluster.label("child"), - Contains.parent.label("parent"), - literal(1).label("level"), - ) - .select_from(Contains) - .where( - and_( - Contains.child == source_cluster, - Contains.parent.in_(select(valid_clusters.c.cluster)), - ) - ) - .cte("hierarchy_up", recursive=True) - ) - - # Recursive case going up - recursive_up = ( - select( - hierarchy_up.c.original_cluster, - hierarchy_up.c.parent.label("child"), - Contains.parent.label("parent"), - (hierarchy_up.c.level + 1).label("level"), - ) - .select_from(hierarchy_up) - .join(Contains, Contains.child == hierarchy_up.c.parent) - .where(Contains.parent.in_(select(valid_clusters.c.cluster))) - ) - - hierarchy_up = hierarchy_up.union_all(recursive_up) - - # Get highest parent - highest_parent = ( - select(hierarchy_up.c.parent) - .order_by(hierarchy_up.c.level.desc()) - .limit(1) - .scalar_subquery() - ) - - # Build recursive hierarchy CTE going down - hierarchy_down = ( - # Base case: direct children from highest parent - select( - highest_parent.label("parent"), - Contains.child.label("child"), - literal(1).label("level"), - unnested_clusters.c.dataset.label("dataset"), - unnested_clusters.c.id.label("id"), - ) - .select_from(Contains) - .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) - .where(Contains.parent == highest_parent) - 
.cte("hierarchy_down", recursive=True) - ) - - # Recursive case going down - recursive_down = ( - select( - hierarchy_down.c.parent, - Contains.child.label("child"), - (hierarchy_down.c.level + 1).label("level"), - unnested_clusters.c.dataset.label("dataset"), - unnested_clusters.c.id.label("id"), - ) - .select_from(hierarchy_down) - .join(Contains, Contains.parent == hierarchy_down.c.child) - .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) - ) - - hierarchy_down = hierarchy_down.union_all(recursive_down) + # Build the query components + unnested = _build_unnested_clusters() + source_cluster = _find_source_cluster(unnested, source_dataset.hash, source_id) + hierarchy_up = _build_hierarchy_up(source_cluster, valid_clusters) + highest = _find_highest_parent(hierarchy_up) + hierarchy_down = _build_hierarchy_down(highest, unnested, valid_clusters) # Get all matched IDs final_stmt = ( From 72a9b845a0d871b2cd97cef64259f6276609e9bc Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 5 Dec 2024 13:12:58 +0000 Subject: [PATCH 12/22] Wrote functions to visualise the subgraph to help debug --- src/matchbox/helpers/visualisation.py | 92 ++++++++++++++++++++++ src/matchbox/server/postgresql/utils/db.py | 47 ++++++++++- 2 files changed, 138 insertions(+), 1 deletion(-) diff --git a/src/matchbox/helpers/visualisation.py b/src/matchbox/helpers/visualisation.py index ff41d59..71b130b 100644 --- a/src/matchbox/helpers/visualisation.py +++ b/src/matchbox/helpers/visualisation.py @@ -1,3 +1,6 @@ +from collections import defaultdict +from itertools import count + import rustworkx as rx from matplotlib.figure import Figure from rustworkx.visualization import mpl_draw @@ -40,3 +43,92 @@ def draw_model_tree(backend: MatchboxDBAdapter) -> Figure: edge_labels=lambda edge: edge["type"], font_size=8, ) + + +def draw_data_tree(graph: rx.PyDiGraph) -> str: + """ + Convert a rustworkx PyDiGraph to Mermaid graph visualization code. 
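+    The output is a left-to-right Mermaid definition ("graph LR"), with source
+    tables, data rows and clusters as labelled nodes.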
+ + Args: + graph (rx.PyDiGraph): A rustworkx directed graph with nodes containing 'id' and + 'type' attributes + + Returns: + str: Mermaid graph definition code + """ + mermaid_lines = ["graph LR"] + + counters = defaultdict(count, {"hash": count(1)}) + node_to_var = {} + node_types = {} + data_nodes = set() + + def format_id(id_value): + """Format ID value, converting bytes to hex if needed.""" + if isinstance(id_value, bytes): + return f"\\x{id_value.hex()}" + return f"['{str(id_value)}']" + + for node_idx in graph.node_indices(): + node_data = graph.get_node_data(node_idx) + if isinstance(node_data, dict): + node_type = node_data.get("type", "") + node_types[node_idx] = node_type + if node_type == "data": + data_nodes.add(node_idx) + + for node_idx, node_type in node_types.items(): + if node_type == "source": + node_data = graph.get_node_data(node_idx) + table_name = node_data["id"].split(".")[-1] + node_to_var[node_idx] = table_name + + counter = count(1) + for predecessor in graph.predecessor_indices(node_idx): + if predecessor in data_nodes: + node_to_var[predecessor] = f"{table_name}{str(next(counter))}" + data_nodes.remove(predecessor) + + remaining_counter = count(len(node_to_var) + 1) + for node_idx in data_nodes: + node_to_var[node_idx] = str(next(remaining_counter)) + + for node_idx, node_type in node_types.items(): + if node_type == "cluster": + node_to_var[node_idx] = f"hash{next(counters['hash'])}" + + sources = [] + data_defs = [] + clusters = [] + + for node_idx, node_type in node_types.items(): + node_data = graph.get_node_data(node_idx) + var_name = node_to_var[node_idx] + + if node_type == "source": + node_def = f' {var_name}["{node_data["id"]}"]' + sources.append(node_def) + elif node_type == "data": + node_label = format_id(node_data["id"]) + node_label = node_label.strip("[]'") + node_def = f' {var_name}["{node_label}"]' + data_defs.append(node_def) + elif node_type == "cluster": + node_label = format_id(node_data["id"]) + node_def = f' {var_name}["{node_label}"]' + clusters.append(node_def) + + mermaid_lines.extend(sources) + mermaid_lines.extend(data_defs) + mermaid_lines.extend(clusters) + + mermaid_lines.append("") + + for edge in graph.edge_list(): + source = edge[0] + target = edge[1] + source_var = node_to_var[source] + target_var = node_to_var[target] + mermaid_lines.append(f" {source_var} --> {target_var}") + + return "\n".join(mermaid_lines) diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 613f556..b2eb5bf 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -11,7 +11,14 @@ from sqlalchemy.engine.base import Connection from sqlalchemy.orm import DeclarativeMeta, Session -from matchbox.server.postgresql.orm import Models, ModelsFrom, ModelType, Sources +from matchbox.server.postgresql.orm import ( + Clusters, + Contains, + Models, + ModelsFrom, + ModelType, + Sources, +) # Retrieval @@ -50,6 +57,44 @@ def get_model_subgraph(engine: Engine) -> rx.PyDiGraph: return G +def get_data_subgraph(engine: Engine) -> rx.PyDiGraph: + """Retrieves the complete data subgraph as a PyDiGraph.""" + G = rx.PyDiGraph() + nodes = {} + + with Session(engine) as session: + sources = {source.model: source for source in session.query(Sources).all()} + + for source in sources.values(): + source_id = f"{source.schema}.{source.table}" + if source_id not in nodes: + source_idx = G.add_node({"id": source_id, "type": "source"}) + nodes[source_id] = source_idx + + for cluster 
in session.query(Clusters).all(): + cluster_id = cluster.hash + if cluster_id not in nodes: + cluster_idx = G.add_node({"id": cluster_id, "type": "cluster"}) + nodes[cluster_id] = cluster_idx + + if cluster.id is not None and cluster.dataset is not None: + source = sources.get(cluster.dataset) + if source: + data_id = str(cluster.id) + data_idx = G.add_node({"id": data_id, "type": "data"}) + + source_id = f"{source.schema}.{source.table}" + G.add_edge(data_idx, nodes[source_id], {"type": "source"}) + G.add_edge(nodes[cluster_id], data_idx, {"type": "data"}) + + for contains in session.query(Contains).all(): + G.add_edge( + nodes[contains.parent], nodes[contains.child], {"type": "contains"} + ) + + return G + + # SQLAlchemy profiling From f6a3225c7c9d60de97c585a51f4a815cca941e91 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Sun, 8 Dec 2024 12:43:29 +0000 Subject: [PATCH 13/22] Test cases with data working --- src/matchbox/server/postgresql/utils/query.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 981082c..89897a6 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -459,7 +459,7 @@ def _build_hierarchy_down( unnested_clusters: CTE with unnested cluster IDs valid_clusters: Optional CTE of valid clusters to filter by """ - # Base case: direct children + # Base case: Get both direct children and their IDs base = ( select( highest_parent.label("parent"), @@ -469,7 +469,12 @@ def _build_hierarchy_down( unnested_clusters.c.id.label("id"), ) .select_from(Contains) - .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) + .join_from( + Contains, + unnested_clusters, + unnested_clusters.c.hash == Contains.child, + isouter=True, + ) .where(Contains.parent == highest_parent) ) @@ -479,7 +484,7 @@ def _build_hierarchy_down( hierarchy_down = base.cte("hierarchy_down", recursive=True) - # Recursive case + # Recursive case: Get both intermediate nodes AND their leaf records recursive = ( select( hierarchy_down.c.parent, @@ -489,8 +494,18 @@ def _build_hierarchy_down( unnested_clusters.c.id.label("id"), ) .select_from(hierarchy_down) - .join(Contains, Contains.parent == hierarchy_down.c.child) - .join(unnested_clusters, unnested_clusters.c.hash == Contains.child) + .join_from( + hierarchy_down, + Contains, + Contains.parent == hierarchy_down.c.child, + ) + .join_from( + Contains, + unnested_clusters, + unnested_clusters.c.hash == Contains.child, + isouter=True, + ) + .where(hierarchy_down.c.id.is_(None)) # Only recurse on non-leaf nodes ) # Add valid clusters filter to recursive part if provided From 6a936d7b784fbf549aa4eb9afe8b8350c2520a34 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Sun, 8 Dec 2024 12:47:38 +0000 Subject: [PATCH 14/22] All unit tests passing --- src/matchbox/common/db.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/matchbox/common/db.py b/src/matchbox/common/db.py index 5be3128..7f7455c 100644 --- a/src/matchbox/common/db.py +++ b/src/matchbox/common/db.py @@ -6,7 +6,7 @@ from matchbox.common.hash import HASH_FUNC from pandas import DataFrame from pyarrow import Table as ArrowTable -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from sqlalchemy import ( LABEL_STYLE_TABLENAME_PLUS_COL, ColumnElement, @@ -36,12 +36,24 @@ class Match(BaseModel): """A 
match between primary keys in the Matchbox database.""" - cluster: bytes + cluster: bytes | None source: str source_id: set[str] = Field(default_factory=set) target: str target_id: set[str] = Field(default_factory=set) + @model_validator(mode="after") + def found_or_none(self) -> "Match": + if self.cluster is None and (self.source_id or self.target_id): + raise ValueError( + "A match must have a cluster if source_id or target_id is set." + ) + elif self.cluster is not None and not (self.source_id or self.target_id): + raise ValueError( + "A match must have source_id or target_id if cluster is set." + ) + return self + class Probability(BaseModel): """A probability of a match in the Matchbox database. From 79ee5b5f89af06b8721210a8d706c9c13941b60c Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Sun, 8 Dec 2024 13:38:38 +0000 Subject: [PATCH 15/22] Minor refactor of query so it's easier to read --- src/matchbox/server/postgresql/utils/query.py | 101 +++++++----------- 1 file changed, 40 insertions(+), 61 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 89897a6..e45f194 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -50,6 +50,31 @@ def key_to_sqlalchemy_label(key: str, source: Source) -> str: return f"{source.db_schema}_{source.db_table}_{key}" +def source_to_dataset_model(source: Source | str, session: Session) -> Models: + """Converts a Source object to a Sources ORM object.""" + if isinstance(source, str): + source_schema, source_table = get_schema_table_names(source, validate=True) + else: + source_schema, source_table = source.db_schema, source.db_table + + source_dataset = ( + session.query(Models) + .join(Sources, Sources.model == Models.hash) + .filter( + Sources.schema == source_schema, + Sources.table == source_table, + ) + .first() + ) + if source_dataset is None: + raise MatchboxDatasetError( + db_schema=source_schema, + db_table=source_table, + ) + + return source_dataset + + def _resolve_thresholds( lineage_truths: dict[str, float], model: Models, @@ -294,26 +319,10 @@ def query( # Process each source dataset for source, fields in selector.items(): - # Get the dataset model - dataset = ( - session.query(Models) - .join(Sources, Sources.model == Models.hash) - .filter( - Sources.schema == source.db_schema, - Sources.table == source.db_table, - Sources.id == source.db_pk, - ) - .first() - ) - - if dataset is None: - raise MatchboxDatasetError( - db_schema=source.db_schema, db_table=source.db_table - ) - + dataset_model = source_to_dataset_model(source, session) hash_query = _resolve_cluster_hierarchy( - dataset_hash=dataset.hash, - model=truth_model if truth_model else dataset, + dataset_hash=dataset_model.hash, + model=truth_model if truth_model else dataset_model, threshold=threshold, engine=engine, ) @@ -537,48 +546,19 @@ def match( * Returns the results as Match objects, one per target """ # Split source and target into schema/table - source_schema, source_table = get_schema_table_names(source, validate=True) targets = [target] if isinstance(target, str) else target target_pairs = [get_schema_table_names(t, validate=True) for t in targets] with Session(engine) as session: - # Get truth model + # Get source, target and truth models + source_model = source_to_dataset_model(source, session) + target_models = [] + for target in targets: + target_models.append(source_to_dataset_model(target, session)) truth_model = 
session.query(Models).filter(Models.name == model).first() if truth_model is None: raise MatchboxModelError(f"Model {model} not found") - # Get source dataset - source_dataset = ( - session.query(Models) - .join(Sources, Sources.model == Models.hash) - .filter( - Sources.schema == source_schema, - Sources.table == source_table, - ) - .first() - ) - if source_dataset is None: - raise MatchboxDatasetError( - db_schema=source_schema, - db_table=source_table, - ) - - # Get target datasets - target_datasets = [] - for schema, table in target_pairs: - dataset = ( - session.query(Models) - .join(Sources, Sources.model == Models.hash) - .filter( - Sources.schema == schema, - Sources.table == table, - ) - .first() - ) - if dataset is None: - raise MatchboxDatasetError(db_schema=schema, db_table=table) - target_datasets.append(dataset) - # Get model lineage and resolve thresholds lineage_truths = truth_model.get_lineage() thresholds = _resolve_thresholds( @@ -593,7 +573,7 @@ def match( # Build the query components unnested = _build_unnested_clusters() - source_cluster = _find_source_cluster(unnested, source_dataset.hash, source_id) + source_cluster = _find_source_cluster(unnested, source_model.hash, source_id) hierarchy_up = _build_hierarchy_up(source_cluster, valid_clusters) highest = _find_highest_parent(hierarchy_up) hierarchy_down = _build_hierarchy_down(highest, unnested, valid_clusters) @@ -622,14 +602,13 @@ def match( # Create Match objects for each target result = [] - for target_dataset in target_datasets: + for target_model in target_models: # Get source/target table names - source_name = f"{source_schema}.{source_table}" target_schema, target_table = next( (schema, table) for schema, table in target_pairs - if session.get(Sources, target_dataset.hash).schema == schema - and session.get(Sources, target_dataset.hash).table == table + if session.get(Sources, target_model.hash).schema == schema + and session.get(Sources, target_model.hash).table == table ) target_name = f"{target_schema}.{target_table}" @@ -637,17 +616,17 @@ def match( source_ids = { id for _, dataset_hash, id in matches - if dataset_hash == source_dataset.hash + if dataset_hash == source_model.hash } target_ids = { id for _, dataset_hash, id in matches - if dataset_hash == target_dataset.hash + if dataset_hash == target_model.hash } match_obj = Match( cluster=cluster, - source=source_name, + source=source, source_id=source_ids, target=target_name, target_id=target_ids, From 1db389fed1f19e65285fd275e68fe2d472298d82 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 12 Dec 2024 16:56:13 +0000 Subject: [PATCH 16/22] Updated Match validation --- src/matchbox/common/db.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/matchbox/common/db.py b/src/matchbox/common/db.py index f736668..f8d7946 100644 --- a/src/matchbox/common/db.py +++ b/src/matchbox/common/db.py @@ -51,14 +51,12 @@ class Match(BaseModel): @model_validator(mode="after") def found_or_none(self) -> "Match": - if self.cluster is None and (self.source_id or self.target_id): + if self.target and not (self.source and self.cluster): raise ValueError( - "A match must have a cluster if source_id or target_id is set." - ) - elif self.cluster is not None and not (self.source_id or self.target_id): - raise ValueError( - "A match must have source_id or target_id if cluster is set." + "A match must have sources and a cluster if target was found." 
) + if self.cluster and not self.source: + raise ValueError("A match must have source if cluster is set.") return self From f300e961619c1fcb8a256f1b2c98ba19ea59b1ba Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 12 Dec 2024 17:10:37 +0000 Subject: [PATCH 17/22] Dealt with all comments --- src/matchbox/server/postgresql/utils/db.py | 15 +++++++-------- test/fixtures/data.py | 12 +++++++----- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 8f1b66b..abc82c3 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -70,14 +70,13 @@ def get_data_subgraph(engine: Engine) -> rx.PyDiGraph: nodes[cluster_id] = cluster_idx if cluster.id is not None and cluster.dataset is not None: - source = sources.get(cluster.dataset) - if source: - data_id = str(cluster.id) - data_idx = G.add_node({"id": data_id, "type": "data"}) - - source_id = f"{source.schema}.{source.table}" - G.add_edge(data_idx, nodes[source_id], {"type": "source"}) - G.add_edge(nodes[cluster_id], data_idx, {"type": "data"}) + source = sources[cluster.dataset] + data_id = str(cluster.id) + data_idx = G.add_node({"id": data_id, "type": "data"}) + + source_id = f"{source.schema}.{source.table}" + G.add_edge(data_idx, nodes[source_id], {"type": "source"}) + G.add_edge(nodes[cluster_id], data_idx, {"type": "data"}) for contains in session.query(Contains).all(): G.add_edge( diff --git a/test/fixtures/data.py b/test/fixtures/data.py index 1fc7ffa..4534469 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -120,6 +120,12 @@ def revolution_inc( Revolution Inc. as it exists across all three datasets. UUIDs are converted to strings to mirror how Matchbox stores them. + + Based on the above fixtures, should return: + + * Three CRNs + * One DUNS + * Two CDMS """ crn_ids = crn_companies[ crn_companies["company_name"].str.contains("Revolution", case=False) @@ -129,11 +135,7 @@ def revolution_inc( duns_companies["company_name"].str.contains("Revolution", case=False) ]["id"].tolist() - revolution_crn = crn_companies[ - crn_companies["company_name"].str.contains("Revolution", case=False) - ]["crn"].iloc[0] - - cdms_ids = cdms_companies[cdms_companies["crn"] == revolution_crn]["id"].tolist() + cdms_ids = cdms_companies[cdms_companies["crn"] == crn_ids[0]]["id"].tolist() return { "crn": [str(id) for id in crn_ids], From 7b4a1999d70d0faccee653fb8c1e570c2fb70f8e Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 12 Dec 2024 17:26:42 +0000 Subject: [PATCH 18/22] Fixed validation --- src/matchbox/common/db.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/matchbox/common/db.py b/src/matchbox/common/db.py index f8d7946..19a4359 100644 --- a/src/matchbox/common/db.py +++ b/src/matchbox/common/db.py @@ -51,11 +51,11 @@ class Match(BaseModel): @model_validator(mode="after") def found_or_none(self) -> "Match": - if self.target and not (self.source and self.cluster): + if self.target_id and not (self.source_id and self.cluster): raise ValueError( "A match must have sources and a cluster if target was found." 
) - if self.cluster and not self.source: + if self.cluster and not self.source_id: raise ValueError("A match must have source if cluster is set.") return self From 1c29775a1dadf338100191b7a0e63734013d3e87 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 13 Dec 2024 10:59:49 +0000 Subject: [PATCH 19/22] Removed data subgraph visualisation functions --- src/matchbox/client/visualisation.py | 92 ---------------------- src/matchbox/server/postgresql/utils/db.py | 41 ---------- 2 files changed, 133 deletions(-) diff --git a/src/matchbox/client/visualisation.py b/src/matchbox/client/visualisation.py index b92b19f..e7e324e 100644 --- a/src/matchbox/client/visualisation.py +++ b/src/matchbox/client/visualisation.py @@ -1,6 +1,3 @@ -from collections import defaultdict -from itertools import count - import rustworkx as rx from matplotlib.figure import Figure from rustworkx.visualization import mpl_draw @@ -42,92 +39,3 @@ def draw_resolution_graph() -> Figure: labels=lambda node: node["name"], font_size=8, ) - - -def draw_data_tree(graph: rx.PyDiGraph) -> str: - """ - Convert a rustworkx PyDiGraph to Mermaid graph visualization code. - - Args: - graph (rx.PyDiGraph): A rustworkx directed graph with nodes containing 'id' and - 'type' attributes - - Returns: - str: Mermaid graph definition code - """ - mermaid_lines = ["graph LR"] - - counters = defaultdict(count, {"hash": count(1)}) - node_to_var = {} - node_types = {} - data_nodes = set() - - def format_id(id_value): - """Format ID value, converting bytes to hex if needed.""" - if isinstance(id_value, bytes): - return f"\\x{id_value.hex()}" - return f"['{str(id_value)}']" - - for node_idx in graph.node_indices(): - node_data = graph.get_node_data(node_idx) - if isinstance(node_data, dict): - node_type = node_data.get("type", "") - node_types[node_idx] = node_type - if node_type == "data": - data_nodes.add(node_idx) - - for node_idx, node_type in node_types.items(): - if node_type == "source": - node_data = graph.get_node_data(node_idx) - table_name = node_data["id"].split(".")[-1] - node_to_var[node_idx] = table_name - - counter = count(1) - for predecessor in graph.predecessor_indices(node_idx): - if predecessor in data_nodes: - node_to_var[predecessor] = f"{table_name}{str(next(counter))}" - data_nodes.remove(predecessor) - - remaining_counter = count(len(node_to_var) + 1) - for node_idx in data_nodes: - node_to_var[node_idx] = str(next(remaining_counter)) - - for node_idx, node_type in node_types.items(): - if node_type == "cluster": - node_to_var[node_idx] = f"hash{next(counters['hash'])}" - - sources = [] - data_defs = [] - clusters = [] - - for node_idx, node_type in node_types.items(): - node_data = graph.get_node_data(node_idx) - var_name = node_to_var[node_idx] - - if node_type == "source": - node_def = f' {var_name}["{node_data["id"]}"]' - sources.append(node_def) - elif node_type == "data": - node_label = format_id(node_data["id"]) - node_label = node_label.strip("[]'") - node_def = f' {var_name}["{node_label}"]' - data_defs.append(node_def) - elif node_type == "cluster": - node_label = format_id(node_data["id"]) - node_def = f' {var_name}["{node_label}"]' - clusters.append(node_def) - - mermaid_lines.extend(sources) - mermaid_lines.extend(data_defs) - mermaid_lines.extend(clusters) - - mermaid_lines.append("") - - for edge in graph.edge_list(): - source = edge[0] - target = edge[1] - source_var = node_to_var[source] - target_var = node_to_var[target] - mermaid_lines.append(f" {source_var} --> {target_var}") - - return 
"\n".join(mermaid_lines) diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index abc82c3..2402cd2 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -5,7 +5,6 @@ from itertools import islice from typing import Any, Callable, Iterable -import rustworkx as rx from pg_bulk_ingest import Delete, Upsert, ingest from sqlalchemy import Engine, Index, MetaData, Table from sqlalchemy.engine.base import Connection @@ -18,11 +17,8 @@ ResolutionNodeType, ) from matchbox.server.postgresql.orm import ( - Clusters, - Contains, ResolutionFrom, Resolutions, - Sources, ) # Retrieval @@ -49,43 +45,6 @@ def get_resolution_graph(engine: Engine) -> ResolutionGraph: return G -def get_data_subgraph(engine: Engine) -> rx.PyDiGraph: - """Retrieves the complete data subgraph as a PyDiGraph.""" - G = rx.PyDiGraph() - nodes = {} - - with Session(engine) as session: - sources = {source.model: source for source in session.query(Sources).all()} - - for source in sources.values(): - source_id = f"{source.schema}.{source.table}" - if source_id not in nodes: - source_idx = G.add_node({"id": source_id, "type": "source"}) - nodes[source_id] = source_idx - - for cluster in session.query(Clusters).all(): - cluster_id = cluster.hash - if cluster_id not in nodes: - cluster_idx = G.add_node({"id": cluster_id, "type": "cluster"}) - nodes[cluster_id] = cluster_idx - - if cluster.id is not None and cluster.dataset is not None: - source = sources[cluster.dataset] - data_id = str(cluster.id) - data_idx = G.add_node({"id": data_id, "type": "data"}) - - source_id = f"{source.schema}.{source.table}" - G.add_edge(data_idx, nodes[source_id], {"type": "source"}) - G.add_edge(nodes[cluster_id], data_idx, {"type": "data"}) - - for contains in session.query(Contains).all(): - G.add_edge( - nodes[contains.parent], nodes[contains.child], {"type": "contains"} - ) - - return G - - # SQLAlchemy profiling From fce37818f4abb20315567d2e0dc45f4b6a26729d Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 13 Dec 2024 12:57:55 +0000 Subject: [PATCH 20/22] One to zero test added --- test/fixtures/data.py | 32 ++++++++++++++++++++++++++++++++ test/server/test_adapter.py | 26 ++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/test/fixtures/data.py b/test/fixtures/data.py index 4534469..4acb59c 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -144,6 +144,38 @@ def revolution_inc( } +@pytest.fixture(scope="session") +def winner_inc( + crn_companies: DataFrame, duns_companies: DataFrame, cdms_companies: DataFrame +) -> dict[str, str]: + """ + Winner Inc. as it exists across all three datasets. + + UUIDs are converted to strings to mirror how Matchbox stores them. 
+
+    Based on the above fixtures, should return:
+
+    * Three CRNs
+    * Zero DUNS
+    * Two CDMS
+    """
+    crn_ids = crn_companies[
+        crn_companies["company_name"].str.contains("Winner", case=False)
+    ]["id"].tolist()
+
+    duns_ids = duns_companies[
+        duns_companies["company_name"].str.contains("Winner", case=False)
+    ]["id"].tolist()
+
+    cdms_ids = cdms_companies[cdms_companies["crn"] == crn_ids[0]]["id"].tolist()
+
+    return {
+        "crn": [str(id) for id in crn_ids],
+        "duns": [str(id) for id in duns_ids],
+        "cdms": [str(id) for id in cdms_ids],
+    }
+
+
 @pytest.fixture(scope="function")
 def query_clean_crn(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py
index f0ea4f0..35a165e 100644
--- a/test/server/test_adapter.py
+++ b/test/server/test_adapter.py
@@ -597,6 +597,7 @@ def test_match_one_to_many(self, revolution_inc: dict[str, list[str]]):
         assert isinstance(res, Match)
         assert res.source == str(duns_wh)
         assert res.target == str(crn_wh)
+        assert res.cluster is not None
         assert res.source_id == set(revolution_inc["duns"])
         assert res.target_id == set(revolution_inc["crn"])

@@ -619,9 +620,33 @@ def test_match_many_to_one(self, revolution_inc: dict[str, list[str]]):
         assert isinstance(res, Match)
         assert res.source == str(crn_wh)
         assert res.target == str(duns_wh)
+        assert res.cluster is not None
         assert res.source_id == set(revolution_inc["crn"])
         assert res.target_id == set(revolution_inc["duns"])

+    def test_match_one_to_none(self, winner_inc: dict[str, list[str]]):
+        """Test that matching data works when the target has no IDs."""
+        self.setup_database("link")
+
+        crn_x_duns = "deterministic_naive_test.crn_naive_test.duns"
+        crn_wh = self.warehouse_data[0]
+        duns_wh = self.warehouse_data[1]
+
+        res = match(
+            backend=self.backend,
+            source_id=winner_inc["crn"][0],
+            source=str(crn_wh),
+            target=str(duns_wh),
+            resolution=crn_x_duns,
+        )
+
+        assert isinstance(res, Match)
+        assert res.source == str(crn_wh)
+        assert res.target == str(duns_wh)
+        assert res.cluster is not None
+        assert res.source_id == set(winner_inc["crn"])
+        assert res.target_id == set() == set(winner_inc["duns"])
+
     def test_match_none_to_none(self):
         """Test that matching data works when the supplied key doesn't exist."""
         self.setup_database("link")
@@ -641,6 +666,7 @@ def test_match_none_to_none(self):
         assert isinstance(res, Match)
         assert res.source == str(crn_wh)
         assert res.target == str(duns_wh)
+        assert res.cluster is None
         assert res.source_id == set()
         assert res.target_id == set()

From 592ca99fa94eae58ee4bd5bc80e7995ab769cfe6 Mon Sep 17 00:00:00 2001
From: Will Langdale
Date: Fri, 13 Dec 2024 13:31:34 +0000
Subject: [PATCH 21/22] Removed a bunch of redundant code from processing of
 match results

---
 src/matchbox/server/postgresql/utils/query.py | 45 ++++++-------------
 1 file changed, 14 insertions(+), 31 deletions(-)

diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py
index bea65d1..32476b3 100644
--- a/src/matchbox/server/postgresql/utils/query.py
+++ b/src/matchbox/server/postgresql/utils/query.py
@@ -52,7 +52,7 @@ def key_to_sqlalchemy_label(key: str, source: Source) -> str:


 def source_to_dataset_resolution(source: Source | str, session: Session) -> Resolutions:
-    """Converts a Source object to a Sources ORM object."""
+    """Converts a common Source object to a Resolutions ORM object."""
     if isinstance(source, str):
         source_schema, source_table = get_schema_table_names(source, validate=True)
     else:
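The hunk below rebuilds one Match per target from the grouped rows. Under the validation settled in PATCH 18, a Match can only take three shapes, and the tests above exercise them one by one. A minimal sketch of those shapes, assuming the Match model from src/matchbox/common/db.py as it stands after PATCH 18 (the keys and table names are illustrative, not drawn from the fixtures):

    from matchbox.common.db import Match

    # Found in source and target: cluster, source_id and target_id all set.
    Match(
        cluster=b"\x01",
        source="test.crn",
        source_id={"crn-123"},
        target="test.duns",
        target_id={"duns-456"},
    )

    # Found in source only (one-to-none): cluster and source_id set,
    # target_id left to default to an empty set.
    Match(cluster=b"\x01", source="test.crn", source_id={"crn-123"}, target="test.duns")

    # Not found at all (none-to-none): no cluster and no IDs.
    Match(cluster=None, source="test.crn", target="test.duns")

    # Invalid: a target_id without both a cluster and a source_id
    # fails the found_or_none validator with a ValueError.
    # Match(cluster=None, source="test.crn", target="test.duns", target_id={"duns-456"})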
@@ -578,19 +578,24 @@ def match( """ # Split source and target into schema/table targets = [target] if isinstance(target, str) else target - target_pairs = [get_schema_table_names(t, validate=True) for t in targets] with Session(engine) as session: # Get source, target and truth resolutions source_resolution = source_to_dataset_resolution(source, session) + + # Get target resolutions with schema/table info target_resolutions = [] - for target in targets: - target_resolutions.append(source_to_dataset_resolution(target, session)) + for t in targets: + schema, table = get_schema_table_names(t, validate=True) + target_resolution = source_to_dataset_resolution(t, session) + target_resolutions.append((target_resolution, f"{schema}.{table}")) + + # Get truth resolution truth_resolution = ( session.query(Resolutions).filter(Resolutions.name == resolution).first() ) if truth_resolution is None: - raise MatchboxResolutionError(f"Resolution {resolution} not found") + raise MatchboxResolutionError(resolution_name=resolution) # Get resolution lineage and resolve thresholds lineage_truths = truth_resolution.get_lineage() @@ -627,7 +632,7 @@ def match( # Group matches by dataset cluster = None - matches_by_dataset = {} + matches_by_dataset: dict[bytes, set] = {} for cluster_hash, dataset_hash, id in matches: if cluster is None: cluster = cluster_hash @@ -635,36 +640,14 @@ def match( matches_by_dataset[dataset_hash] = set() matches_by_dataset[dataset_hash].add(id) - # Create Match objects for each target result = [] - for target_resolution in target_resolutions: - # Get source/target table names - target_schema, target_table = next( - (schema, table) - for schema, table in target_pairs - if session.get(Sources, target_resolution.hash).schema == schema - and session.get(Sources, target_resolution.hash).table == table - ) - target_name = f"{target_schema}.{target_table}" - - # Get source and target IDs - source_ids = { - id - for _, dataset_hash, id in matches - if dataset_hash == source_resolution.hash - } - target_ids = { - id - for _, dataset_hash, id in matches - if dataset_hash == target_resolution.hash - } - + for target_resolution, target_name in target_resolutions: match_obj = Match( cluster=cluster, source=source, - source_id=source_ids, + source_id=matches_by_dataset.get(source_resolution.hash, set()), target=target_name, - target_id=target_ids, + target_id=matches_by_dataset.get(target_resolution.hash, set()), ) result.append(match_obj) From ce588051efa3a3d05417932daee5c882461f258f Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 13 Dec 2024 15:36:28 +0000 Subject: [PATCH 22/22] Fixed unit test --- test/fixtures/data.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/test/fixtures/data.py b/test/fixtures/data.py index 4acb59c..e685fe2 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -135,14 +135,24 @@ def revolution_inc( duns_companies["company_name"].str.contains("Revolution", case=False) ]["id"].tolist() - cdms_ids = cdms_companies[cdms_companies["crn"] == crn_ids[0]]["id"].tolist() + revolution_crn = crn_companies[ + crn_companies["company_name"].str.contains("Revolution", case=False) + ]["crn"].iloc[0] + + cdms_ids = cdms_companies[cdms_companies["crn"] == revolution_crn]["id"].tolist() - return { + revolution = { "crn": [str(id) for id in crn_ids], "duns": [str(id) for id in duns_ids], "cdms": [str(id) for id in cdms_ids], } + assert len(revolution.get("crn", [])) == 3 + assert len(revolution.get("duns", [])) == 1 + 
assert len(revolution.get("cdms", [])) == 2
+
+    return revolution
+

 @pytest.fixture(scope="session")
 def winner_inc(
@@ -167,14 +177,24 @@ def winner_inc(
         duns_companies["company_name"].str.contains("Winner", case=False)
     ]["id"].tolist()

-    cdms_ids = cdms_companies[cdms_companies["crn"] == crn_ids[0]]["id"].tolist()
+    winner_crn = crn_companies[
+        crn_companies["company_name"].str.contains("Winner", case=False)
+    ]["crn"].iloc[0]
+
+    cdms_ids = cdms_companies[cdms_companies["crn"] == winner_crn]["id"].tolist()

-    return {
+    winner = {
         "crn": [str(id) for id in crn_ids],
         "duns": [str(id) for id in duns_ids],
         "cdms": [str(id) for id in cdms_ids],
     }

+    assert len(winner.get("crn", [])) == 3
+    assert len(winner.get("duns", [])) == 0
+    assert len(winner.get("cdms", [])) == 2
+
+    return winner
+

 @pytest.fixture(scope="function")
 def query_clean_crn(