diff --git a/src/matchbox/common/db.py b/src/matchbox/common/db.py
index 0c913b8..19a4359 100644
--- a/src/matchbox/common/db.py
+++ b/src/matchbox/common/db.py
@@ -40,6 +40,26 @@ T = TypeVar("T")
+class Match(BaseModel):
+    """A match between primary keys in the Matchbox database."""
+
+    cluster: bytes | None
+    source: str
+    source_id: set[str] = Field(default_factory=set)
+    target: str
+    target_id: set[str] = Field(default_factory=set)
+
+    @model_validator(mode="after")
+    def found_or_none(self) -> "Match":
+        if self.target_id and not (self.source_id and self.cluster):
+            raise ValueError(
+                "A match must have sources and a cluster if target was found."
+            )
+        if self.cluster and not self.source_id:
+            raise ValueError("A match must have a source if cluster is set.")
+        return self
+
+
 class Probability(BaseModel):
     """A probability of a match in the Matchbox database.
 
diff --git a/src/matchbox/helpers/selector.py b/src/matchbox/helpers/selector.py
index 56105e3..36d4d70 100644
--- a/src/matchbox/helpers/selector.py
+++ b/src/matchbox/helpers/selector.py
@@ -4,7 +4,7 @@
 from pyarrow import Table as ArrowTable
 from sqlalchemy import Engine, inspect
 
-from matchbox.common.db import Source, get_schema_table_names
+from matchbox.common.db import Match, Source, get_schema_table_names
 from matchbox.server import MatchboxDBAdapter, inject_backend
 
 
@@ -91,3 +91,37 @@ def query(
         return_type="pandas" if not return_type else return_type,
         limit=limit,
     )
+
+
+@inject_backend
+def match(
+    backend: MatchboxDBAdapter,
+    source_id: str,
+    source: str,
+    target: str | list[str],
+    resolution: str,
+    threshold: float | dict[str, float] | None = None,
+) -> Match | list[Match]:
+    """Matches IDs against the selected backend.
+
+    Args:
+        backend: the backend to query
+        source_id: The ID of the source to match.
+        source: The name of the source dataset.
+        target: The name of the target dataset(s).
+        resolution: the resolution to use for filtering results
+        threshold (optional): the threshold to use for creating clusters
+            If None, uses the resolution's default threshold
+            If a float, uses that threshold for the specified resolution, and the
+            resolution's cached thresholds for its ancestors
+            If a dictionary, expects a shape similar to resolution.ancestors, keyed
+            by resolution name and valued by the threshold to use for that resolution.
+            Will use these threshold values instead of the cached thresholds
+    """
+    return backend.match(
+        source_id=source_id,
+        source=source,
+        target=target,
+        resolution=resolution,
+        threshold=threshold,
+    )
diff --git a/src/matchbox/server/api.py b/src/matchbox/server/api.py
index 71da31c..36babd3 100644
--- a/src/matchbox/server/api.py
+++ b/src/matchbox/server/api.py
@@ -150,6 +150,11 @@ async def query():
     raise HTTPException(status_code=501, detail="Not implemented")
 
 
+@app.get("/match")
+async def match():
+    raise HTTPException(status_code=501, detail="Not implemented")
+
+
 @app.get("/validate/hash")
 async def validate_hashes():
     raise HTTPException(status_code=501, detail="Not implemented")
diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py
index d31ff3b..95bdcd0 100644
--- a/src/matchbox/server/base.py
+++ b/src/matchbox/server/base.py
@@ -19,7 +19,7 @@
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from sqlalchemy import Engine
 
-from matchbox.common.db import Source
+from matchbox.common.db import Match, Source
 from matchbox.common.graph import ResolutionGraph
 
 if TYPE_CHECKING:
@@ -251,6 +251,16 @@ def query(
         limit: int = None,
     ) -> PandasDataFrame | ArrowTable | PolarsDataFrame: ...
 
+    @abstractmethod
+    def match(
+        self,
+        source_id: str,
+        source: str,
+        target: str | list[str],
+        resolution: str,
+        threshold: float | dict[str, float] | None = None,
+    ) -> Match | list[Match]: ...
+
     @abstractmethod
     def index(self, dataset: Source) -> None: ...
 
diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py
index 3290e76..c8af7dc 100644
--- a/src/matchbox/server/postgresql/adapter.py
+++ b/src/matchbox/server/postgresql/adapter.py
@@ -5,7 +5,7 @@
 from sqlalchemy import Engine, and_, bindparam, delete, func, or_, select
 from sqlalchemy.orm import Session
 
-from matchbox.common.db import Source, SourceWarehouse
+from matchbox.common.db import Match, Source, SourceWarehouse
 from matchbox.common.exceptions import (
     MatchboxDataError,
     MatchboxDatasetError,
@@ -29,7 +29,7 @@
     insert_model,
     insert_results,
 )
-from matchbox.server.postgresql.utils.query import query
+from matchbox.server.postgresql.utils.query import match, query
 from matchbox.server.postgresql.utils.results import (
     get_model_clusters,
     get_model_probabilities,
@@ -300,6 +300,39 @@ def query(
             limit=limit,
         )
 
+    def match(
+        self,
+        source_id: str,
+        source: str,
+        target: str | list[str],
+        resolution: str,
+        threshold: float | dict[str, float] | None = None,
+    ) -> Match | list[Match]:
+        """Matches an ID in a source dataset and returns the keys in the targets.
+
+        Args:
+            source_id: The ID of the source to match.
+            source: The name of the source dataset.
+            target: The name of the target dataset(s).
+            resolution: The name of the resolution to use for matching.
+            threshold (optional): the threshold to use for creating clusters
+                If None, uses the resolution's default threshold
+                If a float, uses that threshold for the specified resolution, and the
+                resolution's cached thresholds for its ancestors
+                If a dictionary, expects a shape similar to resolution.ancestors, keyed
+                by resolution name and valued by the threshold to use for that
+                resolution.
+ Will use these threshold values instead of the cached thresholds + """ + return match( + source_id=source_id, + source=source, + target=target, + resolution=resolution, + engine=MBDB.get_engine(), + threshold=threshold, + ) + def index(self, dataset: Source) -> None: """Indexes a data from your data warehouse within Matchbox. diff --git a/src/matchbox/server/postgresql/orm.py b/src/matchbox/server/postgresql/orm.py index 70c0e71..d9beea3 100644 --- a/src/matchbox/server/postgresql/orm.py +++ b/src/matchbox/server/postgresql/orm.py @@ -5,6 +5,7 @@ CheckConstraint, Column, ForeignKey, + Index, UniqueConstraint, select, ) @@ -98,6 +99,22 @@ def descendants(self) -> set["Resolutions"]: ) return set(session.execute(descendant_query).scalars().all()) + def get_lineage(self) -> dict[bytes, float]: + """Returns all ancestors and their cached truth values from this model.""" + with Session(MBDB.get_engine()) as session: + lineage_query = ( + select(ResolutionFrom.parent, ResolutionFrom.truth_cache) + .where(ResolutionFrom.child == self.hash) + .order_by(ResolutionFrom.level.desc()) + ) + + results = session.execute(lineage_query).all() + + lineage = {parent: truth for parent, truth in results} + lineage[self.hash] = self.truth + + return lineage + def get_lineage_to_dataset( self, dataset: "Resolutions" ) -> tuple[bytes, dict[bytes, float]]: @@ -108,13 +125,11 @@ def get_lineage_to_dataset( ) if self.hash == dataset.hash: - return {} + return {dataset.hash: None} with Session(MBDB.get_engine()) as session: path_query = ( - select( - ResolutionFrom.parent, ResolutionFrom.truth_cache, Resolutions.type - ) + select(ResolutionFrom.parent, ResolutionFrom.truth_cache) .join(Resolutions, Resolutions.hash == ResolutionFrom.parent) .where(ResolutionFrom.child == self.hash) .order_by(ResolutionFrom.level.desc()) @@ -122,17 +137,12 @@ def get_lineage_to_dataset( results = session.execute(path_query).all() - if not any(parent == dataset.hash for parent, _, _ in results): + if not any(parent == dataset.hash for parent, _ in results): raise ValueError( f"No path between resolution {self.name}, dataset {dataset.name}" ) - lineage = { - parent: truth - for parent, truth, type in results - if type != ResolutionNodeType.DATASET.value - } - + lineage = {parent: truth for parent, truth in results} lineage[self.hash] = self.truth return lineage @@ -181,8 +191,12 @@ class Contains(CountMixin, MBDB.MatchboxBase): BYTEA, ForeignKey("clusters.hash", ondelete="CASCADE"), primary_key=True ) - # Constraints - __table_args__ = (CheckConstraint("parent != child", name="no_self_containment"),) + # Constraints and indices + __table_args__ = ( + CheckConstraint("parent != child", name="no_self_containment"), + Index("ix_contains_parent_child", "parent", "child"), + Index("ix_contains_child_parent", "child", "parent"), + ) class Clusters(CountMixin, MBDB.MatchboxBase): @@ -211,6 +225,9 @@ class Clusters(CountMixin, MBDB.MatchboxBase): backref="parents", ) + # Constraints and indices + __table_args__ = (Index("ix_clusters_id_gin", id, postgresql_using="gin"),) + class Probabilities(CountMixin, MBDB.MatchboxBase): """Table of probabilities that a cluster is correct, according to a resolution.""" diff --git a/src/matchbox/server/postgresql/utils/db.py b/src/matchbox/server/postgresql/utils/db.py index 21cdad7..2402cd2 100644 --- a/src/matchbox/server/postgresql/utils/db.py +++ b/src/matchbox/server/postgresql/utils/db.py @@ -3,10 +3,10 @@ import io import pstats from itertools import islice -from typing import Any, Callable, 
Iterable, Tuple +from typing import Any, Callable, Iterable from pg_bulk_ingest import Delete, Upsert, ingest -from sqlalchemy import Engine, MetaData, Table +from sqlalchemy import Engine, Index, MetaData, Table from sqlalchemy.engine.base import Connection from sqlalchemy.orm import DeclarativeMeta, Session @@ -16,7 +16,10 @@ ResolutionNode, ResolutionNodeType, ) -from matchbox.server.postgresql.orm import ResolutionFrom, Resolutions +from matchbox.server.postgresql.orm import ( + ResolutionFrom, + Resolutions, +) # Retrieval @@ -79,18 +82,55 @@ def batched(iterable: Iterable, n: int) -> Iterable: def data_to_batch( records: list[tuple], table: Table, batch_size: int -) -> Callable[[str], Tuple[Any]]: +) -> Callable[[str], tuple[Any]]: """Constructs a batches function for any dataframe and table.""" def _batches( high_watermark, # noqa ARG001 required for pg_bulk_ingest - ) -> Iterable[Tuple[None, None, Iterable[Tuple[Table, tuple]]]]: + ) -> Iterable[tuple[None, None, Iterable[tuple[Table, tuple]]]]: for batch in batched(records, batch_size): yield None, None, ((table, t) for t in batch) return _batches +def isolate_table(table: DeclarativeMeta) -> tuple[MetaData, Table]: + """Creates an isolated copy of a SQLAlchemy table. + + This is used to prevent pg_bulk_ingest from attempting to drop unrelated tables + in the same schema. The function creates a new Table instance with: + + * A fresh MetaData instance + * Copied columns + * Recreated indices properly bound to the new table + + Args: + table: The DeclarativeMeta class whose table should be isolated + + Returns: + A tuple of: + * The isolated SQLAlchemy MetaData + * A new SQLAlchemy Table instance with all columns and indices + """ + isolated_metadata = MetaData(schema=table.__table__.schema) + + isolated_table = Table( + table.__table__.name, + isolated_metadata, + *[c._copy() for c in table.__table__.columns], + schema=table.__table__.schema, + ) + + for idx in table.__table__.indexes: + Index( + idx.name, + *[isolated_table.c[col.name] for col in idx.columns], + **{k: v for k, v in idx.kwargs.items()}, + ) + + return isolated_metadata, isolated_table + + def batch_ingest( records: list[tuple[Any]], table: DeclarativeMeta, @@ -102,14 +142,7 @@ def batch_ingest( We isolate the table and metadata as pg_bulk_ingest will try and drop unrelated tables if they're in the same schema. 
""" - - isolated_metadata = MetaData(schema=table.__table__.schema) - isolated_table = Table( - table.__table__.name, - isolated_metadata, - *[c._copy() for c in table.__table__.columns], - schema=table.__table__.schema, - ) + isolated_metadata, isolated_table = isolate_table(table) fn_batch = data_to_batch( records=records, diff --git a/src/matchbox/server/postgresql/utils/query.py b/src/matchbox/server/postgresql/utils/query.py index 2b3534c..32476b3 100644 --- a/src/matchbox/server/postgresql/utils/query.py +++ b/src/matchbox/server/postgresql/utils/query.py @@ -4,12 +4,21 @@ import pyarrow as pa from pandas import ArrowDtype, DataFrame -from sqlalchemy import Engine, and_, cast, func, literal, null, select, union +from sqlalchemy import ( + Engine, + and_, + cast, + func, + literal, + null, + select, + union, +) from sqlalchemy.dialects.postgresql import BYTEA from sqlalchemy.orm import Session -from sqlalchemy.sql.selectable import Select +from sqlalchemy.sql.selectable import CTE, Select -from matchbox.common.db import Source, sql_to_df +from matchbox.common.db import Match, Source, get_schema_table_names, sql_to_df from matchbox.common.exceptions import ( MatchboxDatasetError, MatchboxResolutionError, @@ -42,6 +51,31 @@ def key_to_sqlalchemy_label(key: str, source: Source) -> str: return f"{source.db_schema}_{source.db_table}_{key}" +def source_to_dataset_resolution(source: Source | str, session: Session) -> Resolutions: + """Converts a common Source object to a Resolutions ORM object.""" + if isinstance(source, str): + source_schema, source_table = get_schema_table_names(source, validate=True) + else: + source_schema, source_table = source.db_schema, source.db_table + + source_dataset = ( + session.query(Resolutions) + .join(Sources, Sources.resolution == Resolutions.hash) + .filter( + Sources.schema == source_schema, + Sources.table == source_table, + ) + .first() + ) + if source_dataset is None: + raise MatchboxDatasetError( + db_schema=source_schema, + db_table=source_table, + ) + + return source_dataset + + def _resolve_thresholds( lineage_truths: dict[str, float], resolution: Resolutions, @@ -63,6 +97,12 @@ def _resolve_thresholds( resolved_thresholds = {} for resolution_hash, default_truth in lineage_truths.items(): + # Dataset + if default_truth is None: + resolved_thresholds[resolution_hash] = None + continue + + # Model if threshold is None: resolved_thresholds[resolution_hash] = default_truth elif isinstance(threshold, float): @@ -104,9 +144,16 @@ def _union_valid_clusters(lineage_thresholds: dict[bytes, float]) -> Select: valid_clusters = None for resolution_hash, threshold in lineage_thresholds.items(): - resolution_valid = _get_valid_clusters_for_resolution( - resolution_hash, threshold - ) + if threshold is None: + # This is a dataset - get all its clusters directly + resolution_valid = select(Clusters.hash.label("cluster")).where( + Clusters.dataset == hash_to_hex_decode(resolution_hash) + ) + else: + # This is a model - get clusters meeting threshold + resolution_valid = _get_valid_clusters_for_resolution( + resolution_hash, threshold + ) if valid_clusters is None: valid_clusters = resolution_valid @@ -141,6 +188,9 @@ def _resolve_cluster_hierarchy( """ with Session(engine) as session: dataset_resolution = session.get(Resolutions, dataset_hash) + if dataset_resolution is None: + raise MatchboxDatasetError("Dataset not found") + try: lineage_truths = resolution.get_lineage_to_dataset( dataset=dataset_resolution @@ -168,6 +218,7 @@ def _resolve_cluster_hierarchy( ) 
.where( and_( + Clusters.hash.in_(select(valid_clusters.c.cluster)), Clusters.dataset == hash_to_hex_decode(dataset_hash), Clusters.id.isnot(None), ) @@ -287,21 +338,7 @@ def query( # Process each source dataset for source, fields in selector.items(): # Get the dataset resolution - dataset = ( - session.query(Resolutions) - .join(Sources, Sources.resolution == Resolutions.hash) - .filter( - Sources.schema == source.db_schema, - Sources.table == source.db_table, - Sources.id == source.db_pk, - ) - .first() - ) - - if dataset is None: - raise MatchboxDatasetError( - db_schema=source.db_schema, db_table=source.db_table - ) + dataset_resolution = source_to_dataset_resolution(source, session) # Warn if non-indexed fields have been requested not_indexed = set(fields) - set( @@ -315,8 +352,8 @@ def query( ) hash_query = _resolve_cluster_hierarchy( - dataset_hash=dataset.hash, - resolution=point_of_truth if point_of_truth else dataset, + dataset_hash=dataset_resolution.hash, + resolution=point_of_truth if point_of_truth else dataset_resolution, threshold=threshold, engine=engine, ) @@ -364,3 +401,254 @@ def query( ) else: raise ValueError(f"return_type of {return_type} not valid") + + +def _build_unnested_clusters() -> CTE: + """Create CTE that unnests cluster IDs for easier joining.""" + return ( + select(Clusters.hash, Clusters.dataset, func.unnest(Clusters.id).label("id")) + .select_from(Clusters) + .cte("unnested_clusters") + ) + + +def _find_source_cluster( + unnested_clusters: CTE, source_dataset_hash: bytes, source_id: str +) -> Select: + """Find the initial cluster containing the source ID.""" + return ( + select(unnested_clusters.c.hash) + .select_from(unnested_clusters) + .where( + and_( + unnested_clusters.c.dataset == hash_to_hex_decode(source_dataset_hash), + unnested_clusters.c.id == source_id, + ) + ) + .scalar_subquery() + ) + + +def _build_hierarchy_up( + source_cluster: Select, valid_clusters: CTE | None = None +) -> CTE: + """ + Build recursive CTE that finds all parent clusters. 
+ + Args: + source_cluster: Subquery that finds starting cluster + valid_clusters: Optional CTE of valid clusters to filter by + """ + # Base case: direct parents + base = ( + select( + source_cluster.label("original_cluster"), + source_cluster.label("child"), + Contains.parent.label("parent"), + literal(1).label("level"), + ) + .select_from(Contains) + .where(Contains.child == source_cluster) + ) + + # Add valid clusters filter if provided + if valid_clusters is not None: + base = base.where(Contains.parent.in_(select(valid_clusters.c.cluster))) + + hierarchy_up = base.cte("hierarchy_up", recursive=True) + + # Recursive case + recursive = ( + select( + hierarchy_up.c.original_cluster, + hierarchy_up.c.parent.label("child"), + Contains.parent.label("parent"), + (hierarchy_up.c.level + 1).label("level"), + ) + .select_from(hierarchy_up) + .join(Contains, Contains.child == hierarchy_up.c.parent) + ) + + # Add valid clusters filter to recursive part if provided + if valid_clusters is not None: + recursive = recursive.where( + Contains.parent.in_(select(valid_clusters.c.cluster)) + ) + + return hierarchy_up.union_all(recursive) + + +def _find_highest_parent(hierarchy_up: CTE) -> Select: + """Find the topmost parent cluster from the hierarchy.""" + return ( + select(hierarchy_up.c.parent) + .order_by(hierarchy_up.c.level.desc()) + .limit(1) + .scalar_subquery() + ) + + +def _build_hierarchy_down( + highest_parent: Select, unnested_clusters: CTE, valid_clusters: CTE | None = None +) -> CTE: + """ + Build recursive CTE that finds all child clusters and their IDs. + + Args: + highest_parent: Subquery that finds top cluster + unnested_clusters: CTE with unnested cluster IDs + valid_clusters: Optional CTE of valid clusters to filter by + """ + # Base case: Get both direct children and their IDs + base = ( + select( + highest_parent.label("parent"), + Contains.child.label("child"), + literal(1).label("level"), + unnested_clusters.c.dataset.label("dataset"), + unnested_clusters.c.id.label("id"), + ) + .select_from(Contains) + .join_from( + Contains, + unnested_clusters, + unnested_clusters.c.hash == Contains.child, + isouter=True, + ) + .where(Contains.parent == highest_parent) + ) + + # Add valid clusters filter if provided + if valid_clusters is not None: + base = base.where(Contains.child.in_(select(valid_clusters.c.cluster))) + + hierarchy_down = base.cte("hierarchy_down", recursive=True) + + # Recursive case: Get both intermediate nodes AND their leaf records + recursive = ( + select( + hierarchy_down.c.parent, + Contains.child.label("child"), + (hierarchy_down.c.level + 1).label("level"), + unnested_clusters.c.dataset.label("dataset"), + unnested_clusters.c.id.label("id"), + ) + .select_from(hierarchy_down) + .join_from( + hierarchy_down, + Contains, + Contains.parent == hierarchy_down.c.child, + ) + .join_from( + Contains, + unnested_clusters, + unnested_clusters.c.hash == Contains.child, + isouter=True, + ) + .where(hierarchy_down.c.id.is_(None)) # Only recurse on non-leaf nodes + ) + + # Add valid clusters filter to recursive part if provided + if valid_clusters is not None: + recursive = recursive.where( + Contains.child.in_(select(valid_clusters.c.cluster)) + ) + + return hierarchy_down.union_all(recursive) + + +def match( + source_id: str, + source: str, + target: str | list[str], + resolution: str, + engine: Engine, + threshold: float | dict[str, float] | None = None, +) -> Match | list[Match]: + """Matches an ID in a source dataset and returns the keys in the targets. 
+ + To accomplish this, the function: + + * Reconstructs the resolution lineage from the specified resolution + * Iterates through each target, and + * Retrieves its cluster hash according to the resolution + * Retrieves all other IDs in the cluster in the source dataset + * Retrieves all other IDs in the cluster in the target dataset + * Returns the results as Match objects, one per target + """ + # Split source and target into schema/table + targets = [target] if isinstance(target, str) else target + + with Session(engine) as session: + # Get source, target and truth resolutions + source_resolution = source_to_dataset_resolution(source, session) + + # Get target resolutions with schema/table info + target_resolutions = [] + for t in targets: + schema, table = get_schema_table_names(t, validate=True) + target_resolution = source_to_dataset_resolution(t, session) + target_resolutions.append((target_resolution, f"{schema}.{table}")) + + # Get truth resolution + truth_resolution = ( + session.query(Resolutions).filter(Resolutions.name == resolution).first() + ) + if truth_resolution is None: + raise MatchboxResolutionError(resolution_name=resolution) + + # Get resolution lineage and resolve thresholds + lineage_truths = truth_resolution.get_lineage() + thresholds = _resolve_thresholds( + lineage_truths=lineage_truths, + resolution=truth_resolution, + threshold=threshold, + session=session, + ) + + # Get valid clusters across all resolutions + valid_clusters = _union_valid_clusters(thresholds) + + # Build the query components + unnested = _build_unnested_clusters() + source_cluster = _find_source_cluster( + unnested, source_resolution.hash, source_id + ) + hierarchy_up = _build_hierarchy_up(source_cluster, valid_clusters) + highest = _find_highest_parent(hierarchy_up) + hierarchy_down = _build_hierarchy_down(highest, unnested, valid_clusters) + + # Get all matched IDs + final_stmt = ( + select( + hierarchy_down.c.parent.label("cluster"), + hierarchy_down.c.dataset, + hierarchy_down.c.id, + ) + .distinct() + .select_from(hierarchy_down) + ) + matches = session.execute(final_stmt).all() + + # Group matches by dataset + cluster = None + matches_by_dataset: dict[bytes, set] = {} + for cluster_hash, dataset_hash, id in matches: + if cluster is None: + cluster = cluster_hash + if dataset_hash not in matches_by_dataset: + matches_by_dataset[dataset_hash] = set() + matches_by_dataset[dataset_hash].add(id) + + result = [] + for target_resolution, target_name in target_resolutions: + match_obj = Match( + cluster=cluster, + source=source, + source_id=matches_by_dataset.get(source_resolution.hash, set()), + target=target_name, + target_id=matches_by_dataset.get(target_resolution.hash, set()), + ) + result.append(match_obj) + + return result[0] if isinstance(target, str) else result diff --git a/test/fixtures/data.py b/test/fixtures/data.py index c85c163..e685fe2 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -1,6 +1,6 @@ import logging -import uuid from pathlib import Path +from uuid import UUID import numpy as np import pandas as pd @@ -34,7 +34,7 @@ def all_companies(test_root_dir: Path) -> DataFrame: df = pd.read_csv( Path(test_root_dir, "data", "all_companies.csv"), encoding="utf-8" ).reset_index(names="id") - df["id"] = df["id"].apply(lambda x: uuid.UUID(int=x)) + df["id"] = df["id"].apply(lambda x: UUID(int=x)) return df @@ -60,7 +60,7 @@ def crn_companies(all_companies: DataFrame) -> DataFrame: df_crn["id"] = range(df_crn.shape[0]) df_crn = df_crn.filter(["id", "company_name", 
"crn"]) - df_crn["id"] = df_crn["id"].apply(lambda x: uuid.UUID(int=x)) + df_crn["id"] = df_crn["id"].apply(lambda x: UUID(int=x)) df_crn = df_crn.convert_dtypes(dtype_backend="pyarrow") return df_crn @@ -79,7 +79,7 @@ def duns_companies(all_companies: DataFrame) -> DataFrame: """ df_duns = ( all_companies.filter(["company_name", "duns"]) - .sample(n=500) + .sample(n=500, random_state=1618) .reset_index(drop=True) .convert_dtypes(dtype_backend="pyarrow") ) @@ -106,12 +106,96 @@ def cdms_companies(all_companies: DataFrame) -> DataFrame: df_cdms.columns = ["crn", "cdms"] df_cdms.reset_index(names="id", inplace=True) - df_cdms["id"] = df_cdms["id"].apply(lambda x: uuid.UUID(int=x)) + df_cdms["id"] = df_cdms["id"].apply(lambda x: UUID(int=x)) df_cdms = df_cdms.convert_dtypes(dtype_backend="pyarrow") return df_cdms +@pytest.fixture(scope="session") +def revolution_inc( + crn_companies: DataFrame, duns_companies: DataFrame, cdms_companies: DataFrame +) -> dict[str, str]: + """ + Revolution Inc. as it exists across all three datasets. + + UUIDs are converted to strings to mirror how Matchbox stores them. + + Based on the above fixtures, should return: + + * Three CRNs + * One DUNS + * Two CDMS + """ + crn_ids = crn_companies[ + crn_companies["company_name"].str.contains("Revolution", case=False) + ]["id"].tolist() + + duns_ids = duns_companies[ + duns_companies["company_name"].str.contains("Revolution", case=False) + ]["id"].tolist() + + revolution_crn = crn_companies[ + crn_companies["company_name"].str.contains("Revolution", case=False) + ]["crn"].iloc[0] + + cdms_ids = cdms_companies[cdms_companies["crn"] == revolution_crn]["id"].tolist() + + revolution = { + "crn": [str(id) for id in crn_ids], + "duns": [str(id) for id in duns_ids], + "cdms": [str(id) for id in cdms_ids], + } + + assert len(revolution.get("crn", [])) == 3 + assert len(revolution.get("duns", [])) == 1 + assert len(revolution.get("cdms", [])) == 2 + + return revolution + + +@pytest.fixture(scope="session") +def winner_inc( + crn_companies: DataFrame, duns_companies: DataFrame, cdms_companies: DataFrame +) -> dict[str, str]: + """ + Winner Inc. as it exists across all three datasets. + + UUIDs are converted to strings to mirror how Matchbox stores them. 
+
+    Based on the above fixtures, should return:
+
+    * Three CRNs
+    * Zero DUNS
+    * Two CDMS
+    """
+    crn_ids = crn_companies[
+        crn_companies["company_name"].str.contains("Winner", case=False)
+    ]["id"].tolist()
+
+    duns_ids = duns_companies[
+        duns_companies["company_name"].str.contains("Winner", case=False)
+    ]["id"].tolist()
+
+    winner_crn = crn_companies[
+        crn_companies["company_name"].str.contains("Winner", case=False)
+    ]["crn"].iloc[0]
+
+    cdms_ids = cdms_companies[cdms_companies["crn"] == winner_crn]["id"].tolist()
+
+    winner = {
+        "crn": [str(id) for id in crn_ids],
+        "duns": [str(id) for id in duns_ids],
+        "cdms": [str(id) for id in cdms_ids],
+    }
+
+    assert len(winner.get("crn", [])) == 3
+    assert len(winner.get("duns", [])) == 0
+    assert len(winner.get("cdms", [])) == 2
+
+    return winner
+
+
 @pytest.fixture(scope="function")
 def query_clean_crn(
     matchbox_postgres: MatchboxPostgres, warehouse_data: list[Source]
diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py
index 26388c2..35a165e 100644
--- a/test/server/test_adapter.py
+++ b/test/server/test_adapter.py
@@ -3,7 +3,7 @@
 import pytest
 from dotenv import find_dotenv, load_dotenv
 
-from matchbox.common.db import Source, SourceColumn
+from matchbox.common.db import Match, Source, SourceColumn
 from matchbox.common.exceptions import (
     MatchboxDataError,
     MatchboxDatasetError,
@@ -17,7 +17,7 @@
     Results,
     to_clusters,
 )
-from matchbox.helpers.selector import query, selector, selectors
+from matchbox.helpers.selector import match, query, selector, selectors
 from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter
 from pandas import DataFrame
 
@@ -578,6 +578,98 @@ def test_query_with_link_model(self):
         }
         assert crn_duns.hash.nunique() == 1000
 
+    def test_match_one_to_many(self, revolution_inc: dict[str, list[str]]):
+        """Test that matching data works when the target has many IDs."""
+        self.setup_database("link")
+
+        crn_x_duns = "deterministic_naive_test.crn_naive_test.duns"
+        crn_wh = self.warehouse_data[0]
+        duns_wh = self.warehouse_data[1]
+
+        res = match(
+            backend=self.backend,
+            source_id=revolution_inc["duns"][0],
+            source=str(duns_wh),
+            target=str(crn_wh),
+            resolution=crn_x_duns,
+        )
+
+        assert isinstance(res, Match)
+        assert res.source == str(duns_wh)
+        assert res.target == str(crn_wh)
+        assert res.cluster is not None
+        assert res.source_id == set(revolution_inc["duns"])
+        assert res.target_id == set(revolution_inc["crn"])
+
+    def test_match_many_to_one(self, revolution_inc: dict[str, list[str]]):
+        """Test that matching data works when the source has more possible IDs."""
+        self.setup_database("link")
+
+        crn_x_duns = "deterministic_naive_test.crn_naive_test.duns"
+        crn_wh = self.warehouse_data[0]
+        duns_wh = self.warehouse_data[1]
+
+        res = match(
+            backend=self.backend,
+            source_id=revolution_inc["crn"][0],
+            source=str(crn_wh),
+            target=str(duns_wh),
+            resolution=crn_x_duns,
+        )
+
+        assert isinstance(res, Match)
+        assert res.source == str(crn_wh)
+        assert res.target == str(duns_wh)
+        assert res.cluster is not None
+        assert res.source_id == set(revolution_inc["crn"])
+        assert res.target_id == set(revolution_inc["duns"])
+
+    def test_match_one_to_none(self, winner_inc: dict[str, list[str]]):
+        """Test that matching data works when the target has no IDs."""
+        self.setup_database("link")
+
+        crn_x_duns = "deterministic_naive_test.crn_naive_test.duns"
+        crn_wh = self.warehouse_data[0]
+        duns_wh = self.warehouse_data[1]
+
+        res = match(
+            backend=self.backend,
+            source_id=winner_inc["crn"][0],
+            source=str(crn_wh),
+            target=str(duns_wh),
+            resolution=crn_x_duns,
+        )
+
+        assert isinstance(res, Match)
+        assert res.source == str(crn_wh)
+        assert res.target == str(duns_wh)
+        assert res.cluster is not None
+        assert res.source_id == set(winner_inc["crn"])
+        assert res.target_id == set() == set(winner_inc["duns"])
+
+    def test_match_none_to_none(self):
+        """Test that matching data works when the supplied key doesn't exist."""
+        self.setup_database("link")
+
+        crn_x_duns = "deterministic_naive_test.crn_naive_test.duns"
+        crn_wh = self.warehouse_data[0]
+        duns_wh = self.warehouse_data[1]
+
+        res = match(
+            backend=self.backend,
+            source_id="foo",
+            source=str(crn_wh),
+            target=str(duns_wh),
+            resolution=crn_x_duns,
+        )
+
+        assert isinstance(res, Match)
+        assert res.source == str(crn_wh)
+        assert res.target == str(duns_wh)
+        assert res.cluster is None
+        assert res.source_id == set()
+        assert res.target_id == set()
+
     def test_clear(self):
         """Test clearing the database."""
         self.setup_database("dedupe")
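
The sketches below illustrate the behaviour this changeset adds. First, client-side usage of the new matchbox.helpers.selector.match helper. The dataset and resolution names are hypothetical placeholders, and it is assumed that @inject_backend supplies the default backend when none is passed explicitly (the tests above pass backend=self.backend instead):

from matchbox.helpers.selector import match

# Find which keys in the target dataset share a resolved cluster with one
# key in the source dataset, according to a given resolution.
# All names below are hypothetical placeholders, not values from this changeset.
res = match(
    source_id="8534735",
    source="companieshouse.companies",
    target="hmrc.exporters",
    resolution="companies_x_exporters",
    threshold=0.9,  # optional: override the resolution's cached truth
)

# res.cluster is the hash of the highest cluster containing the source ID;
# res.source_id and res.target_id are the full key sets in each dataset.
print(res.cluster, res.source_id, res.target_id)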
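The threshold docstrings repeated above compress a lot of behaviour. Here is a toy, name-keyed analogue of _resolve_thresholds; the real function keys the lineage by resolution hash and reads overrides from a dict shaped like resolution.ancestors, both simplified here to plain names:

# Toy analogue of _resolve_thresholds. A lineage maps each resolution to its
# cached truth; None marks a dataset, which never has a threshold.
lineage = {"companies": None, "naive_crn": 0.7, "crn_x_duns": 0.9}

def resolve(lineage, resolution, threshold=None):
    resolved = {}
    for name, cached in lineage.items():
        if cached is None:
            resolved[name] = None  # dataset: pass through
        elif threshold is None:
            resolved[name] = cached  # use every cached truth
        elif isinstance(threshold, float):
            # the float applies to the queried resolution only
            resolved[name] = threshold if name == resolution else cached
        else:
            # dict: per-resolution overrides beat the cached truths
            resolved[name] = threshold.get(name, cached)
    return resolved

assert resolve(lineage, "crn_x_duns", 0.5)["crn_x_duns"] == 0.5
assert resolve(lineage, "crn_x_duns", 0.5)["naive_crn"] == 0.7
assert resolve(lineage, "crn_x_duns", {"naive_crn": 0.8})["naive_crn"] == 0.8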
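The found_or_none validator on Match admits exactly three shapes: fully matched, matched in the source only, and nothing found. A short sketch of those invariants, assuming Match is importable from matchbox.common.db as added here:

from pydantic import ValidationError

from matchbox.common.db import Match

# Fully matched: a cluster, source IDs and target IDs.
Match(cluster=b"\x01", source="s.a", source_id={"1"}, target="s.b", target_id={"2"})

# Matched in the source but not the target: a cluster and source IDs only.
Match(cluster=b"\x01", source="s.a", source_id={"1"}, target="s.b")

# Nothing found at all: no cluster and no IDs.
Match(cluster=None, source="s.a", target="s.b")

# Illegal: target IDs without both a cluster and source IDs.
try:
    Match(cluster=None, source="s.a", target="s.b", target_id={"2"})
except ValidationError as e:
    print(e)  # "A match must have sources and a cluster if target was found."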
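The heart of the PostgreSQL implementation is the pair of recursive CTEs in utils/query.py: _build_hierarchy_up climbs the Contains edges from the source record's cluster to its highest valid parent, and _build_hierarchy_down descends from that parent to every leaf key. A minimal in-memory analogue over a parent-to-children mapping, with the valid-cluster threshold filter omitted:

# Toy analogue of the two recursive CTEs over an in-memory Contains relation.
contains = {
    "top": {"mid_a", "mid_b"},
    "mid_a": {"leaf_1", "leaf_2"},
    "mid_b": {"leaf_3"},
}
parents = {c: p for p, cs in contains.items() for c in cs}

def highest_parent(node: str) -> str:
    """Walk up until there is no parent left (hierarchy_up + highest_parent)."""
    while node in parents:
        node = parents[node]
    return node

def leaves(node: str) -> set[str]:
    """Walk down from the top cluster, collecting leaf records (hierarchy_down)."""
    if node not in contains:
        return {node}
    found: set[str] = set()
    for child in contains[node]:
        found |= leaves(child)
    return found

top = highest_parent("leaf_1")
assert top == "top"
assert leaves(top) == {"leaf_1", "leaf_2", "leaf_3"}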
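Finally, isolate_table exists because pg_bulk_ingest will drop tables it finds in the target schema but not in the metadata it is given. A sketch of the guarantee, using a hypothetical Widget model that is not part of this changeset:

from sqlalchemy import Column, Index, Integer, String
from sqlalchemy.orm import DeclarativeBase

from matchbox.server.postgresql.utils.db import isolate_table

class Base(DeclarativeBase):
    pass

class Widget(Base):
    """Hypothetical table, for illustration only."""

    __tablename__ = "widgets"
    __table_args__ = (Index("ix_widgets_name", "name"), {"schema": "mb"})
    id = Column(Integer, primary_key=True)
    name = Column(String)

isolated_metadata, isolated_table = isolate_table(Widget)

# The fresh MetaData registers exactly one table, so an ingest driven by it
# cannot see (and therefore cannot drop) any schema-mates of "mb.widgets"...
assert set(isolated_metadata.tables) == {"mb.widgets"}

# ...and the index was rebuilt against the copied columns of the new Table.
assert {idx.name for idx in isolated_table.indexes} == {"ix_widgets_name"}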