diff --git a/src/matchbox/common/exceptions.py b/src/matchbox/common/exceptions.py
index 2df3f93..aeeb027 100644
--- a/src/matchbox/common/exceptions.py
+++ b/src/matchbox/common/exceptions.py
@@ -22,7 +22,26 @@ def __init__(self, message: str = None, model_name: str = None):
         self.model_name = model_name
 
 
-class MatchboxDBDataError(Exception):
+class MatchboxDatasetError(Exception):
+    """Dataset not found."""
+
+    def __init__(
+        self,
+        message: str = None,
+        db_schema: str | None = None,
+        db_table: str | None = None,
+    ):
+        if message is None:
+            message = "Dataset not found."
+            if db_table is not None:
+                message = f"Dataset {db_schema or ''}.{db_table} not found."
+
+        super().__init__(message)
+        self.db_schema = db_schema
+        self.db_table = db_table
+
+
+class MatchboxDataError(Exception):
     """Data doesn't exist in the Matchbox source table."""
 
     def __init__(
diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py
index 3ef5bf3..b233d21 100644
--- a/src/matchbox/server/postgresql/adapter.py
+++ b/src/matchbox/server/postgresql/adapter.py
@@ -6,7 +6,11 @@
 from sqlalchemy.engine.result import ChunkedIteratorResult
 from sqlalchemy.orm import Session
 
-from matchbox.common.exceptions import MatchboxDBDataError, MatchboxModelError
+from matchbox.common.exceptions import (
+    MatchboxDataError,
+    MatchboxDatasetError,
+    MatchboxModelError,
+)
 from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter
 from matchbox.server.models import Cluster, Probability, Source, SourceWarehouse
 from matchbox.server.postgresql.clusters import Clusters, clusters_association
@@ -174,7 +178,7 @@ def validate_hashes(
             hash_type: The type of hash to validate.
 
         Raises:
-            MatchboxDBDataError: If some items don't exist in the target table.
+            MatchboxDataError: If some items don't exist in the target table.
         """
         if hash_type == "data":
             Source = SourceData
@@ -199,7 +203,7 @@
         )
 
         if len(data_inner_join) != len(hashes):
-            raise MatchboxDBDataError(
+            raise MatchboxDataError(
                 message=(
                     f"Some items don't exist the target table. "
                     f"Did you use {tgt_col} as your ID when deduplicating?"
@@ -221,12 +225,15 @@ def get_dataset(self, db_schema: str, db_table: str, engine: Engine) -> Source:
             .filter_by(db_schema=db_schema, db_table=db_table)
             .first()
         )
-        return Source(
-            db_schema=dataset.db_schema,
-            db_table=dataset.db_table,
-            db_pk=dataset.db_id,
-            database=SourceWarehouse.from_engine(engine),
-        )
+        if dataset:
+            return Source(
+                db_schema=dataset.db_schema,
+                db_table=dataset.db_table,
+                db_pk=dataset.db_id,
+                database=SourceWarehouse.from_engine(engine),
+            )
+        else:
+            raise MatchboxDatasetError(db_schema=db_schema, db_table=db_table)
 
     def get_model_subgraph(self) -> PyDiGraph:
         """Get the full subgraph of a model."""
@@ -265,7 +272,7 @@ def insert_model(
             description: A description of the model
 
         Raises
-            MatchboxDBDataError if, for a linker, the source models weren't found in
+            MatchboxDataError if, for a linker, the source models weren't found in
                 the database
         """
         if right:
diff --git a/src/matchbox/server/postgresql/utils/hash.py b/src/matchbox/server/postgresql/utils/hash.py
index 86f975b..4626db5 100644
--- a/src/matchbox/server/postgresql/utils/hash.py
+++ b/src/matchbox/server/postgresql/utils/hash.py
@@ -1,7 +1,7 @@
 from sqlalchemy import Engine, select
 from sqlalchemy.orm import Session
 
-from matchbox.common.exceptions import MatchboxDBDataError
+from matchbox.common.exceptions import MatchboxDataError
 from matchbox.helpers.selector import get_schema_table_names
 from matchbox.server.postgresql.data import SourceDataset
 from matchbox.server.postgresql.models import Models
@@ -29,7 +29,7 @@ def table_name_to_uuid(schema_table: str, engine: Engine) -> bytes:
         dataset_uuid = session.execute(stmt).scalar()
 
     if dataset_uuid is None:
-        raise MatchboxDBDataError(table=SourceDataset.__tablename__, data=schema_table)
+        raise MatchboxDataError(table=SourceDataset.__tablename__, data=schema_table)
 
     return dataset_uuid
@@ -52,6 +52,6 @@ def model_name_to_hash(run_name: str, engine: Engine) -> bytes:
         model_hash = session.execute(stmt).scalar()
 
     if model_hash is None:
-        raise MatchboxDBDataError(table=Models.__tablename__, data=run_name)
+        raise MatchboxDataError(table=Models.__tablename__, data=run_name)
 
     return model_hash
diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py
index c33f34c..9896016 100644
--- a/src/matchbox/server/postgresql/utils/insert.py
+++ b/src/matchbox/server/postgresql/utils/insert.py
@@ -7,7 +7,7 @@
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.orm import Session
 
-from matchbox.common.exceptions import MatchboxDBDataError
+from matchbox.common.exceptions import MatchboxDataError
 from matchbox.common.hash import list_to_value_ordered_hash
 from matchbox.server.models import Cluster, Probability
 from matchbox.server.postgresql.clusters import Clusters, clusters_association
@@ -124,7 +124,7 @@ def insert_probabilities(
         model_hash = db_model.sha1
 
     if db_model is None:
-        raise MatchboxDBDataError(source=Models, data=model)
+        raise MatchboxDataError(source=Models, data=model)
 
     # Clear old model probabilities
     old_probs_subquery = (
@@ -203,7 +203,7 @@ def insert_clusters(
         model_hash = db_model.sha1
 
     if db_model is None:
-        raise MatchboxDBDataError(source=Models, data=model)
+        raise MatchboxDataError(source=Models, data=model)
 
     # Clear old model endorsements
     old_cluster_creates_subquery = db_model.creates.select().with_only_columns(
diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py
index 84e7658..24d7cbd 100644
--- a/test/server/test_adapter.py
+++ b/test/server/test_adapter.py
@@ -1,8 +1,9 @@
 import pytest
 from dotenv import find_dotenv, load_dotenv
-from matchbox.common.exceptions import MatchboxDBDataError
+from matchbox.common.exceptions import MatchboxDataError, MatchboxDatasetError
 from matchbox.common.hash import HASH_FUNC
 from matchbox.helpers.selector import query, selector, selectors
+from matchbox.server import MatchboxDBAdapter
 from matchbox.server.models import Source
 from matchbox.server.postgresql import MatchboxPostgres
 from pandas import DataFrame
@@ -23,18 +24,33 @@
 load_dotenv(dotenv_path)
 
 
+backends = [
+    pytest.param("matchbox_postgres", id="postgres"),
+]
+
+
+@pytest.fixture(params=backends)
+def matchbox_backend(request: pytest.FixtureRequest) -> MatchboxDBAdapter:
+    """Fixture to provide different backend implementations."""
+    return request.getfixturevalue(request.param)
+
+
+@pytest.mark.parametrize("backend", backends)
 def test_index(
-    matchbox_postgres: MatchboxPostgres,
+    backend: MatchboxDBAdapter,
     db_add_indexed_data: AddIndexedDataCallable,
     warehouse_data: list[Source],
     crn_companies: DataFrame,
     duns_companies: DataFrame,
     cdms_companies: DataFrame,
+    request: pytest.FixtureRequest,
 ):
     """Test that indexing data works."""
-    assert matchbox_postgres.data.count() == 0
+    backend = request.getfixturevalue(backend)
+
+    assert backend.data.count() == 0
 
-    db_add_indexed_data(backend=matchbox_postgres, warehouse_data=warehouse_data)
+    db_add_indexed_data(backend=backend, warehouse_data=warehouse_data)
 
     def count_deduplicates(df: DataFrame) -> int:
         return df.drop(columns=["id"]).drop_duplicates().shape[0]
@@ -43,17 +59,21 @@ def count_deduplicates(df: DataFrame) -> int:
         count_deduplicates(df) for df in [crn_companies, duns_companies, cdms_companies]
     )
 
-    assert matchbox_postgres.data.count() == unique
+    assert backend.data.count() == unique
 
 
+@pytest.mark.parametrize("backend", backends)
 def test_query_single_table(
-    matchbox_postgres: MatchboxPostgres,
+    backend: MatchboxDBAdapter,
     db_add_indexed_data: AddIndexedDataCallable,
     warehouse_data: list[Source],
+    request: pytest.FixtureRequest,
 ):
     """Test querying data from the database."""
+    backend = request.getfixturevalue(backend)
+
     # Setup
-    db_add_indexed_data(backend=matchbox_postgres, warehouse_data=warehouse_data)
+    db_add_indexed_data(backend=backend, warehouse_data=warehouse_data)
 
     # Test
     crn = warehouse_data[0]
@@ -66,7 +86,7 @@
     df_crn_sample = query(
         selector=select_crn,
-        backend=matchbox_postgres,
+        backend=backend,
         model=None,
         return_type="pandas",
         limit=10,
     )
@@ -77,7 +97,7 @@
     df_crn_full = query(
         selector=select_crn,
-        backend=matchbox_postgres,
+        backend=backend,
         model=None,
         return_type="pandas",
     )
@@ -90,14 +110,18 @@
     }
 
 
+@pytest.mark.parametrize("backend", backends)
 def test_query_multi_table(
-    matchbox_postgres: MatchboxPostgres,
+    backend: MatchboxDBAdapter,
     db_add_indexed_data: AddIndexedDataCallable,
     warehouse_data: list[Source],
+    request: pytest.FixtureRequest,
 ):
     """Test querying data from multiple tables from the database."""
+    backend = request.getfixturevalue(backend)
+
     # Setup
-    db_add_indexed_data(backend=matchbox_postgres, warehouse_data=warehouse_data)
+    db_add_indexed_data(backend=backend, warehouse_data=warehouse_data)
 
     # Test
     crn = warehouse_data[0]
@@ -117,7 +141,7 @@
     df_crn_duns_full = query(
         selector=select_crn_duns,
-        backend=matchbox_postgres,
+        backend=backend,
         model=None,
         return_type="pandas",
     )
@@ -135,18 +159,21 @@
     }
+@pytest.mark.parametrize("backend", backends)
 def test_query_with_dedupe_model(
-    matchbox_postgres: MatchboxPostgres,
+    backend: MatchboxDBAdapter,
     db_add_dedupe_models_and_data: AddDedupeModelsAndDataCallable,
     db_add_indexed_data: AddIndexedDataCallable,
     warehouse_data: list[Source],
     request: pytest.FixtureRequest,
 ):
     """Test querying data from a deduplication point of truth."""
+    backend = request.getfixturevalue(backend)
+
     # Setup
     db_add_dedupe_models_and_data(
         db_add_indexed_data=db_add_indexed_data,
-        backend=matchbox_postgres,
+        backend=backend,
         warehouse_data=warehouse_data,
         dedupe_data=dedupe_data_test_params,
         dedupe_models=[dedupe_model_test_params[0]],  # Naive deduper,
@@ -164,7 +191,7 @@
     df_crn = query(
         selector=select_crn,
-        backend=matchbox_postgres,
+        backend=backend,
         model="naive_test.crn",
         return_type="pandas",
     )
@@ -181,8 +208,9 @@
     assert df_crn.cluster_hash.nunique() == 1000
 
 
+@pytest.mark.parametrize("backend", backends)
 def test_query_with_link_model(
-    matchbox_postgres: MatchboxPostgres,
+    backend: MatchboxDBAdapter,
     db_add_dedupe_models_and_data: AddDedupeModelsAndDataCallable,
     db_add_indexed_data: AddIndexedDataCallable,
     db_add_link_models_and_data: AddLinkModelsAndDataCallable,
@@ -190,11 +218,13 @@
     request: pytest.FixtureRequest,
 ):
     """Test querying data from a link point of truth."""
+    backend = request.getfixturevalue(backend)
+
     # Setup
     db_add_link_models_and_data(
         db_add_indexed_data=db_add_indexed_data,
         db_add_dedupe_models_and_data=db_add_dedupe_models_and_data,
-        backend=matchbox_postgres,
+        backend=backend,
         warehouse_data=warehouse_data,
         dedupe_data=dedupe_data_test_params,
         dedupe_models=[dedupe_model_test_params[0]],  # Naive deduper,
@@ -224,7 +254,7 @@
     crn_duns = query(
         selector=select_crn_duns,
-        backend=matchbox_postgres,
+        backend=backend,
         model=linker_name,
         return_type="pandas",
     )
@@ -241,18 +271,21 @@
     assert crn_duns.cluster_hash.nunique() == 1000
 
 
+@pytest.mark.parametrize("backend", backends)
 def test_validate_hashes(
-    matchbox_postgres: MatchboxPostgres,
+    backend: MatchboxDBAdapter,
     db_add_dedupe_models_and_data: AddDedupeModelsAndDataCallable,
     db_add_indexed_data: AddIndexedDataCallable,
     warehouse_data: list[Source],
     request: pytest.FixtureRequest,
 ):
     """Test validating data hashes."""
+    backend = request.getfixturevalue(backend)
+
     # Setup
     db_add_dedupe_models_and_data(
         db_add_indexed_data=db_add_indexed_data,
-        backend=matchbox_postgres,
+        backend=backend,
         warehouse_data=warehouse_data,
         dedupe_data=dedupe_data_test_params,
         dedupe_models=[dedupe_model_test_params[0]],  # Naive deduper,
@@ -267,31 +300,50 @@
     )
     df_crn = query(
         selector=select_crn,
-        backend=matchbox_postgres,
+        backend=backend,
         model="naive_test.crn",
         return_type="pandas",
     )
 
     # Test validating data hashes
-    matchbox_postgres.validate_hashes(
-        hashes=df_crn.data_hash.to_list(), hash_type="data"
-    )
+    backend.validate_hashes(hashes=df_crn.data_hash.to_list(), hash_type="data")
 
     # Test validating cluster hashes
-    matchbox_postgres.validate_hashes(
+    backend.validate_hashes(
         hashes=df_crn.cluster_hash.drop_duplicates().to_list(), hash_type="cluster"
     )
 
     # Test validating nonexistant hashes errors
-    with pytest.raises(MatchboxDBDataError):
-        matchbox_postgres.validate_hashes(
+    with pytest.raises(MatchboxDataError):
+        backend.validate_hashes(
             hashes=[HASH_FUNC(b"nonexistant").digest()],
             hash_type="data"
         )
 
 
-def test_get_dataset(matchbox_postgres: MatchboxPostgres):
-    # Test getting an existing model
-    pass
+@pytest.mark.parametrize("backend", backends)
+def test_get_dataset(
+    backend: MatchboxDBAdapter,
+    db_add_indexed_data: AddIndexedDataCallable,
+    warehouse_data: list[Source],
+    request: pytest.FixtureRequest,
+):
+    """Test getting a dataset from the database."""
+    backend = request.getfixturevalue(backend)
+
+    # Setup
+    db_add_indexed_data(backend=backend, warehouse_data=warehouse_data)
+    crn = warehouse_data[0]
+
+    # Test get a real dataset
+    backend.get_dataset(
+        db_schema=crn.db_schema, db_table=crn.db_table, engine=crn.database.engine
+    )
+
+    # Test getting a dataset that doesn't exist in Matchbox
+    with pytest.raises(MatchboxDatasetError):
+        backend.get_dataset(
+            db_schema="nonexistent", db_table="nonexistent", engine=crn.database.engine
+        )
 
 
 def test_get_model_subgraph(matchbox_postgres: MatchboxPostgres):