Skip to content

Commit

Permalink
Working get dataset unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Langdale committed Oct 21, 2024
1 parent fbde0bd commit 212e142
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 47 deletions.
21 changes: 20 additions & 1 deletion src/matchbox/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,26 @@ def __init__(self, message: str = None, model_name: str = None):
self.model_name = model_name


class MatchboxDBDataError(Exception):
class MatchboxDatasetError(Exception):
"""Model not found."""

def __init__(
self,
message: str = None,
db_schema: str | None = None,
db_table: str | None = None,
):
if message is None:
message = "Dataset not found."
if db_table is not None:
message = f"Dataset {db_schema or ''}.{db_table} not found."

super().__init__(message)
self.db_schema = db_schema
self.db_table = db_table


class MatchboxDataError(Exception):
"""Data doesn't exist in the Matchbox source table."""

def __init__(
Expand Down
27 changes: 17 additions & 10 deletions src/matchbox/server/postgresql/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@
from sqlalchemy.engine.result import ChunkedIteratorResult
from sqlalchemy.orm import Session

from matchbox.common.exceptions import MatchboxDBDataError, MatchboxModelError
from matchbox.common.exceptions import (
MatchboxDataError,
MatchboxDatasetError,
MatchboxModelError,
)
from matchbox.server.base import MatchboxDBAdapter, MatchboxModelAdapter
from matchbox.server.models import Cluster, Probability, Source, SourceWarehouse
from matchbox.server.postgresql.clusters import Clusters, clusters_association
Expand Down Expand Up @@ -174,7 +178,7 @@ def validate_hashes(
hash_type: The type of hash to validate.
Raises:
MatchboxDBDataError: If some items don't exist in the target table.
MatchboxDataError: If some items don't exist in the target table.
"""
if hash_type == "data":
Source = SourceData
Expand All @@ -199,7 +203,7 @@ def validate_hashes(
)

if len(data_inner_join) != len(hashes):
raise MatchboxDBDataError(
raise MatchboxDataError(
message=(
f"Some items don't exist the target table. "
f"Did you use {tgt_col} as your ID when deduplicating?"
Expand All @@ -221,12 +225,15 @@ def get_dataset(self, db_schema: str, db_table: str, engine: Engine) -> Source:
.filter_by(db_schema=db_schema, db_table=db_table)
.first()
)
return Source(
db_schema=dataset.db_schema,
db_table=dataset.db_table,
db_pk=dataset.db_id,
database=SourceWarehouse.from_engine(engine),
)
if dataset:
return Source(
db_schema=dataset.db_schema,
db_table=dataset.db_table,
db_pk=dataset.db_id,
database=SourceWarehouse.from_engine(engine),
)
else:
raise MatchboxDatasetError(db_schema=db_schema, db_table=db_table)

def get_model_subgraph(self) -> PyDiGraph:
"""Get the full subgraph of a model."""
Expand Down Expand Up @@ -265,7 +272,7 @@ def insert_model(
description: A description of the model
Raises
MatchboxDBDataError if, for a linker, the source models weren't found in
MatchboxDataError if, for a linker, the source models weren't found in
the database
"""
if right:
Expand Down
6 changes: 3 additions & 3 deletions src/matchbox/server/postgresql/utils/hash.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session

from matchbox.common.exceptions import MatchboxDBDataError
from matchbox.common.exceptions import MatchboxDataError
from matchbox.helpers.selector import get_schema_table_names
from matchbox.server.postgresql.data import SourceDataset
from matchbox.server.postgresql.models import Models
Expand Down Expand Up @@ -29,7 +29,7 @@ def table_name_to_uuid(schema_table: str, engine: Engine) -> bytes:
dataset_uuid = session.execute(stmt).scalar()

if dataset_uuid is None:
raise MatchboxDBDataError(table=SourceDataset.__tablename__, data=schema_table)
raise MatchboxDataError(table=SourceDataset.__tablename__, data=schema_table)

return dataset_uuid

Expand All @@ -52,6 +52,6 @@ def model_name_to_hash(run_name: str, engine: Engine) -> bytes:
model_hash = session.execute(stmt).scalar()

if model_hash is None:
raise MatchboxDBDataError(table=Models.__tablename__, data=run_name)
raise MatchboxDataError(table=Models.__tablename__, data=run_name)

return model_hash
6 changes: 3 additions & 3 deletions src/matchbox/server/postgresql/utils/insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from matchbox.common.exceptions import MatchboxDBDataError
from matchbox.common.exceptions import MatchboxDataError
from matchbox.common.hash import list_to_value_ordered_hash
from matchbox.server.models import Cluster, Probability
from matchbox.server.postgresql.clusters import Clusters, clusters_association
Expand Down Expand Up @@ -124,7 +124,7 @@ def insert_probabilities(
model_hash = db_model.sha1

if db_model is None:
raise MatchboxDBDataError(source=Models, data=model)
raise MatchboxDataError(source=Models, data=model)

# Clear old model probabilities
old_probs_subquery = (
Expand Down Expand Up @@ -203,7 +203,7 @@ def insert_clusters(
model_hash = db_model.sha1

if db_model is None:
raise MatchboxDBDataError(source=Models, data=model)
raise MatchboxDataError(source=Models, data=model)

# Clear old model endorsements
old_cluster_creates_subquery = db_model.creates.select().with_only_columns(
Expand Down
Loading

0 comments on commit 212e142

Please sign in to comment.