From 08e07ecad395c7966305d0eb826074013fae976e Mon Sep 17 00:00:00 2001 From: Sophie Glinton Date: Mon, 2 Dec 2024 10:18:20 +0000 Subject: [PATCH 1/3] Add name to model hash --- src/matchbox/server/postgresql/utils/insert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py index 3d9ba75..3708bfc 100644 --- a/src/matchbox/server/postgresql/utils/insert.py +++ b/src/matchbox/server/postgresql/utils/insert.py @@ -125,7 +125,7 @@ def insert_model( """ logic_logger.info(f"[{model}] Registering model") with Session(engine) as session: - model_hash = list_to_value_ordered_hash([left.hash, right.hash]) + model_hash = list_to_value_ordered_hash([left.hash, right.hash, model]) # Create new model new_model = Models( From db5cf39faabc0c0cf40e8103217023e3d1a5db9e Mon Sep 17 00:00:00 2001 From: Sophie Glinton Date: Mon, 2 Dec 2024 11:04:50 +0000 Subject: [PATCH 2/3] Logic passes without unit tests --- .../server/postgresql/utils/insert.py | 75 +++++++++++-------- 1 file changed, 45 insertions(+), 30 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py index 3708bfc..c355c0b 100644 --- a/src/matchbox/server/postgresql/utils/insert.py +++ b/src/matchbox/server/postgresql/utils/insert.py @@ -4,6 +4,7 @@ from sqlalchemy import ( Engine, delete, + select, ) from sqlalchemy.dialects.postgresql import insert from sqlalchemy.exc import SQLAlchemyError @@ -125,55 +126,69 @@ def insert_model( """ logic_logger.info(f"[{model}] Registering model") with Session(engine) as session: - model_hash = list_to_value_ordered_hash([left.hash, right.hash, model]) + model_hash = list_to_value_ordered_hash([left.hash, right.hash, bytes(model, encoding="utf-8")]) - # Create new model - new_model = Models( + # Check if model exists + exists_stmt = select(Models).where(Models.hash == model_hash) + exists = session.scalar(exists_stmt) is not None + + # Upsert new model + stmt = insert(Models).values( hash=model_hash, type=ModelType.MODEL.value, name=model, description=description, truth=1.0, + ).on_conflict_do_update( + index_elements=['hash'], + set_={ + 'name': model, + 'description': description + } ) - session.add(new_model) - session.flush() - - def _create_closure_entries(parent_model: Models) -> None: - """Create closure entries for the new model, i.e. mappings between - nodes and any of their direct or indirect parents""" - session.add( - ModelsFrom( - parent=parent_model.hash, - child=model_hash, - level=1, - truth_cache=parent_model.truth, - ) - ) - ancestor_entries = ( - session.query(ModelsFrom) - .filter(ModelsFrom.child == parent_model.hash) - .all() - ) + session.execute(stmt) - for entry in ancestor_entries: + if not exists: + + def _create_closure_entries(parent_model: Models) -> None: + """Create closure entries for the new model, i.e. mappings between + nodes and any of their direct or indirect parents""" session.add( ModelsFrom( - parent=entry.parent, + parent=parent_model.hash, child=model_hash, - level=entry.level + 1, - truth_cache=entry.truth_cache, + level=1, + truth_cache=parent_model.truth, ) ) - # Create model lineage entries - _create_closure_entries(parent_model=left) + ancestor_entries = ( + session.query(ModelsFrom) + .filter(ModelsFrom.child == parent_model.hash) + .all() + ) + + for entry in ancestor_entries: + session.add( + ModelsFrom( + parent=entry.parent, + child=model_hash, + level=entry.level + 1, + truth_cache=entry.truth_cache, + ) + ) + + # Create model lineage entries + _create_closure_entries(parent_model=left) - if right != left: - _create_closure_entries(parent_model=right) + if right != left: + _create_closure_entries(parent_model=right) session.commit() + status = "Inserted new" if not exists else "Updated existing" + logic_logger.info(f"[{model}] {status} model with hash {model_hash}") logic_logger.info(f"[{model}] Done!") From c24e2438a4967942e512c8cfca109d511d02e08e Mon Sep 17 00:00:00 2001 From: Sophie Glinton Date: Mon, 2 Dec 2024 11:09:54 +0000 Subject: [PATCH 3/3] Unit test for model upsert --- .../server/postgresql/utils/insert.py | 29 ++++++++++--------- test/server/test_adapter.py | 7 +++++ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/src/matchbox/server/postgresql/utils/insert.py b/src/matchbox/server/postgresql/utils/insert.py index c355c0b..65d8b8f 100644 --- a/src/matchbox/server/postgresql/utils/insert.py +++ b/src/matchbox/server/postgresql/utils/insert.py @@ -126,25 +126,28 @@ def insert_model( """ logic_logger.info(f"[{model}] Registering model") with Session(engine) as session: - model_hash = list_to_value_ordered_hash([left.hash, right.hash, bytes(model, encoding="utf-8")]) + model_hash = list_to_value_ordered_hash( + [left.hash, right.hash, bytes(model, encoding="utf-8")] + ) # Check if model exists exists_stmt = select(Models).where(Models.hash == model_hash) exists = session.scalar(exists_stmt) is not None # Upsert new model - stmt = insert(Models).values( - hash=model_hash, - type=ModelType.MODEL.value, - name=model, - description=description, - truth=1.0, - ).on_conflict_do_update( - index_elements=['hash'], - set_={ - 'name': model, - 'description': description - } + stmt = ( + insert(Models) + .values( + hash=model_hash, + type=ModelType.MODEL.value, + name=model, + description=description, + truth=1.0, + ) + .on_conflict_do_update( + index_elements=["hash"], + set_={"name": model, "description": description}, + ) ) session.execute(stmt) diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 25913c3..439796e 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -220,6 +220,13 @@ def test_insert_model(self): assert self.backend.models.count() == model_count + 3 + # Test model upsert + self.backend.insert_model( + "link_1", left="dedupe_1", right="dedupe_2", description="Test upsert" + ) + + assert self.backend.models.count() == model_count + 3 + def test_model_get_probabilities(self): """Test that a model's ProbabilityResults can be retrieved.""" self.setup_database("dedupe")