Skip to content

Commit

Permalink
Test Splink settings and Splink training functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lmazz1-dbt committed Nov 18, 2024
1 parent ac6d92c commit 042e6bf
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/matchbox/models/linkers/splinklinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,15 @@ def check_ids_match(self) -> "SplinkSettings":
"left_id and right_id must match in a Splink linker."
)
return self

@model_validator(mode="after")
def add_enforced_settings(self) -> "SplinkSettings":
def check_link_only(self) -> "SplinkSettings":
if self.linker_settings.link_type != "link_only":
raise ValueError('link_type must be set to "link_only"')
self.linker_settings.link_type = "link_only"
return self

@model_validator(mode="after")
def add_enforced_settings(self) -> "SplinkSettings":
self.linker_settings.unique_id_column_name = self.left_id
return self

Expand Down
47 changes: 47 additions & 0 deletions test/client/test_linkers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pytest
from matchbox import make_model, query
from matchbox.helpers import selectors
from matchbox.models.linkers.splinklinker import SplinkLinkerFunction, SplinkSettings
from matchbox.server.models import Source, SourceWarehouse
from matchbox.server.postgresql import MatchboxPostgres
from pandas import DataFrame
from splink import SettingsCreator

from ..fixtures.db import AddDedupeModelsAndDataCallable, AddIndexedDataCallable
from ..fixtures.models import (
Expand Down Expand Up @@ -187,3 +189,48 @@ def unique_non_null(s):

assert isinstance(clusters, DataFrame)
assert clusters.hash.nunique() == fx_data.unique_n


def test_splink_training_functions():
# You can create a valid SplinkLinkerFunction
SplinkLinkerFunction(
function="estimate_u_using_random_sampling",
arguments={"max_pairs": 1e4},
)
# You can't reference a function that doesn't exist
with pytest.raises(ValueError):
SplinkLinkerFunction(function="made_up_funcname", arguments=dict())
# You can't pass arguments that don't exist
with pytest.raises(ValueError):
SplinkLinkerFunction(
function="estimate_u_using_random_sampling", arguments={"foo": "bar"}
)

def test_splink_settings():
valid_settings = SplinkSettings(
left_id="hash",
right_id="hash",
linker_training_functions=[],
linker_settings=SettingsCreator(link_type="link_only"),
threshold=None,
)
assert valid_settings.linker_settings.unique_id_column_name == "hash"
# Can only use "link_only"
with pytest.raises(ValueError):
valid_settings = SplinkSettings(
left_id="hash",
right_id="hash",
linker_training_functions=[],
linker_settings=SettingsCreator(link_type="dedupe_only"),
threshold=None,
)
# Left and right ID must coincide
with pytest.raises(ValueError):
valid_settings = SplinkSettings(
left_id="hash",
right_id="hash2",
linker_training_functions=[],
linker_settings=SettingsCreator(link_type="link_only"),
threshold=None,
)

0 comments on commit 042e6bf

Please sign in to comment.