From 9a1f9704bfa8af162c6349d44fd50cc301154a7f Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Mon, 21 Oct 2024 11:36:37 +0100 Subject: [PATCH] Working fixture to add link models --- src/matchbox/common/exceptions.py | 13 +++++++++ src/matchbox/server/postgresql/adapter.py | 8 ++--- .../server/postgresql/utils/selector.py | 9 ++++-- test/fixtures/data.py | 6 ++-- test/fixtures/db.py | 1 + test/server/test_adapter.py | 29 +++++++++++++++++-- 6 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/matchbox/common/exceptions.py b/src/matchbox/common/exceptions.py index edaa3f8a..2df3f932 100644 --- a/src/matchbox/common/exceptions.py +++ b/src/matchbox/common/exceptions.py @@ -9,6 +9,19 @@ class MatchboxValidatonError(Exception): """Validation of data failed.""" +class MatchboxModelError(Exception): + """Model not found.""" + + def __init__(self, message: str = None, model_name: str = None): + if message is None: + message = "Model not found." + if model_name is not None: + message = f"Model {model_name} not found." 
+ + super().__init__(message) + self.model_name = model_name + + class MatchboxDBDataError(Exception): """Data doesn't exist in the Matchbox source table.""" diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index c947185c..3ef5bf35 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -6,7 +6,7 @@ from sqlalchemy.engine.result import ChunkedIteratorResult from sqlalchemy.orm import Session -from matchbox.common.exceptions import MatchboxDBDataError +from matchbox.common.exceptions import MatchboxDBDataError, MatchboxModelError from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter from matchbox.server.models import Cluster, Probability, Source, SourceWarehouse from matchbox.server.postgresql.clusters import Clusters, clusters_association @@ -106,10 +106,10 @@ def insert_clusters( @classmethod def get_model(cls, model_name: str) -> "MatchboxPostgresModel": with Session(MBDB.get_engine()) as session: - model = session.query(Models).filter_by(name=model_name).first() - if model: + if model := session.query(Models).filter_by(name=model_name).first(): return cls(model) - return None + else: + raise MatchboxModelError(model_name=model_name) class MatchboxPostgres(MatchboxDBAdapter): diff --git a/src/matchbox/server/postgresql/utils/selector.py b/src/matchbox/server/postgresql/utils/selector.py index 99d213cb..50ab008f 100644 --- a/src/matchbox/server/postgresql/utils/selector.py +++ b/src/matchbox/server/postgresql/utils/selector.py @@ -8,6 +8,7 @@ from sqlalchemy.sql.selectable import Select from matchbox.common.db import sql_to_df +from matchbox.common.exceptions import MatchboxModelError from matchbox.server.models import Source from matchbox.server.postgresql.clusters import Clusters, clusters_association from matchbox.server.postgresql.data import SourceData, SourceDataset @@ -105,9 +106,11 @@ def _parent_to_tree(model_name: str, engine: Engine) -> 
tuple[bytes, list[bytes] """ with Session(engine) as session: - model = session.query(Models).filter_by(name=model_name).first() - model_children = get_all_children(model) - model_children.pop(0) # includes original model + if model := session.query(Models).filter_by(name=model_name).first(): + model_children = get_all_children(model) + model_children.pop(0) # includes original model + else: + raise MatchboxModelError(model_name=model_name) return model.sha1, [m.sha1 for m in model_children] diff --git a/test/fixtures/data.py b/test/fixtures/data.py index de516d3f..b9558d0e 100644 --- a/test/fixtures/data.py +++ b/test/fixtures/data.py @@ -209,7 +209,7 @@ def query_clean_crn_deduped( crn = query( selector=select_crn, backend=matchbox_postgres, - model="naive_mb.crn", + model="naive_test.crn", return_type="pandas", ) @@ -239,7 +239,7 @@ def query_clean_duns_deduped( duns = query( selector=select_duns, backend=matchbox_postgres, - model="naive_mb.duns", + model="naive_test.duns", return_type="pandas", ) @@ -269,7 +269,7 @@ def query_clean_cdms_deduped( cdms = query( selector=select_cdms, backend=matchbox_postgres, - model="naive_mb.cdms", + model="naive_test.cdms", return_type="pandas", ) diff --git a/test/fixtures/db.py b/test/fixtures/db.py index b7122b96..415bc0dc 100644 --- a/test/fixtures/db.py +++ b/test/fixtures/db.py @@ -124,6 +124,7 @@ def _db_add_link_models_and_data( ) -> None: """Links data from the warehouse and logs in Matchbox.""" db_add_dedupe_models_and_data( + db_add_indexed_data=db_add_indexed_data, backend=backend, warehouse_data=warehouse_data, dedupe_data=dedupe_data, diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 3e896658..1496cefa 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -5,10 +5,16 @@ from pandas import DataFrame from pytest import FixtureRequest -from ..fixtures.db import AddDedupeModelsAndDataCallable, AddIndexedDataCallable +from ..fixtures.db import ( + 
+ AddDedupeModelsAndDataCallable, + AddIndexedDataCallable, + AddLinkModelsAndDataCallable, +) from ..fixtures.models import ( dedupe_data_test_params, dedupe_model_test_params, + link_data_test_params, + link_model_test_params, ) dotenv_path = find_dotenv() @@ -173,9 +179,26 @@ def test_query_with_dedupe_model( assert df_crn.cluster_hash.nunique() == 1000 -def test_query_with_link_model(): +def test_query_with_link_model( + matchbox_postgres: MatchboxPostgres, + db_add_dedupe_models_and_data: AddDedupeModelsAndDataCallable, + db_add_indexed_data: AddIndexedDataCallable, + db_add_link_models_and_data: AddLinkModelsAndDataCallable, + warehouse_data: list[Source], + request: FixtureRequest, +): """Test querying data from a link point of truth.""" - pass + db_add_link_models_and_data( + db_add_indexed_data=db_add_indexed_data, + db_add_dedupe_models_and_data=db_add_dedupe_models_and_data, + backend=matchbox_postgres, + warehouse_data=warehouse_data, + dedupe_data=dedupe_data_test_params, + dedupe_models=[dedupe_model_test_params[0]], # Naive deduper + link_data=link_data_test_params, + link_models=[link_model_test_params[0]], # Deterministic linker + request=request, + ) def test_validate_hashes(matchbox_postgres):