Skip to content

Commit

Permalink
Merge pull request #13 from uktrade/fix/update-model-insert
Browse files (browse the repository at this point in the history)
Fix/update model insert
  • Loading branch information
sophie-daintta authored Dec 2, 2024
2 parents 7cf4562 + c24e243 commit 8a49e11
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 35 deletions.
88 changes: 53 additions & 35 deletions src/matchbox/server/postgresql/utils/insert.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy import (
Engine,
delete,
select,
)
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import SQLAlchemyError
Expand Down Expand Up @@ -125,55 +126,72 @@ def insert_model(
"""
logic_logger.info(f"[{model}] Registering model")
with Session(engine) as session:
model_hash = list_to_value_ordered_hash([left.hash, right.hash])

# Create new model
new_model = Models(
hash=model_hash,
type=ModelType.MODEL.value,
name=model,
description=description,
truth=1.0,
model_hash = list_to_value_ordered_hash(
[left.hash, right.hash, bytes(model, encoding="utf-8")]
)
session.add(new_model)
session.flush()

def _create_closure_entries(parent_model: Models) -> None:
"""Create closure entries for the new model, i.e. mappings between
nodes and any of their direct or indirect parents"""
session.add(
ModelsFrom(
parent=parent_model.hash,
child=model_hash,
level=1,
truth_cache=parent_model.truth,
)
)

ancestor_entries = (
session.query(ModelsFrom)
.filter(ModelsFrom.child == parent_model.hash)
.all()
# Check if model exists
exists_stmt = select(Models).where(Models.hash == model_hash)
exists = session.scalar(exists_stmt) is not None

# Upsert new model
stmt = (
insert(Models)
.values(
hash=model_hash,
type=ModelType.MODEL.value,
name=model,
description=description,
truth=1.0,
)
.on_conflict_do_update(
index_elements=["hash"],
set_={"name": model, "description": description},
)
)

session.execute(stmt)

for entry in ancestor_entries:
if not exists:

def _create_closure_entries(parent_model: Models) -> None:
"""Create closure entries for the new model, i.e. mappings between
nodes and any of their direct or indirect parents"""
session.add(
ModelsFrom(
parent=entry.parent,
parent=parent_model.hash,
child=model_hash,
level=entry.level + 1,
truth_cache=entry.truth_cache,
level=1,
truth_cache=parent_model.truth,
)
)

# Create model lineage entries
_create_closure_entries(parent_model=left)
ancestor_entries = (
session.query(ModelsFrom)
.filter(ModelsFrom.child == parent_model.hash)
.all()
)

for entry in ancestor_entries:
session.add(
ModelsFrom(
parent=entry.parent,
child=model_hash,
level=entry.level + 1,
truth_cache=entry.truth_cache,
)
)

# Create model lineage entries
_create_closure_entries(parent_model=left)

if right != left:
_create_closure_entries(parent_model=right)
if right != left:
_create_closure_entries(parent_model=right)

session.commit()

status = "Inserted new" if not exists else "Updated existing"
logic_logger.info(f"[{model}] {status} model with hash {model_hash}")
logic_logger.info(f"[{model}] Done!")


Expand Down
7 changes: 7 additions & 0 deletions test/server/test_adapter.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -220,6 +220,13 @@ def test_insert_model(self):

assert self.backend.models.count() == model_count + 3

# Test model upsert
self.backend.insert_model(
"link_1", left="dedupe_1", right="dedupe_2", description="Test upsert"
)

assert self.backend.models.count() == model_count + 3

def test_model_get_probabilities(self):
"""Test that a model's ProbabilityResults can be retrieved."""
self.setup_database("dedupe")
Expand Down

0 comments on commit 8a49e11

Please sign in to comment.